import torch
import math
from stork.nodes.base import CellGroup


class custom_InputGroup(CellGroup):
    """Group that supplies batched dense tensor input to the network.

    A batch is handed over with :meth:`feed_data` and emitted one time step
    (indexed by ``self.clk`` along dim 1 of the stored data) per call to
    :meth:`forward`.

    When ``output_feedback`` is enabled, the last two features of every row of
    the emitted tensor carry a 2-value feedback signal taken from
    ``self.src.states["out"]`` (the source group is registered with
    :meth:`add_src`). In training mode that feedback is mixed, per sample,
    with values taken from the data according to ``teacher_forcing_ratio``,
    which is annealed with a cosine schedule by :meth:`set_epoch` or
    :meth:`set_epoch_step`.

    Args:
        shape: Group shape forwarded to ``CellGroup``. With ``output_feedback``
            the code indexes ``shape[0]`` (rows) and ``shape[1]`` (features),
            so it is presumably 2-D — confirm for your configuration.
        name: Group name.
        output_feedback: Whether to append the fed-back source output.
        teacher_forcing_ratio_start: Teacher-forcing ratio at the start of the
            schedule (probability of using the data-provided value).
        teacher_forcing_ratio_end: Teacher-forcing ratio at the end of the
            schedule.
        max_epochs: Number of epochs over which :meth:`set_epoch` anneals the
            ratio.
        **kwargs: Forwarded to ``CellGroup``.
    """

    def __init__(self, shape, name="Input", output_feedback=False,
                 teacher_forcing_ratio_start=1.0, teacher_forcing_ratio_end=0.0,
                 max_epochs=100, **kwargs):
        super(custom_InputGroup, self).__init__(shape, name=name, **kwargs)
        self.output_feedback = output_feedback
        # Schedule parameters are set unconditionally so that set_epoch /
        # set_epoch_step never raise AttributeError; without output_feedback
        # the ratio is simply unused.
        self.teacher_forcing_ratio_start = teacher_forcing_ratio_start  # ground-truth share at the start
        self.teacher_forcing_ratio_end = teacher_forcing_ratio_end  # ground-truth share at the end
        self.max_epochs = max_epochs  # length of the per-epoch schedule
        self.current_epoch = 0
        self.current_step = 0
        self.teacher_forcing_ratio = teacher_forcing_ratio_start  # current ground-truth share

    def add_src(self, src):
        """Register the group whose ``states["out"]`` is fed back as input."""
        self.src = src

    def _cosine_ratio(self, step, total):
        """Return the cosine-annealed teacher-forcing ratio at ``step`` of ``total``.

        Progress is clamped to [0, 1]: ``step <= 0`` gives the start ratio and
        ``step >= total`` (or a non-positive ``total``) gives the end ratio.
        Without the clamp, cos() is periodic and the ratio would rise again
        after ``total``.
        """
        if total <= 0:
            progress = 1.0
        else:
            progress = min(max(step / total, 0.0), 1.0)
        span = self.teacher_forcing_ratio_start - self.teacher_forcing_ratio_end
        return self.teacher_forcing_ratio_end + span * (1 + math.cos(math.pi * progress)) / 2

    def set_epoch(self, epoch):
        """Record the current epoch and update the teacher-forcing ratio.

        The ratio follows a cosine curve from ``teacher_forcing_ratio_start``
        (epoch 0) to ``teacher_forcing_ratio_end`` (epoch ``max_epochs``) and
        stays at the end value afterwards.
        """
        self.current_epoch = epoch
        self.teacher_forcing_ratio = self._cosine_ratio(epoch, self.max_epochs)

    def set_epoch_step(self, step, total_stepNum, verbose=False):
        """Record the current training step and update the teacher-forcing ratio.

        Same cosine schedule as :meth:`set_epoch`, but measured in steps out
        of ``total_stepNum``.

        Args:
            step: Current training step.
            total_stepNum: Number of steps over which the ratio is annealed.
            verbose: If true, print the new ratio (previously printed
                unconditionally on every call).
        """
        self.current_step = step
        self.teacher_forcing_ratio = self._cosine_ratio(step, total_stepNum)
        if verbose:
            print(self.teacher_forcing_ratio)

    def reset_state(self, batch_size=None):
        """Reset base-class state and zero the output buffer."""
        super().reset_state(batch_size)
        self.out = self.states["out"] = torch.zeros(self.int_shape, device=self.device, dtype=self.dtype)

    def feed_data(self, data):
        """Store a batch of input data for step-wise emission by :meth:`forward`.

        The first two dims of ``data`` are kept as-is; ``forward`` indexes
        dim 1 with ``self.clk``, so dim 1 is presumably time and dim 0 batch.
        The remaining elements are reshaped to the group shape or, with
        ``output_feedback``, to ``(shape[0], shape[1] - 2)``, because the last
        two features are filled with feedback in ``forward``.
        """
        if self.output_feedback:
            target_shape = (self.shape[0], self.shape[1] - 2)
        else:
            target_shape = self.shape
        self.local_data = data.reshape((data.shape[:2] + target_shape)).to(self.device)

    def forward(self):
        """Emit the data for time step ``self.clk`` into ``self.out`` / ``states["out"]``."""
        if not self.output_feedback:
            self.out = self.states["out"] = self.local_data[:, self.clk]
            return

        step_data = self.local_data[:, self.clk]
        batchsize = step_data.shape[0]

        if self.clk != 0:
            feedback = self.src.states["out"]
        else:
            # No source output exists before the first step, so feed zeros.
            # Allocated with shape[0] rows so the reshape below also works
            # for multi-row groups (a (batchsize, 2) tensor only reshaped
            # when shape[0] == 1).
            feedback = torch.zeros(batchsize, self.shape[0], 2, device=self.device, dtype=self.dtype)

        feedback = feedback.reshape((batchsize, self.shape[0], 2))

        if self.training:
            # Training: per sample (one draw shared across all rows), use the
            # data-provided value with probability teacher_forcing_ratio,
            # otherwise the actual fed-back output.
            use_ground_truth = torch.rand(batchsize, 1, 1, device=self.device) < self.teacher_forcing_ratio
            use_ground_truth = use_ground_truth.expand_as(feedback)
            # The "ground truth" is taken from the last two data features of
            # this step.
            # NOTE(review): those two columns are also dropped from the data
            # part, so rows here are (shape[1] - 4) + 2 = shape[1] - 2 wide,
            # whereas the eval branch emits (shape[1] - 2) + 2 = shape[1].
            # Confirm the intended data layout; the two modes currently
            # produce different widths.
            ground_truth = step_data[:, :, -2:]
            mixed_feedback = torch.where(use_ground_truth, ground_truth, feedback)
            self.out = self.states["out"] = torch.cat((step_data[:, :, :-2], mixed_feedback), dim=2)
        else:
            # Evaluation: always append the model's own fed-back output.
            self.out = self.states["out"] = torch.cat((step_data, feedback), dim=2)

# 输出等于输入（的一部分）
class custom_fake_InputGroup(CellGroup):
    """Pass-through group whose output is its own input buffer.

    On every ``forward`` call the tensor currently stored in
    ``self.states["input"]`` is exposed unchanged (same tensor object) as
    ``self.out`` and ``self.states["out"]``. What populates
    ``states["input"]`` is handled outside this class (presumably the base
    ``CellGroup`` / incoming connections — confirm).

    Args:
        shape: Group shape forwarded to ``CellGroup``.
        name: Group name.
        **kwargs: Forwarded to ``CellGroup``.
    """

    def __init__(self, shape, name="fake_Input", **kwargs):
        super().__init__(shape, name=name, **kwargs)

    def reset_state(self, batch_size=None):
        """Reset base-class state and set the output to a zero tensor."""
        super().reset_state(batch_size)
        blank = torch.zeros(self.int_shape, device=self.device, dtype=self.dtype)
        self.out = blank
        self.states["out"] = blank

    def forward(self):
        """Forward the input buffer as this step's output."""
        passthrough = self.states["input"]
        self.out = passthrough
        self.states["out"] = passthrough

# 用于feedback的输入组
class custom_feedback_InputGroup(CellGroup):
    """Sliding-window group that keeps a rolling history of its input.

    Each ``forward`` call drops the entry at index 0 along dim 1 of the
    current output and appends, at the end of dim 1, the entry at index 0
    along dim 1 of ``self.states["input"]``. The output therefore behaves as
    a first-in/first-out window whose length is the size of dim 1 of the
    buffer created in ``reset_state`` (presumably ``shape[0]`` if
    ``int_shape`` is ``(batch,) + shape`` — confirm against ``CellGroup``).

    Args:
        shape: Group shape forwarded to ``CellGroup``.
        name: Group name.
        output_feedback_timestep: Stored as an attribute; not read anywhere
            in this class.
        **kwargs: Forwarded to ``CellGroup``.
    """

    def __init__(self, shape, name="fake_Input", output_feedback_timestep=1, **kwargs):
        super().__init__(shape, name=name, **kwargs)
        self.output_feedback_timestep = output_feedback_timestep

    def reset_state(self, batch_size=None):
        """Reset base-class state and clear the history window to zeros."""
        super().reset_state(batch_size)
        empty_window = torch.zeros(self.int_shape, device=self.device, dtype=self.dtype)
        self.out = empty_window
        self.states["out"] = empty_window

    def forward(self):
        """Advance the window by one step and append the newest input entry."""
        kept_history = self.out[:, 1:]  # everything except the oldest entry
        # Slice 0:1 keeps dim 1 as a singleton so it can be concatenated.
        newest_entry = self.states["input"][:, 0:1, :]
        self.out = torch.cat([kept_history, newest_entry], dim=1)
        self.states["out"] = self.out
