import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import lr_linear_impu, get_target_broadcast
import math
import numpy as np




def apply_rope(x):
    """Rotate ``x`` with rotary position embeddings (RoPE) along the sequence axis.

    Args:
        x: tensor shaped ``[batch, seq_len, n_heads, head_dim]`` (the shape is
            unpacked directly). ``head_dim`` is assumed even — with an odd size the
            two halves have different lengths and broadcasting fails.

    Returns:
        Tensor of the same shape. The input is split into a first and a second
        half, but the rotated pairs are written back *interleaved*
        ``(a0, b0, a1, b1, ...)``. Because queries and keys receive the same
        permutation, their dot products are unaffected.
    """
    _, seq_len, _, head_dim = x.shape
    half = head_dim // 2
    device = x.device

    inv_freq = 1.0 / (10000 ** (torch.arange(half, dtype=torch.float32, device=device) / half))
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    theta = torch.outer(positions, inv_freq)          # [seq_len, half]

    # Broadcast over batch and heads: [1, seq_len, 1, half]
    cos = theta.cos()[None, :, None, :]
    sin = theta.sin()[None, :, None, :]

    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos - second * sin
    rotated_second = first * sin + second * cos
    return torch.stack((rotated_first, rotated_second), dim=-1).flatten(-2)




class Implicit_Temporal_Func(nn.Module):
    """Resample a ``[B, T, D]`` sequence to new temporal lengths.

    Each resampling step:
      1. unfolds every feature channel into its 3 temporal neighbours
         (t-1, t, t+1, zero-padded at the borders);
      2. queries those neighbour values at ``HR_T`` evenly spaced target times
         using a two-sided local ensemble (nearest LR cell to the left and to the
         right, each weighted by the distance to the *other* one);
      3. mixes the 3 neighbour channels with a 1x1 ``Conv1d`` (``self.down``).

    Args:
        dim, hidden_dim: sizes of ``self.tsnet``. NOTE: ``tsnet`` is constructed
            but currently bypassed in ``ensemble``.
        down_in: input channels of ``self.down``; must equal 3 (the number of
            unfolded neighbour channels).
        down_out: output channels of ``self.down``. Only ``down_out == 1`` yields
            a ``[B, T', D]`` output that can be chained across schema steps.
        args: object exposing ``unfold_dim`` (``"self"``/``"all"``) and
            ``unfold_style`` (``"one"``/``"dilation"``); both are validated and
            stored but not otherwise used in this class.
        device: device the sub-modules are moved to at construction.
    """

    def __init__(self, dim, hidden_dim, down_in, down_out, args, device= "cuda"):
        super(Implicit_Temporal_Func, self).__init__()
        unfold_dim = args.unfold_dim
        unfold_style = args.unfold_style
        assert unfold_dim in ["self", "all"]
        assert unfold_style in ["one", "dilation"]
        self.unfold_dim = unfold_dim
        self.unfold_style = unfold_style
        self.device = device
        self.tsnet = nn.Sequential(
            nn.Linear(dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        ).to(self.device)
        self.down = nn.Conv1d(down_in, down_out, 1).to(self.device)

    def unfolding(self, feat):
        """Return the 3 temporal neighbours of every value.

        Args:
            feat: ``[B, T, D]`` tensor.

        Returns:
            ``[B, 3, T, D]`` tensor; channel k holds ``feat[:, t + k - 1, d]``
            (zero outside the sequence).
        """
        feat = feat.unsqueeze(1).permute(0, 1, 3, 2)        # [B, 1, D, T]
        unfold_feat = F.unfold(
            feat,
            kernel_size= (1, 3),
            padding= (0, 1)
        )                                                   # [B, 3, D * T]
        unfold_feat = unfold_feat.view(feat.shape[0], feat.shape[1] * 3, feat.shape[2], feat.shape[3])
        return unfold_feat.permute(0, 1, 3, 2)              # [B, 3, T, D]

    def coord_gen(self, shape):
        """Return cell-centre coordinates in ``[-1, 1]`` for a grid of ``shape``.

        For ``shape = [n0, n1, ...]`` the result has shape ``[n0, n1, ..., len(shape)]``;
        along an axis of size n the centres are ``-1 + (2i + 1) / n``.
        """
        coord_seqs = []
        for n in shape:
            v0, v1 = -1, 1
            r = (v1 - v0) / (2 * n)
            coord_seqs.append(v0 + r + (2 * r) * torch.arange(n).float())
        return torch.stack(torch.meshgrid(*coord_seqs, indexing="ij"), dim=-1)

    def ensemble(self, feat, HR_T):
        """Query the unfolded ``feat`` at ``HR_T`` evenly spaced time coordinates.

        Args:
            feat: ``[B, LR_T, D]`` tensor.
            HR_T: number of output time steps.

        Returns:
            ``[B, HR_T, D, 3]`` tensor: for each target step, the 3 unfolded
            neighbour values taken (nearest sampling) from the LR cell just left
            and just right of it, blended with weights proportional to the
            distance to the opposite cell.
        """
        unfold_feat = self.unfolding(feat)                      # [B, 3, LR_T, D]
        LR_B, LR_C, LR_T, LR_D = unfold_feat.shape
        n_series = LR_B * LR_D
        device = unfold_feat.device                             # follow the input's device
        r = 1.0 / LR_T                                          # half an LR cell in [-1, 1] coordinates
        eps_shift = 1e-6

        # Target times and LR cell-centre times (the second coord_gen axis is a dummy of size 1).
        hr_t = self.coord_gen([HR_T, 1]).to(device)[:, 0, 0]   # [HR_T]
        lr_t = self.coord_gen([LR_T, 1]).to(device)[:, 0, 0]   # [LR_T]

        # Fold D into the batch axis (series index d * B + b); grid_sample then
        # treats every (b, d) series as an image of height LR_T and width 1.
        feat_in = unfold_feat.permute(3, 0, 1, 2).reshape(n_series, LR_C, LR_T, 1)
        lr_t_in = lr_t.reshape(1, 1, LR_T, 1)

        preds, dists = [], []
        for vx in (-1, 1):
            shifted = (hr_t + (vx * r + eps_shift)).clamp(-1 + eps_shift, 1 - eps_shift)
            # grid_sample reads (x = width, y = height): column 0, row = time.
            grid = torch.stack([torch.zeros_like(shifted), shifted], dim=-1).reshape(1, HR_T, 1, 2)

            q_feat = F.grid_sample(
                feat_in, grid.repeat(n_series, 1, 1, 1),
                mode= 'nearest', align_corners= False,
            )                                                   # [D*B, 3, HR_T, 1]
            q_feat = q_feat.reshape(LR_D, LR_B, LR_C, HR_T).permute(1, 3, 0, 2)   # [B, HR_T, D, 3]

            # Coordinates are identical for every series, so sample them once.
            q_t = F.grid_sample(
                lr_t_in, grid, mode= 'nearest', align_corners= False,
            ).reshape(HR_T)                                     # [HR_T]

            # NOTE: self.tsnet is not applied; the prediction is the sampled feature itself.
            preds.append(q_feat)
            dists.append(((hr_t - q_t) * LR_T).abs() + 1e-9)   # [HR_T]

        (pred_l, pred_r), (dist_l, dist_r) = preds, dists
        total = dist_l + dist_r
        # Each side is weighted by the distance to the *opposite* side.
        w_l = (dist_r / total).reshape(1, HR_T, 1, 1)
        w_r = (dist_l / total).reshape(1, HR_T, 1, 1)
        return pred_l * w_l + pred_r * w_r

    def forward(self, feat, HR_T_list, verbose= False):
        """Resample ``feat`` successively to each length in ``HR_T_list``.

        Args:
            feat: ``[B, T, D]`` tensor.
            HR_T_list: non-empty sequence of target lengths, applied in order.
            verbose: print the output size after each step.

        Returns:
            ``[B, HR_T_list[-1], D]`` when ``down_out == 1``, otherwise
            ``[B, HR_T_list[-1], D, down_out]``.
        """
        assert len(HR_T_list) >= 1
        out = feat
        for i, HR_T in enumerate(HR_T_list):
            ens = self.ensemble(out, HR_T)                          # [B, T', D, C]
            B, T_new, D, C = ens.shape
            # One row per (b, d) series with channels C over time T' for Conv1d.
            series = ens.permute(0, 2, 3, 1).reshape(B * D, C, T_new)
            mixed = self.down(series)                               # [B*D, down_out, T']
            out = mixed.reshape(B, D, -1, T_new).permute(0, 3, 1, 2)   # [B, T', D, down_out]
            if out.shape[-1] == 1:
                out = out.squeeze(-1)                               # [B, T', D]
            if verbose : print(f"ITF schema {i} with size {out.size()}")
        return out




class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sine/cosine positional table added to a scaled input.

    ``forward`` returns ``dropout(x * sqrt(d_model) + pe[start_pos:start_pos+T])``.

    Args:
        d_model: feature size (even or odd; for odd sizes the last column is a sine).
        max_len: number of precomputed positions.
        dropout: dropout probability applied to the result.
    """

    def __init__(self, d_model, max_len=4096, dropout=0.0):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x, start_pos=0):
        """Encode ``x`` shaped ``[B, T, d_model]`` starting at position ``start_pos``.

        Raises:
            ValueError: if ``[start_pos, start_pos + T)`` is outside the table.
        """
        T = x.size(1)
        end = start_pos + T
        if start_pos < 0 or end > self.pe.size(0):
            raise ValueError(
                f"positions [{start_pos}, {end}) exceed the encoding table of length {self.pe.size(0)}"
            )
        pos = self.pe[start_pos:end].unsqueeze(0).to(device=x.device, dtype=x.dtype)
        return self.dropout(x * math.sqrt(self.d_model) + pos)




def make_attention_mask(mask, dtype, device):
    """Reshape a ``[B, S]`` key mask to ``[B, 1, 1, S]`` so it broadcasts over
    attention scores ``[B, heads, tgt, S]``.

    Values are kept as given (1 = keep, 0 = mask) and cast to ``dtype``/``device``.
    """
    broadcastable = mask[:, None, None, :]
    return broadcastable.to(device=device, dtype=dtype)




def generate_causal_mask(B, T, device):
    """Return a ``[B, T, T]`` float mask with ones on and below the diagonal.

    The batch axis is a broadcast view (``expand``), not a copy.
    """
    lower = torch.ones(T, T, device=device).tril()
    return lower.expand(B, T, T)




class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional RoPE.

    Self-attention when ``memory`` is ``None``; otherwise queries come from ``x``
    and keys/values from ``memory``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of heads.
        rope: apply ``apply_rope`` to queries and keys.
        dropout: dropout on the attention weights.
    """

    def __init__(self, d_model, n_heads, rope=True, dropout=0.2):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = rope
        # Sub-module creation order is kept so seeded init and state_dict order are unchanged.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch, length):
        """View ``[B, length, d_model]`` as ``[B, length, heads, head_dim]``."""
        return projected.view(batch, length, self.n_heads, self.head_dim)

    def forward(self, x, memory=None, key_mask=None, attn_mask=None):
        """Attend from ``x`` (``[B, L, d_model]``) to itself or to ``memory``.

        Args:
            attn_mask: ``[B, L, S]`` mask, 0 entries are blocked.
            key_mask: ``[B, S]`` mask over keys, 0 entries are blocked.

        Returns:
            ``[B, L, d_model]`` tensor.
        """
        batch, tgt_len, _ = x.shape
        source = x if memory is None else memory
        src_len = source.shape[1]

        query = self._split_heads(self.q_proj(x), batch, tgt_len)
        key = self._split_heads(self.k_proj(source), batch, src_len)
        value = self._split_heads(self.v_proj(source), batch, src_len)

        if self.rope:
            query, key = apply_rope(query), apply_rope(key)

        scale = self.head_dim ** 0.5
        scores = torch.einsum('bqhd,bkhd->bhqk', query, key) / scale   # [B, heads, tgt, src]
        neg_inf = float('-inf')
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask[:, None, :, :] == 0, neg_inf)
        if key_mask is not None:
            keep = make_attention_mask(key_mask, dtype=scores.dtype, device=scores.device)
            scores = scores.masked_fill(keep == 0, neg_inf)

        probs = self.attn_dropout(F.softmax(scores, dim=-1))
        context = torch.einsum('bhqk,bkhd->bqhd', probs, value).reshape(batch, tgt_len, -1)
        return self.out_proj(context)




class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, cross-attention to ``cond``,
    gated cross-attention to ``lr``, then a GELU feed-forward network.

    Supports pre-LN (``ln_method="pre"``) and post-LN (``"post"``) residual layouts.

    Args:
        d_model, n_heads, d_ff: layer sizes.
        use_rope: enable RoPE in the self-attention only.
        dropout: dropout applied to the feed-forward branch output only.
    """

    def __init__(self, d_model, n_heads, d_ff, use_rope, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, n_heads, rope=use_rope)
        self.ln2 = nn.LayerNorm(d_model)
        self.cross_attn_1 = MultiHeadAttention(d_model, n_heads, rope=False)
        self.ln3 = nn.LayerNorm(d_model)
        self.cross_attn_2 = MultiHeadAttention(d_model, n_heads, rope=False)
        self.ln4 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )

    @staticmethod
    def _gate(values, mask):
        """Scale ``values`` ``[B, T, C]`` per query position by ``mask`` ``[B, T]``.

        ``mask=None`` disables gating. The mask is cast to ``values.dtype`` so an
        integer or float64 mask cannot promote the residual stream's dtype.
        """
        if mask is None:
            return values
        return mask.unsqueeze(-1).to(values.dtype) * values

    def forward(self, x, cond, lr, causal_mask, mask_c1, mask_c2, task_type, ln_method= "pre"):
        """Run the layer.

        Args:
            x: ``[B, T, d_model]`` decoder stream.
            cond: ``[B, S, d_model]`` memory for the first cross-attention.
            lr: ``[B, S', d_model]`` memory for the second cross-attention.
            causal_mask: ``[B, T, T]`` self-attention mask (0 = blocked).
            mask_c1: ``[B, S]`` key mask for ``cond`` or ``None``.
            mask_c2: ``[B, T]`` per-query gate on the ``lr`` cross-attention
                output, or ``None`` for no gating.
            task_type: unused here.
            ln_method: ``"pre"`` or ``"post"``.
        """
        assert ln_method in ["pre", "post"], "ln_method must be 'pre' or 'post'"
        if ln_method == "pre":
            x = x + self.self_attn(self.ln1(x), attn_mask=causal_mask)
            x = x + self.cross_attn_1(self.ln2(x), memory=cond, key_mask=mask_c1)
            # The lr cross-attention uses no key mask; its output is gated per query instead.
            x = x + self._gate(self.cross_attn_2(self.ln3(x), memory=lr), mask_c2)
            x = x + self.dropout(self.ff(self.ln4(x)))
        else:
            x = self.ln1(x + self.self_attn(x, attn_mask=causal_mask))
            x = self.ln2(x + self.cross_attn_1(x, memory=cond, key_mask=mask_c1))
            x = self.ln3(x + self._gate(self.cross_attn_2(x, memory=lr), mask_c2))
            x = self.ln4(x + self.dropout(self.ff(x)))
        return x




class CrossMaskedDecoder(nn.Module):
    """Stack of ``DecoderBlock`` layers mapping an input sequence to ``x_dim`` outputs.

    The decoder stream attends causally to itself and to an embedded condition
    sequence built from ``cond = (cs, ct)``. It also attends to an embedded LR
    sequence, gated per time step by ``1 - target_mask``.
    """

    def __init__(self, x_dim, c_dim, d_model, d_ff, n_heads, num_layers, itf_dim, itf_hidden, itf_schema, args, combined= True,
                 itf= True, task_type= "SSR", device= "cuda"):
        super().__init__()
        self.x_embed = nn.Linear(x_dim, d_model).to(device)
        # NOTE(review): t_embed is created when args.time_emd is set but is not used in forward.
        if args.time_emd : self.t_embed = nn.Linear(1, args.t_dim).to(device)
        self.c_embed = nn.Linear(c_dim, d_model).to(device)
        self.pos_enc = SinusoidalPositionalEncoding(d_model).to(device)
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, not args.sinu_pe) for _ in range(num_layers)
        ]).to(device)
        self.ln_final = nn.LayerNorm(d_model).to(device)
        self.out_proj = nn.Linear(d_model, x_dim).to(device)

        self.sinu_pos = args.sinu_pe
        self.task_type = task_type
        self.combined = combined
        self.use_itf = itf
        self.itf_schema = itf_schema
        self.device = device

        self.itf = Implicit_Temporal_Func(
            dim= itf_dim,
            hidden_dim= itf_hidden,
            down_in= itf_dim,
            down_out= 1,
            args= args,
            device= device,
        )
        self.args = args

    def forward(self, x_t, cond, lr, t, target_mask, use= "both", debug= False):
        """Predict ``[B, T, x_dim]`` from ``x_t``.

        Args:
            x_t: ``[B, T, x_dim]`` input sequence.
            cond: pair ``(cs, ct)``. Each is resampled with ``self.itf`` (steps
                ``itf_schema``) or with ``lr_linear_impu`` to length T.
            lr: LR sequence. For ``"SSR"`` it is embedded directly; for ``"ASR"`` it
                first goes through ``get_target_broadcast(lr, target_mask)``.
            t: currently unused. NOTE(review): no time conditioning is applied.
            target_mask: 1-D mask of length T (numpy array or non-bool tensor).
                ``1 - target_mask`` gates the LR cross-attention per step.
            use: ``"both"`` (concatenate cs/ct), ``"s"`` (cs only) or ``"t"`` (ct only).
            debug: print input shapes.

        Returns:
            ``(y, cs, ct)``: the ``[B, T, x_dim]`` prediction and the resampled conditions.

        Raises:
            ValueError: unknown ``use`` or ``task_type``.
            NotImplementedError: the model was built with ``combined=False``.
        """
        if use not in ("both", "s", "t"):
            raise ValueError(f"use must be 'both', 's' or 't', got {use!r}")
        if not self.combined:
            raise NotImplementedError("CrossMaskedDecoder.forward only supports combined=True")

        x = self.x_embed(x_t)   # [B, T, d_model]
        # NOTE(review): pos_enc already returns x*sqrt(d_model) + pe, so adding it to x gives
        # x*(1 + sqrt(d_model)) + pe. The same pattern is used for c and lr; kept to preserve
        # trained-model behaviour.
        if self.sinu_pos:
            x = x + self.pos_enc(x)

        cs, ct = cond[0], cond[1]
        if self.use_itf:
            cs = self.itf(cs, self.itf_schema)
            ct = self.itf(ct, self.itf_schema)
        else:
            cs = lr_linear_impu(cs, x.size(1), self.device)
            ct = lr_linear_impu(ct, x.size(1), self.device)
        if use == "both":
            c = torch.cat([cs, ct], dim= 2)
        elif use == "s":
            c = cs
        else:
            c = ct
        c = self.c_embed(c)   # [B, T, d_model]
        c = c + self.pos_enc(c)

        if self.task_type == "SSR":
            lr_input = torch.cat([lr, lr], dim= 2) if use == "both" else lr
        elif self.task_type == "ASR":
            lr_input = get_target_broadcast(lr, target_mask)
            if use == "both":
                lr_input = torch.cat([lr_input, lr_input], dim= 2)
        else:
            raise ValueError("task_type must be 'SSR' or 'ASR'")
        lr = self.c_embed(lr_input)   # [B, S', d_model]
        lr = lr + self.pos_enc(lr)

        B, T, _ = x.shape
        # Gate for the LR cross-attention; cast so it cannot promote the activation dtype.
        m2 = torch.as_tensor(1 - target_mask).unsqueeze(0).repeat(B, 1).to(device=x.device, dtype=x.dtype)
        m1 = None   # no key mask on the condition cross-attention

        if debug:
            print("SIZE : ", x_t.size(), cs.size(), ct.size(), m2.size())

        causal_mask = generate_causal_mask(B, T, x.device)   # identical for every layer
        for layer in self.layers:
            x = layer(x, c, lr, causal_mask, m1, m2, self.task_type)

        x = self.ln_final(x)
        y = self.out_proj(x)
        return y, cs, ct




if __name__ == "__main__" :
    # Smoke test: build the model and report its trainable parameter count.
    from types import SimpleNamespace

    # Fields read from `args` by CrossMaskedDecoder / Implicit_Temporal_Func.
    # NOTE(review): placeholder values; align them with the real training config.
    demo_args = SimpleNamespace(
        unfold_dim= "self",
        unfold_style= "one",
        time_emd= False,
        t_dim= 1,
        sinu_pe= False,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    velocity_predictor = CrossMaskedDecoder(
        x_dim= 7,
        c_dim= 7,
        d_model= 512,
        d_ff= 512,
        n_heads= 16,
        num_layers= 8,
        itf_dim= 3,
        itf_hidden= 128,
        itf_schema= [128, 337],
        args= demo_args,
        task_type= "SSR",
        device= device,
        itf= True,
    )

    total_params = sum(p.numel() for p in velocity_predictor.parameters() if p.requires_grad)
    print(f"param scale : {total_params}")