import torch
import torch.nn as nn
from ITF import Implicit_Temporal_Func




class ConvBlock1D(nn.Module):
    """Length-preserving conv unit: Conv1d(k=3, pad=1) -> BatchNorm1d -> ReLU -> Dropout.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, in_ch, out_ch, dropout=0.2):
        super().__init__()
        layers = [
            nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.BatchNorm1d(num_features=out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        ]
        # The attribute name `block` and the layer order determine the
        # state_dict keys (block.0.*, block.1.*), so both are kept stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the unit to ``x`` laid out as ``[B, in_ch, L]``; returns ``[B, out_ch, L]``."""
        result = self.block(x)
        return result




class UNet1D(nn.Module):
    """1-D U-Net conditioned on covariates (through an ITF encoder) and on time.

    ``X`` is concatenated along its last axis with the ITF encoding of ``C``
    and with ``t`` broadcast over the sequence. The result is run through a
    Conv1d encoder / bottleneck / decoder with skip connections.

    Args:
        input_dim: Last-axis size of ``X``.
        covariate_dim: Last-axis size of the ITF output that is concatenated to
            ``X``. It must match what ``self.itf`` returns.
            NOTE(review): confirm against ``Implicit_Temporal_Func``.
        output_dim: Channel count of the prediction.
        itf_dim: Passed to ``Implicit_Temporal_Func`` as ``dim`` and ``down_in``.
        itf_hidden: Passed to ``Implicit_Temporal_Func`` as ``hidden_dim``.
        itf_schema: Second argument given to ``self.itf`` on every forward.
        t_dim: Last-axis size of ``t`` once it is made at least 2-D.
        chs: Channel widths ``[enc_1, ..., enc_k, bottleneck, dec_1, ..., dec_k]``.
            The first ``len(chs) // 2`` entries are encoder stages, the next one
            is the bottleneck, and the rest are decoder stages. An odd length is
            expected. With an even length the decoder has fewer stages than the
            encoder and the output is shorter than the input.
        unfold_dim, unfold_style: Forwarded to ``Implicit_Temporal_Func``.
        dropout: Dropout probability used in every ``ConvBlock1D``.
        device: Device the module is moved to and inputs are sent to in forward.
    """

    def __init__(self, input_dim, covariate_dim, output_dim, itf_dim, itf_hidden, itf_schema, t_dim=1,
                 chs=(32, 128, 1024, 128, 32), unfold_dim="self", unfold_style="one", dropout=0.2,
                 device="cuda"):
        super().__init__()
        # Immutable default (tuple) avoids a shared mutable list; callers may
        # still pass a list.
        chs = list(chs)
        total_in_dim = input_dim + covariate_dim + t_dim
        n_down = len(chs) // 2
        enc_chs = chs[:n_down]
        bottleneck_ch = chs[n_down]
        dec_chs = chs[n_down + 1:]

        self.down_blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        prev_ch = total_in_dim
        for ch in enc_chs:
            self.down_blocks.append(ConvBlock1D(prev_ch, ch, dropout))
            self.pools.append(nn.MaxPool1d(2))
            prev_ch = ch

        self.bottleneck = ConvBlock1D(prev_ch, bottleneck_ch, dropout)
        prev_ch = bottleneck_ch

        self.up_blocks = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        # In forward, decoder stage i is concatenated with the output of the
        # i-th encoder stage counted from the deepest one. Its input width is
        # therefore ch + that skip width. This equals ch * 2 for symmetric
        # `chs`, so existing checkpoints keep the same shapes.
        for ch, skip_ch in zip(dec_chs, reversed(enc_chs)):
            self.up_convs.append(nn.ConvTranspose1d(prev_ch, ch, 2, stride=2))
            self.up_blocks.append(ConvBlock1D(ch + skip_ch, ch, dropout))
            prev_ch = ch
        self.output_conv = nn.Conv1d(prev_ch, output_dim, 1)
        self.device = device

        self.itf_schema = itf_schema
        self.itf = Implicit_Temporal_Func(
            dim=itf_dim,
            hidden_dim=itf_hidden,
            down_in=itf_dim,
            down_out=1,
            unfold_dim=unfold_dim,
            unfold_style=unfold_style,
            device=device,
        )
        # Move last, so the ITF submodule registered above is moved as well.
        self.to(device)

    def forward(self, X, t, C):
        """Predict an output sequence from ``X`` conditioned on ``t`` and ``C``.

        Args:
            X: Input sequence, unpacked as ``[B, L, input_dim]``.
            t: Time conditioning. Accepts a 0-dim tensor, ``[B]``/``[1]``, or
                ``[B, t_dim]``/``[1, t_dim]``. It is broadcast over batch and
                sequence.
            C: Covariates, passed to ``self.itf`` together with ``itf_schema``.

        Returns:
            ``(y, cond_itf)``. ``y`` is laid out as ``[B, L_out, output_dim]``,
            where ``L_out == L`` when ``chs`` has odd length. ``cond_itf`` is the
            ITF encoding that was concatenated to ``X``.
        """
        X = X.to(self.device)
        t = t.to(self.device)
        C = C.to(self.device)

        B, L, _ = X.shape
        if t.dim() == 0:
            t = t.expand(B)              # one scalar time shared by the batch
        if t.dim() == 1:
            t = t.unsqueeze(-1)          # [B] -> [B, 1]
        # [B, t_dim] -> [B, L, t_dim]; expand also broadcasts a batch of 1.
        t_feat = t.unsqueeze(1).expand(B, L, -1)
        cond_itf = self.itf(C, self.itf_schema)
        x_cat = torch.cat([X, cond_itf, t_feat], dim=-1)   # [B, L, in_ch]
        out = x_cat.permute(0, 2, 1)                       # [B, in_ch, L]

        # Encoder: keep each stage's pre-pooling output for the skip connections.
        enc_feats = []
        for block, pool in zip(self.down_blocks, self.pools):
            out = block(out)
            enc_feats.append(out)
            out = pool(out)

        # Bottleneck
        out = self.bottleneck(out)

        # Decoder: upsample, align lengths (MaxPool floors odd lengths), then
        # concatenate the matching encoder feature along channels.
        for up_conv, up_block, enc_feat in zip(self.up_convs, self.up_blocks, reversed(enc_feats)):
            out = up_conv(out)
            diff = enc_feat.shape[-1] - out.shape[-1]
            if diff > 0:
                out = nn.functional.pad(out, (0, diff))
            elif diff < 0:
                enc_feat = nn.functional.pad(enc_feat, (0, -diff))
            out = torch.cat([out, enc_feat], dim=1)
            out = up_block(out)

        y = self.output_conv(out)          # [B, output_dim, L]
        y = y.permute(0, 2, 1)             # [B, L, output_dim]
        return y, cond_itf



if __name__ == "__main__":
    # Smoke check: build the model at its reference size and report the
    # parameter count. Fall back to CPU so the script also runs without a GPU.
    demo_device = "cuda" if torch.cuda.is_available() else "cpu"

    velocity_predictor = UNet1D(
        input_dim=7,
        covariate_dim=7,
        t_dim=1,
        chs=[32, 128, 1024, 128, 32],
        output_dim=7,
        itf_dim=3,
        itf_hidden=128,
        itf_schema=[128, 337],
        dropout=0.2,
        device=demo_device,
    )
    total_params = sum(p.numel() for p in velocity_predictor.parameters())
    print(f"param scale : {total_params}")