基于动态Teacher Forcing策略的Transformer

Teacher Forcing策略

一般来说,Transformer会接受训练数据,然后根据训练数据生成预测序列。先预测下一时间步,然后按照预测的时间步再生成再下一个时间步的数据。

会产生一个问题,模型会犯错,会跑偏,没有正确的数据介入来使他自己调整自己是不行的。

Teacher Forcing人如其名,学生(模型)犯的错误会被老师(正确数据)进行一定程度的纠正,然后再用这个纠正后的数据进行下一步推测,提高预测准确性。

好处很明显,提高了长期预测的准确性,而且可以不断更新模型参数,优化模型。

也有坏处,一直有老师的帮助,模型会过于自信,对于极端情况(极端气候)的预测不准确。

动态的Teacher Forcing

对于这种,我们可以进行一定程度的参数微调。模型刚起步的阶段,给教师数据较大的权重,时间长了就把教师权重调低。

Forward && Predict

那么我们现在就有两个情况,一种是训练模型的时候,我们除了最开始训练集的数据,还可以从传感器获取训练集最后时间以后的数据,Teacher存在,我们可以利用他来优化参数

还有一种,就是我们真的要将其运用于实际应用中,将按照一般方式工作(没有teacher)

Implemention

模型构成

class DynamicTransformer(nn.Module):
    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            d_model: int = 512,
            nhead: int = 8,
            num_encoder_layers: int = 6,
            num_decoder_layers: int = 6,
            dropout: float = 0.1,
            tf_schedule_params: dict = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.d_model = d_model

        # 线性映射,将输入序列的维度作为d_model
        # 还没搞懂为啥iTransformer要用这个当输入维度
        self.input_projection = nn.Linear(input_dim, d_model)
        self.output_projection = nn.Linear(d_model, output_dim)

        # 位置编码,Transformer没办法自己识别位置
        # 嵌入位置编码
        self.pos_encoder = PositionalEncoding(d_model)

        # Pytorch内置的Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout
        )

        # Teacher Forcing 调参
        self.tf_schedule_params = tf_schedule_params or {
            'initial_ratio': 0.9,
            'min_ratio': 0.1,
            'decay_factor': 0.995,
            'schedule_type': 'exponential'  # 'exponential', 'linear', or 'cyclical'
        }

        self.current_tf_ratio = self.tf_schedule_params['initial_ratio']
        self.training_steps = 0

        # 记录loss信息
        self.loss_history = []
        self.patience_counter = 0
        self.best_loss = float('inf')

Forward

def forward(
            self,
            src: torch.Tensor,
            tgt: torch.Tensor,
            use_teacher_forcing: bool = True
    ) -> torch.Tensor:
        """
        前向传播,支持混合Teacher Forcing策略
        src: [batch_size, seq_len, input_dim]
        tgt: [batch_size, tgt_seq_len, output_dim]
        """
        # 屏蔽填充Token
        src_mask = None
        tgt_mask = self.create_mask(tgt.size(1)).to(tgt.device)

        # 线性映射
        src = self.input_projection(src)
        src = self.pos_encoder(src)
        src = src.transpose(0, 1)  # [seq_len, batch_size, d_model]

        batch_size = src.size(1)
        tgt_len = tgt.size(1)

        # 初始化解码器输入(使用序列的最后一个时间步)
        decoder_input = src[-1:].transpose(0, 1)  # [batch_size, 1, d_model]
        decoder_input = self.input_projection(decoder_input)
        decoder_input = decoder_input.transpose(0, 1)  # [1, batch_size, d_model]

        outputs = []

        for i in range(tgt_len):
            # 为当前时间步生成预测
            step_output = self.transformer(
                src,
                decoder_input,
                src_mask=src_mask,
                tgt_mask=self.create_mask(decoder_input.size(0)).to(src.device)
            )

            # 获取最新的预测
            current_pred = step_output[-1:].transpose(0, 1)  # [batch_size, 1, d_model]
            current_pred = self.output_projection(current_pred)
            outputs.append(current_pred)

            # 决定下一个时间步使用的输入
            if use_teacher_forcing and i < tgt_len - 1:  # 对于最后一个时间步不需要准备下一步输入
                # 对每个样本分别决定是否使用teacher forcing
                tf_mask = (torch.rand(batch_size, 1, 1).to(tgt.device) < self.current_tf_ratio)

                # 准备真实值
                true_next = tgt[:, i:i+1, :]  # [batch_size, 1, output_dim]
                true_next = self.input_projection(true_next)  # [batch_size, 1, d_model]

                # 准备预测值
                pred_next = current_pred
                pred_next = self.input_projection(pred_next)  # [batch_size, 1, d_model]

                # 混合真实值和预测值
                next_input = torch.where(tf_mask, true_next, pred_next)
                next_input = next_input.transpose(0, 1)  # [1, batch_size, d_model]

                # 更新解码器输入
                decoder_input = torch.cat([decoder_input, next_input], dim=0)
            else:
                # 完全使用预测值
                pred_next = self.input_projection(current_pred)
                decoder_input = torch.cat([
                    decoder_input,
                    pred_next.transpose(0, 1)
                ], dim=0)

        # 合并所有时间步的输出
        output = torch.cat(outputs, dim=1)

        return output

这里讲讲Teacher Forcing怎么起作用的

tf_mask = (torch.rand(batch_size, 1, 1).to(tgt.device) < self.current_tf_ratio)

tf_mask是一个布尔值序列,通过判断生成的随机数与current_tf_ratio的大小对比,生成一个布尔值矩阵。

current_tf_ratio越大,矩阵里面的True占比就越大,true_next在next_input里面占比也就越大

transpose(0, 1)是因为Transformer需要的维度顺序与我们之间输入的不同,这里做一个变换

Predict

def predict(
            self,
            src: torch.Tensor,
            prediction_length: int,
            device: torch.device
    ) -> torch.Tensor:
        """
        生成预测序列
        """
        self.eval()

        with torch.no_grad():
            # 准备源序列
            src = self.input_projection(src)
            src = self.pos_encoder(src)
            src = src.transpose(0, 1)  # [seq_len, batch_size, d_model]

            # 初始化解码器输入
            decoder_input = src[-1:].transpose(0, 1)  # [batch_size, 1, d_model]
            decoder_input = self.input_projection(decoder_input)
            decoder_input = decoder_input.transpose(0, 1)  # [1, batch_size, d_model]

            predictions = []

            # 逐步生成预测
            for _ in range(prediction_length):
                tgt_mask = self.create_mask(decoder_input.size(0)).to(device)

                output = self.transformer(
                    src,
                    decoder_input,
                    src_mask=None,
                    tgt_mask=tgt_mask
                )

                # 获取最后一个时间步的预测
                next_pred = output[-1:].transpose(0, 1)  # [batch_size, 1, d_model]
                next_pred = self.output_projection(next_pred)
                predictions.append(next_pred)

                # 更新解码器输入
                next_pred = self.input_projection(next_pred)
                decoder_input = torch.cat([
                    decoder_input,
                    next_pred.transpose(0, 1)
                ], dim=0)

            predictions = torch.cat(predictions, dim=1)

        return predictions

将最后一个时间步输入transformer,再将预测出的数据作为transformer下一个时间步的输入,典型的时间预测