import numpy as np

import torch
import torch.nn as nn
from .context_encoder import build_context_encoder
from .motion_decoder import build_decoder
from .motion_decoder.mtr_decoder import modulate
from .utils.common_layers import build_mlps
from einops import repeat, rearrange
from models.context_encoder.mtr_encoder import SinusoidalPosEmb
from models.context_encoder.condition_encoder import ZEncoder, ZFiLM
from models.neural_sde.ctr_sde import simulate_sde_paths, ControlledSSLSDE
from models.neural_sde.z0_encoder import Z0Encoder


class MotionTransformer(nn.Module):
    def __init__(self, model_config, logger, config):
        """Build the denoising motion transformer and its mode-specific conditioning layers.

        Args:
            model_config: Model section of the configuration. Read through
                attribute access (``CONTEXT_ENCODER``, ``MOTION_DECODER``,
                ``AGENT_DIM``, ``MODEL_OUT_DIM``, ...) and through ``.get``.
            logger: Object exposing ``info``; stored on ``self.logger`` and used
                once to report parameter counts.
            config: Run-level configuration. Supplies ``CNSDE`` (the ablation
                mode), ``device``, optionally ``future_frames`` / ``dt``, and
                ``agents`` when the mode is ``"m2"``.

        Raises:
            AssertionError: If ``USE_PRE_NORM`` is enabled in ``model_config``.
            ValueError: If ``config.CNSDE`` is not ``"m0"``, ``"m1"`` or ``"m2"``.
        """
        super().__init__()
        self.model_cfg = model_config
        self.dim = self.model_cfg.CONTEXT_ENCODER.D_MODEL
        self.config = config
        self.logger = logger
        use_pre_norm = self.model_cfg.get('USE_PRE_NORM', False)
        self.ablation_mode = self.config.CNSDE
        assert not use_pre_norm, "Pre-norm is not supported in this model"

        # Fail fast on an unknown mode. Previously such a value slipped through every
        # branch below and surfaced much later as a NameError on `in_fuse`.
        if self.ablation_mode not in ("m0", "m1", "m2"):
            raise ValueError(
                "Unsupported ablation mode config.CNSDE={!r}; expected 'm0', 'm1' or 'm2'".format(
                    self.ablation_mode
                )
            )

        self.T_f = self.config.get('future_frames', 0)
        self.dt = self.config.get('dt', 0)
        self.agent_dim = self.model_cfg.AGENT_DIM
        self.context_encoder = build_context_encoder(self.model_cfg.CONTEXT_ENCODER, use_pre_norm, config.device)

        # "m0" registers no extra conditioning modules.

        if self.ablation_mode == "m1":
            # Static condition encoder plus FiLM layers applied to the noisy-sample embedding.
            self.z_encoder = ZEncoder(
                d_hist = self.model_cfg.get('COND_D_HIST', 0),
                d_cue = self.model_cfg.get('COND_D_CUE', 0),
                d_model = self.dim,
                d_z = self.model_cfg.get('COG_D_Z', 0)
            )

            self.cond_proj = nn.Linear(self.dim, self.dim)
            # NOTE(review): z_film / z_gamma / z_beta are not referenced by forward() in this
            # class, and the output of z_proj does not reach the prediction. They stay
            # registered so existing checkpoints keep loading.
            self.z_proj = nn.Linear(self.model_cfg.get('COG_D_Z', 0), self.dim)
            self.z_film = ZFiLM(d_feat=self.dim)
            self.z_gamma = nn.Linear(self.dim, self.dim)
            self.z_beta = nn.Linear(self.dim, self.dim)

            # Zero-initialise so these layers output zeros at the start of training.
            nn.init.zeros_(self.z_gamma.weight)
            nn.init.zeros_(self.z_gamma.bias)
            nn.init.zeros_(self.z_beta.weight)
            nn.init.zeros_(self.z_beta.bias)

            self.cond_film_gamma = nn.Linear(self.dim, self.dim)
            self.cond_film_beta  = nn.Linear(self.dim, self.dim)

        if self.ablation_mode == "m2":
            # Latent initial-state encoder plus controlled SDE; the rolled-out latent
            # sequence modulates the fused embedding per future frame in forward().
            # NOTE(review): IMLETransformer in this file passes kp_dim=6 to the same
            # encoder; confirm that AGENT_DIM * 3 is the intended width here.
            self.z0_encoder = Z0Encoder(
                num_keypoints=self.config.agents,
                kp_dim=self.agent_dim * 3,
                stim_dim=self.model_cfg.get('COND_D_CUE', 0),
                hidden_dim=self.dim,
                z_dim=self.model_cfg.get('COG_D_Z', 0)
            )

            self.neural_sde = ControlledSSLSDE(
                z_dim=self.model_cfg.get('COG_D_Z', 0),
                stim_dim=self.model_cfg.get('COND_D_CUE', 0),
                num_regimes=3,
                num_bases=16,
                hidden_dim=self.dim,
                init_scale=0.1,
                dataset_type=self.model_cfg.CONTEXT_ENCODER.DATA_TYPE
            )
            self.z_seq_proj =  nn.Linear(self.model_cfg.get('COG_D_Z', 0), self.dim)
            self.z_seq_gamma = nn.Linear(self.dim, self.dim)
            self.z_seq_beta = nn.Linear(self.dim, self.dim)

        # Width of the tensor concatenated in forward() ahead of init_emb_fusion_mlp:
        # encoder_out + time embedding + y_emb, plus cond_vec only in "m1".
        in_fuse = self.dim + (self.dim * 1) + self.dim
        if self.ablation_mode == "m1":
            in_fuse += self.dim

        self.motion_query_embedding = nn.Embedding(self.model_cfg.NUM_PROPOSED_QUERY, self.dim)
        self.agent_order_embedding = nn.Embedding(self.model_cfg.CONTEXT_ENCODER.NUM_OF_ATTN_NEIGHBORS, self.dim)
        self.post_pe_cat_mlp = nn.Sequential(
            nn.Linear(self.dim, self.dim),
            nn.LayerNorm(self.dim),
            nn.ReLU(),
            nn.Linear(self.dim, self.dim),
        )

        time_dim = self.dim * 1
        sinu_pos_emb = SinusoidalPosEmb(self.dim, theta = 10000)

        # Time value -> sinusoidal features -> two-layer MLP.
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(self.dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )

        # Embeds the noisy sample (last dim MODEL_OUT_DIM) into the model width.
        self.noisy_y_mlp = nn.Sequential(
            nn.Linear(self.model_cfg.MODEL_OUT_DIM, self.dim),
            nn.ReLU(),
            nn.Linear(self.dim, self.dim),
            nn.ReLU(),
            nn.Linear(self.dim, self.dim),
        )

        # Two self-attention layers: forward() applies the first across the query axis
        # and the second across the agent axis.
        dropout_ = self.model_cfg.MOTION_DECODER.DROPOUT_OF_ATTN
        self.noisy_y_attn_k = nn.TransformerEncoderLayer(d_model=self.dim, nhead=4, dim_feedforward=self.dim * 4, dropout=dropout_, batch_first=True)
        self.noisy_y_attn_a = nn.TransformerEncoderLayer(d_model=self.dim, nhead=4, dim_feedforward=self.dim * 4, dropout=dropout_, batch_first=True)

        dim_decoder = self.model_cfg.MOTION_DECODER.D_MODEL

        self.init_emb_fusion_mlp = nn.Sequential(
            nn.Linear(in_fuse, self.dim),
            nn.LayerNorm(self.dim),
            nn.ReLU(),
            nn.Linear(self.dim, dim_decoder),
        )

        # NOTE(review): readout_mlp and cls_head are not referenced by the methods of
        # this class visible in this file; kept for checkpoint compatibility.
        self.readout_mlp = nn.Sequential(
            nn.Linear(dim_decoder, dim_decoder),
            nn.ReLU(),
            nn.Linear(dim_decoder, self.model_cfg.MODEL_OUT_DIM),
        )

        self.motion_decoder = build_decoder(self.model_cfg.MOTION_DECODER, use_pre_norm)

        # NOTE(review): reg_head is sized with self.dim, and post_pe_cat_mlp (dim -> dim)
        # consumes the dim_decoder-wide fusion output, so the forward pass appears to
        # assume MOTION_DECODER.D_MODEL == CONTEXT_ENCODER.D_MODEL; confirm.
        self.reg_head = build_mlps(c_in=self.dim, mlp_channels=self.model_cfg.REGRESSION_MLPS, ret_before_act=True, without_norm=True)
        self.cls_head = build_mlps(c_in=dim_decoder, mlp_channels=self.model_cfg.CLASSIFICATION_MLPS, ret_before_act=True, without_norm=True)

        # Report parameter counts, split into encoder / decoder / everything else.
        params_encoder = sum(p.numel() for p in self.context_encoder.parameters())
        params_decoder = sum(p.numel() for p in self.motion_decoder.parameters())
        params_total = sum(p.numel() for p in self.parameters())
        params_other = params_total - params_encoder - params_decoder
        logger.info("Total parameters: {:,}, Encoder: {:,}, Decoder: {:,}, Other: {:,}".format(
            params_total, params_encoder, params_decoder, params_other
        ))

    def _build_future_control_seq(
            self,
            x_data,
            B: int,
            device: torch.device,
            dtype: torch.dtype,
    ) -> torch.Tensor:
        T_f = self.T_f

        # 情况 1：有明确的未来刺激序列
        if "fut_cond_cue" in x_data:
            # print("--- found fut_cond_cue")
            fut_stim = x_data["fut_cond_cue"]  # [B, T_f, stim_dim]
            # print("fut_cond_cue shape = {} {}".format(fut_stim.shape, T_f))
            assert fut_stim.shape[0] == B
            assert fut_stim.shape[1] == T_f
            assert fut_stim.shape[2] == self.neural_sde.stim_dim
            u_seq = fut_stim.to(device=device, dtype=dtype)
        else:
            hist_stim = x_data["hist_stim"]  # [B, Th, stim_dim]
            u_last = hist_stim[:, -1, :]  # [B, stim_dim]
            u_seq = u_last.unsqueeze(1).repeat(1, T_f, 1)  # [B, T_f, stim_dim]

        return u_seq

    def apply_PE(self, y_emb, k_pe_batch, a_pe_batch):
        if self.model_cfg.get('USE_PE_QUERY', True) and self.model_cfg.get('USE_PE_AGENT', True):
            y_emb = y_emb + k_pe_batch + a_pe_batch
        elif self.model_cfg.get('USE_PE_QUERY', True):
            y_emb = y_emb + k_pe_batch
        elif self.model_cfg.get('USE_PE_AGENT', True):
            y_emb = y_emb + a_pe_batch
        else:
            pass
        return y_emb
    
    def get_z_rollout(self, x_data):
        """Encode the initial latent state and roll it forward with the neural SDE.

        Reads ``x_data["past_traj"]`` and ``x_data["hist_cond_cue"]`` to obtain the
        initial latent, builds the future control sequence, and integrates
        ``self.neural_sde`` with step ``self.dt``.

        Returns:
            Tuple ``(z_seq, u_seq)``: the simulated latent path and the control
            sequence that drove it.
        """
        z_init = self.z0_encoder(x_data["past_traj"], x_data["hist_cond_cue"])  # [B, z_dim]

        # Controls are created on the latent's device and dtype, with the latent's batch size.
        controls = self._build_future_control_seq(
            x_data=x_data,
            B=z_init.shape[0],
            device=z_init.device,
            dtype=z_init.dtype,
        )  # [B, T_f, stim_dim]

        latent_path = simulate_sde_paths(
            sde=self.neural_sde,
            z0=z_init,
            u_seq=controls,
            dt=self.dt,
        )
        return latent_path, controls
        
    def forward(self, y, time, x_data):
        """Predict the denoised output for noisy samples ``y`` at denoising step ``time``.

        Args:
            y: Rank-4 tensor unpacked as ``(B, K, A, _)``; the last dimension is fed
                to ``noisy_y_mlp``.
            time: Time values fed to ``time_mlp`` (presumably shape ``[B]``; confirm
                against the caller). Multiplied by 1000 when
                ``config.denoising_method == 'fm'``.
            x_data: Mapping of inputs. Always reads ``'past_traj_original_scale'``;
                modes ``"m1"`` / ``"m2"`` read additional entries through
                ``z_encoder`` / ``get_z_rollout``.

        Returns:
            ``denoiser_x`` alone, or ``(denoiser_x, (z_seq, u_seq))`` when
            ``config.LOSS_CTRL`` or ``config.LOSS_STAB`` is truthy. ``z_seq`` and
            ``u_seq`` are ``None`` unless the ablation mode is ``"m2"``.
        """
        device = y.device
        B, K, A, _ = y.shape

        encoder_out = self.context_encoder(x_data['past_traj_original_scale'])  # [B, A, D]
        encoder_out_batch = repeat(encoder_out, 'b a d -> b k a d', k=K, a=A)   # [B, K, A, D]

        y_emb = self.noisy_y_mlp(y)  # [B, K, A, D]

        # Bound up front so the tuple return at the end is well-defined in every mode.
        # Previously these were only assigned in "m2", which raised UnboundLocalError
        # when LOSS_CTRL / LOSS_STAB was enabled in another mode.
        z_seq = u_seq = None

        if self.ablation_mode == "m1":
            # The second output of z_encoder is not used for the prediction.
            cond_flow, _ = self.z_encoder(x_data)
            cond_flow = self.cond_proj(cond_flow)                       # [B, D]
            cond_bka = repeat(cond_flow, 'b d -> b k a d', k=K, a=A)    # [B, K, A, D]
            # FiLM on the noisy-sample embedding. The linear layers run once per batch
            # element and broadcast over K and A instead of running on the expanded tensor.
            gamma = self.cond_film_gamma(cond_flow)[:, None, None, :]   # [B, 1, 1, D]
            beta = self.cond_film_beta(cond_flow)[:, None, None, :]     # [B, 1, 1, D]
            y_emb = gamma * y_emb + beta

        if self.ablation_mode == "m2":
            z_seq, u_seq = self.get_z_rollout(x_data)
            z_frame = self.z_seq_proj(z_seq)                            # [B, T_f, D]
            # Per-frame FiLM parameters, computed once per (batch, frame) and broadcast
            # over K and A where they are applied below.
            gamma = self.z_seq_gamma(z_frame)[:, None, None, :, :]      # [B, 1, 1, T_f, D]
            beta = self.z_seq_beta(z_frame)[:, None, None, :, :]        # [B, 1, 1, T_f, D]

        time_ = time  # unscaled time, used by the embedding-drop schedule below
        if self.config.denoising_method == 'fm':
            time = time * 1000.0  # flow matching time upscaling

        t_emb = self.time_mlp(time)                                     # [B, D]
        t_emb_batch = repeat(t_emb, 'b d -> b k a d', b=B, k=K, a=A)    # [B, K, A, D]

        k_pe = self.motion_query_embedding(torch.arange(self.model_cfg.NUM_PROPOSED_QUERY, device=device))  # [K, D]
        k_pe_batch = repeat(k_pe, 'k d -> b k a d', b=B, a=A)           # [B, K, A, D]

        a_pe = self.agent_order_embedding(torch.arange(self.model_cfg.CONTEXT_ENCODER.NUM_OF_ATTN_NEIGHBORS, device=device))  # [A, D]
        a_pe_batch = repeat(a_pe, 'a d -> b k a d', b=B, k=K)           # [B, K, A, D]

        # Self-attention across the K axis (per agent), then across the A axis (per query).
        y_emb_k = rearrange(self.apply_PE(y_emb, k_pe_batch, a_pe_batch), 'b k a d -> (b a) k d')
        y_emb_k = self.noisy_y_attn_k(y_emb_k)
        y_emb = rearrange(y_emb_k, '(b a) k d -> b k a d', b=B, a=A)

        y_emb_a = rearrange(y_emb, 'b k a d -> (b k) a d')
        y_emb_a = self.noisy_y_attn_a(y_emb_a)
        y_emb = rearrange(y_emb_a, '(b k) a d -> b k a d', b=B, k=K)

        if self.training and self.config.get('drop_method', None) == 'emb':
            assert self.config.get('drop_logi_k', None) is not None and self.config.get('drop_logi_m', None) is not None
            m, k = self.config.drop_logi_m, self.config.drop_logi_k
            # Logistic drop probability in the unscaled time; one draw per batch element
            # zeroes that element's whole y embedding.
            p_m = 1 / (1 + torch.exp(-k * (time_ - m)))
            p_m = p_m[:, None, None, None]
            y_emb = y_emb.masked_fill(torch.rand_like(p_m) < p_m, 0.)

        # Fuse the context, sample and time embeddings (plus the condition vector in "m1").
        if self.ablation_mode == "m1":
            emb_in = torch.cat((encoder_out_batch, y_emb, t_emb_batch, cond_bka), dim=-1)
        else:
            emb_in = torch.cat((encoder_out_batch, y_emb, t_emb_batch), dim=-1)
        emb_fusion = self.init_emb_fusion_mlp(emb_in)                   # [B, K, A, D]

        if self.ablation_mode == "m2":
            # Add a frame axis and apply the per-frame FiLM. Broadcasting gives the same
            # values as materialising T_f copies with .repeat().
            emb_fusion = emb_fusion.unsqueeze(3) * (1 + gamma) + beta   # [B, K, A, T_f, D]
            a_pe_batch = a_pe_batch.unsqueeze(3)                        # [B, K, A, 1, D]
            k_pe_batch = k_pe_batch.unsqueeze(3)                        # [B, K, A, 1, D]

        query_token = self.post_pe_cat_mlp(self.apply_PE(emb_fusion, k_pe_batch, a_pe_batch))
        readout_token = self.motion_decoder(query_token, t_emb)

        denoiser_x = self.reg_head(readout_token)
        if self.ablation_mode == "m2":
            # Flatten the frame axis into the feature axis.
            denoiser_x = rearrange(denoiser_x, 'b k a t d -> b k a (t d)')

        if self.config.LOSS_CTRL or self.config.LOSS_STAB:
            return denoiser_x, (z_seq, u_seq)
        return denoiser_x
        

class IMLETransformer(nn.Module):
    def __init__(self, model_config, logger, config):
        """Assemble the IMLE generator: context encoder, latent SDE, embeddings, decoder and head.

        Args:
            model_config: Model section of the configuration (attribute access plus ``.get``).
            logger: Object exposing ``info``; used once to report parameter counts.
            config: Run-level configuration; supplies ``objective``, ``device``,
                ``agents`` and optionally ``future_frames`` / ``dt``.

        Raises:
            AssertionError: If ``USE_PRE_NORM`` is enabled in ``model_config``.
        """
        super().__init__()
        self.model_cfg = model_config
        self.dim = self.model_cfg.CONTEXT_ENCODER.D_MODEL
        self.cfg = config
        self.objective = self.cfg.objective

        pre_norm = self.model_cfg.get('USE_PRE_NORM', False)
        assert not pre_norm, "Pre-norm is not supported in this model"

        self.T_f = self.cfg.get('future_frames', 0)
        self.dt = self.cfg.get('dt', 0)

        width = self.dim

        # Context encoder over the past trajectories.
        self.context_encoder = build_context_encoder(self.model_cfg.CONTEXT_ENCODER, pre_norm, config.device)

        # Latent initial-state encoder and controlled SDE.
        cue_dim = self.model_cfg.get('COND_D_CUE', 0)
        latent_dim = self.model_cfg.get('COG_D_Z', 0)
        self.z0_encoder = Z0Encoder(
            num_keypoints=self.cfg.agents,
            kp_dim=6,
            stim_dim=cue_dim,
            hidden_dim=width,
            z_dim=latent_dim,
        )
        self.neural_sde = ControlledSSLSDE(
            z_dim=latent_dim,
            stim_dim=cue_dim,
            num_regimes=3,
            num_bases=16,
            hidden_dim=width,
            init_scale=0.1,
            dataset_type=self.model_cfg.CONTEXT_ENCODER.DATA_TYPE,
        )

        # Projection of the latent path and the FiLM parameter layers fed from it.
        self.z_seq_proj = nn.Linear(latent_dim, width)
        self.z_seq_gamma = nn.Linear(width, width)
        self.z_seq_beta = nn.Linear(width, width)

        # The query embedding is only registered for the 'set' objective.
        if self.objective == 'set':
            self.motion_query_embedding = nn.Embedding(self.model_cfg.NUM_PROPOSED_QUERY, width)

        self.agent_order_embedding = nn.Embedding(self.model_cfg.CONTEXT_ENCODER.NUM_OF_ATTN_NEIGHBORS, width)

        # Embeds the sampled noise vector.
        self.noisy_vec_mlp = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width)
        )

        # Applied after the positional embeddings are added to the fused embedding.
        self.pe_mlp = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

        decoder_width = self.model_cfg.MOTION_DECODER.D_MODEL
        # Fuses [context embedding, noise embedding] into the decoder width.
        self.init_emb_fusion_mlp = nn.Sequential(
            nn.Linear(width + width, width),
            nn.LayerNorm(width),
            nn.ReLU(),
            nn.Linear(width, decoder_width),
        )

        self.readout_mlp = nn.Sequential(
            nn.Linear(decoder_width, decoder_width),
            nn.ReLU(),
            nn.Linear(decoder_width, self.model_cfg.MODEL_OUT_DIM),
        )

        self.motion_decoder = build_decoder(self.model_cfg.MOTION_DECODER, pre_norm, use_adaln=False)

        self.reg_head = build_mlps(c_in=width, mlp_channels=self.model_cfg.REGRESSION_MLPS, ret_before_act=True, without_norm=True)

        # Report parameter counts, split into encoder / decoder / everything else.
        def _num_params(module):
            return sum(p.numel() for p in module.parameters())

        n_encoder = _num_params(self.context_encoder)
        n_decoder = _num_params(self.motion_decoder)
        n_total = _num_params(self)
        n_other = n_total - n_encoder - n_decoder
        logger.info("Total parameters: {:,}, Encoder: {:,}, Decoder: {:,}, Other: {:,}".format(n_total, n_encoder, n_decoder, n_other))

    def _build_future_control_seq(
            self,
            x_data,
            B: int,
            device: torch.device,
            dtype: torch.dtype,
    ) -> torch.Tensor:
        T_f = self.T_f

        if "fut_cond_cue" in x_data:
            # print("--- found fut_cond_cue")
            fut_stim = x_data["fut_cond_cue"]  # [B, T_f, stim_dim]
            # print("fut_cond_cue shape = {} {}".format(fut_stim.shape, T_f))
            assert fut_stim.shape[0] == B
            assert fut_stim.shape[1] == T_f
            assert fut_stim.shape[2] == self.neural_sde.stim_dim
            u_seq = fut_stim.to(device=device, dtype=dtype)
        else:
            assert "hist_stim" in x_data, "需要在 x_data 中提供 hist_stim 或 fut_stim"
            hist_stim = x_data["hist_stim"]  # [B, Th, stim_dim]
            u_last = hist_stim[:, -1, :]  # [B, stim_dim]
            u_seq = u_last.unsqueeze(1).repeat(1, T_f, 1)  # [B, T_f, stim_dim]

        return u_seq
    
    def forward(self, x_data, num_to_gen=None):
        """Generate ``M`` sets of predictions from sampled noise, modulated by the latent SDE path.

        Args:
            x_data: Mapping of inputs. Reads ``'past_traj_original_scale'`` (rank 4,
                unpacked as ``(B, A, T, _)``), ``'past_traj'`` and ``'hist_cond_cue'``,
                plus whatever ``_build_future_control_seq`` reads.
            num_to_gen: Number of noise samples ``M`` in eval mode. Ignored in
                training mode, where ``cfg.num_to_gen`` is used. When ``None`` in
                eval mode, ``cfg.num_to_gen`` is used as well.

        Returns:
            Tensor produced by ``reg_head`` with the frame axis flattened into the
            last axis, i.e. laid out as ``[B, M, K, A, T_f * d]``.

        Raises:
            NotImplementedError: If ``cfg.objective`` is not ``'set'``.
        """
        # Reject unsupported objectives before running the encoder and SDE.
        if self.cfg.objective != 'set':
            raise NotImplementedError(
                "objective {!r} is not implemented; only 'set' is supported".format(self.cfg.objective)
            )

        past_scaled = x_data['past_traj_original_scale']
        device = past_scaled.device
        B, A, _, _ = past_scaled.shape
        K = self.cfg.denoising_head_preds
        D = self.dim

        # Previously eval mode with num_to_gen=None left M as None and crashed further down.
        if self.training or num_to_gen is None:
            M = self.cfg.num_to_gen
        else:
            M = num_to_gen

        # 1) Context encoder.
        encoder_out = self.context_encoder(past_scaled)                                   # [B, A, D]
        encoder_out_batch = repeat(encoder_out, 'b a d -> b m k a d', m=M, k=K, a=A)      # [B, M, K, A, D]

        # 2) Initial latent state and SDE rollout.
        past_traj = x_data["past_traj"]
        hist_stim = x_data["hist_cond_cue"]      # [B, Th, stim_dim]

        z0 = self.z0_encoder(past_traj, hist_stim)                                        # [B, z_dim]
        u_seq = self._build_future_control_seq(x_data, B, device, z0.dtype)               # [B, T_f, stim_dim]
        z_seq = simulate_sde_paths(self.neural_sde, z0, u_seq, dt=self.dt)                # [B, T_f, z_dim]

        # Per-frame FiLM parameters, computed once per (batch, frame) and broadcast over
        # M, K and A below instead of running the linear layers on the expanded tensor.
        z_frame = self.z_seq_proj(z_seq)                                                  # [B, T_f, D]
        gamma = self.z_seq_gamma(z_frame)[:, None, None, None, :, :]                      # [B, 1, 1, 1, T_f, D]
        beta = self.z_seq_beta(z_frame)[:, None, None, None, :, :]                        # [B, 1, 1, 1, T_f, D]

        # 3) Noise embeddings, one noise vector per (batch, sample).
        noise = torch.randn((B, M, D), device=device)                                     # [B, M, D]
        noise_emb = self.noisy_vec_mlp(noise)                                             # [B, M, D]
        noise_emb_batch = repeat(noise_emb, 'b m d -> b m k a d', k=K, a=A)               # [B, M, K, A, D]

        # Positional embeddings, shaped to broadcast against [B, M, K, A, T_f, D].
        k_pe = self.motion_query_embedding(torch.arange(K, device=device))                # [K, D]
        k_pe_batch = k_pe[None, None, :, None, None, :]                                   # [1, 1, K, 1, 1, D]
        a_pe = self.agent_order_embedding(torch.arange(A, device=device))                 # [A, D]
        a_pe_batch = a_pe[None, None, None, :, None, :]                                   # [1, 1, 1, A, 1, D]

        # 4) Fuse context and noise, add a frame axis, and apply the per-frame FiLM.
        emb_fusion = self.init_emb_fusion_mlp(torch.cat((encoder_out_batch, noise_emb_batch), dim=-1))  # [B, M, K, A, D]
        emb_fusion = emb_fusion.unsqueeze(-2) * (1.0 + gamma) + beta                      # [B, M, K, A, T_f, D]

        query_token = self.pe_mlp(emb_fusion + k_pe_batch + a_pe_batch)                   # [B, M, K, A, T_f, D]

        # 5) Decode with the sample axis folded into the batch axis.
        query_token = rearrange(query_token, 'b m k a t d -> (b m) k a t d')
        readout_token = self.motion_decoder(query_token)
        readout_token = rearrange(readout_token, '(b m) k a t d -> b m k a t d', m=M)

        # 6) Readout, then flatten the frame axis into the feature axis.
        denoiser_x = self.reg_head(readout_token)
        denoiser_x = rearrange(denoiser_x, 'b m k a t d -> b m k a (t d)')

        return denoiser_x
