import numpy as np
import torch
import torch.nn as nn

from models.utils import polyline_encoder
from models.context_encoder.mtr_encoder import SinusoidalPosEmb
from einops import rearrange
import math

class HistGRU(nn.Module):
    """Encode per-agent history sequences and pool them into one vector per sample.

    Every agent sequence is run through a single-layer bidirectional GRU. The
    final hidden states of both directions form the agent embedding; a small
    MLP scores each agent from the projected last-timestep GRU output, and the
    embeddings are combined with softmax weights taken over the agent axis.

    Args:
        in_dim: feature size of each history timestep.
        d_model: embedding size. Each GRU direction uses ``d_model // 2`` hidden
            units, so ``d_model`` must be even for ``proj`` to accept the GRU
            output.
    """

    def __init__(self, in_dim, d_model):
        super().__init__()
        self.gru = nn.GRU(
            input_size=in_dim,
            hidden_size=d_model // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.proj = nn.Linear(d_model, d_model)

        # Maps a projected agent feature to one scalar attention logit.
        self.score = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)
        )

    def forward(self, hist_feats, mask=None):  # [B, V, T, C_h]
        """Return the attention-pooled history embedding, shape ``[B, d_model]``.

        Args:
            hist_feats: tensor of shape ``[B, V, T, C]``.
            mask: optional agent-validity mask, non-zero / ``True`` for agents to
                keep. Presumably shaped ``[B, V]`` (it must broadcast against the
                ``[B, V]`` score tensor) -- TODO confirm against callers.

        Rows in which every agent is masked out yield a zero vector.
        """
        B, V, T, C = hist_feats.shape
        x = hist_feats.contiguous().view(B * V, T, C)   # -> [B*V, T, C]
        y, h = self.gru(x)                               # y: [B*V, T, d_model], h: [2, B*V, d_model//2]

        # h is laid out direction-major. Move the direction axis next to the
        # feature axis before flattening so each agent gets its own
        # [forward | backward] state; reshaping h directly would interleave the
        # states of different agents.
        x_agent = h.transpose(0, 1).reshape(B, V, -1)    # [B, V, d_model]

        y_last = self.proj(y[:, -1, :])                  # -> [B*V, d_model]
        s = self.score(y_last.view(B, V, -1)).squeeze(-1)  # [B, V]

        invalid = None
        if mask is not None:
            # `~mask.bool()` also accepts boolean masks, for which `1 - mask` raises.
            invalid = ~mask.bool()
            s = s.masked_fill(invalid, float("-inf"))
        w = torch.softmax(s, dim=1)                      # [B, V]
        if invalid is not None:
            # A row with no valid agent softmaxes to NaN; zero the weights at
            # masked positions so such a row pools to zeros instead.
            w = w.masked_fill(invalid, 0.0)

        return (x_agent * w.unsqueeze(-1)).sum(dim=1)    # [B, d_model]


class ZEncoder(nn.Module):
    """Produce a conditioning vector and a latent code from one shared encoding.

    A ``ConditionEncoder`` (with its goal / z_d / z_c feature sizes set to 0)
    encodes the input dict; two independent linear heads then map that shared
    feature to the condition vector and to the latent ``z``.
    """

    def __init__(self, d_hist, d_cue, d_model, d_z):
        super().__init__()
        self.core = ConditionEncoder(
            d_hist=d_hist,
            d_cue=d_cue,
            d_goal=0,
            d_zd=0,
            d_zc=0,
            d_model=d_model,
        )
        # Both heads read the same d_model-wide feature.
        self.head_cond = nn.Linear(d_model, d_model)
        self.head_z = nn.Linear(d_model, d_z)

    def forward(self, x_data):
        """Return ``(cond, z)`` computed from the shared encoding of ``x_data``."""
        shared = self.core(x_data)            # [B, d_model]
        cond = self.head_cond(shared)         # [B, d_model]
        latent = self.head_z(shared)          # [B, d_z]
        return cond, latent


class ZFiLM(nn.Module):
    """Scale-and-shift modulation of a feature sequence by a vector ``z``.

    Two linear layers map ``z`` to a per-channel scale offset and a per-channel
    shift; the result is ``feat * (1 + scale) + shift``, broadcast over the
    sequence axis.
    """

    def __init__(self, d_feat):
        super().__init__()
        self.gamma = nn.Linear(d_feat, d_feat)
        self.beta = nn.Linear(d_feat, d_feat)

    def forward(self, feat, z):   # feat:[B,L,D], z:[B,D]
        """Return ``feat`` modulated by ``z``, with the same shape as ``feat``."""
        scale = self.gamma(z)                 # [B,D]
        shift = self.beta(z)                  # [B,D]
        # Insert a length-1 sequence axis so both terms broadcast over L.
        scaled = feat * (1 + scale[:, None, :])
        return scaled + shift[:, None, :]


class ConditionEncoder(nn.Module):
    """Encode conditioning sequences into one fused vector per sample.

    Every conditioning input with a positive feature size gets its own branch
    (bidirectional GRU -> mean over time -> linear projection). The outputs of
    all enabled branches are concatenated and passed through a two-layer MLP.

    Args:
        d_hist: feature size handed to the ``HistGRU`` history branch.
        d_cue: feature size of the cue sequences; ``<= 0`` disables the branch.
        d_goal: feature size of ``inputs['goal_rel']``; ``<= 0`` disables it.
        d_zd: feature size of ``inputs['z_d']``; ``<= 0`` disables it.
        d_zc: feature size of ``inputs['z_c']``; ``<= 0`` disables it.
        d_model: width of every branch output and of the fused result.
    """

    def __init__(self, d_hist, d_cue, d_goal, d_zd, d_zc, d_model):
        super().__init__()

        def branch(d_in):
            # A non-positive feature size means "this condition is not used".
            if d_in <= 0:
                return None
            enc = nn.GRU(d_in, d_model // 2, num_layers=1, batch_first=True, bidirectional=True)
            proj = nn.Linear(d_model, d_model)
            return nn.ModuleDict({'enc': enc, 'proj': proj})

        # NOTE: the history branch is constructed (its parameters are part of
        # the state dict) but forward() does not use it, and it is not counted
        # in the fuse input width below.
        self.br_hist = HistGRU(d_hist, d_model)
        self.br_cue = branch(d_cue)
        self.br_goal = branch(d_goal)
        self.br_zd = branch(d_zd)
        self.br_zc = branch(d_zc)

        enabled = [b for b in (self.br_cue, self.br_goal, self.br_zd, self.br_zc) if b is not None]
        in_cat = d_model * len(enabled)
        self.fuse = nn.Sequential(
            nn.Linear(in_cat if in_cat > 0 else d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

    def _enc_one(self, br, x):  # x: [B, T, C]
        """Encode one sequence with branch ``br``; ``None`` if either is missing."""
        if br is None or x is None:
            return None
        h, _ = br['enc'](x)          # per-timestep GRU outputs
        s = torch.mean(h, dim=1)     # average over the time axis
        return br['proj'](s)

    def _zero_condition(self, inputs):
        """Build the all-zero fallback used when no condition could be encoded."""
        ref = next(self.parameters())
        width = self.fuse[-1].out_features
        # Assumes dim 0 of the input tensors is the batch axis -- TODO confirm.
        for value in inputs.values():
            if torch.is_tensor(value) and value.dim() > 0:
                return torch.zeros(value.shape[0], width, device=ref.device, dtype=ref.dtype)
        # No tensor to infer a batch size from: fall back to an unbatched vector.
        return torch.zeros(width, device=ref.device, dtype=ref.dtype)

    def forward(self, inputs: dict):
        """Return the fused condition vector, shape ``[B, d_model]``.

        ``inputs['hist_cond_cue']`` and ``inputs['fut_cond_cue']`` are required
        when the cue branch is enabled and are concatenated along dim 1. The
        other enabled branches read ``'goal_rel'``, ``'z_d'`` and ``'z_c'``; an
        enabled branch whose input is absent contributes zeros, so the
        concatenated width always matches the first layer of ``fuse``.
        """
        feats = []  # one entry per enabled branch, None where its input is absent
        if self.br_cue is not None:
            cue_seq = torch.cat([inputs['hist_cond_cue'], inputs['fut_cond_cue']], dim=1)
            feats.append(self._enc_one(self.br_cue, cue_seq))
        optional = (
            (self.br_goal, 'goal_rel'),
            (self.br_zd, 'z_d'),
            (self.br_zc, 'z_c'),
        )
        for br, key in optional:
            if br is not None:
                feats.append(self._enc_one(br, inputs.get(key)))

        template = next((f for f in feats if f is not None), None)
        if template is None:
            # Fallback: with no condition available, return a zero vector.
            return self._zero_condition(inputs)

        feats = [torch.zeros_like(template) if f is None else f for f in feats]
        cond = torch.cat(feats, dim=-1)
        return self.fuse(cond)  # [B, d_model]