import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import numpy as np

class PositionalEncoding(nn.Module):
    """Sinusoidal encoding of scalar values into ``2 * L`` features.

    For an input ``x`` of shape ``[...]`` the output has shape ``[..., 2 * L]``:
    the first ``L`` features are ``sin(f_k * x)`` and the last ``L`` are
    ``cos(f_k * x)`` with ``f_k = pi * 2 ** (k / L)`` for ``k = 0 .. L-1``.

    NOTE(review): these frequencies all lie in ``[pi, 2*pi)`` (one octave); a
    NeRF-style encoding would use ``pi * 2 ** k`` -- confirm the narrow band is intended.
    """

    def __init__(self, L, device):
        super().__init__()
        self.L = L              # number of frequencies
        self.device = device    # device the frequencies and the result are placed on

    def forward(self, x):
        idx = torch.arange(self.L, dtype=torch.float)[None, :]       # [1, L]
        freqs = (2 ** (idx / self.L) * math.pi).to(self.device)
        angles = freqs * x[..., None]                                 # [..., L]
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1).to(self.device)


class TeacherForcingScheduler:
    """Linear schedule for the teacher-forcing probability.

    ``prob(step)`` moves linearly from ``p_start`` (step 0) to ``p_end``
    (step ``total_steps``) and stays at ``p_end`` afterwards.
    ``total_steps`` is floored at 1 to avoid division by zero.
    """

    def __init__(self, p_start=1.0, p_end=0.1, total_steps=50000):
        self.p_start = p_start
        self.p_end = p_end
        self.total_steps = max(1, total_steps)

    def prob(self, step):
        progress = step / self.total_steps
        if progress > 1.0:
            progress = 1.0
        return float(self.p_start + (self.p_end - self.p_start) * progress)


class LambdaCosineScheduler:
    """Cosine ramp from ``lam_start`` to ``lam_end`` over ``total_steps`` steps.

    ``value(step) = s + (e - s) * (1 - cos(pi * t)) / 2`` with
    ``t = min(step / N, 1)``; it stays at ``lam_end`` after ``N`` steps.
    ``total_steps`` is floored at 1 to avoid division by zero.
    """

    def __init__(self, lam_start=0.0, lam_end=1.0, total_steps=50000):
        self.s = lam_start
        self.e = lam_end
        self.N = max(1, total_steps)

    def value(self, step):
        frac = step / self.N
        if frac > 1.0:
            frac = 1.0
        half_span = 0.5 * (self.e - self.s)
        return float(self.s + half_span * (1 - math.cos(math.pi * frac)))


def gaussian_nll(mu, log_var, target, reduce='mean', eps=1e-6):
    """Gaussian negative log-likelihood of ``target`` under N(mu, exp(log_var)).

    Args:
        mu: predicted mean.
        log_var: predicted log-variance (broadcastable with ``mu``).
        target: observed values.
        reduce: ``'mean'`` or ``'sum'`` reduce to a scalar; any other value
            returns the unreduced elementwise NLL.
        eps: lower bound applied to the precision ``exp(-log_var)``.
    """
    precision = torch.exp(-log_var).clamp(min=eps)
    squared_error = (target - mu) ** 2
    log_norm = math.log(2 * math.pi)
    per_elem = 0.5 * (precision * squared_error + log_var + log_norm)
    if reduce == 'mean':
        return per_elem.mean()
    elif reduce == 'sum':
        return per_elem.sum()
    else:
        return per_elem


def var_to_weight(log_var, alpha=1.0, eps=1e-6):
    """Map log-variances to loss weights ``1 / (1 + alpha * exp(log_var))`` in ``[eps, 1]``.

    The variance is computed from a detached copy of ``log_var``, so the
    returned weights carry no gradient back into ``log_var``.
    """
    variance = log_var.detach().exp()
    denom = 1.0 + alpha * variance
    return (1.0 / denom).clamp(min=eps, max=1.0)



class TransformerLoc(nn.Module):
    """Per-frame Transformer encoder over a set of AP tokens.

    Each token is the linear projection of its features plus a sinusoidal
    encoding of its (x, y) coordinates. A learnable CLS token (plus its own
    learnable position embedding) is prepended before ``nn.TransformerEncoder``.
    """

    def __init__(self, cfg):
        super(TransformerLoc, self).__init__()
        self.name = 'TransformerLoc'

        # Hyper-parameters copied from cfg.
        self.seq_len = cfg.seq_len
        self.input_feature_num = cfg.input_feature_num
        self.d_model = cfg.model_dim
        self.feedforward_dim = cfg.feedforward_dim
        self.nhead = cfg.n_heads
        self.num_encoder_layers = cfg.n_layers
        self.mlp_input_feature_num = self.d_model
        self.fc_hidden_num = cfg.fc_hidden_num
        self.device = cfg.device
        self.return_cls = cfg.return_cls

        # NOTE: sub-modules are created in the same order as before so that
        # seeded parameter initialisation is unchanged.
        cls_shape = (1, 1, 1, self.d_model)
        self.cls_token = nn.Parameter(torch.zeros(cls_shape))
        self.cls_pos_embed = nn.Parameter(torch.zeros(cls_shape))

        self.linear_proj = nn.Linear(self.input_feature_num, self.d_model)

        layer = nn.TransformerEncoderLayer(
            self.d_model, self.nhead, self.feedforward_dim,
            dropout=cfg.dropout_p, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=self.num_encoder_layers)

        # Each coordinate axis contributes 2 * (d_model // 4) features, so the
        # concatenated x/y encoding is d_model wide when d_model % 4 == 0.
        self.coords_pos_encoder = PositionalEncoding(self.d_model // 4, device=self.device)

        # Plain tensor (not a parameter/buffer): scales coordinates cm -> m.
        self.scale_factor = torch.tensor(1e-2)

        # NOTE(review): not used by forward(); kept so state_dict keys are unchanged.
        self.regressor = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.mlp_input_feature_num, self.fc_hidden_num),
                nn.LeakyReLU(),
                nn.Dropout(cfg.fc_dropout_p)
            ),
            nn.Linear(self.fc_hidden_num, self.input_feature_num)
        ])

    def forward(self, ap_coords, src):
        """Encode one frame of tokens.

        Args:
            ap_coords: ``[B, E, A, 2]`` coordinates (cm); channels 0 and 1 are used.
            src: ``[B, E, A, input_feature_num]`` token features.

        Returns:
            ``[B*E, d_model]`` CLS outputs when ``return_cls`` is set,
            otherwise ``[B*E, A, d_model]`` token outputs (CLS dropped).
        """
        batch_size, ensemble_num, seq_ap_num, _ = src.shape
        n_seq = batch_size * ensemble_num

        # Unit change cm -> m, then flatten batch and ensemble dims.
        coords_m = (ap_coords * self.scale_factor).reshape(n_seq, seq_ap_num, -1)

        feat_tok = self.linear_proj(src.reshape(n_seq * seq_ap_num, -1)).reshape(n_seq, seq_ap_num, -1)
        coord_tok = torch.cat(
            (self.coords_pos_encoder(coords_m[:, :, 0]),
             self.coords_pos_encoder(coords_m[:, :, 1])),
            dim=2,
        )
        tokens = (feat_tok + coord_tok).reshape(batch_size, ensemble_num, seq_ap_num, -1)

        cls = (self.cls_token + self.cls_pos_embed).expand(batch_size, ensemble_num, -1, -1)
        tokens = torch.cat((cls, tokens), dim=2).reshape(n_seq, seq_ap_num + 1, -1)

        encoded = self.encoder(tokens)
        return encoded[:, 0, :] if self.return_cls else encoded[:, 1:, :]
        

class SinTimePos(nn.Module):
    """Standard sinusoidal time position encoding, added directly to the tokens.

    ``pe[t, 2i] = sin(t * w_i)`` and ``pe[t, 2i+1] = cos(t * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``. The table is a (non-learned) buffer
    of shape ``[1, max_len, d_model]``. Odd ``d_model`` is supported (the last
    column is then a sine).
    """

    def __init__(self, d_model, device, max_len=2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model, device=device)
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # There are floor(d_model / 2) cosine columns but ceil(d_model / 2)
        # frequencies; slicing keeps odd d_model from raising a shape error.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Add the encodings for positions ``0..L-1`` to ``x`` of shape ``[B, L, d_model]``.

        Raises:
            ValueError: if ``L`` exceeds the precomputed ``max_len``.
        """
        L = x.size(1)
        if L > self.pe.size(1):
            raise ValueError(f"sequence length {L} exceeds max_len {self.pe.size(1)}")
        return x + self.pe[:, :L]

        
class Transformergpt(nn.Module):
    """GPT-style temporal decoder over per-token frame sequences.

    Frames are linearly projected to ``d_model`` and given a fixed sinusoidal
    time encoding (``SinTimePos``), then decoded with ``nn.TransformerDecoder``.
    """

    def __init__(self, cfg):
        super(Transformergpt, self).__init__()

        self.seq_len = cfg.seq_len
        self.input_feature_num = cfg.input_feature_num
        self.d_model = cfg.model_dim
        self.feedforward_dim = cfg.feedforward_dim
        self.nhead = cfg.n_heads
        self.num_encoder_layers = cfg.n_layers

        self.mlp_input_feature_num = self.d_model
        self.fc_hidden_num = cfg.fc_hidden_num
        self.device = cfg.device

        self.linear_proj = nn.Linear(self.input_feature_num, self.d_model)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=self.d_model, nhead=self.nhead, dim_feedforward=self.feedforward_dim,
            dropout=cfg.dropout_p, batch_first=True
        )
        self.time_decoder = nn.TransformerDecoder(dec_layer, num_layers=self.num_encoder_layers)

        # Fixed (not learned) sinusoidal time encoding used by the time decoder.
        self.time_pos = SinTimePos(self.d_model, device=self.device, max_len=1024)

    @staticmethod
    def causal_mask(L, device):
        """Return an ``[L, L]`` bool mask; True (strictly above the diagonal) = masked.

        This is the convention ``nn.TransformerDecoder`` expects (True = not attended).
        """
        return torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1)

    def project_frames(self, frames):
        """Project ``[N, L, D_in]`` frames to ``d_model`` and add time encodings for indices ``0..L-1``."""
        return self.time_pos(self.linear_proj(frames))

    class NTPState:
        """Mutable rollout state holding the projected decoder memory ``mem_proj`` ``[N, L, d_model]``."""

        def __init__(self, mem_proj):
            self.mem_proj = mem_proj

    def ntp_step_init(self, first_frame):
        """Create the rollout state from the first frame.

        Args:
            first_frame: ``[B, E, A, D_in]``.

        Returns:
            ``NTPState`` whose ``mem_proj`` is ``[B*E*A, 1, d_model]`` (time index 0).
        """
        B, E, A, D = first_frame.shape
        cur = first_frame.reshape(B * E * A, 1, D)   # [B*E*A, 1, D_in]
        mem = self.project_frames(cur)               # [B*E*A, 1, d_model]
        return Transformergpt.NTPState(mem)

    def ntp_step_once(self, state):
        """Compute the decoder state used to predict the frame after ``state.mem_proj``.

        The whole memory ``[N, L, d_model]`` is decoded as the target with a
        causal self-attention mask and a causal memory mask, so position t
        only sees frames 0..t in both self- and cross-attention (decoding only
        the last token would give its self-attention nothing but itself).
        ``state`` is not modified; the caller appends new frames to it.

        Returns:
            ``[N, d_model]`` decoder output at the last time step.
        """
        mem = state.mem_proj
        mask = self.causal_mask(mem.size(1), device=mem.device)
        dec = self.time_decoder(tgt=mem, memory=mem, tgt_mask=mask, memory_mask=mask)
        return dec[:, -1, :]


class MLMHeadWithUncertainty(nn.Module):
    """Regression head that predicts a per-feature mean and log-variance.

    A shared ``Linear -> LeakyReLU -> Dropout`` trunk feeds two linear heads:
    ``mu`` and ``log_var``. The log-variance is clamped to ``[-2, 5]``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.mlp_input_feature_num = cfg.model_dim
        self.fc_hidden_num = cfg.fc_hidden_num
        self.input_feature_num = cfg.input_feature_num

        hidden = self.fc_hidden_num
        out_dim = self.input_feature_num
        self.ff = nn.Sequential(
            nn.Linear(self.mlp_input_feature_num, hidden),
            nn.LeakyReLU(),
            nn.Dropout(cfg.fc_dropout_p),
        )
        self.mu = nn.Linear(hidden, out_dim)
        self.log_var = nn.Linear(hidden, out_dim)

    def forward(self, h):
        """Map ``h`` of shape ``[..., model_dim]`` to ``(mu, log_var)``, each ``[..., input_feature_num]``."""
        shared = self.ff(h)
        return self.mu(shared), self.log_var(shared).clamp(min=-2.0, max=5.0)


class PretrainBertGpt(nn.Module):
    """Joint masked-reconstruction (BERT-style) and next-frame (GPT-style) pretraining.

    Each frame is encoded independently by ``TransformerLoc`` and decoded by
    ``MLMHeadWithUncertainty`` into a per-feature mean / log-variance. The
    reconstructed frames, mixed with the inputs via teacher forcing, are fed to
    ``Transformergpt``. A per-step GRU + MLP readout then predicts the next frame.
    """

    def __init__(self, cfg):
        super(PretrainBertGpt, self).__init__()

        self.name = 'Pretrain_bert_gpt'
        self.seq_len = cfg.seq_len
        self.input_feature_num = cfg.input_feature_num
        self.d_model = cfg.model_dim
        self.feedforward_dim = cfg.feedforward_dim
        self.nhead = cfg.n_heads
        self.num_encoder_layers = cfg.n_layers

        self.mlp_input_feature_num = self.d_model
        self.fc_hidden_num = cfg.fc_hidden_num
        self.device = cfg.device

        # Sub-module creation order is unchanged so seeded initialisation is identical.
        self.bert_encoder = TransformerLoc(cfg)
        self.gpt_decoder = Transformergpt(cfg)

        self.mlm_head = MLMHeadWithUncertainty(cfg)

        self.bigru = nn.GRU(
            input_size=self.mlp_input_feature_num,
            hidden_size=self.fc_hidden_num,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        self.ntp_head = nn.Sequential(
            nn.Linear(self.fc_hidden_num * 2, self.fc_hidden_num),
            nn.LeakyReLU(),
            nn.Dropout(cfg.fc_dropout_p),
            nn.Linear(self.fc_hidden_num, self.input_feature_num)
        )

    def _ntp_readout(self, dec):
        """Map decoder outputs ``[N, T, d_model]`` to next-frame predictions ``[N, T, D_in]``.

        The bidirectional GRU is applied to every time step independently (as a
        length-1 sequence). Running it along the time axis would let the
        backward direction at step t read decoder outputs of later steps, i.e.
        leak the frames being predicted. Training and rollout both use this path.
        """
        n, t, d = dec.shape
        h, _ = self.bigru(dec.reshape(n * t, 1, d))   # [N*T, 1, 2*hid]
        out = self.ntp_head(h)                         # [N*T, 1, D_in]
        return out.reshape(n, t, -1)

    def ntp_decode_1step(self, src):
        """Teacher-forced causal decoding of all frames at once.

        Args:
            src: ``[B, T, E, A, D_in]`` frame sequence (T >= 2).

        Returns:
            ``[B*E*A, T-1, d_model]``; row t is the decoder state used to
            predict frame t+1 and has only seen frames 0..t.
        """
        batch_size, frame_num, ensemble_num, seq_ap_num, input_feature_num = src.shape
        resize_src = src.permute(0, 2, 3, 1, 4).contiguous()
        resize_src = resize_src.reshape(batch_size * ensemble_num * seq_ap_num, frame_num, -1)
        resize_src_proj = self.gpt_decoder.project_frames(resize_src)

        mem_src = resize_src_proj                   # [N, T, d]
        tgt_src = resize_src_proj[:, :-1, :]        # [N, T-1, d]
        n_tgt = tgt_src.size(1)

        # Causal mask for the target self-attention: [T-1, T-1]
        tgt_mask = self.gpt_decoder.causal_mask(n_tgt, device=src.device)
        # Causal mask for cross-attention: query t may only see memory frames
        # 0..t. Without it the decoder could read frame t+1 (the target).
        memory_mask = torch.triu(
            torch.ones(n_tgt, frame_num, dtype=torch.bool, device=src.device), diagonal=1
        )  # [T-1, T]

        return self.gpt_decoder.time_decoder(
            tgt=tgt_src, memory=mem_src, tgt_mask=tgt_mask, memory_mask=memory_mask
        )

    def ntp_rollout_k(self, start_seq, k_steps):
        """Autoregressively predict ``k_steps`` frames after a context sequence.

        Args:
            start_seq: ``[B, t0, E, A, D_in]`` context frames (t0 >= 1).
            k_steps: number of frames to predict.

        Returns:
            ``[B, k_steps, E, A, D_in]`` predictions for frames ``t0 .. t0+k_steps-1``.

        Every frame placed in the decoder memory (context or prediction) gets the
        time encoding of its true index. ``project_frames`` on a length-1 tensor
        would always use index 0. Each prediction is projected and appended to
        the memory for the next step.

        Raises:
            ValueError: if ``t0 + k_steps`` exceeds the time-encoding length.
        """
        B, t0, E, A, D = start_seq.shape
        if k_steps <= 0:
            return start_seq.new_zeros((B, 0, E, A, D))

        gpt = self.gpt_decoder
        pe = gpt.time_pos.pe                                   # [1, max_len, d_model]
        if t0 + k_steps > pe.size(1):
            raise ValueError(
                f"rollout length {t0 + k_steps} exceeds the time encoding length {pe.size(1)}")

        def embed(frame_flat, t):
            # [B*E*A, 1, D_in] -> [B*E*A, 1, d_model] with the time encoding of index t.
            return gpt.linear_proj(frame_flat) + pe[:, t:t + 1]

        # Frame 0 sits at time index 0, so ntp_step_init's encoding is already correct.
        state = gpt.ntp_step_init(start_seq[:, 0])             # mem: [B*E*A, 1, d_model]
        # Feed the remaining context frames into memory (no outputs produced).
        for idx in range(1, t0):
            frame = start_seq[:, idx].reshape(B * E * A, 1, D)
            state.mem_proj = torch.cat([state.mem_proj, embed(frame, idx)], dim=1)

        preds = []
        for _ in range(k_steps):
            dec_last = gpt.ntp_step_once(state)                     # [B*E*A, d_model]
            pred_flat = self._ntp_readout(dec_last.unsqueeze(1))    # [B*E*A, 1, D_in]
            preds.append(pred_flat.reshape(B, E, A, 1, D).permute(0, 3, 1, 2, 4).contiguous())

            # Append the prediction at its own time index for the next step.
            t_next = state.mem_proj.size(1)
            state.mem_proj = torch.cat([state.mem_proj, embed(pred_flat, t_next)], dim=1)

        return torch.cat(preds, dim=1)  # [B, k, E, A, D_in]

    def forward(self, ap_coords, src,
                mask_pos_batch,
                mask_label_batch,
                original_src,
                step,
                tf_sched: TeacherForcingScheduler,
                lam_sched: LambdaCosineScheduler,
                cfg_train):
        """Compute the combined pretraining loss for one batch.

        Args:
            ap_coords: per-frame coordinates; ``ap_coords[:, t]`` is passed to
                ``TransformerLoc`` (which reshapes it to ``[B*E, A, -1]``).
            src: ``[B, T, E, A, D_in]`` input frames; T must be >= 2.
            mask_pos_batch: ``[B, N_mask, >=3]`` integer indices; channels 0/1/2
                index the frame, ensemble and token axes of the reconstruction.
            mask_label_batch: target values for the masked entries (compared
                with the ``[B, N_mask, D_in]`` gathered predictions).
            original_src: ``[B, T, E, A, D_in]`` frames used as next-frame targets
                (presumably the unmasked sequence -- confirm with the caller).
            step: global step fed to both schedulers.
            tf_sched: teacher-forcing probability schedule.
            lam_sched: weight schedule for the NTP losses.
            cfg_train: dict of options: ``detach_ntp_input`` (default True),
                ``noise_sigma`` (default 0), ``alpha_var_weight`` (default 1.0),
                ``k_step_ratio`` (default 0, no rollout), and ``k_max``
                (default: the longest rollout the sequence allows).

        Returns:
            dict with ``total_loss``, ``mlm_loss``, ``ntp_loss``, ``ntp_loss_k``,
            ``p_tf`` and ``lambda``.

        Raises:
            ValueError: if ``src`` has fewer than two frames.
        """
        batch_size, frame_num, ensemble_num, seq_ap_num, input_feature_num = src.shape
        if frame_num < 2:
            raise ValueError(f"next-frame prediction needs at least 2 frames, got {frame_num}")

        # === Per-frame encoding + reconstruction (mean / log-variance) ===
        mu_frames = []
        logvar_frames = []
        for i in range(frame_num):
            enc_t = self.bert_encoder(ap_coords[:, i], src[:, i])
            mu_t, logvar_t = self.mlm_head(enc_t)
            mu_frames.append(mu_t.reshape(batch_size, ensemble_num, seq_ap_num, -1))
            logvar_frames.append(logvar_t.reshape(batch_size, ensemble_num, seq_ap_num, -1))

        # Stack back along time: [B, T, E, A, D_in]
        recon_seq = torch.stack(mu_frames, dim=1)
        logvar_seq = torch.stack(logvar_frames, dim=1)

        # === MLM loss (masked entries only) ===
        if mask_pos_batch.numel() > 0:
            t_idx = mask_pos_batch[..., 0].to(self.device)
            i_idx = mask_pos_batch[..., 1].to(self.device)
            j_idx = mask_pos_batch[..., 2].to(self.device)
            b_idx = torch.arange(batch_size, device=self.device).unsqueeze(1).expand_as(t_idx)
            pred_masked = recon_seq[b_idx, t_idx, i_idx, j_idx, :]   # [B, N_mask, D_in]
            logv_masked = logvar_seq[b_idx, t_idx, i_idx, j_idx, :]
            target_mask = mask_label_batch.to(self.device)
            mlm_loss = gaussian_nll(pred_masked, logv_masked, target_mask, reduce='mean')
        else:
            mlm_loss = torch.tensor(0.0, device=self.device)

        # === NTP input: per-(sample, frame) teacher forcing + optional detach / noise ===
        p_tf = tf_sched.prob(step)
        bern = torch.rand(size=(batch_size, frame_num, 1, 1, 1), device=self.device)
        use_gt = (bern < p_tf).float()
        mixed = use_gt * src + (1 - use_gt) * recon_seq

        if cfg_train.get('detach_ntp_input', True):
            mixed = mixed.detach()
        if cfg_train.get('noise_sigma', 0.0) > 0:
            mixed = mixed + torch.randn_like(mixed) * cfg_train['noise_sigma']

        # === 1-step NTP (teacher-forced over all frames) ===
        dec_out = self.ntp_decode_1step(mixed)                 # [B*E*A, T-1, d_model]
        pred_next = self._ntp_readout(dec_out)                 # [B*E*A, T-1, D_in]
        pred_next = pred_next.reshape(batch_size, ensemble_num, seq_ap_num, frame_num - 1, -1)
        pred_next = pred_next.permute(0, 3, 1, 2, 4).contiguous()   # [B, T-1, E, A, D_in]

        target_next = original_src[:, 1:]
        # NOTE(review): the weight for predicting frame t+1 comes from the
        # log-variance of frame t; confirm this pairing is intended.
        w = var_to_weight(logvar_seq[:, :-1], alpha=cfg_train.get('alpha_var_weight', 1.0))
        ntp_loss = ((pred_next - target_next) ** 2 * w).mean()

        # === Occasional k-step autoregressive rollout ===
        ntp_loss_k = torch.tensor(0.0, device=self.device)
        rollout_ratio = cfg_train.get('k_step_ratio', 0.0)
        if torch.rand(1).item() < rollout_ratio:
            # Context = first start_len frames of `mixed`; predictions begin at frame start_len.
            start_len = min(3, frame_num - 1)
            horizon = frame_num - start_len
            k = min(cfg_train.get('k_max', horizon), horizon)
            if k > 0:
                start_seq = mixed[:, :start_len]                       # [B, start_len, E, A, D_in]
                preds_k = self.ntp_rollout_k(start_seq, k_steps=k)     # [B, k, E, A, D_in]
                # Targets are aligned with the first predicted frame (start_len), not frame 1.
                tgt_k = original_src[:, start_len:start_len + k]
                ntp_loss_k = F.mse_loss(preds_k, tgt_k)

        # === Combined loss (λ schedule) ===
        lam = lam_sched.value(step)
        total = mlm_loss + lam * (0.7 * ntp_loss + 0.3 * ntp_loss_k)

        return {
            'total_loss': total,
            'mlm_loss': mlm_loss,
            'ntp_loss': ntp_loss,
            'ntp_loss_k': ntp_loss_k,
            'p_tf': p_tf,
            'lambda': lam
        }