import os
import argparse
import logging
import time
import numpy as np
import numpy.random as npr
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from memKNO.network import FourierNet
from torch.nn.parameter import Parameter
import math
from torch import nn
from torch.nn import init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple

#####################################################################
# Phase I: latent process -> mean spatial field w.r.t. the slow scale
#####################################################################
class FieldDecoder(nn.Module):
    def __init__(self,
                 x_grid: torch.Tensor,    
                 fourier_hidden_dim: int,
                 code_dim: int,
                 n_fourier_layers: int = 3,
                 input_scale: float = 256.0,
                 chunk_t: int = 0,
                 mlp_in: bool = False,
                 mlp_layers: int = 2,
                 mlp_act: str = "gelu",
                 use_sigmoid: bool = False,
                 **kwargs):
        """
        Args:
            x_grid: [h, w, grid_dim]
            chunk_t: batch number of chunk_t codes at each decoding pass
        """
        super(FieldDecoder, self).__init__()
        self.register_buffer('x_grid', x_grid)
        self.hidden_dim = fourier_hidden_dim
        self.code_dim = code_dim
        self.chunk_t = chunk_t
        self.use_sigmoid = use_sigmoid
        self.mlp_in = mlp_in

        if self.mlp_in:
            self.mlp_proj = MLP(
                in_dim=code_dim, hidden_dim=code_dim*2, out_dim=code_dim, 
                num_layers=mlp_layers, nl=mlp_act, 
                use_layernorm=True, norm_where="both"
            )

        self.net = FourierNet(
            grid_dim=x_grid.shape[-1],
            hidden_feat_dim=self.hidden_dim,
            code_dim=self.code_dim,
            out_dim=1,
            n_layers=n_fourier_layers,
            input_scale=input_scale
        )


    def forward(self, codes: torch.Tensor, *, 
                x_grid_override: torch.Tensor | None = None) -> torch.Tensor:
        """
        codes: [B, T, S, code_dim]
        x_grid_override:  [H_new, W_new, grid_dim]  (optional)
        """
        if self.chunk_t is None or self.chunk_t <= 0:
            return self._run_net(codes, x_grid_override)   # [B,T,H,W,S]

        _, T, _, _ = codes.shape
        outs = []
        for start in range(0, T, self.chunk_t):
            end = min(T, start + self.chunk_t)
            outs.append(self._run_net(codes[:, start:end], x_grid_override))
        out = torch.cat(outs, dim=1)                       # [B,T,H,W,S]

        return torch.sigmoid(out) if self.use_sigmoid else out


    def _run_net(self, code_slice: torch.Tensor, x_grid_override: torch.Tensor | None = None) -> torch.Tensor:
        grid = self.x_grid if x_grid_override is None else x_grid_override
        # grid: [H, W, 2] ; code_slice: [B, t, S, code_dim]
        if self.mlp_in:
            code_shape = list(code_slice.shape)
            code_flat = code_slice.reshape(-1, self.code_dim)
            code_proj = self.mlp_proj(code_flat)
            code_slice = code_proj.view(code_shape)
            
        return self.net(grid, code_slice)                  # [B,t,H,W,S]
    



from memKNO.network import GaussianFourierFeatureTransform, FeedForward
from memKNO.attention import CrossLinearAttention
class CrossFormerDecoder2D(nn.Module):
    """Decode latent tokens onto the fixed query grid ``x_grid`` via cross-attention.

    Every grid coordinate is embedded (Gaussian Fourier features + MLP), then
    refined by ``atten_blocks`` blocks of cross-attention over the latent
    tokens followed by a feed-forward layer, and finally projected to
    ``out_channels``.
    """

    def __init__(self,
                 x_grid: torch.Tensor,    # [H, W, 2]
                 atten_blocks: int,
                 heads: int,
                 latent_channels: int,
                 out_channels: int,
                 random_fourier_scale: float = 8.0,
                 mlp_hidden_dim: int = 256,
                 dropout: float = 0.0
                 ):
        """
        Args:
            x_grid: query coordinates, [H, W, 2]; registered as a buffer.
            atten_blocks: number of cross-attention + feed-forward blocks.
            heads: attention heads; each head gets ``latent_channels // heads`` channels.
            latent_channels: width of the query embeddings and of the latent tokens.
            out_channels: channels of the decoded field.
            random_fourier_scale: ``scale`` of the Gaussian Fourier features.
            mlp_hidden_dim: hidden width of each feed-forward layer.
            dropout: dropout of each feed-forward layer.
        """
        super(CrossFormerDecoder2D, self).__init__()
        self.register_buffer('x_grid', x_grid)
        self.H, self.W = self.x_grid.shape[:2]
        self.latent_channels = latent_channels
        self.out_channels = out_channels

        width, half_width = latent_channels, latent_channels // 2

        # Coordinate embedding. The Fourier transform is built with
        # `half_width` features while the next Linear expects `width` inputs,
        # so it presumably emits sin/cos pairs (TODO confirm in memKNO.network).
        self.coordinate_proj = nn.Sequential(
            GaussianFourierFeatureTransform(2, half_width, scale=random_fourier_scale),
            nn.Linear(width, width, bias=False),
            nn.GELU(),
            nn.Linear(width, width, bias=False),
        )

        def make_block() -> nn.ModuleList:
            # Attention is built before the feed-forward layer, which fixes
            # the order of random draws during parameter initialisation.
            attention = CrossLinearAttention(
                dim_q=width, dim_kv=width,
                heads=heads, dim_head=width // heads,
                relative_emb=False, cat_pos=False, init_params=True,
            )
            feed_forward = FeedForward(width, mlp_hidden_dim, dropout=dropout)
            # Layout [norm, attention, norm, feed-forward] is unpacked in forward().
            return nn.ModuleList([nn.LayerNorm(width), attention,
                                  nn.LayerNorm(width), feed_forward])

        self.cross_atten_blocks = nn.ModuleList(make_block() for _ in range(atten_blocks))

        self.out_proj = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, half_width, bias=False),
            nn.GELU(),
            nn.Linear(half_width, half_width, bias=False),
            nn.GELU(),
            nn.Linear(half_width, out_channels, bias=True),
        )

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode latent tokens onto the grid.

        Args:
            codes: latent tokens, [B, K, latent_channels].

        Returns:
            Decoded field, [B, H, W, out_channels].
        """
        batch_size = codes.shape[0]
        coords = self.x_grid.flatten(0, 1)[None].expand(batch_size, -1, -1)   # [B, H*W, 2]
        h = self.coordinate_proj(coords)                                     # [B, H*W, latent_channels]

        for norm_attn, attention, norm_ffn, feed_forward in self.cross_atten_blocks:
            # NOTE(review): the skip connection adds back the *normalised*
            # tensor, so LayerNorm is applied to the residual stream itself
            # (not standard pre-norm). Confirm this is intended.
            h = norm_attn(h)
            h = h + attention(h, codes)
            h = norm_ffn(h)
            h = h + feed_forward(h)

        return self.out_proj(h).unflatten(1, (self.H, self.W))


#####################################################################
# FourierNet-style decoder conditioned on a latent vector
#####################################################################
from memKNO.network import MLP, ConditionalFourierLayer

class LatentModulation(nn.Module):
    """Linear map of position features plus a latent-conditioned shift.

    out = pos_feat @ A^T + latent_feat @ B^T + mlp_modulation(latent_feat) + bias,
    where the two latent terms are shared by every point of a sample.
    """

    def __init__(self, pos_feat_dim: int, latent_dim: int, out_feat_dim: int, device=None, dtype=None,
                 *, mlp_layers: int = 2, mlp_act: str = "gelu"):
        """
        Args:
            pos_feat_dim: last dimension of ``pos_feat``.
            latent_dim: last dimension of ``latent_feat``.
            out_feat_dim: output feature width.
            device, dtype: factory arguments for ``A``, ``B`` and ``bias``;
                also applied to the residual MLP so the whole module lives on
                one device/dtype.
            mlp_layers, mlp_act: depth and activation name of the residual MLP.

        Shapes:
            pos_in: [B, N_pts, pos_feat_dim]
            latent_feat: [B, latent_dim]
            out: [B, N_pts, out_feat_dim]
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(LatentModulation, self).__init__()

        self.pos_feat_dim = pos_feat_dim
        self.latent_dim = latent_dim
        self.out_feat_dim = out_feat_dim

        self.A = nn.Parameter(torch.empty(out_feat_dim, pos_feat_dim, **factory_kwargs))
        self.B = nn.Parameter(torch.empty(out_feat_dim, latent_dim, **factory_kwargs))
        # Residual latent path; its last layer is built with last_zero_init=True
        # (presumably zero-initialised so it starts as a no-op — see MLP).
        self.mlp_modulation = MLP(
            in_dim=latent_dim, hidden_dim=latent_dim*2, out_dim=out_feat_dim,
            num_layers=mlp_layers, nl=mlp_act,
            last_bias=False, last_kaiming=False, last_kaiming_a=math.sqrt(5), last_zero_init=True,
            use_layernorm=True, norm_where="pre"
        )
        # MLP takes no factory kwargs, so move it explicitly; otherwise a
        # non-default device/dtype would leave it behind while A, B and bias
        # are created on the requested device/dtype.
        to_kwargs = {key: value for key, value in factory_kwargs.items() if value is not None}
        if to_kwargs:
            self.mlp_modulation.to(**to_kwargs)
        self.bias = nn.Parameter(torch.empty(out_feat_dim, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``A``, ``B`` and ``bias`` (the MLP is left untouched)."""
        # Bias bound mirrors nn.Linear's 1/sqrt(fan_in), using only the
        # position-feature fan-in.
        bound = 1 / math.sqrt(self.pos_feat_dim)
        init.kaiming_uniform_(self.A, a=math.sqrt(5))
        init.kaiming_uniform_(self.B, a=math.sqrt(5))
        init.uniform_(self.bias, -bound, bound)

    def forward(self, pos_feat: torch.Tensor, latent_feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pos_feat: [B, N_pts, pos_feat_dim]
            latent_feat: [B, latent_dim]

        Returns:
            [B, N_pts, out_feat_dim]
        """
        pos_flat = pos_feat.reshape(-1, self.pos_feat_dim)    # [B*N_pts, pos_feat_dim]
        pos_proj = pos_flat @ self.A.t()    # [B*N_pts, out_feat_dim]
        pos_proj = pos_proj.view(pos_feat.shape[0], pos_feat.shape[1], self.out_feat_dim)

        # Latent terms are per-sample; repeat them over the N_pts points.
        latent_proj = latent_feat @ self.B.t()
        latent_proj = latent_proj.unsqueeze(1).expand(-1, pos_feat.shape[1], -1)
        latent_proj_res = self.mlp_modulation(latent_feat)    # [B, out_feat_dim]
        latent_proj_res = latent_proj_res.unsqueeze(1).expand(-1, pos_feat.shape[1], -1)

        return pos_proj + latent_proj + latent_proj_res + self.bias.view(1, 1, -1)

                 
        
class FourierDecoder(nn.Module):
    """Latent-conditioned decoder built from multiplied Fourier-filter responses.

    Layer 0 multiplies the response of ``filters[0]`` by a modulation that
    depends only on the latent vector. Each later layer ``i`` multiplies the
    response of ``filters[i]`` by a modulation of the previous layer's output.
    A final linear layer maps the hidden features to ``out_dim``.
    """

    def __init__(self, grid_dim: int, fourier_hidden_dim: int, latent_dim: int, out_dim: int, n_fourier_layers: int = 3,
                 input_scale: float = 256.0,
                 *, modmlp_layers: int = 2, modmlp_act: str = "gelu",
                 use_freq_scale: bool = True, use_phase: bool = False, use_coord_scale: bool = False,
                 lora_rank: int = 0, lora_scale: float = 1.0,):
        """
        Args:
            grid_dim: coordinate dimension of ``grid``.
            fourier_hidden_dim: hidden feature width shared by all layers.
            latent_dim: size of the conditioning latent vector.
            out_dim: number of output channels.
            n_fourier_layers: number of layers after the first; the decoder
                builds ``n_fourier_layers + 1`` filters and modulations.
            input_scale: each filter receives
                ``input_scale / sqrt(n_fourier_layers + 1)``.
            modmlp_layers, modmlp_act: depth and activation of the MLP inside
                each ``LatentModulation``.
            use_freq_scale, use_phase, use_coord_scale, lora_rank, lora_scale:
                forwarded unchanged to every ``ConditionalFourierLayer``.
        """
        super(FourierDecoder, self).__init__()
        self.n_layers = n_fourier_layers
        self.hidden_feat_dim = fourier_hidden_dim

        # mods[0] consumes grid_dim-wide position features; the rest consume
        # the previous layer's hidden features.
        self.mods = nn.ModuleList(
            [LatentModulation(grid_dim, latent_dim, self.hidden_feat_dim, mlp_layers=modmlp_layers, mlp_act=modmlp_act)] +
            [LatentModulation(self.hidden_feat_dim, latent_dim, self.hidden_feat_dim, mlp_layers=modmlp_layers, mlp_act=modmlp_act)
             for _ in range(int(n_fourier_layers))]
        )

        self.out_proj = nn.Linear(self.hidden_feat_dim, out_dim)

        freq_mod_params = {
            "use_freq_scale": use_freq_scale, "use_phase": use_phase, "use_coord_scale": use_coord_scale,
            "lora_rank": lora_rank, "lora_scale": lora_scale
        }
        self.filters = nn.ModuleList(
            [ConditionalFourierLayer(grid_dim, fourier_hidden_dim, input_scale / np.sqrt(n_fourier_layers+1),
                                     z_dim=latent_dim, **freq_mod_params) for _ in range(n_fourier_layers+1)]
        )

    def forward(self, grid: torch.Tensor, latent_feat: torch.Tensor) -> torch.Tensor:
        """
        Inputs:
        - grid: [N_pt, grid_dim]
        - latent_feat: [B, latent_dim]
        Outputs:
        - out: [B, N_pts, out_dim]
        """
        bs = latent_feat.shape[0]

        pos_emb0 = self.filters[0](grid)
        if pos_emb0.dim() == 2:
            # Unbatched filter response: share it across the batch.
            pos_emb0 = pos_emb0.unsqueeze(0).expand(bs, -1, -1)  # [B, N_pt, hidden_feat_dim]
        # mods[0] sees all-zero position features, so only its latent terms and
        # bias contribute. Allocate with latent_feat's dtype/device (not the
        # default float32) so the matmul against mods[0].A still type-checks
        # after the model is cast, e.g. with `.double()`.
        pos_feat = latent_feat.new_zeros(bs, *grid.shape)
        hidden_feat0 = self.mods[0](pos_feat=pos_feat, latent_feat=latent_feat)
        out = pos_emb0 * hidden_feat0    # [B, N_pt, hidden_feat_dim]

        for i in range(1, self.n_layers + 1):
            pos_embi = self.filters[i](grid)
            if pos_embi.dim() == 2:
                pos_embi = pos_embi.unsqueeze(0).expand(bs, -1, -1)  # [B, N_pt, hidden_feat_dim]
            hidden_feati = self.mods[i](pos_feat=out, latent_feat=latent_feat)
            out = pos_embi * hidden_feati

        return self.out_proj(out)    # [B, N_pt, out_dim]
        
