import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Basic import MLP
from layers.Embedding import timestep_embedding, unified_pos_embedding
from layers.UNet_Blocks import DoubleConv1D, Down1D, Up1D, OutConv1D, DoubleConv2D, Down2D, Up2D, OutConv2D, \
    DoubleConv3D, Down3D, Up3D, OutConv3D

from baselines.model_factory import U_NetConfig

# Building blocks looked up by the number of spatial dimensions (1, 2 or 3).
# Slot 0 is a ``None`` placeholder so that ``len(shapelist)`` can be used
# directly as the index. Each row of the table below holds one dimensionality.
ConvList, DownList, UpList, OutList = (
    [None, *per_dim_blocks]
    for per_dim_blocks in zip(
        (DoubleConv1D, Down1D, Up1D, OutConv1D),  # 1-D grids
        (DoubleConv2D, Down2D, Up2D, OutConv2D),  # 2-D grids
        (DoubleConv3D, Down3D, Up3D, OutConv3D),  # 3-D grids
    )
)


class Model(nn.Module):
    """
    U-Net baseline for fields sampled on a regular grid.

    fx:[B,N,Cin], x:[B,N,d] -> [B,N,Cout]

    The N points are reshaped to a grid of shape ``cfg.shapelist`` (so N must
    equal the product of ``cfg.shapelist``). Before the U-Net, each grid
    dimension of a 2-D/3-D input is zero-padded at its far end up to a multiple
    of 16; the result is cropped back to ``cfg.shapelist`` afterwards.

    Args:
        model_cfg: model configuration (grid shape, channels, embedding/time options).
        bilinear: forwarded to the Up blocks; when True the decoder channel
            counts are halved (``factor = 2``), otherwise kept (``factor = 1``).
    """

    def __init__(self, model_cfg: U_NetConfig, bilinear=True):
        super(Model, self).__init__()
        self.__name__ = 'U_Net'
        self.normtype = "in"
        self.cfg = model_cfg
        self.spatial_dim = len(self.cfg.shapelist)
        hidden = self.cfg.hidden_channels
        dim = self.spatial_dim

        # modules (creation order is kept stable so parameter init / state_dict keys do not change)
        ## embedding
        if self.cfg.pos_emb:
            pos = unified_pos_embedding(self.cfg.shapelist, self.cfg.ref, device=self.cfg.device)  # [1,N,ref^d]
            # Non-persistent: rebuilt here from cfg, so it is not stored in the state_dict.
            self.register_buffer('pos', pos, persistent=False)
            in_features = self.cfg.in_channels + self.cfg.ref ** dim
        else:
            in_features = self.cfg.in_channels + dim
        self.preprocess = MLP(in_features, hidden * 2, hidden,
                              n_layers=0, res=False, act=self.cfg.activation)
        if self.cfg.time_input:
            self.time_fc = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(),
                                         nn.Linear(hidden, hidden))

        # Presumably each Down block halves every spatial dim; four of them need
        # sizes divisible by 2**4 = 16, hence the per-dimension padding amount.
        self.padding = [(16 - size % 16) % 16 for size in self.cfg.shapelist]
        # ------------------------------------------------------------------------------------------------------
        # multiscale modules, picked by spatial dimensionality
        self.inc = ConvList[dim](hidden, hidden, normtype=self.normtype)
        self.down1 = DownList[dim](hidden, hidden * 2, normtype=self.normtype)
        self.down2 = DownList[dim](hidden * 2, hidden * 4, normtype=self.normtype)
        self.down3 = DownList[dim](hidden * 4, hidden * 8, normtype=self.normtype)
        factor = 2 if bilinear else 1
        self.down4 = DownList[dim](hidden * 8, hidden * 16 // factor, normtype=self.normtype)
        self.up1 = UpList[dim](hidden * 16, hidden * 8 // factor, bilinear, normtype=self.normtype)
        self.up2 = UpList[dim](hidden * 8, hidden * 4 // factor, bilinear, normtype=self.normtype)
        self.up3 = UpList[dim](hidden * 4, hidden * 2 // factor, bilinear, normtype=self.normtype)
        self.up4 = UpList[dim](hidden * 2, hidden, bilinear, normtype=self.normtype)
        self.outc = OutList[dim](hidden, hidden)
        # projectors
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, self.cfg.out_channels)

    def multiscale(self, x):
        """Run the 4-level encoder/decoder with skip connections on a grid tensor."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x

    def _needs_padding(self):
        """Return True when the grid must be padded before the U-Net."""
        # NOTE(review): only 2-D/3-D grids have ever been padded here; 1-D inputs are
        # passed through unpadded to keep existing 1-D behaviour — confirm intent.
        return self.spatial_dim >= 2 and any(self.padding)

    def _pad_to_multiple(self, x):
        """Zero-pad the trailing spatial dims of ``x`` at their far end by ``self.padding``."""
        if not self._needs_padding():
            return x
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        pad = []
        for amount in reversed(self.padding):
            pad.extend((0, amount))
        return F.pad(x, pad)

    def _crop_to_shape(self, x):
        """Undo ``_pad_to_multiple`` by slicing each spatial dim back to ``cfg.shapelist``."""
        if not self._needs_padding():
            return x
        # Slice with explicit sizes: a `:-p` slice would be empty when p == 0.
        return x[(..., *(slice(0, size) for size in self.cfg.shapelist))]

    def forward(self, fx: torch.Tensor, x: torch.Tensor | None = None, T: torch.Tensor | None = None):
        """
        Args:
            fx: [B, N, in_channels] input features, N == prod(cfg.shapelist).
            x: [B, N, spatial_dim] coordinates. Ignored when cfg.pos_emb is set
                (the precomputed grid embedding is used instead); required otherwise.
            T: timesteps passed to ``timestep_embedding``; only used when
                cfg.time_input is set (assumed one per batch sample — TODO confirm).

        Returns:
            [B, N, out_channels] tensor.

        Raises:
            ValueError: if ``x`` is None while cfg.pos_emb is disabled.
        """
        B, N, _ = fx.shape
        if self.cfg.pos_emb:
            x = self.pos.expand(B, -1, -1)    # [B, N_points, ref^d]
        elif x is None:
            raise ValueError("U_Net.forward: `x` is required when cfg.pos_emb is disabled")
        fx = torch.cat((x, fx), -1)
        fx = self.preprocess(fx)    # [B, N_points, hidden_channels]

        # time embedding
        if self.cfg.time_input and (T is not None):
            # time_fc acts on the last dim only, so run it once per sample and
            # broadcast over the N points instead of on an N-fold expanded copy.
            time_emb = self.time_fc(timestep_embedding(T, self.cfg.hidden_channels))
            fx = fx + time_emb.unsqueeze(1)    # broadcast to [B, N_points, hidden_channels]
        x = fx.permute(0, 2, 1).reshape(B, self.cfg.hidden_channels, *self.cfg.shapelist)
        # ------------------------------------------------------------------------------------------------------
        x = self._pad_to_multiple(x)
        x = self.multiscale(x)  ## U-Net
        x = self._crop_to_shape(x)
        x = x.reshape(B, self.cfg.hidden_channels, -1).permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    