import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import lr_linear_impu
from ITF import Implicit_Temporal_Func
from utils import TIME_ENCODING




class Temporal_Block(nn.Module):
    """Residual block built from two dilated 1-D convolutions.

    Data flow: ``conv1 -> ReLU -> dropout -> conv2 -> ReLU``; the block input
    (sent through a 1x1 convolution when the channel counts differ) is then
    added and a final ReLU is applied.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout= 0.2):
        """Create the convolutions, dropout, optional shortcut and activation.

        Args:
            n_inputs: channel count of the block input.
            n_outputs: channel count produced by both convolutions.
            kernel_size, stride, dilation, padding: forwarded unchanged to
                both ``nn.Conv1d`` layers.
            dropout: drop probability applied after the first activation.
        """
        super().__init__()
        self.in_channels, self.out_channels = n_inputs, n_outputs

        shared_conv_args = dict(stride=stride, padding=padding, dilation=dilation)
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, **shared_conv_args)
        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, **shared_conv_args)
        self.dropout = nn.Dropout(dropout)
        # A 1x1 convolution aligns channel counts for the residual addition;
        # it is only needed when input and output widths differ.
        if n_inputs == n_outputs:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        self.activation = nn.ReLU()

        # Re-draw every convolution weight from N(0, 0.01); biases keep the
        # default nn.Conv1d initialisation.
        for layer in (self.conv1, self.conv2, self.downsample):
            if layer is not None:
                nn.init.normal_(layer.weight, mean=0.0, std=0.01)

    def forward(self, x, verbose= False):
        """Run the block on ``x`` and return the activated residual sum.

        Args:
            x: input tensor accepted by ``nn.Conv1d`` with ``n_inputs``
                channels.
            verbose: when true, print the input and output sizes.
        """
        shortcut = x if self.downsample is None else self.downsample(x)

        hidden = self.dropout(self.activation(self.conv1(x)))
        hidden = self.activation(self.conv2(hidden))
        out = self.activation(hidden + shortcut)

        if verbose:
            print("TCN block input size : ", x.size())
            print("TCN block output size : ", out.size())
        return out




class Predictor(nn.Module):
    """Stack of dilated ``Temporal_Block`` layers followed by a 1x1 convolution
    that maps back to ``input_dim`` channels.

    The network input is assembled in ``forward`` from the sequence ``x_t``,
    an optional conditioning sequence and an optional time embedding. Two
    assembly modes exist:

    * ``"con"`` - concatenate ``[x_t, t_emb, cond]`` along the channel axis.
    * ``"pro"`` - add a 1x1-conv projection of ``cond`` to ``x_t`` and then
      concatenate ``t_emb``.

    NOTE(review): the first block is sized for ``input_dim + cond_dim``
    (+ ``time_emb_dim`` when ``timebed`` is true) channels, so the caller is
    responsible for choosing dimensions under which the assembled input has
    that width (e.g. ``cond=None`` only fits when ``cond_dim == 0``).
    """

    def __init__(self, input_dim, cond_dim, itf_dim, itf_hidden, seq_len, itf_schema, itf_first= False, channels_list= [64, 64, 256, 256], kernel_size= 3,
                 unfold_dim= "self", unfold_style= "one", dropout= 0.1, itf= True, timebed= False, time_emb_dim= 32, device= "cuda"):
        """Build the time encoder, the implicit temporal function and the TCN.

        Args:
            input_dim: channel count of ``x_t`` and of the network output.
            cond_dim: channel count reserved for the conditioning input.
            itf_dim, itf_hidden, unfold_dim, unfold_style: forwarded to
                ``Implicit_Temporal_Func``.
            seq_len: temporal length used to expand the time embedding and
                passed to ``lr_linear_impu``.
            itf_schema: passed to the ITF module on every forward call.
            itf_first: stored on the instance; not used inside this class.
            channels_list: output width of each temporal block, in order.
            kernel_size: kernel size of every temporal block.
            dropout: dropout probability of every temporal block.
            itf: when true the conditioning goes through the ITF module,
                otherwise through ``lr_linear_impu``.
            timebed: when true a time embedding is concatenated to the input.
            time_emb_dim: width used for all three ``TIME_ENCODING`` dims.
            device: device the sub-modules and inputs are moved to.
        """
        super(Predictor, self).__init__()
        self.input_dim = input_dim
        self.cond_dim = cond_dim
        self.time_emb_dim = time_emb_dim
        # Copy so the instance never aliases the shared default list (or a
        # list the caller may mutate later).
        self.channels = list(channels_list)
        self.block_num = len(self.channels)
        self.kernel_size = kernel_size
        self.seq_len = seq_len
        self.device = device
        self.timembed = timebed
        self.use_itf = itf
        self.itf_schema = itf_schema
        self.itf_first = itf_first

        self.time_encoding = TIME_ENCODING(
            input_dim= time_emb_dim,
            hidden_dim= time_emb_dim,
            output_dim= time_emb_dim,
        ).to(device)

        # Built unconditionally, even when ``itf`` is false.
        self.itf = Implicit_Temporal_Func(
            dim= itf_dim,
            hidden_dim= itf_hidden,
            down_in= itf_dim,
            down_out= 1,
            unfold_dim= unfold_dim,
            unfold_style= unfold_style,
            device= device,
        )

        first_in_channels = input_dim + cond_dim + (time_emb_dim if timebed else 0)
        self.blocks = []
        for i, out_channels in enumerate(self.channels):
            dilation_size = 2 ** i
            in_channels = first_in_channels if i == 0 else self.channels[i - 1]
            # Symmetric padding that keeps the temporal length for odd kernels.
            padding = (kernel_size - 1) * dilation_size // 2
            self.blocks.append(Temporal_Block(
                n_inputs= in_channels,
                n_outputs= out_channels,
                kernel_size= kernel_size,
                stride= 1,
                dilation= dilation_size,
                padding= padding,
                dropout= dropout,
            ))
        self.blocks.append(nn.Conv1d(self.channels[-1], input_dim, 1))
        self.tcn = nn.Sequential(*self.blocks).to(device)

        # Projection used by the "pro" mode; created on first use because its
        # width depends on the channel count of the processed conditioning.
        self.cond_proj = None

    def _project_cond(self, cond):
        """Return ``cond`` projected to twice its channel count by a cached 1x1 conv."""
        proj = getattr(self, "cond_proj", None)
        if proj is None:
            channels = cond.shape[1]
            # Assigning a module to an attribute registers it, so the layer
            # is reused across calls and shows up in parameters()/state_dict().
            # NOTE: it only exists after the first "pro" forward pass; build
            # the optimizer after such a pass if it should be trained.
            proj = nn.Conv1d(in_channels= channels, out_channels= 2 * channels, kernel_size= 1).to(self.device)
            self.cond_proj = proj
        return proj(cond)

    def forward(self, x_t, t, cond= None, m= "con"):      # input x_t and cond size (B, T, D)
        """Assemble the network input and run the TCN.

        Args:
            x_t: sequence tensor, (B, T, D) per the signature comment.
            t: input of ``self.time_encoding``; only read when the model was
                built with ``timebed=True``. The encoder output must be 2-D
                because it is unsqueezed and expanded to ``seq_len``.
            cond: optional conditioning tensor, or ``None``.
            m: ``"con"`` (concatenate) or ``"pro"`` (project and add).

        Returns:
            Tuple ``(output, cond)``: the TCN output permuted back to
            time-major layout, and the processed (un-projected) conditioning
            in the same layout, or ``None`` when no conditioning was given.

        Raises:
            NotImplementedError: if ``m`` is neither ``"pro"`` nor ``"con"``.
        """
        if m not in ("pro", "con"):
            raise NotImplementedError(f"unknown mode {m!r}; expected 'pro' or 'con'")

        x_t = x_t.permute(0, 2, 1).to(self.device)       # (B, D, T)
        if cond is not None :
            if self.use_itf :
                cond = self.itf(cond, self.itf_schema)
                cond = cond.permute(0, 2, 1).to(self.device)
            else :
                cond = lr_linear_impu(cond, self.seq_len, self.device).permute(0, 2, 1)

        t_emb = None
        if self.timembed :
            t_emb = self.time_encoding(t).unsqueeze(-1).expand(-1, -1, self.seq_len)

        # Channel-axis pieces, kept in the order each mode defines.
        if m == "pro" :
            parts = [x_t if cond is None else x_t + self._project_cond(cond)]
            if t_emb is not None :
                parts.append(t_emb)
        else :
            parts = [x_t]
            if t_emb is not None :
                parts.append(t_emb)
            if cond is not None :
                parts.append(cond)

        x = parts[0] if len(parts) == 1 else torch.cat(parts, dim= 1)
        x = self.tcn(x)
        cond_out = cond.permute(0, 2, 1) if cond is not None else None
        return x.permute(0, 2, 1), cond_out




if __name__ == "__main__" :

    # Smoke check: build a Predictor and print how many parameters it holds.
    # Fall back to the CPU so the script also runs on machines without CUDA
    # instead of failing inside the constructor's ``.to(device)`` calls.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    velocity_predictor = Predictor(
        input_dim= 7,
        cond_dim= 7,
        channels_list= [64, 64, 256, 256],
        kernel_size= 3,
        itf_dim= 3,
        itf_hidden= 128,
        seq_len= 337,
        itf_schema= [128, 337],
        dropout= 0.1,
        device= device,
        itf= True,
        timebed= False,
    )

    total_params = sum(p.numel() for p in velocity_predictor.parameters())
    print(f"param scale : {total_params}")