import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from layers.Basic import MLP
from layers.Embedding import timestep_embedding, unified_pos_embedding
from layers.FNO_Layers import SpectralConv1d, SpectralConv2d, SpectralConv3d
from layers.UNet_Blocks import DoubleConv1D, Down1D, Up1D, OutConv1D, DoubleConv2D, Down2D, Up2D, OutConv2D, \
    DoubleConv3D, Down3D, Up3D, OutConv3D

from baselines.model_factory import U_NOConfig

# Per-dimensionality lookup tables for the U-NO building blocks. Each table is
# indexed directly by the number of spatial dimensions (1, 2 or 3); slot 0 is a
# ``None`` placeholder so that ``Table[d]`` selects the d-D variant.
# Rows below are grouped per dimensionality; zip() transposes them into one
# list per block kind.
ConvList, DownList, UpList, OutList, BlockList = (
    [None, *variants]
    for variants in zip(
        (DoubleConv1D, Down1D, Up1D, OutConv1D, SpectralConv1d),
        (DoubleConv2D, Down2D, Up2D, OutConv2D, SpectralConv2d),
        (DoubleConv3D, Down3D, Up3D, OutConv3D, SpectralConv3d),
    )
)


class Model(nn.Module):
    """U-shaped neural operator (U-NO) on a regular grid.

    fx:[B,N,Cin], x:[B,N,d] -> [B,N,Cout]

    The N input points are reshaped onto the grid ``cfg.shapelist`` (so N must
    equal the product of its sizes), zero-padded at the end of every axis up to
    a multiple of 16, and passed through a 4-level U-Net in which every level is
    followed by an FNO-style update ``gelu(SpectralConv(h) + W(h))``. The
    padding is cropped away again before the pointwise projection head.
    """

    def __init__(self, model_cfg: U_NOConfig, bilinear=True):
        """Build all sub-modules.

        Args:
            model_cfg: model configuration; the fields read here are
                ``shapelist``, ``pos_emb``, ``ref``, ``device``, ``in_channels``,
                ``hidden_channels``, ``out_channels``, ``activation``,
                ``time_input`` and ``n_modes``.
            bilinear: forwarded to the Up blocks. When True, ``factor = 2`` and
                the deepest Down block plus the first three Up blocks produce
                half as many channels.
        """
        super(Model, self).__init__()
        self.__name__ = 'U_NO'
        self.normtype = "in"
        self.cfg = model_cfg
        self.spatial_dim = len(self.cfg.shapelist)
        dim = self.spatial_dim
        hc = self.cfg.hidden_channels

        # ---- input embedding ----
        if self.cfg.pos_emb:
            pos = unified_pos_embedding(self.cfg.shapelist, self.cfg.ref, device=self.cfg.device)  # [1,N,ref^d]
            self.register_buffer('pos', pos, persistent=False)
            coord_channels = self.cfg.ref ** dim
        else:
            coord_channels = dim
        self.preprocess = MLP(self.cfg.in_channels + coord_channels, hc * 2, hc,
                              n_layers=0, res=False, act=self.cfg.activation)
        if self.cfg.time_input:
            self.time_fc = nn.Sequential(nn.Linear(hc, hc), nn.SiLU(), nn.Linear(hc, hc))

        # ---- grid padding ----
        # The encoder halves the resolution four times, so each spatial size is
        # zero-padded (at the end of its axis) up to a multiple of 2**4.
        multiple = 2 ** 4
        self.padding = [(multiple - size % multiple) % multiple for size in self.cfg.shapelist]
        self.augmented_resolution = [size + pad for size, pad in zip(self.cfg.shapelist, self.padding)]
        # F.pad takes (before, after) pairs starting from the LAST axis; building
        # the spec generically covers 1-D, 2-D and 3-D grids alike.
        self._pad_spec = [amount for pad in reversed(self.padding) for amount in (0, pad)]

        # ---- Fourier modes per level ----
        if isinstance(self.cfg.n_modes, (tuple, list)):
            n_modes_dim = [int(m) for m in self.cfg.n_modes]
            if len(n_modes_dim) != dim:
                # One entry per axis is expected; otherwise fall back to the
                # smallest requested count on every axis.
                n_modes_dim = [min(n_modes_dim)] * dim
        else:
            n_modes_dim = [int(self.cfg.n_modes)] * dim

        def _nm(scale: int):
            # Per-axis mode counts for one level: the requested count, capped at
            # augmented_resolution // scale and never below 1.
            return [max(1, min(n, res // scale)) for n, res in zip(n_modes_dim, self.augmented_resolution)]

        Conv, Down, Up, Out, Block = ConvList[dim], DownList[dim], UpList[dim], OutList[dim], BlockList[dim]
        factor = 2 if bilinear else 1

        # ---- U-Net encoder / decoder ----
        # (Construction order matches the historical order so parameter
        # initialisation consumes the RNG identically.)
        self.inc = Conv(hc, hc, normtype=self.normtype)
        self.down1 = Down(hc, hc * 2, normtype=self.normtype)
        self.down2 = Down(hc * 2, hc * 4, normtype=self.normtype)
        self.down3 = Down(hc * 4, hc * 8, normtype=self.normtype)
        self.down4 = Down(hc * 8, hc * 16 // factor, normtype=self.normtype)
        self.up1 = Up(hc * 16, hc * 8 // factor, bilinear, normtype=self.normtype)
        self.up2 = Up(hc * 8, hc * 4 // factor, bilinear, normtype=self.normtype)
        self.up3 = Up(hc * 4, hc * 2 // factor, bilinear, normtype=self.normtype)
        self.up4 = Up(hc * 2, hc, bilinear, normtype=self.normtype)
        self.outc = Out(hc, hc)

        # ---- FNO updates on the encoder side (one per resolution level) ----
        self.process1_down = Block(hc, hc, *_nm(2))
        self.process2_down = Block(hc * 2, hc * 2, *_nm(4))
        self.process3_down = Block(hc * 4, hc * 4, *_nm(8))
        self.process4_down = Block(hc * 8, hc * 8, *_nm(16))
        self.process5_down = Block(hc * 16 // factor, hc * 16 // factor, *_nm(32))
        # NOTE(review): the positional ``1`` below is the third argument of the
        # DoubleConv block. If that parameter is ``mid_channels`` rather than a
        # kernel size, these W paths are 1-channel bottlenecks, not the usual
        # 1x1 pointwise FNO ``W`` — confirm against layers.UNet_Blocks.
        self.w1_down = Conv(hc, hc, 1)
        self.w2_down = Conv(hc * 2, hc * 2, 1)
        self.w3_down = Conv(hc * 4, hc * 4, 1)
        self.w4_down = Conv(hc * 8, hc * 8, 1)
        self.w5_down = Conv(hc * 16 // factor, hc * 16 // factor, 1)

        # ---- FNO updates on the decoder side ----
        self.process1_up = Block(hc, hc, *_nm(2))
        self.process2_up = Block(hc * 2 // factor, hc * 2 // factor, *_nm(4))
        self.process3_up = Block(hc * 4 // factor, hc * 4 // factor, *_nm(8))
        self.process4_up = Block(hc * 8 // factor, hc * 8 // factor, *_nm(16))
        self.process5_up = Block(hc * 16 // factor, hc * 16 // factor, *_nm(32))
        self.w1_up = Conv(hc, hc, 1)
        self.w2_up = Conv(hc * 2 // factor, hc * 2 // factor, 1)
        self.w3_up = Conv(hc * 4 // factor, hc * 4 // factor, 1)
        self.w4_up = Conv(hc * 8 // factor, hc * 8 // factor, 1)
        self.w5_up = Conv(hc * 16 // factor, hc * 16 // factor, 1)

        # ---- pointwise projection head ----
        self.fc1 = nn.Linear(hc, hc * 2)
        self.fc2 = nn.Linear(hc * 2, self.cfg.out_channels)

    def forward(self, fx: torch.Tensor, x: torch.Tensor | None = None, T: torch.Tensor | None = None):
        """Map per-point features (plus coordinates / time) to per-point outputs.

        Args:
            fx: [B, N, in_channels] input features, with N == prod(cfg.shapelist).
            x: [B, N, spatial_dim] coordinates. Ignored when ``cfg.pos_emb`` is
                set (the precomputed grid embedding is used instead); required
                otherwise.
            T: optional timesteps; used only when ``cfg.time_input`` is set.

        Returns:
            [B, N, out_channels] tensor.

        Raises:
            ValueError: if ``x`` is None while ``cfg.pos_emb`` is disabled.
        """
        B = fx.shape[0]
        if self.cfg.pos_emb:
            x = self.pos.expand(B, -1, -1)  # [B, N, ref^d]
        elif x is None:
            raise ValueError("U_NO.forward needs coordinates `x` when cfg.pos_emb is disabled")
        fx = self.preprocess(torch.cat((x, fx), -1))  # [B, N, hidden_channels]

        # time embedding
        if self.cfg.time_input and (T is not None):
            # Run the time MLP once per sample and broadcast over the N points
            # instead of pushing N identical copies through it.
            time_emb = self.time_fc(timestep_embedding(T, self.cfg.hidden_channels))  # [B, hidden_channels]
            fx = fx + time_emb.unsqueeze(1)

        # points -> grid, then pad every axis up to a multiple of 16
        x = fx.permute(0, 2, 1).reshape(B, self.cfg.hidden_channels, *self.cfg.shapelist)
        if any(self.padding):
            x = F.pad(x, self._pad_spec)

        # encoder
        x1 = self.inc(x)
        x1 = F.gelu(self.process1_down(x1) + self.w1_down(x1))
        x2 = self.down1(x1)
        x2 = F.gelu(self.process2_down(x2) + self.w2_down(x2))
        x3 = self.down2(x2)
        x3 = F.gelu(self.process3_down(x3) + self.w3_down(x3))
        x4 = self.down3(x3)
        x4 = F.gelu(self.process4_down(x4) + self.w4_down(x4))
        x5 = self.down4(x4)
        x5 = F.gelu(self.process5_down(x5) + self.w5_down(x5))
        # bottleneck + decoder with skip connections
        x5 = F.gelu(self.process5_up(x5) + self.w5_up(x5))
        x = self.up1(x5, x4)
        x = F.gelu(self.process4_up(x) + self.w4_up(x))
        x = self.up2(x, x3)
        x = F.gelu(self.process3_up(x) + self.w3_up(x))
        x = self.up3(x, x2)
        x = F.gelu(self.process2_up(x) + self.w2_up(x))
        x = self.up4(x, x1)
        x = F.gelu(self.process1_up(x) + self.w1_up(x))
        x = self.outc(x)

        if any(self.padding):
            # Slice each spatial axis back to its original length. Negative
            # slicing (``:-pad``) would empty any axis whose padding is 0.
            x = x[(..., *(slice(0, size) for size in self.cfg.shapelist))]
        # grid -> points, then pointwise projection
        x = x.reshape(B, self.cfg.hidden_channels, -1).permute(0, 2, 1)
        return self.fc2(F.gelu(self.fc1(x)))
