import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from collections import OrderedDict

from models.nn.components import NonLinear, LinearActivation, FeedForward
from models.functional.krylov import krylov
from models.functional.companion_krylov import companion_krylov
from models.spacetime.kernel_companion import *
# from models.spacetime.kernel_diagonal import *



def process_network_config(config, args=None):
    """Fill in per-layer defaults of a SpaceTime network config, in place.

    For every entry of ``config['layers']`` that has an ``'ssm'`` dict, the
    values of the ``n*`` keys in that dict are summed into ``'n_hippos'``.
    Entries that are not layers (``'_name_'`` is a plain string, the
    encoder/decoder have no ``'ssm'``) are skipped for that step.

    For every entry other than ``'_name_'``, ``'encoder'`` and ``'decoder'``
    the following defaults are filled in when absent:

    * ``layer_lr`` / ``layer_d_ffn``: taken from the layer's own ``lr`` /
      ``d_ffn`` if present, otherwise from ``config['lr']`` /
      ``config['d_ffn']``; an explicitly given value is kept.
    * ``ssm['recurrent'] = False`` and ``ssm['n_ma_error'] = 0``.
    * ``ssm['a_init']``, ``ssm['b_init']``, ``ssm['c_init']`` set to ``None``.
    * ``ssm['memory_norm']`` set from ``args.memory_norm``, only when ``args``
      is given (otherwise it is left unset).

    Also sets ``config['n_layers']`` to the number of such layers. When
    ``args`` is given, it copies ``args.lag`` / ``args.horizon`` into
    ``config``. It defaults ``config['spacetime_args']['learn_dt']`` to
    ``False``.

    Args:
        config: nested network config dict; mutated in place.
        args: optional namespace with ``lag``, ``horizon`` and
            ``memory_norm`` attributes (e.g. parsed argparse args).

    Returns:
        The same ``config`` object.
    """
    reserved = ('_name_', 'encoder', 'decoder')
    layer_count = 0
    for layer, layer_config in config['layers'].items():
        # Non-layer entries raise here: KeyError (no 'ssm'), TypeError
        # ('_name_' is a string), AttributeError (non-str keys / non-dict ssm).
        try:
            n_hippos = sum(int(n_hippo)
                           for name, n_hippo in layer_config['ssm'].items()
                           if name.startswith('n'))
        except (KeyError, TypeError, AttributeError):
            pass
        else:
            layer_config['n_hippos'] = n_hippos

        if layer in reserved:
            continue

        # A per-layer 'lr' / 'd_ffn' overrides the global value; an already
        # resolved 'layer_lr' / 'layer_d_ffn' is left untouched.
        if 'layer_lr' not in layer_config:
            layer_config['layer_lr'] = (layer_config['lr'] if 'lr' in layer_config
                                        else config['lr'])
        if 'layer_d_ffn' not in layer_config:
            layer_config['layer_d_ffn'] = (layer_config['d_ffn']
                                           if 'd_ffn' in layer_config
                                           else config['d_ffn'])

        ssm = layer_config['ssm']
        ssm.setdefault('recurrent', False)
        ssm.setdefault('n_ma_error', 0)
        for matrix_init in ('a_init', 'b_init', 'c_init'):
            ssm.setdefault(matrix_init, None)
        if 'memory_norm' not in ssm and args is not None:
            ssm['memory_norm'] = args.memory_norm  # 0, 1

        layer_count += 1
    config['n_layers'] = layer_count

    # Add argparse args
    if args is not None:
        config['lag'] = args.lag
        config['horizon'] = args.horizon

    config['spacetime_args'].setdefault('learn_dt', False)
    return config


class TwoInputSequential(nn.Module):
    """Chain modules that each map ``(x, v) -> (x, v)``.

    Args:
        args: either a sequence holding one iterable of modules, which is
            unpacked into ``nn.ModuleList`` as before
            (``TwoInputSequential([[m1, m2]])``), or a flat sequence of two or
            more modules (``TwoInputSequential([m1, m2])``).
    """

    def __init__(self, args):
        super().__init__()
        modules = list(args)
        if len(modules) > 1 and all(isinstance(m, nn.Module) for m in modules):
            # Flat list of modules: nn.ModuleList(*modules) would raise here.
            self.layers = nn.ModuleList(modules)
        else:
            self.layers = nn.ModuleList(*modules)
        # NOTE: stored as ``layers``, not ``modules``. ``nn.Module.__setattr__``
        # files submodules under ``_modules``, so ``self.modules`` resolved to
        # the inherited ``nn.Module.modules`` method and could not be iterated.

    def forward(self, x, v=0):
        """Thread ``(x, v)`` through every module in order and return it."""
        for module in self.layers:
            x, v = module(x, v)
        return x, v
    
    
class mySequential(nn.Sequential):
    """``nn.Sequential`` that threads multiple values between submodules.

    The running value is splatted into the next submodule as positional
    arguments whenever its type is exactly ``tuple``. Otherwise it is passed
    as a single argument. The first submodule therefore receives the call's
    own arguments unpacked. Each submodule's return value feeds the next one,
    and the last return value is the result.
    """

    def forward(self, *inputs):
        state = inputs
        for submodule in self._modules.values():
            state = submodule(*state) if type(state) is tuple else submodule(state)
        return state


def linear_layer(input_dim, output_dim,
                 weight=None, bias=None, train=False):
    """Build an ``nn.Linear(input_dim, output_dim)`` with optional fixed init.

    Args:
        input_dim: input features.
        output_dim: output features.
        weight: optional tensor copied into ``layer.weight``. It is cast to the
            layer's tensor type and reshaped with ``view`` to
            ``(output_dim, input_dim)``, so it must have
            ``output_dim * input_dim`` elements.
        bias: optional tensor copied into ``layer.bias``. It is cast and
            viewed to ``(output_dim,)``.
        train: whether weight and bias get ``requires_grad=True``.

    Returns:
        The configured ``nn.Linear``.

    Raises:
        ValueError: if ``bias`` cannot be viewed as ``(output_dim,)``.
    """
    layer = nn.Linear(input_dim, output_dim)
    with torch.no_grad():
        if weight is not None:
            layer.weight.copy_(
                weight.type(layer.weight.type()).view(*layer.weight.shape)
            )
        if bias is not None:
            try:
                layer.bias.copy_(
                    bias.type(layer.bias.type()).view(*layer.bias.shape)
                )
            except RuntimeError as err:
                # Previously dropped into ``breakpoint()``, which hangs
                # non-interactive runs.
                raise ValueError(
                    f'bias of shape {tuple(bias.shape)} does not fit a linear '
                    f'layer bias of shape {tuple(layer.bias.shape)}'
                ) from err
    layer.weight.requires_grad_(bool(train))
    layer.bias.requires_grad_(bool(train))
    return layer


def get_encoder(input_dim, output_dim, method='identity', train=False):
    """Build the input projection ``nn.Linear(input_dim, output_dim)``.

    With ``method='identity'`` the weight is all ones and the bias zero. Every
    output is then the sum of the inputs; for ``input_dim == 1`` the input is
    copied to each output. Any other ``method`` keeps PyTorch's default
    ``nn.Linear`` initialization. ``train`` is forwarded to ``linear_layer``
    and decides whether the parameters require gradients.
    """
    init_weight, init_bias = None, None
    if method == 'identity':
        # Built directly in nn.Linear's (out, in) layout.
        init_weight = torch.ones(output_dim, input_dim)
        init_bias = torch.zeros(output_dim)
    return linear_layer(input_dim, output_dim, init_weight, init_bias, train)


def get_decoder(input_dim, output_dim, method='linear', train=True):
    """Build the output projection ``nn.Linear(input_dim, output_dim)``.

    The layer comes from ``get_encoder`` with the same ``method`` and
    ``train``. Its bias is then zeroed and frozen (``requires_grad=False``)
    regardless of ``train``.
    """
    decoder_layer = get_encoder(input_dim, output_dim, method, train)
    bias = decoder_layer.bias
    with torch.no_grad():
        bias.zero_()
    bias.requires_grad_(False)
    return decoder_layer


class SpaceTime(nn.Module):
    """SpaceTime model: linear encoder -> stack of SpaceTimeLayers -> linear decoder.

    Args:
        input_dim: feature size of the model input.
        output_dim: feature size produced by the decoder.
        n_layers: stored as ``self.n_layers``; the layer stack itself is built
            from ``network_config['layers']``.
        channels: passed to every ``SpaceTimeLayer``.
        d_state: passed to every ``SpaceTimeLayer``.
        lag: passed to every ``SpaceTimeLayer``.
        **network_config: must contain ``horizon``, ``layers`` and
            ``spacetime_args``. May contain ``multivariate`` (default 1),
            which divides the encoder/decoder widths.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 n_layers,
                 channels,
                 d_state,
                 lag,
                 **network_config):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.n_layers = n_layers
        self.channels = channels
        self.d_state = d_state

        self.horizon = network_config['horizon']
        self.lag = lag

        self.multivariate = network_config.get('multivariate', 1)  # No MIMO for now

        self.nn = self._make_layers(**network_config)

    def reset_horizon(self, horizon):
        """Set ``self.horizon``; already-built layers are not touched."""
        self.horizon = horizon

    def reset_lag(self, lag):
        """Set ``self.lag``; already-built layers are not touched."""
        self.lag = lag

    def _make_layers(self, **network_config):
        """Build the layer stack and assign ``self.encoder`` / ``self.decoder``.

        Every entry of ``network_config['layers']`` except ``encoder``,
        ``decoder`` and ``_name_`` becomes a ``SpaceTimeLayer``. The encoder
        maps ``input_dim`` to the first layer's ``n_hippos // multivariate``.
        The decoder maps the last layer's ``n_hippos // multivariate`` to
        ``output_dim``. Layers are built before the encoder/decoder, as before.

        Side effect: writes ``self.horizon`` into
        ``network_config['spacetime_args']``. That dict is the caller's own
        object, because ``**`` only copies the top level.

        Returns:
            ``mySequential`` of the layers.

        Raises:
            ValueError: if no SpaceTime layers are configured.
        """
        layer_configs = network_config['layers']
        spacetime_args = network_config['spacetime_args']
        spacetime_args['horizon'] = self.horizon
        n_hippos_per_layer = []
        layers = []
        for name, layer_config in layer_configs.items():
            if name in ('encoder', 'decoder', '_name_'):
                continue
            n_hippos_per_layer.append(layer_config['n_hippos'])
            layers.append(
                SpaceTimeLayer(self.channels, self.d_state, self.lag,
                               **layer_config, **spacetime_args)
            )
        if not layers:
            raise ValueError("network_config['layers'] defines no SpaceTime layers")

        self.encoder = get_encoder(self.input_dim,
                                   n_hippos_per_layer[0] // self.multivariate,
                                   **layer_configs['encoder'])
        self.decoder = get_decoder(n_hippos_per_layer[-1] // self.multivariate,
                                   self.output_dim,
                                   **layer_configs['decoder'])

        return mySequential(*layers)

    def forward(self, x, horizon=0):
        """Encode, run the layer stack, decode. ``horizon`` is accepted but unused."""
        x = self.encoder(x)
        x, _ = self.nn(x)
        x = self.decoder(x)
        return x

    def encode(self, x):
        """Encode and run the layer stack.

        Returns ``self.nn``'s raw output. ``forward`` unpacks that same output
        as ``x, _``, so callers may need to unpack it too.
        """
        x = self.encoder(x)
        x = self.nn(x)
        return x

    def decode(self, x):
        """Apply the decoder."""
        return self.decoder(x)

    def _iter_kernels(self):
        """Yield every SSM kernel of every layer, in layer order."""
        for layer in self.nn:
            yield from layer.ssm_block.kernels.values()

    @staticmethod
    def _set_post_ssm_requires_grad(layer, requires_grad):
        """Set ``requires_grad`` on the post-SSM output projections of ``layer``.

        Targets ``post_ssm_layer.output_linear`` and ``post_ssm_layer.f[2]``.
        The latter is the last element of an ``(f): Sequential(GELU,
        DropoutNd, TransposedLinear)`` stack. Either may be absent for a given
        post-SSM layer type (e.g. ``nn.Identity``); absent ones are skipped.
        """
        try:
            for p in layer.post_ssm_layer.output_linear.parameters():
                p.requires_grad = requires_grad
        except AttributeError:
            pass
        try:
            for p in layer.post_ssm_layer.f[2].parameters():
                p.requires_grad = requires_grad
        except (AttributeError, IndexError, TypeError):
            pass

    def train_ssm(self):
        """Toggle trainable parameters via each kernel's ``train_ssm()``."""
        for kernel in self._iter_kernels():
            kernel.train_ssm()

    def freeze_ssm(self):
        """Toggle trainable parameters via each kernel's ``freeze_ssm()``."""
        for kernel in self._iter_kernels():
            kernel.freeze_ssm()

    def train_feedback(self):
        """Toggle trainable parameters via each kernel's ``train_feedback()``."""
        for kernel in self._iter_kernels():
            kernel.train_feedback()

    def freeze_feedback(self):
        """Toggle trainable parameters via each kernel's ``freeze_feedback()``."""
        for kernel in self._iter_kernels():
            kernel.freeze_feedback()

    def process_lag(self):
        """Call ``process_lag()`` on recurrent kernels and freeze everything else.

        For each layer, parameters of non-recurrent kernels and of the post-SSM
        output projections get ``requires_grad = False``.
        """
        for layer in self.nn:
            for kernel in layer.ssm_block.kernels.values():
                if kernel.recurrent is True:
                    kernel.process_lag()
                else:
                    # Don't train prior layer kernels
                    for p in kernel.parameters():
                        p.requires_grad = False
            self._set_post_ssm_requires_grad(layer, False)

    def process_horizon(self, train_k=False):
        """Call ``process_horizon(train_k)`` on recurrent kernels and unfreeze the rest.

        Non-recurrent kernels get ``train_ssm()``. Post-SSM output projections
        get ``requires_grad = True``. ``train_k`` now defaults to ``False``,
        matching ``MultiHorizonSpaceTime``.
        """
        for layer in self.nn:
            for kernel in layer.ssm_block.kernels.values():
                if kernel.recurrent is True:
                    kernel.process_horizon(train_k)
                else:
                    kernel.train_ssm()  # Train all layer kernels
            self._set_post_ssm_requires_grad(layer, True)

    def get_feedback(self):
        """Return ``(feedback_inputs, reference_inputs)`` of the first recurrent kernel.

        Assumes only one recurrent layer. Returns ``None`` if no kernel is
        recurrent.
        """
        for kernel in self._iter_kernels():
            if kernel.recurrent is True:
                return kernel.feedback_inputs, kernel.reference_inputs
        return None
                    
    
    
class MultiHorizonSpaceTime(nn.Module):
    """SpaceTime variant whose ``forward`` threads ``horizon`` through the layers.

    It has the same construction as ``SpaceTime``: linear encoder -> stack of
    SpaceTimeLayers -> linear decoder. ``forward(x, horizon)`` passes
    ``horizon`` to the layer stack as the layers' second input.

    Args:
        input_dim: feature size of the model input.
        output_dim: feature size produced by the decoder.
        n_layers: stored as ``self.n_layers``; the layer stack itself is built
            from ``network_config['layers']``.
        channels: passed to every ``SpaceTimeLayer``.
        d_state: passed to every ``SpaceTimeLayer``.
        lag: passed to every ``SpaceTimeLayer``.
        **network_config: must contain ``horizon``, ``layers`` and
            ``spacetime_args``. May contain ``multivariate`` (default 1),
            which divides the encoder/decoder widths.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 n_layers,
                 channels,
                 d_state,
                 lag,
                 **network_config):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.n_layers = n_layers
        self.channels = channels
        self.d_state = d_state

        self.horizon = network_config['horizon']
        self.lag = lag

        self.multivariate = network_config.get('multivariate', 1)  # No MIMO for now

        self.nn = self._make_layers(**network_config)

    def reset_horizon(self, horizon):
        """Set ``self.horizon``; already-built layers are not touched."""
        self.horizon = horizon

    def reset_lag(self, lag):
        """Set ``self.lag``; already-built layers are not touched."""
        self.lag = lag

    def _make_layers(self, **network_config):
        """Build the layer stack and assign ``self.encoder`` / ``self.decoder``.

        Every entry of ``network_config['layers']`` except ``encoder``,
        ``decoder`` and ``_name_`` becomes a ``SpaceTimeLayer``. The encoder
        maps ``input_dim`` to the first layer's ``n_hippos // multivariate``.
        The decoder maps the last layer's ``n_hippos // multivariate`` to
        ``output_dim``. Layers are built before the encoder/decoder, as before.

        Side effect: writes ``self.horizon`` into
        ``network_config['spacetime_args']``. That dict is the caller's own
        object, because ``**`` only copies the top level.

        Returns:
            ``mySequential`` of the layers.

        Raises:
            ValueError: if no SpaceTime layers are configured.
        """
        layer_configs = network_config['layers']
        spacetime_args = network_config['spacetime_args']
        spacetime_args['horizon'] = self.horizon
        n_hippos_per_layer = []
        layers = []
        for name, layer_config in layer_configs.items():
            if name in ('encoder', 'decoder', '_name_'):
                continue
            n_hippos_per_layer.append(layer_config['n_hippos'])
            layers.append(
                SpaceTimeLayer(self.channels, self.d_state, self.lag,
                               **layer_config, **spacetime_args)
            )
        if not layers:
            raise ValueError("network_config['layers'] defines no SpaceTime layers")

        self.encoder = get_encoder(self.input_dim,
                                   n_hippos_per_layer[0] // self.multivariate,
                                   **layer_configs['encoder'])
        self.decoder = get_decoder(n_hippos_per_layer[-1] // self.multivariate,
                                   self.output_dim,
                                   **layer_configs['decoder'])

        return mySequential(*layers)

    def forward(self, x, horizon=0):
        """Encode, run the layer stack with ``horizon`` as second input, decode."""
        x = self.encoder(x)
        x, _ = self.nn(x, horizon)
        x = self.decoder(x)
        return x

    def encode(self, x):
        """Encode and run the layer stack.

        Returns ``self.nn``'s raw output. ``forward`` unpacks that same output
        as ``x, _``, so callers may need to unpack it too.
        """
        x = self.encoder(x)
        x = self.nn(x)
        return x

    def decode(self, x):
        """Apply the decoder."""
        return self.decoder(x)

    def _iter_kernels(self):
        """Yield every SSM kernel of every layer, in layer order."""
        for layer in self.nn:
            yield from layer.ssm_block.kernels.values()

    @staticmethod
    def _set_post_ssm_requires_grad(layer, requires_grad):
        """Set ``requires_grad`` on the post-SSM output projections of ``layer``.

        Targets ``post_ssm_layer.output_linear`` and ``post_ssm_layer.f[2]``.
        The latter is the last element of an ``(f): Sequential(GELU,
        DropoutNd, TransposedLinear)`` stack. Either may be absent for a given
        post-SSM layer type (e.g. ``nn.Identity``); absent ones are skipped.
        """
        try:
            for p in layer.post_ssm_layer.output_linear.parameters():
                p.requires_grad = requires_grad
        except AttributeError:
            pass
        try:
            for p in layer.post_ssm_layer.f[2].parameters():
                p.requires_grad = requires_grad
        except (AttributeError, IndexError, TypeError):
            pass

    def train_ssm(self):
        """Toggle trainable parameters via each kernel's ``train_ssm()``."""
        for kernel in self._iter_kernels():
            kernel.train_ssm()

    def freeze_ssm(self):
        """Toggle trainable parameters via each kernel's ``freeze_ssm()``."""
        for kernel in self._iter_kernels():
            kernel.freeze_ssm()

    def train_feedback(self):
        """Toggle trainable parameters via each kernel's ``train_feedback()``."""
        for kernel in self._iter_kernels():
            kernel.train_feedback()

    def freeze_feedback(self):
        """Toggle trainable parameters via each kernel's ``freeze_feedback()``."""
        for kernel in self._iter_kernels():
            kernel.freeze_feedback()

    def process_lag(self, freeze_weights=False):
        """Call ``process_lag()`` on recurrent kernels.

        When ``freeze_weights`` is true, also set ``requires_grad = False`` on
        the parameters of non-recurrent kernels and of the post-SSM output
        projections.
        """
        for layer in self.nn:
            for kernel in layer.ssm_block.kernels.values():
                if kernel.recurrent is True:
                    kernel.process_lag()
                elif freeze_weights:
                    # Don't train prior layer kernels
                    for p in kernel.parameters():
                        p.requires_grad = False
            if freeze_weights:
                self._set_post_ssm_requires_grad(layer, False)

    def process_horizon(self, train_k=False):
        """Call ``process_horizon(train_k)`` on recurrent kernels and unfreeze the rest.

        Non-recurrent kernels get ``train_ssm()``. Post-SSM output projections
        get ``requires_grad = True``.
        """
        for layer in self.nn:
            for kernel in layer.ssm_block.kernels.values():
                if kernel.recurrent is True:
                    kernel.process_horizon(train_k)
                else:
                    kernel.train_ssm()  # Train all layer kernels
            self._set_post_ssm_requires_grad(layer, True)

    def joint_train(self, freeze_weights=False):
        """Set ``joint_train = True`` on every recurrent kernel.

        ``freeze_weights`` is accepted for interface symmetry and is unused.
        """
        for kernel in self._iter_kernels():
            if kernel.recurrent is True:
                kernel.joint_train = True

    def get_feedback(self):
        """Return ``(feedback_inputs, reference_inputs)`` of the first recurrent kernel.

        Assumes only one recurrent layer. Returns ``None`` if no kernel is
        recurrent.
        """
        for kernel in self._iter_kernels():
            if kernel.recurrent is True:
                return kernel.feedback_inputs, kernel.reference_inputs
        return None
    
    
    
class SpaceTimeLayer(nn.Module):
    """One SpaceTime layer: SSM kernel bank -> post-SSM layer -> FFN (+ optional skip).

    Args:
        channels: passed to the kernels and used to size the post-SSM layer.
        d_state: passed to the kernels.
        lag: passed to ``SSMBlock``.
        n_hippos: hidden size of the layer (stored as ``self.h``).
        post_ssm: ``'identity'``, ``'linear'`` or ``'nonlinear'``; anything
            else raises ``NotImplementedError``.
        ffn: ``'identity'`` for no feed-forward network; any other value
            builds a ``FeedForward``.
        skip_connection: if truthy, the layer input is added to its output.
        layer_lr: passed to the kernels as ``lr``.
        layer_d_ffn: hidden size of the feed-forward network.
        ssm: per-kernel-family config, splatted into ``SSMBlock``.
        **spacetime_args: must provide ``dropout``, ``postact``,
            ``activation``, ``initializer``, ``weight_norm`` and every key in
            ``_KERNEL_ARG_KEYS``; other keys are not read.
    """

    # Keys copied from ``spacetime_args`` into the kernel constructor kwargs.
    _KERNEL_ARG_KEYS = ('learn_theta', 'theta_scale', 'use_initial', 'trap_rule',
                        'learn_dt', 'dt_min', 'dt_max', 'unconstrained_a',
                        'horizon')

    def __init__(self,
                 channels,
                 d_state,
                 lag,
                 n_hippos,
                 post_ssm,
                 ffn,
                 skip_connection,
                 layer_lr,
                 layer_d_ffn,
                 ssm,
                 **spacetime_args):
        super().__init__()

        self.channels = channels
        self.h = n_hippos

        # Linear layer parameters
        self.dropout = spacetime_args['dropout']
        self.postact = spacetime_args['postact']
        self.activation = spacetime_args['activation']
        self.initializer = spacetime_args['initializer']
        self.weight_norm = spacetime_args['weight_norm']

        # Layer after Hippos
        self.post_ssm = post_ssm
        # Feedforward network after Hippos
        self.feedforward = ffn
        # FFN parameters
        self.d_ffn = layer_d_ffn
        # skip connection
        self.skip_connection = skip_connection

        kernel_args = {'channels': channels, 'd_state': d_state, 'lr': layer_lr}
        kernel_args.update({k: spacetime_args[k] for k in self._KERNEL_ARG_KEYS})

        self.ssm_block = SSMBlock(n_hippos, lag, **ssm, **kernel_args)

        if self.post_ssm == 'identity':
            self.post_ssm_layer = nn.Identity()
        elif self.post_ssm == 'linear':
            # Projects the channels * h SSM outputs back down to h.
            self.post_ssm_layer = LinearActivation(self.channels * self.h, self.h,
                                                   bias=True,
                                                   zero_bias_init=False,
                                                   transposed=True,
                                                   activation=self.activation,
                                                   initializer=self.initializer,
                                                   activate=False,
                                                   weight_norm=self.weight_norm)
        elif self.post_ssm == 'nonlinear':
            self.post_ssm_layer = NonLinear(self.h, self.channels,
                                            dropout=self.dropout,
                                            postact=self.postact,
                                            activation=self.activation,
                                            initializer=self.initializer,
                                            weight_norm=self.weight_norm)
        else:
            raise NotImplementedError

        if self.feedforward == 'identity':
            self.ffn = nn.Identity()
        else:
            self.ffn = FeedForward(self.h, self.channels,
                                   d_hidden=self.d_ffn,
                                   dropout=self.dropout,
                                   transposed=True,
                                   activation=self.activation,
                                   initializer=self.initializer,
                                   weight_norm=self.weight_norm)

    def forward(self, u, v=0):
        """Apply the layer.

        Args:
            u: input laid out as (B, L, H).
            v: extra term passed to every SSM kernel and returned unchanged.

        Returns:
            ``(y, v)``. ``y`` is rearranged back to (batch, length, features)
            layout after the channel-first SSM / post-SSM / FFN stages.
        """
        x = rearrange(u, 'b l h -> b h l')
        x = self.ssm_block(x, v)  # v is an extra constant term; could be helpful as a mean thing
        # Errors propagate (previously swallowed into ``breakpoint()``).
        x = self.post_ssm_layer(x)
        x = self.ffn(x)
        x = rearrange(x, 'b h l -> b l h')
        if self.skip_connection:
            x = x + u
        return x, v
    
    
    
class SSMBlock(nn.Module):
    """Bank of SSM kernels, each applied to its own slice of the hidden dims.

    Each ``n_*`` argument is the number of hidden dims ("hippos") handled by
    one kernel family; together they must add up to ``n_hippos``. Kernels are
    created, and run in ``forward``, in this order: AR, MA; then
    differencing or MA-error, or a single residual kernel when both are
    requested; then continuous diagonal, discrete diagonal, shift, companion.
    ``forward`` slices the hidden axis in that same order.

    Args:
        n_hippos: total hidden dims covered by all kernels.
        lag: state size for the MA, differencing, MA-error and residual
            kernels, and for the AR kernel when ``n_ar > 1`` (else ``None``).
        n_ar, n_ma, n_diff, n_simple, n_simple_discrete, n_shift,
            n_companion, n_ma_error: hidden dims per kernel family; 0 disables
            that family.
        **kernel_args: forwarded to every kernel constructor (e.g. channels,
            d_state, lr).

    Raises:
        ValueError: if the ``n_*`` counts do not add up to ``n_hippos``.
    """

    def __init__(self,
                 n_hippos,
                 lag,
                 n_ar,
                 n_ma,
                 n_diff,
                 n_simple,
                 n_simple_discrete,
                 n_shift,
                 n_companion,
                 n_ma_error,
                 **kernel_args):  # channels, d_state, lr
        super().__init__()
        n_total = (n_ar + n_ma + n_diff + n_simple + n_simple_discrete
                   + n_shift + n_companion + n_ma_error)
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if n_hippos != n_total:
            raise ValueError(
                f'n_hippos ({n_hippos}) does not match the sum of the '
                f'per-kernel counts ({n_total})'
            )

        print('Lag:', lag)

        self.kernels = nn.ModuleDict()

        # ------- "AR" and "MA" kernels ---------

        if n_ar > 0:
            ar_d_state = lag if n_ar > 1 else None
            ssm = ARKernel(n_hippos=n_ar,
                           ar_d_state=ar_d_state,
                           **copy.deepcopy(kernel_args))
            self.kernels.update({'ar': ssm})  # preserve ordering
            print(f'AR ssm d_state: {ssm.d_state}')

        if n_ma > 0:
            ssm = CompleteMAKernel(n_hippos=n_ma,  # CompleteMAKernel
                                   ma_d_state=lag,
                                   **copy.deepcopy(kernel_args))
            self.kernels.update({'ma': ssm})
            print(f'MA ssm d_state: {ssm.d_state}')

        # ------- "Data preprocessing" kernels ---------

        if n_diff > 0 and n_ma_error == 0:
            ssm = DifferencingKernel(n_hippos=n_diff,
                                     diff_d_state=lag,  # kernel_args['d_state']
                                     **copy.deepcopy(kernel_args))
            self.kernels.update({'diff': ssm})

            # Best-effort debug print of the kernel's weights.
            try:
                print('-> Differencing kernel weights:')
                print(self.kernels['diff'].c[0, :13, :10])
            except Exception as e:
                print(e)

        elif n_ma_error > 0 and n_diff == 0:
            ssm = MAErrorKernel(n_hippos=n_ma_error,
                                ma_d_state=lag,
                                **kernel_args)
            self.kernels.update({'ma_error': ssm})

            # Best-effort debug print of the kernel's weights.
            try:
                print('-> MA Error kernel weights:')
                print(self.kernels['ma_error'].c[0, :13, :10])
            except Exception as e:
                print(e)

        elif n_diff > 0 and n_ma_error > 0:  # Combine into a single kernel
            ssm = ResidualKernel(n_diff, n_ma_error,
                                 ma_d_state=lag, num_orders=4,
                                 **kernel_args)
            self.kernels.update({'diff_error': ssm})

        # ------- Diagonal SpaceTime kernels ---------

        if n_simple > 0:
            ssm = ContinuousDiagonalKernel(n_hippos=n_simple,
                                           **kernel_args)
            self.kernels.update({'diagonal_cont': ssm})

        if n_simple_discrete > 0:
            ssm = DiscreteDiagonalKernel(n_hippos=n_simple_discrete,
                                         **kernel_args)
            self.kernels.update({'diagonal_disc': ssm})

        # ------- Companion matrix kernels ---------

        if n_shift > 0:
            ssm = ShiftKernel(n_hippos=n_shift,
                              **kernel_args)
            self.kernels.update({'shift': ssm})

        if n_companion > 0:
            ssm = CompanionKernel(n_hippos=n_companion,
                                  **kernel_args)
            self.kernels.update({'companion': ssm})

    def forward(self, u, v=0):
        """Run each kernel on its slice of dim 1 of ``u`` and concatenate.

        Each kernel reads ``ssm.n_hippos`` consecutive entries of dim 1,
        starting where the previous kernel stopped. Its first output is kept;
        its second output is ignored. ``u`` is presumably (B, H, L); the
        caller rearranges to that layout before calling.

        Args:
            u: input tensor, sliced along dim 1.
            v: passed to every kernel as ``v=v``.

        Returns:
            Kernel outputs concatenated along dim 1, in kernel order.
        """
        outputs = []
        start = 0
        for ssm in self.kernels.values():
            stop = start + ssm.n_hippos
            y, _ = ssm(u[:, start:stop, :], v=v)  # second output ignored for now
            outputs.append(y)
            start = stop
        return torch.cat(outputs, dim=1)
            
        
        