import torch
import torch.nn as nn
import math
from ITF import Implicit_Temporal_Func




def get_sinusoidal_positional_encoding(seq_len, dim, device):
    """Build a fixed sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * freq_k)`` and odd columns hold
    ``cos(pos * freq_k)``, where ``freq_k = 10000 ** (-2k / dim)``.

    Args:
        seq_len: Number of positions (rows) to encode.
        dim: Encoding width (columns). Odd widths are supported; the last
            sine column then simply has no cosine partner.
        device: Device on which the table is allocated.

    Returns:
        A float tensor of shape ``[seq_len, dim]``.
    """
    positions = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1)  # [seq_len, 1]
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float, device=device)
        * (-math.log(10000.0) / dim)
    )  # [ceil(dim / 2)]
    angles = positions * div_term  # [seq_len, ceil(dim / 2)]

    pe = torch.zeros(seq_len, dim, device=device)
    pe[:, 0::2] = torch.sin(angles)
    # An odd `dim` has one fewer odd-indexed column than even-indexed ones, so
    # trim the frequencies to fit instead of failing with a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe  # [seq_len, dim]




class TimeSeriesGenerator(nn.Module):
    """Transformer encoder/decoder that generates a sequence from a source
    sequence, a scalar ``t`` per sample, and a condition signal.

    The condition is first passed through ``Implicit_Temporal_Func`` and the
    result is embedded and concatenated with the embedded source, the ``t``
    embedding and a sinusoidal positional encoding before entering the
    encoder. The decoder either consumes a provided target sequence or
    generates one frame at a time.
    """

    def __init__(
        self,
        input_dim,      # D
        cond_dim,
        itf_dim, itf_hidden, itf_schema,
        unfold_dim="self", unfold_style="one",       # C
        t_dim=8,        # t embed dim
        hidden_dim=64,  # Transformer model dim
        num_layers=6,
        num_heads=4,
        device='cpu'
    ):
        """Create all sub-modules and move them to ``device``.

        Args:
            input_dim: Feature width ``D`` of the source/target sequences.
            cond_dim: Feature width of the tensor fed to ``cond_embed``
                (i.e. of the ``Implicit_Temporal_Func`` output).
            itf_dim: Passed to ``Implicit_Temporal_Func`` as ``dim`` and
                ``down_in``.
            itf_hidden: Passed to ``Implicit_Temporal_Func`` as ``hidden_dim``.
            itf_schema: Stored and handed to the ITF module on every forward.
            unfold_dim: Forwarded unchanged to ``Implicit_Temporal_Func``.
            unfold_style: Forwarded unchanged to ``Implicit_Temporal_Func``.
            t_dim: Width of the embedding of the scalar ``t``.
            hidden_dim: Transformer model width; also the positional-encoding
                width.
            num_layers: Number of layers in both the encoder and the decoder.
            num_heads: Number of attention heads per layer.
            device: Device the module is moved to after construction.
        """
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim

        # Embedding
        self.input_embed = nn.Linear(input_dim, hidden_dim)
        self.cond_embed = nn.Linear(cond_dim, hidden_dim)
        self.t_embed = nn.Linear(1, t_dim)
        self.pos_dim = hidden_dim

        # Width of [input_emb | cond_emb | t_emb | pos_enc].
        self.concat_dim = hidden_dim * 2 + t_dim + self.pos_dim

        self.itf_schema = itf_schema
        self.itf = Implicit_Temporal_Func(
            dim=itf_dim,
            hidden_dim=itf_hidden,
            down_in=itf_dim,
            down_out=1,
            unfold_dim=unfold_dim,
            unfold_style=unfold_style,
            device=device,
        )

        # Shared by the encoder and the decoder input paths.
        self.input_proj = nn.Linear(self.concat_dim, hidden_dim)
        self.target_embed = nn.Linear(input_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim * 4
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.out_proj = nn.Linear(hidden_dim, input_dim)
        self.to(device)

    def _decode(self, tgt_emb, t_emb, pos_enc, memory):
        """Project embedded targets to model width and run the decoder.

        Args:
            tgt_emb: Embedded target frames, ``[B, L, hidden_dim]``.
            t_emb: ``t`` embedding for the same ``L`` steps, ``[B, L, t_dim]``.
            pos_enc: Positional encoding for those steps, ``[B, L, pos_dim]``.
            memory: Encoder output, ``[T, B, hidden_dim]``.

        Returns:
            Decoder output in sequence-first layout, ``[L, B, hidden_dim]``.
        """
        # `input_proj` expects `concat_dim` features (two hidden-sized slots
        # plus t and pos). The decoder has no separate condition slot, so the
        # target embedding fills both hidden-sized slots.
        decoder_input_cat = torch.cat([tgt_emb, tgt_emb, t_emb, pos_enc], dim=-1)
        decoder_input = self.input_proj(decoder_input_cat)  # [B, L, hidden_dim]
        # The transformer layers are built sequence-first (batch_first is off).
        decoder_input = decoder_input.transpose(0, 1)       # [L, B, hidden_dim]
        return self.decoder(decoder_input, memory)

    def forward(self, src_A, t, cond_B, tgt_seq=None):
        """Encode the source and either decode ``tgt_seq`` or generate.

        Args:
            src_A: Source sequence, ``[B, T, D]``.
            t: One scalar per sample; expected shape ``[B]`` (it is cast to
                float and given a trailing unit axis before embedding).
            cond_B: Condition input handed to the ITF module together with
                ``itf_schema``. The ITF output must be concatenable with the
                source embedding, i.e. presumably ``[B, T, cond_dim]`` —
                TODO confirm against ``Implicit_Temporal_Func``.
            tgt_seq: Optional target sequence ``[B, T, D]``. When given, the
                decoder processes it in a single pass; otherwise frames are
                generated one at a time.

        Returns:
            With ``tgt_seq``: the output tensor ``[B, T, D]``.
            Without ``tgt_seq``: a tuple ``(output, cond_itf)`` where
            ``output`` is ``[B, T, D]`` and ``cond_itf`` is the raw ITF
            output. Note the two modes return different types.
        """
        B, T, D = src_A.size()
        device = src_A.device

        pos_enc = get_sinusoidal_positional_encoding(T, self.pos_dim, device)  # [T, pos_dim]
        pos_enc = pos_enc.unsqueeze(0).repeat(B, 1, 1)  # [B, T, pos_dim]

        t = t.float().unsqueeze(1)                  # [B, 1]
        t_emb = self.t_embed(t)                     # [B, t_dim]
        t_emb = t_emb.unsqueeze(1).repeat(1, T, 1)  # [B, T, t_dim]

        input_emb = self.input_embed(src_A)         # [B, T, hidden_dim]
        cond_itf = self.itf(cond_B, self.itf_schema)
        cond_emb = self.cond_embed(cond_itf)        # [B, T, hidden_dim]

        encoder_input_cat = torch.cat([input_emb, cond_emb, t_emb, pos_enc], dim=-1)  # [B, T, concat_dim]

        encoder_input = self.input_proj(encoder_input_cat)  # [B, T, hidden_dim]
        encoder_input = encoder_input.transpose(0, 1)       # [T, B, hidden_dim]
        encoder_hidden = self.encoder(encoder_input)

        if tgt_seq is not None:
            tgt_emb = self.target_embed(tgt_seq)            # [B, T, hidden_dim]
            # NOTE(review): no tgt_mask is passed to the decoder, so every
            # position can attend to the whole of `tgt_seq` — confirm that
            # this is intended for training.
            decoder_out = self._decode(tgt_emb, t_emb, pos_enc, encoder_hidden)
            output = self.out_proj(decoder_out.transpose(0, 1))  # [B, T, D]
            return output

        gen_seq = []
        prev = torch.zeros(B, D, device=device)
        for t_idx in range(T):
            if t_idx == 0:
                # Nothing generated yet: seed the decoder with an all-zero frame.
                tgt_emb = self.target_embed(prev).unsqueeze(1)            # [B, 1, hidden_dim]
            else:
                # From step 1 on the decoder sees only the generated frames
                # (the zero seed is not kept in the input).
                tgt_emb = self.target_embed(torch.stack(gen_seq, dim=1))  # [B, t_idx, hidden_dim]
            steps = tgt_emb.size(1)
            decoder_out = self._decode(
                tgt_emb,
                t_emb[:, :steps, :],    # [B, steps, t_dim]
                pos_enc[:, :steps, :],  # [B, steps, pos_dim]
                encoder_hidden,
            )
            # Only the last decoder position yields the next frame.
            next_token = self.out_proj(decoder_out[-1])  # [B, D]
            gen_seq.append(next_token)
        output = torch.stack(gen_seq, dim=1)             # [B, T, D]
        return output, cond_itf




if __name__ == "__main__":
    # Smoke test: build the model and report its parameter count. Fall back
    # to the CPU so the script also runs on machines without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    velocity_predictor = TimeSeriesGenerator(
        input_dim=7,
        cond_dim=7,
        hidden_dim=64,
        num_layers=6,
        num_heads=4,
        itf_dim=3,
        itf_hidden=128,
        itf_schema=[128, 337],
        device=device,
    )

    total_params = sum(p.numel() for p in velocity_predictor.parameters())
    print(f"param scale : {total_params}")