References:
"MOE原理解释及从零实现一个MOE(专家混合模型)" (MoE principles explained, with a from-scratch MoE implementation), CSDN blog
"MoE环游记:1、从几何意义出发" (A tour of MoE, part 1: starting from the geometric picture), Scientific Spaces (科学空间)
"深度学习之图像分类(二十八):Sparse-MLP(MoE)网络详解" (Deep learning for image classification, 28: the Sparse-MLP (MoE) network explained), CSDN blog
"深度学习之图像分类(二十九):Sparse-MLP网络详解" (Deep learning for image classification, 29: the Sparse-MLP network explained), CSDN blog
The code below builds a small CNN feature extractor for CIFAR-10, routes the flattened features through a sparse MoE layer with top-2 gating and a load-balancing loss, and trains the whole model end to end:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Hyperparameter settings
num_experts = 4         # number of experts
top_k = 2               # number of experts activated per sample
# input_dim = 3072      # flattened CIFAR-10 image size (32x32x3), if no CNN is used
input_dim = 64 * 8 * 8  # flattened CNN feature size (64 channels at 8x8)
hidden_dim = 512        # hidden width of each expert network
num_classes = 10        # number of classes
# Sparse MoE layer (refs. [5][7])
class SparseMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_experts)])
        self.gate = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=1)
        )
        # Load-balancing settings (refs. [4][7])
        self.balance_loss_weight = 0.01
        self.register_buffer('expert_counts', torch.zeros(num_experts))
        self._last_gate_probs = None  # gate outputs of the last forward pass

    def forward(self, x):
        # Gating: a softmax distribution over the experts for each sample
        gate_scores = self.gate(x)  # [B, num_experts]
        self._last_gate_probs = gate_scores  # saved (with graph) for the balance loss
        # Top-k selection (ref. [5]); the k scores serve as mixture weights.
        # Note they are not renormalized to sum to 1 here.
        topk_scores, topk_indices = torch.topk(gate_scores, top_k, dim=1)  # [B, k]
        # 0/1 selection mask per sample: [B, num_experts]
        mask = F.one_hot(topk_indices, num_experts).float().sum(dim=1)
        # Aggregate expert outputs. For clarity, every expert processes the full
        # batch and the top-k results are gathered afterwards; a production
        # sparse MoE would dispatch each sample only to its selected experts.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # [B, E, H]
        selected_experts = expert_outputs.gather(
            1, topk_indices.unsqueeze(-1).expand(-1, -1, hidden_dim))  # [B, k, H]
        # Weighted sum of the selected experts' outputs
        weighted_outputs = (selected_experts * topk_scores.unsqueeze(-1)).sum(dim=1)  # [B, H]
        # Update expert-usage statistics (monitoring only, no gradient)
        self.expert_counts += mask.sum(dim=0)
        return weighted_outputs

    def balance_loss(self):
        # Load-balancing loss (refs. [4][7]): penalize an uneven average load.
        # It must be computed from the differentiable gate probabilities; a
        # loss built from the expert_counts buffer, as one might be tempted to
        # do, carries no gradient and could never influence the gate.
        if self._last_gate_probs is None:
            return torch.zeros(())
        expert_probs = self._last_gate_probs.mean(dim=0)  # mean gate prob per expert
        loss = torch.std(expert_probs) * self.balance_loss_weight
        self.expert_counts.zero_()  # reset the usage counter
        return loss
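# Optional sanity check (an addition for illustration, not part of the
# referenced posts): a random batch through the layer should come out as
# [B, hidden_dim], and balance_loss() should be a scalar tensor.
_moe = SparseMoE()
_out = _moe(torch.randn(4, input_dim))
assert _out.shape == (4, hidden_dim)
assert _moe.balance_loss().dim() == 0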
# Full model architecture (refs. [2][6]): CNN feature extractor -> MoE -> linear head
class MoEImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 32x32 -> 16x16
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)    # 16x16 -> 8x8, 64 channels
        )
        self.moe_layer = SparseMoE()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # flatten features: [B, 64*8*8]
        x = self.moe_layer(x)
        return self.classifier(x)
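# Optional smoke test (added here; not in the referenced posts): two fake
# CIFAR-10 images in, [2, num_classes] logits out.
_clf = MoEImageClassifier()
assert _clf(torch.randn(2, 3, 32, 32)).shape == (2, num_classes)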
# Data preprocessing (ref. [2])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
# Training loop
model = MoEImageClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        main_loss = criterion(outputs, labels)
        balance_loss = model.moe_layer.balance_loss()
        total_loss = main_loss + balance_loss  # classification loss + balancing penalty
        total_loss.backward()
        optimizer.step()
    # Note: this prints the loss of the last batch in the epoch
    print(f'Epoch [{epoch+1}/10], Loss: {total_loss.item():.4f}')
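To gauge the result, accuracy on the CIFAR-10 test split can be measured with a standard evaluation loop (this block is an addition for illustration, not part of the referenced posts):

# Evaluation on the held-out test split (not from the referenced posts)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False)
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f'Test accuracy: {correct / total:.4f}')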