import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from layers.Basic import MLP, LinearAttention
from layers.Embedding import timestep_embedding, unified_pos_embedding

from baselines.model_factory import GKTConfig


class Galerkin_Transformer_block(nn.Module):
    """Pre-norm residual block: Galerkin-type linear attention, then an MLP.

    Both sub-layers are wrapped in a residual connection. When
    ``last_layer`` is true the block additionally normalizes its result and
    projects it from ``hidden_dim`` to ``out_dim`` with a linear layer.
    """

    def __init__(
            self,
            num_heads: int,
            hidden_dim: int,
            dropout: float,
            act='gelu',
            mlp_ratio=4,
            last_layer=False,
            out_dim=1,
    ):
        super().__init__()
        self.last_layer = last_layer

        # Attention sub-layer. The input is normalized twice with independent
        # LayerNorms (``ln_1`` and ``ln_1a``) and both results are handed to
        # the attention module.
        head_width = hidden_dim // num_heads
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.ln_1a = nn.LayerNorm(hidden_dim)
        self.Attn = LinearAttention(
            hidden_dim,
            heads=num_heads,
            dim_head=head_width,
            dropout=dropout,
            attn_type='galerkin',
        )

        # Feed-forward sub-layer: hidden_dim -> hidden_dim * mlp_ratio -> hidden_dim.
        expanded_width = hidden_dim * mlp_ratio
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP(hidden_dim, expanded_width, hidden_dim, n_layers=0, res=False, act=act)

        # Output head, only created for the final block of a stack.
        if self.last_layer:
            self.ln_3 = nn.LayerNorm(hidden_dim)
            self.mlp2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, fx):
        """Apply the block to ``fx``.

        Returns a tensor with ``hidden_dim`` features, or ``out_dim``
        features when this block was built with ``last_layer=True``.
        """
        attended = self.Attn(self.ln_1(fx), self.ln_1a(fx))
        fx = attended + fx

        fed_forward = self.mlp(self.ln_2(fx))
        fx = fed_forward + fx

        if not self.last_layer:
            return fx
        return self.mlp2(self.ln_3(fx))


class Model(nn.Module):
    """Galerkin-attention transformer acting on per-point features.

    The model lifts the concatenation of point coordinates (or a precomputed
    positional embedding) and input features to ``hidden_channels`` with an
    MLP, optionally adds a learned embedding of the time value ``T``, and
    runs the result through ``n_layers`` ``Galerkin_Transformer_block``s. The
    last block projects to ``out_channels``.

    All hyper-parameters are read from ``model_cfg``: ``shapelist``,
    ``pos_emb``, ``ref``, ``device``, ``in_channels``, ``hidden_channels``,
    ``out_channels``, ``activation``, ``time_input``, ``n_heads``,
    ``n_layers``, ``dropout`` and ``mlp_ratio``.
    """

    def __init__(self, model_cfg: GKTConfig):
        super().__init__()
        self.__name__ = 'Galerkinformer'
        self.cfg = model_cfg
        self.spatial_dim = len(self.cfg.shapelist)

        ## embedding
        if self.cfg.pos_emb:
            pos = unified_pos_embedding(self.cfg.shapelist, self.cfg.ref, device=self.cfg.device)  # [1,N,ref^d]
            # Non-persistent: the embedding is rebuilt from the config, so it
            # is kept out of the state dict.
            self.register_buffer('pos', pos, persistent=False)
            self.preprocess = MLP(self.cfg.in_channels + self.cfg.ref ** self.spatial_dim, self.cfg.hidden_channels * 2,
                                  self.cfg.hidden_channels, n_layers=0, res=False, act=self.cfg.activation)
        else:
            # Without a positional embedding the raw coordinates
            # (``spatial_dim`` values per point) are concatenated instead.
            self.preprocess = MLP(self.cfg.in_channels + self.spatial_dim, self.cfg.hidden_channels * 2, self.cfg.hidden_channels,
                                  n_layers=0, res=False, act=self.cfg.activation)
        if self.cfg.time_input:
            self.time_fc = nn.Sequential(nn.Linear(self.cfg.hidden_channels, self.cfg.hidden_channels), nn.SiLU(),
                                         nn.Linear(self.cfg.hidden_channels, self.cfg.hidden_channels))

        ## models
        final_index = self.cfg.n_layers - 1
        self.blocks = nn.ModuleList([Galerkin_Transformer_block(num_heads=self.cfg.n_heads, hidden_dim=self.cfg.hidden_channels,
                                                                dropout=self.cfg.dropout,
                                                                act=self.cfg.activation,
                                                                mlp_ratio=self.cfg.mlp_ratio,
                                                                out_dim=self.cfg.out_channels,
                                                                last_layer=(layer_idx == final_index))
                                     for layer_idx in range(self.cfg.n_layers)])
        # Learned per-channel offset added to every point after the lifting MLP.
        self.placeholder = nn.Parameter((1 / (self.cfg.hidden_channels)) * torch.rand(self.cfg.hidden_channels, dtype=torch.float))
        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialize every sub-module with ``_init_weights``."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one module.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm / BatchNorm1d layers get unit weight and zero bias. Other
        module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, fx: torch.Tensor, x: torch.Tensor | None = None, T: torch.Tensor | None = None):
        """Map input features to output features.

        Args:
            fx: Input features, ``[B, N_points, in_channels]``.
            x: Point coordinates. Ignored (replaced by the stored positional
                embedding) when ``cfg.pos_emb`` is set; required otherwise.
            T: Optional time values. Only used when ``cfg.time_input`` is
                set; if ``T`` is ``None`` the time embedding is skipped.

        Returns:
            Tensor of shape ``[B, N_points, out_channels]``.

        Raises:
            ValueError: If ``x`` is ``None`` and ``cfg.pos_emb`` is disabled.
        """
        # Unpacking also enforces that ``fx`` is 3-D.
        batch_size, _, _ = fx.shape
        if self.cfg.pos_emb:
            x = self.pos.expand(batch_size, -1, -1)    # [B, N_points, ref^d]
        elif x is None:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O`` and reports what is missing.
            raise ValueError("coordinates `x` must be provided when cfg.pos_emb is disabled")
        fx = torch.cat((x, fx), -1)
        fx = self.preprocess(fx)    # [B, N_points, hidden_channels]
        fx = fx + self.placeholder[None, None, :]
        # time embedding
        if self.cfg.time_input and (T is not None):
            # The embedding is identical for every point of a sample, so run
            # ``time_fc`` once per sample and broadcast over the point axis
            # rather than evaluating it on an N_points-times expanded tensor.
            # ``timestep_embedding`` is expected to return [B, hidden_channels].
            time_emb = self.time_fc(timestep_embedding(T, self.cfg.hidden_channels))
            fx = fx + time_emb.unsqueeze(1)    # [B, 1, hidden_channels] broadcast over points
        for block in self.blocks:
            fx = block(fx)
        return fx