基于动态Teacher Forcing策略的Transformer
Teacher Forcing策略
一般来说,Transformer会接受训练数据,然后根据训练数据生成预测序列。先预测下一时间步,然后按照预测的时间步再生成再下一个时间步的数据。
会产生一个问题,模型会犯错,会跑偏,没有正确的数据介入来使他自己调整自己是不行的。
Teacher Forcing人如其名,学生(模型)犯的错误会被老师(正确数据)进行一定程度的纠正,然后再用这个纠正后的数据进行下一步推测,提高预测准确性。
好处很明显,提高了长期预测的准确性,而且可以不断更新模型参数,优化模型。
也有坏处,一直有老师的帮助,模型会过于自信,对于极端情况(极端气候)的预测不准确。
动态的Teacher Forcing
对于这种,我们可以进行一定程度的参数微调。模型刚起步的阶段,给教师数据较大的权重,时间长了就把教师权重调低。
Forward && Predict
那么我们现在就有两个情况,一种是训练模型的时候,我们除了最开始训练集的数据,还可以从传感器获取训练集最后时间以后的数据,Teacher存在,我们可以利用他来优化参数
还有一种,就是我们真的要将其运用于实际应用中,将按照一般方式工作(没有teacher)
Implemention
模型构成
class DynamicTransformer(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
d_model: int = 512,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dropout: float = 0.1,
tf_schedule_params: dict = None
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.d_model = d_model
# 线性映射,将输入序列的维度作为d_model
# 还没搞懂为啥iTransformer要用这个当输入维度
self.input_projection = nn.Linear(input_dim, d_model)
self.output_projection = nn.Linear(d_model, output_dim)
# 位置编码,Transformer没办法自己识别位置
# 嵌入位置编码
self.pos_encoder = PositionalEncoding(d_model)
# Pytorch内置的Transformer
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout
)
# Teacher Forcing 调参
self.tf_schedule_params = tf_schedule_params or {
'initial_ratio': 0.9,
'min_ratio': 0.1,
'decay_factor': 0.995,
'schedule_type': 'exponential' # 'exponential', 'linear', or 'cyclical'
}
self.current_tf_ratio = self.tf_schedule_params['initial_ratio']
self.training_steps = 0
# 记录loss信息
self.loss_history = []
self.patience_counter = 0
self.best_loss = float('inf')
Forward
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
use_teacher_forcing: bool = True
) -> torch.Tensor:
"""
前向传播,支持混合Teacher Forcing策略
src: [batch_size, seq_len, input_dim]
tgt: [batch_size, tgt_seq_len, output_dim]
"""
# 屏蔽填充Token
src_mask = None
tgt_mask = self.create_mask(tgt.size(1)).to(tgt.device)
# 线性映射
src = self.input_projection(src)
src = self.pos_encoder(src)
src = src.transpose(0, 1) # [seq_len, batch_size, d_model]
batch_size = src.size(1)
tgt_len = tgt.size(1)
# 初始化解码器输入(使用序列的最后一个时间步)
decoder_input = src[-1:].transpose(0, 1) # [batch_size, 1, d_model]
decoder_input = self.input_projection(decoder_input)
decoder_input = decoder_input.transpose(0, 1) # [1, batch_size, d_model]
outputs = []
for i in range(tgt_len):
# 为当前时间步生成预测
step_output = self.transformer(
src,
decoder_input,
src_mask=src_mask,
tgt_mask=self.create_mask(decoder_input.size(0)).to(src.device)
)
# 获取最新的预测
current_pred = step_output[-1:].transpose(0, 1) # [batch_size, 1, d_model]
current_pred = self.output_projection(current_pred)
outputs.append(current_pred)
# 决定下一个时间步使用的输入
if use_teacher_forcing and i < tgt_len - 1: # 对于最后一个时间步不需要准备下一步输入
# 对每个样本分别决定是否使用teacher forcing
tf_mask = (torch.rand(batch_size, 1, 1).to(tgt.device) < self.current_tf_ratio)
# 准备真实值
true_next = tgt[:, i:i+1, :] # [batch_size, 1, output_dim]
true_next = self.input_projection(true_next) # [batch_size, 1, d_model]
# 准备预测值
pred_next = current_pred
pred_next = self.input_projection(pred_next) # [batch_size, 1, d_model]
# 混合真实值和预测值
next_input = torch.where(tf_mask, true_next, pred_next)
next_input = next_input.transpose(0, 1) # [1, batch_size, d_model]
# 更新解码器输入
decoder_input = torch.cat([decoder_input, next_input], dim=0)
else:
# 完全使用预测值
pred_next = self.input_projection(current_pred)
decoder_input = torch.cat([
decoder_input,
pred_next.transpose(0, 1)
], dim=0)
# 合并所有时间步的输出
output = torch.cat(outputs, dim=1)
return output
这里讲讲Teacher Forcing怎么起作用的
tf_mask = (torch.rand(batch_size, 1, 1).to(tgt.device) < self.current_tf_ratio)
tf_mask是一个布尔值序列,通过判断生成的随机数与current_tf_ratio的大小对比,生成一个布尔值矩阵。
current_tf_ratio越大,矩阵里面的True占比就越大,true_next在next_input里面占比也就越大
transpose(0, 1)是因为Transformer需要的维度顺序与我们之间输入的不同,这里做一个变换
Predict
def predict(
self,
src: torch.Tensor,
prediction_length: int,
device: torch.device
) -> torch.Tensor:
"""
生成预测序列
"""
self.eval()
with torch.no_grad():
# 准备源序列
src = self.input_projection(src)
src = self.pos_encoder(src)
src = src.transpose(0, 1) # [seq_len, batch_size, d_model]
# 初始化解码器输入
decoder_input = src[-1:].transpose(0, 1) # [batch_size, 1, d_model]
decoder_input = self.input_projection(decoder_input)
decoder_input = decoder_input.transpose(0, 1) # [1, batch_size, d_model]
predictions = []
# 逐步生成预测
for _ in range(prediction_length):
tgt_mask = self.create_mask(decoder_input.size(0)).to(device)
output = self.transformer(
src,
decoder_input,
src_mask=None,
tgt_mask=tgt_mask
)
# 获取最后一个时间步的预测
next_pred = output[-1:].transpose(0, 1) # [batch_size, 1, d_model]
next_pred = self.output_projection(next_pred)
predictions.append(next_pred)
# 更新解码器输入
next_pred = self.input_projection(next_pred)
decoder_input = torch.cat([
decoder_input,
next_pred.transpose(0, 1)
], dim=0)
predictions = torch.cat(predictions, dim=1)
return predictions
将最后一个时间步输入transformer,再将预测出的数据作为transformer下一个时间步的输入,典型的时间预测