from models.s4_model import S4BaseModel
import torch
from utils.utilities3 import is_iterable
from models.custom_layers import GridIO

from utils.log_utils import get_logger, add_file_handler
import logging

# Module-level logger for this file, created at INFO level.
log = get_logger(__name__, level=logging.INFO)

class DualTransformer(S4BaseModel):
    """S4-based model on a spatial grid with an autoregressive ``predict``.

    Thin configuration wrapper around ``S4BaseModel``. It:
    - derives the input width from the number of state variables plus one grid
      coordinate per spatial dimension;
    - normalizes ``s4block_args["modes"]`` to one entry per layer;
    - builds a ``GridIO`` step processor.
    """

    def __init__(
        self,
        layer_input_processors,
        layer_output_processors,
        spatial_shape,  # e.g. (1024,) in 1D or (128,128) in 2D
        d_model=128,
        d_output=1,
        n_layers=4,
        exo_dropout=0.0,
        prenorm=False,
        interlayer_act=None,
        s4block_args=None,
        input_processor="ConcatTrans",
        output_processor="identity",
        step_input_processor="ConcatTrans",
        step_output_processor="identity",
        residual_type="identity",
        use_spatial_batch=True,
        n_states=1,
        fast=None,
        n_dim=1,
        norm_type="LayerNorm",
        final_mlp_hidden_expansion=None,
        final_mlp_act="gelu",
        ffn_type="zero",
        encoder_kernel_size=1,
        **kwargs,
    ):
        """Build the model.

        Args:
            layer_input_processors, layer_output_processors: forwarded together
                to ``S4BaseModel`` as the ``layer_processor`` pair.
            spatial_shape: spatial grid size, e.g. (1024,) in 1D or (128, 128)
                in 2D.
            s4block_args: keyword arguments for the S4 blocks (default: empty).
                The caller's dict is never modified; a shallow copy is
                normalized instead:
                - a non-iterable ``"modes"`` is broadcast to ``n_layers``
                  entries;
                - every ``-1`` entry is replaced by
                  ``int(self.spatial_shape[0] / 2)``.
            fast: fast-mode options (default: empty). When
                ``fast["use_fast"]`` is truthy, the spatial shape is probed by
                running a zero tensor through ``self.encoder``.
            n_states: number of state variables per grid point.
            n_dim: ignored; always recomputed as ``len(spatial_shape)``. Kept
                for backward compatibility.
            step_input_processor, step_output_processor: passed to ``GridIO``.
            use_spatial_batch: stored on the instance as-is.
            **kwargs: accepted and ignored (not forwarded).

        All other arguments are forwarded unchanged to ``S4BaseModel``.
        """
        # Grid dimensionality is always derived from spatial_shape; this
        # overrides the n_dim argument.
        n_dim = len(spatial_shape)
        # One channel per state variable plus one grid coordinate per spatial
        # dimension (1 in 1D, 2 in 2D, 3 in 3D). Must be computed before the
        # fast-mode probe below, which uses it.
        d_input = n_states + n_dim

        # Fresh containers instead of shared mutable defaults. Copy the block
        # config so the normalization below never leaks into the caller's dict.
        s4block_args = {} if s4block_args is None else dict(s4block_args)
        fast = {} if fast is None else fast

        # The spatial shape may change because of the stride / kernel of fast mode.
        if fast.get("use_fast", False):
            # NOTE(review): self.encoder is read before super().__init__() has
            # run. This only works if the encoder is already available at this
            # point. Verify against S4BaseModel.
            with torch.no_grad():  # shape probe only; no autograd graph needed
                dummy_input = torch.zeros(1, *spatial_shape, d_input)
                dummy_output = self.encoder(dummy_input)
            self.spatial_shape = dummy_output.shape[1:-1]
        else:
            self.spatial_shape = spatial_shape

        if "modes" in s4block_args:
            modes = s4block_args["modes"]
            if not is_iterable(modes):
                # A single mode count applies to every layer.
                modes = [modes] * n_layers
            if -1 in modes:
                half = int(self.spatial_shape[0] / 2)
                log.info(f"Mode set to -1, setting to {half}")
                # Build a new list rather than assigning in place. This leaves
                # the caller's sequence untouched and also works for tuples.
                modes = [half if m == -1 else m for m in modes]
            s4block_args["modes"] = modes

        layer_processor = (layer_input_processors, layer_output_processors)
        super().__init__(
            d_input,
            d_output=d_output,
            d_model=d_model,
            n_layers=n_layers,
            exo_dropout=exo_dropout,
            prenorm=prenorm,
            interlayer_act=interlayer_act,
            s4block_args=s4block_args,
            input_processor=input_processor,
            output_processor=output_processor,
            residual_type=residual_type,
            layer_processor=layer_processor,
            fast=fast,
            n_dim=n_dim,
            norm_type=norm_type,
            final_mlp_hidden_expansion=final_mlp_hidden_expansion,
            final_mlp_act=final_mlp_act,
            ffn_type=ffn_type,
            encoder_kernel_size=encoder_kernel_size,
        )
        self.use_spatial_batch = use_spatial_batch
        self.step_io = GridIO(step_input_processor, step_output_processor)

    def predict(
        self,
        x0,
        grid,
        n_timesteps,
        train_timesteps,
        reset_memory=True,
        LG_length=None,
        batch_dt=None,
        discard_state=False,
    ):
        """Roll the model out autoregressively from an initial state.

        Args:
            x0: initial state, presumably shaped (B, Sx, [Sy], [Sz], V).
                TODO confirm.
            grid: passed unchanged to ``forward`` on every step.
            n_timesteps: number of rollout steps; each step appends one frame.
            train_timesteps, reset_memory, LG_length, discard_state: accepted
                for interface compatibility but not used by this method.
            batch_dt: passed unchanged to ``forward`` on every step.

        Returns:
            The ``n_timesteps`` generated frames stacked along dim -2, with
            ``x0`` excluded. Presumably shaped (B, Sx, [Sy], [Sz], T, V).
        """
        # The time axis is dim -2. The history starts with x0 as its only frame.
        y = x0.unsqueeze(-2)
        # NOTE(review): callers may expect n_timesteps to count x0. As written,
        # n_timesteps new frames are produced and x0 is dropped on return.
        # Confirm the intended convention.
        for _ in range(n_timesteps):
            # Re-run the model on the full history and keep only its last frame.
            y = torch.cat([y, self.forward(y, grid, batch_dt)[..., -1:, :]], dim=-2)
        return y[..., 1:, :]

