import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import opt_einsum as oe

# Choose the einsum backend used below: opt_einsum's `contract` when
# `optimized` is set, otherwise `torch.einsum`.
optimized = True

contract = oe.contract if optimized else torch.einsum

from src.models.nn import LinearActivation, Activation, DropoutNd
from src.models.sequence.block_fft import BlockFFT
from src.models.sequence.long_conv_kernel import LongConvKernel

class LongConv(nn.Module):
    """Long-convolution sequence layer with optional GNN and set mixing.

    The input is convolved with a learned kernel of ``kernel_len`` taps, using
    FFTs or Monarch ``BlockFFT`` blocks when ``block_fft_conv=True``. A learned
    skip term ``D`` is added. The result then goes through
    activation -> dropout -> position-wise output transform.

    Before the convolution, the input can optionally be mixed by a ``gtnet``
    graph network (``use_gnn``) and/or by a ``SetEncoder``
    (``use_set_mixing``). Each of these is only active for layers whose
    ``i_layer`` is <= ``nr_layers_with_gnn`` / ``nr_layers_with_set``.
    """

    def __init__(
            self,
            d_model,
            l_max=1024,
            channels=1,
            bidirectional=False,
            # Arguments for position-wise feedforward components
            activation='gelu', # activation between conv and FF
            postact='glu', # activation after FF
            initializer=None, # initializer on FF
            weight_norm=False, # weight normalization on FF
            dropout=0.0, tie_dropout=False,
            transposed=True, # axis ordering (B, L, D) or (B, D, L)
            verbose=False,
            block_fft_conv=False, # replace the FFT conv with Monarch blocks
            block_fft_conv_args=None,
            use_gnn=False,
            use_small_gnn=False,
            use_layer_norm_gnn=True,
            use_gcn_true=True,
            use_sequence_layer_norm=True,
            gcn_depth=3,
            i_layer=0,
            nr_layers_with_gnn=6,
            use_set_mixing=False,
            nr_layers_with_set=6,
            set_mixing_architecture="MHA",
            set_mixing_dropout=0.0,
            set_debug=False,
            use_layer_norm_set=False,
            set_feature_embedding_dim=None,
            set_chunk_size=3,
            set_expand=2,
            set_projection=False,
            set_common_pool_embedding_dim=2,
            set_n_attn_summary_statistics=True,
            set_nr_attn_heads=4,
            set_var_layer_norm=False,
            set_v_dim=5,
            buildA_true=True,
            kernel_len=None,
            learn_ifft=False, # with block_fft_conv, also use BlockFFT for the inverse transform

            # SSM Kernel arguments
            **kernel_args,
        ):
        """
        d_model: hidden dimension, also denoted by H
        l_max: the maximum sequence length, also denoted by L; default kernel length
        channels: can be interpreted as a number of "heads"; the convolution is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models
        bidirectional: if True, convolution kernel will be two-sided
        kernel_len: number of kernel taps; defaults to l_max

        Position-wise feedforward components:
        --------------------
        activation: activation in between the convolution and FF
        postact: activation after FF ('id' for no activation, None to remove FF layer)
        initializer: initializer on FF
        weight_norm: weight normalization on FF
        dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d

        FFT options:
        --------------------
        block_fft_conv: compute the forward transforms with BlockFFT modules instead of torch.fft.rfft
        block_fft_conv_args: keyword arguments passed to both BlockFFT modules (None means {})
        learn_ifft: only used with block_fft_conv; if True the inverse transform also uses BlockFFT, otherwise torch.fft.ifft

        Mixing options:
        --------------------
        use_gnn / nr_layers_with_gnn / gcn_* / use_*layer_norm* / buildA_true: configure the gtnet mixer
        use_set_mixing / nr_layers_with_set / set_*: configure the SetEncoder mixer
        i_layer: index of this layer, compared against nr_layers_with_gnn / nr_layers_with_set

        Other arguments:
        --------------------
        transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension]
        kernel_args: forwarded to LongConvKernel
        """

        super().__init__()
        if verbose:
            import src.utils.train
            log = src.utils.train.get_logger(__name__)
            log.info(f"Constructing Long Conv (H, L) = ({d_model}, {l_max})")

        self.d_model = d_model
        self.H = d_model
        self.L = l_max
        self.bidirectional = bidirectional
        self.channels = channels
        self.transposed = transposed
        self.block_fft_conv = block_fft_conv
        # None default avoids a shared mutable default argument.
        self.block_fft_conv_args = block_fft_conv_args if block_fft_conv_args is not None else {}
        # Read in forward() on the block_fft_conv path.
        self.learn_ifft = learn_ifft
        self.kernel_len = kernel_len if kernel_len is not None else self.L

        # Skip-connection weights, one per (channel, hidden unit).
        self.D = nn.Parameter(torch.randn(channels, self.H))

        # A bidirectional kernel holds a forward and a backward half. They are
        # split apart again in forward(), so the kernel needs twice the channels.
        if self.bidirectional:
            channels *= 2

        self.kernel = LongConvKernel(self.H, L=self.kernel_len, channels=channels, verbose=verbose, **kernel_args)

        if self.block_fft_conv:
            self.block_fft_u = BlockFFT(**self.block_fft_conv_args)
            self.block_fft_k = BlockFFT(**self.block_fft_conv_args)

        # Pointwise
        self.activation = Activation(activation)
        # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11
        dropout_fn = DropoutNd if tie_dropout else nn.Dropout
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        if postact is None:
            self.output_linear = nn.Identity()
        else:
            self.output_linear = LinearActivation(
                self.d_model * self.channels,
                self.d_model,
                transposed=self.transposed,
                initializer=initializer,
                activation=postact,
                activate=True,
                weight_norm=weight_norm,
            )

        # GNN mixing in the batch dimension
        self.use_sequence_layer_norm = use_sequence_layer_norm # True if the layer norm is applied to the sequence dimension.
        self.use_gcn_true = use_gcn_true
        self.use_gnn = use_gnn
        self.small_gnn = use_small_gnn
        self.use_layer_norm_gnn = use_layer_norm_gnn
        self.gcn_depth = gcn_depth
        self.i_layer = i_layer
        self.nr_layers_with_gnn = nr_layers_with_gnn
        if self._uses_gnn():
            from src.models.sequence.gnn import gtnet
            # NOTE(review): the node count, device and dropout below are hard-coded.
            # 'cuda:0' presumably breaks CPU-only or multi-GPU runs; confirm.
            # The original author also marked this construction "Not working".
            nr_timeseries = 10
            sequence_length = self.L - 1
            feature_dim = self.d_model

            self.gtnet_model = gtnet(
                gcn_true=self.use_gcn_true,
                buildA_true=buildA_true,
                gcn_depth=self.gcn_depth,
                num_nodes=nr_timeseries,
                device='cuda:0',
                predefined_A=None,
                static_feat=None,
                dropout=0.3,
                subgraph_size=nr_timeseries, # smaller or equal to num_nodes
                node_dim=40,
                dilation_exponential=1,
                conv_channels=32,
                residual_channels=32,
                seq_length=sequence_length,
                in_dim=feature_dim,
                propalpha=0.05,
                tanhalpha=3,
                layer_norm_affline=True,
                use_layer_norm=self.use_layer_norm_gnn,
                use_sequence_layer_norm=self.use_sequence_layer_norm
            )

        # Set mixing in the pool dimension:
        self.use_set_mixing = use_set_mixing
        self.nr_layers_with_set = nr_layers_with_set
        self.use_layer_norm_set = use_layer_norm_set
        self.set_mixing_architecture = set_mixing_architecture
        self.set_mixing_dropout = set_mixing_dropout
        self.set_embedding_dim = set_feature_embedding_dim
        self.set_chunk_size = set_chunk_size
        self.expand = set_expand
        self.set_projection = set_projection
        self.common_pool_embedding_dim = set_common_pool_embedding_dim
        self.set_debug = set_debug
        self.n_attn_summary_statistics = set_n_attn_summary_statistics
        self.set_nr_attn_heads = set_nr_attn_heads
        self.set_var_layer_norm = set_var_layer_norm
        self.set_v_dim = set_v_dim
        if self.set_embedding_dim is None:
            self.set_embedding_dim = self.d_model
        if self._uses_set_mixing():
            import sys
            import os
            SAFARI_PATH = os.environ.get("SAFARI_PATH", None)
            # Only extend sys.path when the variable is actually set. Skip it if the
            # path is already there, since this constructor runs once per layer.
            if SAFARI_PATH is not None and SAFARI_PATH not in sys.path:
                sys.path.append(SAFARI_PATH)

            from src.tasks.encoders import SetEncoder
            # NOTE(review): loan_pool_size is hard-coded to 10; confirm it matches the data.
            self.set_encoder = SetEncoder(
                num_states=self.d_model-2,
                loan_pool_size=10,
                d_model=self.d_model,
                common_pool_embedding_dim=self.common_pool_embedding_dim,
                feature_embedding_dim=self.set_embedding_dim,
                debug=self.set_debug,
                architecture=self.set_mixing_architecture,
                use_layer_norm_set=self.use_layer_norm_set,
                chunk_size=self.set_chunk_size,
                nr_attention_heads=self.set_nr_attn_heads,
                n_attn_summary_statistics=self.n_attn_summary_statistics,
                dropout=self.set_mixing_dropout,
                expand=self.expand,
                projection=self.set_projection,
                set_var_layer_norm = self.set_var_layer_norm,
                )

    def _uses_gnn(self):
        """Whether gtnet mixing is enabled for this layer (i_layer <= nr_layers_with_gnn)."""
        return self.use_gnn and self.i_layer <= self.nr_layers_with_gnn

    def _uses_set_mixing(self):
        """Whether SetEncoder mixing is enabled for this layer (i_layer <= nr_layers_with_set)."""
        return self.use_set_mixing and self.i_layer <= self.nr_layers_with_set

    @staticmethod
    def _mask_padding(u, lengths):
        """Zero the positions of ``u`` at or beyond ``lengths`` along the last axis.

        Assumes suffix padding.
        TODO: support an explicit mask instead of lengths.

        u: tensor whose last axis is the sequence axis (3-D, e.g. (B, H, L))
        lengths: None, an int, or a 1-D integer tensor with 1 or u.size(0) entries

        Returns u itself when lengths is None or an int equal to the sequence
        length. Otherwise returns u multiplied by a 0/1 mask.
        """
        L = u.size(-1)
        if isinstance(lengths, int):
            if lengths == L:
                return u
            # Wrap in a list so the tensor is 1-D. A 0-d tensor would fail the check below.
            lengths = torch.tensor([lengths], dtype=torch.long, device=u.device)
        if lengths is None:
            return u
        assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)]
        # Build the mask on u's device so CPU `lengths` work with GPU inputs.
        lengths = lengths.to(u.device)
        mask = torch.where(torch.arange(L, device=u.device) < lengths[:, None, None], 1., 0.)
        return u * mask

    def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask
        """
        u: (B H L) if self.transposed else (B L H)
        state: (H N) never needed, remnant from state spaces repo; forwarded to the kernel
        rate: forwarded to the kernel
        lengths: optional int or 1-D tensor of valid lengths; later positions are zeroed before the convolution
        kwargs: must contain "nr_units" when set mixing is active for this layer

        Returns: (y, None); y uses the same axis ordering as u, with d_model features
        """
        if self._uses_gnn():
            u = self.gtnet_model(u)

        if self._uses_set_mixing():
            # (B*nr_units, L, H) -> (B*nr_units, H, L)
            u = u.transpose(-1, -2)
            # -> (B, nr_units, H, L)
            u = u.reshape(-1, kwargs["nr_units"], u.shape[1], u.shape[2])
            # -> (B, H, nr_units, L)
            u = u.transpose(1, 2)
            u, _ = self.set_encoder(u)

        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)

        # Mask out padding tokens
        u = self._mask_padding(u, lengths)

        # The kernel always has self.kernel_len taps, independent of L and rate.
        L_kernel = self.kernel_len
        k, _ = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L_kernel)

        # Convolution
        if self.bidirectional:
            k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
            k = F.pad(k0, (0, L)) \
                    + F.pad(k1.flip(-1), (L, 0))

        # Transform size L_kernel + L: zero padding to this size keeps the
        # circular FFT convolution from wrapping around onto the first L outputs.
        if self.block_fft_conv:
            k_f = self.block_fft_k(k.to(torch.complex64), N=L_kernel+L) # (C H L)
            u_f = self.block_fft_u(u.to(torch.complex64), N=L_kernel+L) # (B H L)
            y_f = contract('bhl,chl->bchl', u_f, k_f)
            if self.learn_ifft:
                y = self.block_fft_u(y_f, N=L_kernel+L, forward=False).real[..., :L]
            else:
                y = torch.fft.ifft(y_f, n=L_kernel+L, dim=-1).real[..., :L] # (B C H L)
        else:
            k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L)
            u_f = torch.fft.rfft(u, n=L_kernel+L) # (B H L)
            y_f = contract('bhl,chl->bchl', u_f, k_f)
            y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L)

        # Compute skip connection
        y = y + contract('bhl,ch->bchl', u, self.D)

        # Reshape to flatten channels
        y = rearrange(y, '... c h l -> ... (c h) l')

        if not self.transposed: y = y.transpose(-1, -2)
        y = self.activation(y)
        y = self.dropout(y)
        y = self.output_linear(y)
        assert not torch.is_complex(y), f"y became complex: dtype={y.dtype}"

        return y, None

    @property
    def d_state(self):
        # This layer keeps no recurrent state. Returns H, presumably for
        # interface compatibility with the state-spaces modules.
        return self.H

    @property
    def d_output(self):
        # Feature size of forward()'s output (output_linear maps to d_model).
        return self.d_model




