import torch
import torch.nn as nn
from models.s4.s4 import S4Block
from models.s4.s4nd import S4ND
from models.fno_blocks import FNO1dBlock, FNO2dBlock
from models.ffno_blocks import FSpectralConv1d, FSpectralConv2d
from models.transformer_block import TransformerBlock
from models.lstm_block import LSTM_Block
from models.fast_model import fast_input_layer, fast_output_layer
import torch.nn.functional as F
from models.custom_layers import IO, GridIO, get_residual_layer, get_norm_layer, get_ffn_layer, act_registry
from utils.utilities3 import is_iterable
from einops import rearrange

from utils.log_utils import get_logger
import logging

# Module-level logger, emitting at INFO level and above.
log = get_logger(
    __name__,
    level=logging.INFO,
)

from functools import partial

from models.s4_model import  S4BaseModel

# class MultiInputFFNO(nn.Module):
#     def __init__(
#         self,
#         d_input,
#         history_len,
#         d_output=10,
#         d_model=256,
#         n_layers=4,
#         exo_dropout=0.0,
#         prenorm=False,
#         interlayer_act=None,
#         input_processor="Concat",
#         output_processor="identity",
#         residual_type="identity",
#         layer_processor=None,
#         fast={},
#         s4block_args={},
#         n_dim=1,
#         final_mlp_hidden_expansion=None,
#         norm_type="LayerNorm",
#         final_mlp_act = "gelu",
#         ffn_type = "zero",
#         encoder_kernel_size = 1,
#     ):
#         '''S4 Base Model
#         :param exo_dropout: dropout rate outside the S4Block (layer-level dropout)
#         :param s4block_args: arguments for the S4Block, standard S4Block if empty'''
#         super().__init__()

#         self.prenorm = prenorm

#         self.io = GridIO(input_processor, output_processor)

#         d_input = d_input * history_len

#         if fast.get("use_fast",False): 
#             self.encoder = fast_input_layer(kernel_size=fast["kernel_size"], stride=fast["stride"], in_channels=d_input, out_channels=d_model, n_dim=n_dim)
#             self.decoder = fast_output_layer(kernel_size=fast["kernel_size"], stride=fast["stride"], in_channels=d_model, out_channels=d_output, n_dim=n_dim, 
#                                              final_mlp_hidden_expansion=final_mlp_hidden_expansion, final_mlp_act = final_mlp_act)
#         else: 
#             if n_dim == 1: 
#                 self.encoder = Encoder(d_input, d_model, kernel_size=encoder_kernel_size)
#             else: 
#                 self.encoder = nn.Linear(d_input, d_model)

#             if final_mlp_hidden_expansion is None:
#                 self.decoder = nn.Linear(d_model, d_output)
#             else: 
#                 self.decoder = MLP(d_model, d_output, final_mlp_hidden_expansion*d_model, act = final_mlp_act)
        
#         s4blocks = get_s4block(n_layers, s4block_args)
#         assert len(s4blocks) == n_layers, "Number of S4 blocks does not match number of layers"

#         # Stack S4 layers as residual blocks
#         self.s4_layers = nn.ModuleList()
#         self.norms = nn.ModuleList()
#         self.dropouts = nn.ModuleList()
#         self.residuals = nn.ModuleList()
#         self.ffns = nn.ModuleList()
#         self.ffns_norm = nn.ModuleList()
#         norm_types = extend_values(norm_type, n_layers)
#         residual_types = extend_values(residual_type, n_layers)
#         ffn_types = extend_values(ffn_type, n_layers)
#         for s4b, norm_type, residual_type, ffn_type in zip(s4blocks, norm_types, residual_types, ffn_types):
#             self.s4_layers.append(
#                 s4b()
#             )
#             self.norms.append(get_norm_layer(norm_type, d_model))
#             self.dropouts.append(nn.Dropout(exo_dropout))
#             self.residuals.append(get_residual_layer(residual_type, d_model))
#             self.ffns.append(get_ffn_layer(ffn_type, d_model))
#             self.ffns_norm.append(get_norm_layer(norm_type, d_model))

#         # Interlayer activation
#         self.interlayer_act = act_registry[interlayer_act] if interlayer_act is not None else nn.Identity()

#         if layer_processor is None: 
#             self.layer_input_processors = ["identity" for _ in range(n_layers)]
#             self.layer_output_processors = ["identity" for _ in range(n_layers)]
#         else: 
#             in_layer = layer_processor[0] # list of input processors
#             out_layer = layer_processor[1] # list of output processors
#             assert n_layers % len(in_layer)==0 and n_layers % len(in_layer)==0, "Number of layer processors is not a divisor of number of layers"
#             log.info(f"n layers: {n_layers}, layer processors provided: {len(in_layer)}")
#             self.layer_input_processors = list(in_layer) * (n_layers // len(in_layer))
#             self.layer_output_processors = list(out_layer) * (n_layers // len(out_layer))
#         self.layer_processor = nn.ModuleList([IO(in_l,out_l) for in_l, out_l in zip(self.layer_input_processors, self.layer_output_processors)])


class MultiInputFFNO(S4BaseModel):
    """S4BaseModel variant fed with a sliding window of ``history_len`` past states.

    Each window of ``history_len`` time steps of the ``n_states`` state variables
    is flattened into the channel axis, so the per-point input width is
    ``n_states * history_len + n_dim``. The extra ``n_dim`` channels are
    presumably the grid coordinates concatenated by ``self.io`` — TODO confirm
    against ``GridIO``.
    """

    def __init__(
        self,
        history_len,
        layer_input_processors,
        layer_output_processors,
        spatial_shape,  # e.g. (1024,) in 1D or (128,128) in 2D
        d_model=128,
        d_output=1,
        n_layers=4,
        exo_dropout=0.0,
        prenorm=False,
        interlayer_act=None,
        s4block_args=None,
        input_processor="ConcatTrans",
        output_processor="identity",
        step_input_processor="ConcatTrans",
        step_output_processor="identity",
        residual_type="identity",
        use_spatial_batch=True,
        n_states=1,
        fast=None,
        n_dim=1,
        norm_type="LayerNorm",
        final_mlp_hidden_expansion=None,
        final_mlp_act="gelu",
        ffn_type="zero",
        encoder_kernel_size=1,
        **kwargs,
    ):
        """Build the model.

        :param history_len: number of past time steps stacked into each input.
        :param layer_input_processors: per-layer input processor names; tiled over
            ``n_layers`` by the base class (presumably ``list * (n_layers // len)``).
        :param layer_output_processors: per-layer output processor names.
        :param spatial_shape: spatial grid shape; its length overrides ``n_dim``.
        :param s4block_args: S4 block configuration. It is copied, so the caller's
            dict is never modified. A scalar ``modes`` is broadcast to all layers,
            and ``modes == -1`` is resolved to ``spatial_shape[0] / 2 + 1``
            (divided by the fast-mode stride when ``fast["use_fast"]`` is set).
        :param fast: fast-mode encoder/decoder configuration (``use_fast``,
            ``kernel_size``, ``stride``); ``None`` means disabled.
        :param n_dim: ignored; recomputed as ``len(spatial_shape)``.
        :param kwargs: ignored.
        Remaining parameters are forwarded to ``S4BaseModel``.
        """
        import copy

        # Work on a private copy: ``modes`` is rewritten below, and config dicts
        # are commonly shared between several model instantiations.
        s4block_args = copy.deepcopy(s4block_args) if s4block_args is not None else {}
        fast = fast if fast is not None else {}
        use_fast = fast.get("use_fast", False)

        n_dim = len(spatial_shape)
        d_input = n_states * history_len + n_dim  # + n_dim for the grid (1 in 1D, 2 in 2D, 3 in 3D)
        self.history_len = history_len

        # Causality check: a bidirectional layer may only be used with the
        # "BatchTime" input processor (time flattened into the batch, per the
        # original note); any other processor would let it see future steps.
        bidirectionals = s4block_args.get("bidirectional", [False] * n_layers)
        if not is_iterable(bidirectionals):
            bidirectionals = [bidirectionals] * n_layers
        n_procs = len(layer_input_processors)
        if n_procs > 0:
            for i, bidirectional in enumerate(list(bidirectionals)[:n_layers]):
                # Processors are tiled over the layers, so layer i uses entry i % n_procs.
                if layer_input_processors[i % n_procs] != "BatchTime":
                    assert not bidirectional, "Bidirectional must be False when processing time dimension for causality"

        if "modes" in s4block_args:
            modes = s4block_args["modes"]
            modes = list(modes) if is_iterable(modes) else [modes] * n_layers
            if -1 in modes:
                # -1 means "all rFFT modes" of the (possibly strided) spatial axis.
                if use_fast:
                    full_modes = int(spatial_shape[0] / (2 * fast["stride"]) + 1)
                else:
                    full_modes = int(spatial_shape[0] / 2 + 1)
                modes = [full_modes if m == -1 else m for m in modes]
            s4block_args["modes"] = modes

        layer_processor = (layer_input_processors, layer_output_processors)
        super().__init__(
            d_input,
            d_output=d_output,
            d_model=d_model,
            n_layers=n_layers,
            exo_dropout=exo_dropout,
            prenorm=prenorm,
            interlayer_act=interlayer_act,
            s4block_args=s4block_args,
            input_processor=input_processor,
            output_processor=output_processor,
            residual_type=residual_type,
            layer_processor=layer_processor,
            fast=fast,
            n_dim=n_dim,
            norm_type=norm_type,
            final_mlp_hidden_expansion=final_mlp_hidden_expansion,
            final_mlp_act=final_mlp_act,
            ffn_type=ffn_type,
            encoder_kernel_size=encoder_kernel_size,
        )
        self.use_spatial_batch = use_spatial_batch
        self.step_io = GridIO(step_input_processor, step_output_processor)

        if use_fast:
            # The fast encoder's kernel/stride may shrink the grid; probe it once
            # with a dummy input to record the latent spatial shape.
            with torch.no_grad():
                dummy_output = self.encoder(torch.zeros(1, *spatial_shape, d_input))
            self.spatial_shape = dummy_output.shape[1:-1]
        else:
            self.spatial_shape = spatial_shape

    def forward(self, x, grid, batch_dt=None):
        """Teacher-forced prediction over a whole trajectory.

        :param x: (B, Sx, [Sy], [Sz], T, V) trajectory with T >= history_len.
        :param grid: spatial grid forwarded to ``_forward``.
        :param batch_dt: optional time-step tensor forwarded to ``_forward``.
        :return: T steps along dim -2. The first ``history_len - 1`` are copied
            from ``x[..., 1:history_len, :]``, and then come one model output per
            history window. Presumably output step j stands for time step j + 1.
        """
        # unfold appends the window axis LAST: (..., T, V) -> (..., T - H + 1, V, H)
        windows = x.unfold(-2, self.history_len, 1)
        # Flatten each window state-major: feature index = v * H + h.
        windows = rearrange(windows, '... t v h -> ... t (v h)')
        return torch.cat([x[..., 1:self.history_len, :], self._forward(windows, grid, batch_dt)], dim=-2)

    def _forward(self, x, grid, batch_dt=None, **kwargs):
        """Run encoder -> residual stack of S4 layers -> decoder.

        :param x: model input with the history already flattened into the last
            axis; presumably (B, Sx, [Sy], [Sz], [T], C) — TODO confirm.
        :param grid: spatial grid handed to ``self.io.process_input``.
        :param batch_dt: optional tensor of time steps. It is reduced to its mean
            and passed to every layer as ``batch_dt``.
        :param kwargs: ignored.
        :return: decoder output after ``self.io.process_output``.
        """
        x = self.io.process_input(x, grid)
        x = self.encoder(x)

        # Every layer receives the same scalar step size, so reduce it once.
        if batch_dt is not None:
            batch_dt = batch_dt.mean()

        layers = zip(self.s4_layers, self.norms, self.dropouts, self.layer_processor,
                     self.residuals, self.ffns, self.ffns_norm)
        for layer, norm, dropout, layer_io, residual, ffn, ffn_norm in layers:
            z = norm(x) if self.prenorm else x

            # Per-layer input processing (normally identity).
            z = layer_io.process_input(z)
            # The S4 block also returns a state, which is not needed here.
            z, _ = layer(z, batch_dt=batch_dt)
            # Per-layer output processing (normally identity).
            z = layer_io.process_output(z)

            # Layer-level dropout, then the residual connection.
            z = dropout(z)
            x = z + residual(x)

            if not self.prenorm:
                x = norm(x)

            # Feed-forward sub-block.
            # NOTE(review): in prenorm mode the skip path below carries the
            # *normalized* x (x = ffn(norm(x)) + norm(x)) rather than the usual
            # x + ffn(norm(x)). This is kept as-is for checkpoint compatibility;
            # confirm the intent.
            if self.prenorm:
                x = ffn_norm(x)
            x = ffn(x) + x
            if not self.prenorm:
                x = ffn_norm(x)

            x = self.interlayer_act(x)

        # Decode the outputs.
        x = self.decoder(x)
        return self.io.process_output(x)

    def predict_with_history(self,
                x,  # (B, S..., history_len, V)
                grid,
                n_timesteps,
                train_timesteps,
                **kwargs):
        """Autoregressively roll the model out from an initial history window.

        :param x: initial window, (B, *spatial, history_len, V). The last axis
            must equal the model's output width for the rollout to feed back.
        :param grid: spatial grid forwarded to ``_forward``.
        :param n_timesteps: total number of steps in the returned sequence.
        :param train_timesteps: unused; kept for interface compatibility.
        :param kwargs: forwarded to ``_forward``.
        :return: ``x[..., 1:, :]`` followed by
            ``n_timesteps - history_len + 1`` predicted steps along dim -2.
        """
        ys = x[..., 1:, :]
        history = x
        for _ in range(n_timesteps - self.history_len + 1):
            # Same feature layout as ``forward`` (state-major: v * H + t), so the
            # rollout matches what the model saw during training.
            inp = rearrange(history, '... t v -> ... 1 (v t)')
            y = self._forward(inp, grid, **kwargs)
            ys = torch.cat([ys, y], dim=-2)
            # Slide the window by one whole time step (all V states at once).
            history = torch.cat([history[..., 1:, :], y], dim=-2)
        return ys
        
        