Stop Just Reading the Paper! A Hands-On Guide to Reproducing the STGCN Traffic Prediction Model in Python (with a PeMS Dataset Walkthrough)

张开发
2026/4/19 21:00:16 · 15 min read


From Scratch to STGCN: A Complete Guide to Traffic Flow Prediction in Python

The first time I tried to reproduce the STGCN paper, the combination of graph convolutions and time-series handling left me stuck in the data-preprocessing swamp for two full weeks. Then, at 3 a.m. one morning, the first prediction curve finally lined up with the test data, and I still remember that sense of breakthrough. This article walks you around every pitfall I hit and implements this powerful spatio-temporal graph convolutional network in the most direct way possible.

1. Environment Setup and Data Acquisition

Before writing any code, set up a suitable development environment. I strongly recommend creating an isolated Python environment with Anaconda to avoid dependency conflicts:

```bash
conda create -n stgcn python=3.8
conda activate stgcn
pip install torch torch-geometric pandas numpy matplotlib
```

The PeMS datasets are the standard benchmarks in traffic prediction; we will use the PeMSD4 subset, which covers 59 days of traffic data from 307 sensor stations in the San Francisco Bay Area. The raw data must be requested from the Caltrans website, but fortunately a preprocessed version can be downloaded directly:

```python
import numpy as np

# Preprocessed archive, downloaded from:
# https://github.com/Davidham3/ASTGCN/raw/master/data/PEMS04/PEMS04.npz
data = np.load("PEMS04.npz")["data"]
print(f"Dataset shape: {data.shape}")  # should print (16992, 307, 3)
```

The dimensions break down as follows:
- 16992 time steps (one record every 5 minutes, 59 days in total)
- 307 sensor stations
- 3 features: flow, speed, and occupancy

2. Key Data-Preprocessing Techniques

The raw data cannot be fed into the model directly; several processing steps are required.

2.1 Normalization

Traffic volumes vary widely across stations, so standardization is essential:

```python
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
flow = data[..., 0]  # keep only the flow feature, shape (T, N)
scaled_data = scaler.fit_transform(flow.reshape(-1, 1)).reshape(flow.shape)
```

2.2 Building the Graph Structure

STGCN's core innovation is treating the road network as a graph. We use a Gaussian kernel to turn pairwise sensor distances into edge weights:

```python
import numpy as np

def calculate_adjacency_matrix(locations, sigma=0.1):
    """Compute a normalized adjacency matrix from sensor positions.

    :param locations: (N, 2) array of longitude/latitude coordinates
    :param sigma: Gaussian kernel bandwidth
    :return: symmetrically normalized adjacency matrix
    """
    n = len(locations)
    dist_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            dist = np.linalg.norm(locations[i] - locations[j])
            dist_matrix[i][j] = np.exp(-dist**2 / sigma**2)
    # Symmetric normalization: D^{-1/2} A D^{-1/2}
    d_inv_sqrt = np.diag(np.sum(dist_matrix, axis=1) ** -0.5)
    return d_inv_sqrt @ dist_matrix @ d_inv_sqrt
```

2.3 Slicing the Time Series

The continuous series must be converted into the (input window, prediction window) samples that supervised learning expects:

```python
def create_dataset(data, seq_len, pred_len):
    """Build sliding-window samples.

    :param data: input array of shape (T, N)
    :param seq_len: length of the history window
    :param pred_len: length of the prediction horizon
    :return: arrays of shape (samples, seq_len, N) and (samples, pred_len, N)
    """
    total_len = seq_len + pred_len
    result_x, result_y = [], []
    for i in range(len(data) - total_len + 1):
        result_x.append(data[i:i + seq_len])
        result_y.append(data[i + seq_len:i + total_len])
    return np.array(result_x), np.array(result_y)
```
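The training section later assumes `train_x`, `train_y`, and validation splits already exist, but the article never shows the split itself. As a minimal sketch, a chronological split on top of `create_dataset` might look like this (the 60/20/20 ratio and toy array sizes are my assumptions, not from the original):

```python
import numpy as np

def create_dataset(data, seq_len, pred_len):
    # Same sliding-window helper as in section 2.3
    total_len = seq_len + pred_len
    xs, ys = [], []
    for i in range(len(data) - total_len + 1):
        xs.append(data[i:i + seq_len])
        ys.append(data[i + seq_len:i + total_len])
    return np.array(xs), np.array(ys)

# Toy stand-in for scaled_data: 100 time steps, 5 nodes
scaled_data = np.random.rand(100, 5)

x, y = create_dataset(scaled_data, seq_len=12, pred_len=3)

# Chronological 60/20/20 split; shuffling here would leak future data
n = len(x)
i_train, i_val = int(n * 0.6), int(n * 0.8)
train_x, train_y = x[:i_train], y[:i_train]
val_x, val_y = x[i_train:i_val], y[i_train:i_val]
test_x, test_y = x[i_val:], y[i_val:]
print(train_x.shape, val_x.shape, test_x.shape)
```

Keeping the split chronological matters for traffic data: a random split lets the model see near-duplicates of test windows during training.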
3. Implementing the STGCN Architecture

Now for the core of the project: building the STGCN model in PyTorch out of stacked spatio-temporal convolution blocks.

3.1 The Spatio-Temporal Convolution Block

STGCN stacks several ST blocks, each consisting of a temporal gated convolution, a spatial graph convolution, and a second temporal convolution:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class STConvBlock(nn.Module):
    def __init__(self, in_channels, spatial_channels, out_channels, num_nodes):
        super(STConvBlock, self).__init__()
        # Temporal gated convolution: the kernel runs along the time axis
        self.temporal1 = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=(3, 1), padding=(1, 0))
        self.ln1 = nn.LayerNorm([num_nodes, out_channels])
        # Spatial graph convolution: 1x1 conv mixes channels,
        # the einsum in forward() mixes nodes
        self.spatial = nn.Conv2d(out_channels, spatial_channels, kernel_size=1)
        self.ln2 = nn.LayerNorm([num_nodes, spatial_channels])
        # Second temporal convolution
        self.temporal2 = nn.Conv2d(spatial_channels, out_channels,
                                   kernel_size=(3, 1), padding=(1, 0))
        self.ln3 = nn.LayerNorm([num_nodes, out_channels])
        self.relu = nn.ReLU()

    def forward(self, x, A):
        # Input x: (batch, seq_len, num_nodes, in_channels)
        x = x.permute(0, 3, 1, 2)  # (batch, in_channels, seq_len, num_nodes)
        # First temporal convolution
        x = self.temporal1(x)
        x = x.permute(0, 2, 3, 1)  # (batch, seq_len, num_nodes, out_channels)
        x = self.relu(self.ln1(x))
        x = x.permute(0, 3, 1, 2)  # back to (batch, channels, time, nodes)
        # Spatial graph convolution
        x = self.spatial(x)
        x = torch.einsum('nctv,vw->nctw', x, A)  # propagate along graph edges
        x = x.permute(0, 2, 3, 1)
        x = self.relu(self.ln2(x))
        x = x.permute(0, 3, 1, 2)
        # Second temporal convolution
        x = self.temporal2(x)
        x = x.permute(0, 2, 3, 1)
        x = self.ln3(x)
        return self.relu(x)  # (batch, seq_len, num_nodes, out_channels)
```

3.2 Assembling the Full Model

Two ST blocks plus an output layer make up the complete model:

```python
class STGCN(nn.Module):
    def __init__(self, num_nodes, num_features,
                 num_timesteps_input, num_timesteps_output):
        super(STGCN, self).__init__()
        self.block1 = STConvBlock(num_features, 16, 64, num_nodes)
        self.block2 = STConvBlock(64, 16, 64, num_nodes)
        # Output layer: treat the time axis as channels, mapping
        # seq_len input steps to pred_len output steps
        self.final_conv = nn.Conv2d(num_timesteps_input, num_timesteps_output,
                                    kernel_size=(1, 64))
        self.fc = nn.Linear(num_nodes, num_nodes)

    def forward(self, x, A):
        # x: (batch, seq_len, num_nodes, num_features)
        x = self.block1(x, A)
        x = self.block2(x, A)   # (batch, seq_len, num_nodes, 64)
        x = self.final_conv(x)  # (batch, pred_len, num_nodes, 1)
        x = x.squeeze(-1)       # (batch, pred_len, num_nodes)
        x = self.fc(x)          # final node-wise linear adjustment
        return x
```
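The einsum inside STConvBlock is the entire graph convolution, so it is worth checking its shape behavior in isolation. A standalone sketch with toy dimensions of my choosing (batch 2, 8 channels, 12 time steps, 5 nodes):

```python
import torch

# Toy feature tensor: (batch, channels, time, nodes)
x = torch.randn(2, 8, 12, 5)
# Row-normalized random affinity matrix over 5 nodes
A = torch.softmax(torch.randn(5, 5), dim=-1)

# Each output node w aggregates features from all nodes v, weighted by A[v, w]
out = torch.einsum('nctv,vw->nctw', x, A)
print(out.shape)  # torch.Size([2, 8, 12, 5])

# The same operation written as a plain matrix product on the last axis
out2 = x @ A
assert torch.allclose(out, out2, atol=1e-5)
```

Seeing that the einsum is just a matmul over the node axis makes it clear the block mixes information spatially without touching the time or channel dimensions.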
4. Training and Evaluation

With the model and data in place, we can start training. A few details matter here.

4.1 Training Configuration

```python
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Wrap the numpy arrays in tensors and a DataLoader
train_x_tensor = torch.FloatTensor(train_x)
train_y_tensor = torch.FloatTensor(train_y)
train_dataset = TensorDataset(train_x_tensor, train_y_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Instantiate the model: 12 history steps in, 3 predicted steps out
model = STGCN(num_nodes=307, num_features=1,
              num_timesteps_input=12, num_timesteps_output=3)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
```

4.2 The Training Loop

A training loop with early stopping:

```python
def train_model(model, train_loader, val_loader, optimizer, loss_fn, epochs=100):
    best_val_loss = float('inf')
    patience = 10
    counter = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            output = model(x_batch, A_tensor)
            loss = loss_fn(output, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                output = model(x_val, A_tensor)
                val_loss += loss_fn(output, y_val).item()
        print(f"Epoch {epoch + 1}: "
              f"Train Loss {train_loss / len(train_loader):.4f}, "
              f"Val Loss {val_loss / len(val_loader):.4f}")
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            counter += 1
            if counter >= patience:
                print('Early stopping triggered')
                break
```

4.3 Evaluation Metrics

Traffic prediction is usually evaluated with three metrics:

```python
def evaluate_metrics(y_true, y_pred):
    """Compute MAE, RMSE, and MAPE.

    :param y_true: ground truth, shape (batch, pred_len, num_nodes)
    :param y_pred: predictions, same shape
    :return: (mae, rmse, mape)
    """
    mae = np.mean(np.abs(y_true - y_pred))
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mape = np.mean(np.abs((y_true - y_pred) / (y_true + 1e-5))) * 100  # avoid division by zero
    return mae, rmse, mape
```
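Because training happens in standardized space, the metrics above are only interpretable after undoing the scaling: an MAE of 0.1 in z-score units says nothing about vehicles per 5 minutes. A minimal sketch of the inverse transform step (the scaler, arrays, and noise level here are toy stand-ins of my own, not values from the article):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def evaluate_metrics(y_true, y_pred):
    # Same metric helper as in section 4.3
    mae = np.mean(np.abs(y_true - y_pred))
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mape = np.mean(np.abs((y_true - y_pred) / (y_true + 1e-5))) * 100
    return mae, rmse, mape

# Toy flow series standing in for the real PeMSD4 data
rng = np.random.default_rng(0)
flow = rng.uniform(100, 500, size=(1000, 1))
scaler = StandardScaler().fit(flow)

# Pretend these came out of the model in standardized space
y_true_scaled = scaler.transform(flow[:10])
y_pred_scaled = y_true_scaled + rng.normal(0, 0.05, size=y_true_scaled.shape)

# Undo the scaling first, so MAE/RMSE are in real flow units
y_true = scaler.inverse_transform(y_true_scaled)
y_pred = scaler.inverse_transform(y_pred_scaled)
mae, rmse, mape = evaluate_metrics(y_true, y_pred)
```

The same inverse transform should be applied to the saved model's test predictions before reporting numbers.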
5. Advanced Tricks and Optimization Tips

With the basic implementation working, here are a few tricks that improved performance in practice.

5.1 Multi-Task Learning

Predicting flow, speed, and occupancy jointly can improve generalization:

```python
class MultiTaskSTGCN(nn.Module):
    def __init__(self, num_nodes, num_features,
                 num_timesteps_input, num_timesteps_output):
        super().__init__()
        # Shared feature-extraction blocks
        self.shared_block1 = STConvBlock(num_features, 16, 64, num_nodes)
        self.shared_block2 = STConvBlock(64, 16, 64, num_nodes)
        # Task-specific output heads
        self.flow_head = nn.Conv2d(num_timesteps_input, num_timesteps_output,
                                   kernel_size=(1, 64))
        self.flow_fc = nn.Linear(num_nodes, num_nodes)
        self.speed_head = nn.Conv2d(num_timesteps_input, num_timesteps_output,
                                    kernel_size=(1, 64))
        self.speed_fc = nn.Linear(num_nodes, num_nodes)

    def forward(self, x, A):
        shared = self.shared_block2(self.shared_block1(x, A), A)
        # shared: (batch, seq_len, num_nodes, 64)
        flow = self.flow_fc(self.flow_head(shared).squeeze(-1))
        speed = self.speed_fc(self.speed_head(shared).squeeze(-1))
        return flow, speed
```

5.2 Dynamic Graph Convolution

A static adjacency matrix cannot capture how traffic relationships change over time, so we can learn an adaptive graph instead:

```python
class AdaptiveGraphConv(nn.Module):
    def __init__(self, num_nodes, k=3):
        super().__init__()
        self.node_emb1 = nn.Parameter(torch.randn(num_nodes, 10))
        self.node_emb2 = nn.Parameter(torch.randn(10, num_nodes))
        self.k = k  # number of neighbors to keep

    def forward(self, x):
        # x: (batch, channels, seq_len, num_nodes)
        adj = torch.softmax(torch.mm(self.node_emb1, self.node_emb2), dim=-1)
        # Keep only the top-k connections per node
        topk = torch.topk(adj, self.k, dim=-1)
        mask = torch.zeros_like(adj)
        mask.scatter_(-1, topk.indices, 1.0)
        adj = adj * mask
        # Symmetric normalization
        adj = adj + adj.t()
        d_inv_sqrt = torch.diag(torch.sum(adj, dim=-1) ** -0.5)
        adj_normalized = d_inv_sqrt @ adj @ d_inv_sqrt
        return torch.einsum('nctv,vw->nctw', x, adj_normalized)
```

5.3 Mixed-Precision Training

Automatic mixed precision (AMP) sharply reduces memory use and speeds up training:

```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for x_batch, y_batch in train_loader:
    optimizer.zero_grad()
    with autocast():
        output = model(x_batch, A_tensor)
        loss = loss_fn(output, y_batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

On an NVIDIA V100, this roughly doubled training speed in my runs while cutting memory use by about 40%.
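The top-k masking step inside AdaptiveGraphConv is easy to get wrong (scattering the top-k values instead of ones effectively squares the weights), so here is the operation in isolation on a tiny 4-node graph; the graph size and k are toy values of my own:

```python
import torch

torch.manual_seed(0)
num_nodes, k = 4, 2

# Dense learned affinity matrix; softmax makes every row sum to 1
adj = torch.softmax(torch.randn(num_nodes, num_nodes), dim=-1)

# Keep only the k strongest connections in each row, zero out the rest
topk = torch.topk(adj, k, dim=-1)
mask = torch.zeros_like(adj)
mask.scatter_(-1, topk.indices, 1.0)  # binary mask: 1 at top-k positions
sparse_adj = adj * mask

# Every row now has exactly k nonzero entries, with original weights preserved
print((sparse_adj > 0).sum(dim=-1))
```

Because the mask is binary, surviving edges keep their softmax weights unchanged, which is what the subsequent symmetric normalization expects.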
6. Deployment and Production Tips

Once the model performs well enough, the next step is getting it into production.

6.1 Model Quantization

Dynamic quantization shrinks the model with almost no code changes. Note that PyTorch's dynamic quantization targets nn.Linear-style layers; convolutions need static quantization instead:

```python
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
torch.jit.save(torch.jit.script(quantized_model), 'stgcn_quantized.pt')
```

A quantized model is typically about a quarter of the original size, with 2-3x faster inference.

6.2 ONNX Export

```python
dummy_input = torch.randn(1, 12, 307, 1)
torch.onnx.export(
    model, (dummy_input, A_tensor), 'stgcn.onnx',
    input_names=['input', 'adjacency'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
```

ONNX models run on many inference engines, such as TensorRT and ONNX Runtime.

6.3 A Caching Layer

In production, caching can avoid recomputing predictions for identical inputs. A functools.lru_cache decorator is tempting, but it hashes tensor arguments by object identity, so two tensors with identical contents would miss the cache; a content-based key is safer:

```python
_prediction_cache = {}

def predict_with_cache(model, input_data, adj_matrix, max_size=100):
    key = input_data.numpy().tobytes()  # content-based cache key
    if key not in _prediction_cache:
        if len(_prediction_cache) >= max_size:
            _prediction_cache.pop(next(iter(_prediction_cache)))  # evict oldest entry
        with torch.no_grad():
            _prediction_cache[key] = model(input_data, adj_matrix)
    return _prediction_cache[key]
```

For strongly periodic traffic data, this kind of cache can significantly cut the compute load.
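To see the content-keyed cache behave, here is a runnable toy version where a counting function stands in for the real network (the fake model and call counter are illustrative inventions, not part of the article's pipeline):

```python
import torch

_prediction_cache = {}
calls = {'n': 0}

def fake_model(x):
    # Stand-in for the real STGCN forward pass; counts invocations
    calls['n'] += 1
    return x * 2

def predict_with_cache(x):
    key = x.numpy().tobytes()  # identical contents -> identical key
    if key not in _prediction_cache:
        _prediction_cache[key] = fake_model(x)
    return _prediction_cache[key]

a = predict_with_cache(torch.ones(3))
b = predict_with_cache(torch.ones(3))   # same content, different object: cache hit
c = predict_with_cache(torch.zeros(3))  # new content: cache miss
print(calls['n'])  # 2
```

The second call hits the cache even though it passes a different tensor object, which is exactly where an identity-keyed lru_cache would miss.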
