"""
Hedgehog Block
- Spiky multiplicative interactions (Hedgehog Attention)
- Temporal Q, K, V (via SSMs)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

import opt_einsum as oe
from einops import rearrange, repeat
from src.models.nn.components import Activation, DropoutNd

from .ssm.hedgehog import init_ssm

class OptimModule(nn.Module):
    """nn.Module base class with a helper for registering tensors that may
    carry per-tensor optimizer settings."""

    def __init__(self):
        super().__init__()

    def register(self, name, tensor, trainable=False, lr=None, wd=None):
        """Attach `tensor` to the module under `name`.

        When `trainable` is true the tensor is wrapped in `nn.Parameter` and
        registered as a parameter; otherwise it is registered as a buffer.
        If registration raises KeyError (the name is already taken by an
        attribute of a different kind), the existing attribute is deleted and
        the registration is retried.

        For trainable tensors, any of `lr` / `wd` that is not None is stored
        in a dict on the registered parameter as `_optim`, under the keys
        "lr" and "weight_decay".
        """
        if trainable:
            value, add = nn.Parameter(tensor), self.register_parameter
        else:
            value, add = tensor, self.register_buffer

        try:
            add(name, value)
        except KeyError:
            # Name clash with an existing attribute: drop it, then retry.
            delattr(self, name)
            add(name, value)

        hyperparams = {}
        if trainable:
            if lr is not None:
                hyperparams["lr"] = lr
            if wd is not None:
                hyperparams["weight_decay"] = wd
        if hyperparams:
            setattr(getattr(self, name), "_optim", hyperparams)

# --------------
# Hedgehog Block
# --------------
class HedgehogBlock(nn.Module):
    """
    Base Hedgehog block: a sequence "conv" layer (built via `init_ssm`)
    followed by HedgehogAttention.

    Args:
        d_model: model dimension; stored as `model_dim` and exposed as
            `d_output`. Per the original note it should equal
            n_kernels * n_heads_per_kernel * head_dim.
        conv_config: dict with keys 'method' (passed to `init_ssm`) and
            'kwargs' (passed to the constructor that `init_ssm` returns).
        attention_config: dict whose 'kwargs' entry is passed to
            HedgehogAttention.
        ffn_config: accepted for interface compatibility; currently unused
            (the post-attention FFN call is disabled in `forward`).
        **kwargs: ignored.
    """
    def __init__(self,
                 d_model: int,
                 conv_config: dict,
                 attention_config: dict,
                 ffn_config: dict,
                 **kwargs):
        super().__init__()
        # Should be n_kernels * n_heads_per_kernel * head_dim
        self.model_dim = d_model
        # Apply 1D conv over input sequence to compute Q, K, V
        self.conv = init_ssm(conv_config['method'])(**conv_config['kwargs'])
        self.attention = HedgehogAttention(**attention_config['kwargs'])

        # Output dimension of block
        self.d_output = self.model_dim

    def forward(self, x: torch.Tensor, state=None, lengths=None, **kwargs):
        """
        In this implementation, assume "self-attention" only,
        i.e., q, k, v all computed from the same input sequence.

        Args:
            x: input of shape (b, l, d).
            state: unused.
            lengths: optional number of valid (non-padding) positions.
                Either an int (applied to every sequence; an int equal to
                `l` disables masking) or a 1-D tensor with 1 or `b` entries.
                Positions at index >= length are zeroed before the conv.

        Returns:
            Whatever `self.attention` returns. HedgehogAttention.forward in
            this file returns a tuple `(output, None)`.
        """
        b, l, d = x.shape

        # Mask out padding tokens -> borrowed from long-convs
        if isinstance(lengths, int):
            if lengths != l:
                # Wrap in a list so the tensor is 1-D, as the check below requires.
                lengths = torch.tensor([lengths], dtype=torch.long, device=x.device)
            else:
                lengths = None
        if lengths is not None:
            assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, x.size(0)]
            # x is (b, l, d), so the mask has to be (b or 1, l, 1) to broadcast
            # over the feature axis. The borrowed code assumed a (b, d, l) layout.
            keep = torch.arange(l, device=lengths.device) < lengths[:, None]
            x = x * keep.unsqueeze(-1).to(dtype=x.dtype)

        x = self.conv(x)        # Apply SSMs for 1D convolution to compute Q, K, V
        x = self.attention(x)   # Apply spiked linear attention
        # NOTE(review): a post-attention MLP (`self.ffn`) is not wired up; it
        # would need to handle the (output, None) tuple returned by attention.
        return x
    

# ----------------------    
# Spiky Linear Attention
# ----------------------
class HedgehogAttention(OptimModule):
    """
    Hedgehog spiky linear self-attention

    The feature axis of the input is treated as n_kernels * n_heads * head_dim.
    Q, K and V are each produced by their own HedgehogMLP, Q and K go through
    a feature map selected by `attention`, and the attention output is
    computed with sums (bidirectional) or cumulative sums (causal) over the
    sequence axis instead of an explicit L x L attention matrix.

    Args:
        n_kernels: number of kernels (outermost factor of the feature axis).
        n_heads: heads per kernel; derived from `model_dim` when None.
        head_dim: dimension of each head; derived from `model_dim` when None.
        model_dim: layer input/output dimension. Required when one of
            `n_heads` / `head_dim` is None; recomputed as
            n_kernels * n_heads * head_dim when both are given.
        bidirectional: sum K(V) over the whole sequence instead of cumsum.
        attention: name of the Q/K feature map, see `transform_qk`.
        affine_qkv: if True, divide the output by the matching Q.K sums.
        temperature_qkv: scale applied to Q and K in the '*_lse' feature maps.
        context_len: stored only; not used in this class.
        dropout: dropout probability applied after the output FFN.
        layernorm: if True, apply one shared LayerNorm(model_dim) to Q, K, V.
        skip_connection: if True, add `x * skip` with a learned per-feature
            `skip` vector to the output.
        linear_bias: stored only; not used in this class.
        output_method: passed to `init_ffn`; 'mixer' builds the mixer with the
            two patterns below, anything else gets single-linear-layer kwargs.
        output_mix_head_pattern / output_mix_kernel_pattern: einops patterns
            forwarded to the mixer.
        qkv_mlp_kwargs: extra HedgehogMLP kwargs for the Q/K/V MLPs
            (`input_dim`/`output_dim` are always overridden). May be None.
    """
    def __init__(self,
                 n_kernels: int,
                 n_heads: int=None,    # Number of heads per kernel
                 head_dim: int=None,   # Dimension of each head
                 model_dim: int=None,  # Dimension of layer inputs and outputs
                 bidirectional: bool=False,
                 attention: str='spiked_relu_lse',  # spiked_relu_lse for +stability but -expressivity?
                 affine_qkv: bool=True,
                 temperature_qkv: float=1.,
                 context_len: int=None,
                 dropout: float=0.,
                 layernorm: bool=False,
                 skip_connection: bool=True,
                 linear_bias: bool=True,
                 output_method: str='mixer',
                 output_mix_head_pattern: str='nk (nh hd)',  # Which dimensions to compute
                 output_mix_kernel_pattern: str='nh hd nk',
                 qkv_mlp_kwargs: dict=None):
        super().__init__()
        # At least one of these should be int
        assert not (n_heads is None and head_dim is None)

        # Initialize dimensions
        # -> model_dim = n_kernels * n_heads_per_kernel * head_dim
        self.n_kernels = n_kernels
        dims = self.init_heads(n_heads, head_dim, model_dim)
        self.head_dim, self.n_heads, self.model_dim = dims
        self.dim_kwargs = {
            'nk': self.n_kernels, 'nh': self.n_heads, 'hd': self.head_dim
        }

        # Initialize MLPs to compute Q, K, V
        # (each MLP mixes the nk * hd features belonging to one head index)
        mlp_dim = self.n_kernels * self.head_dim
        mlp_qkv = self.init_qkv_mlps(mlp_dim, qkv_mlp_kwargs)
        self.q_mlp, self.k_mlp, self.v_mlp = mlp_qkv

        self.bidirectional   = bidirectional
        self.attention       = attention
        self.affine_qkv      = affine_qkv
        self.temperature_qkv = temperature_qkv
        self.context_len     = context_len
        self.skip_connection = skip_connection
        self.linear_bias     = linear_bias

        # Dropout
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Layernorm
        self.layernorm = nn.LayerNorm(self.model_dim) if layernorm else nn.Identity()

        # Output FFN -> hack for now; hardcode args
        if output_method == 'mixer':
            kwargs = {'n_kernels': self.n_kernels, 'n_heads': self.n_heads,
                      'head_dim': self.head_dim, 'bias': False,
                      'mix_head_pattern': output_mix_head_pattern,
                      'mix_kernel_pattern': output_mix_kernel_pattern}
        else:  # Linear projection
            kwargs = {'input_dim': self.model_dim, 'output_dim': self.model_dim,
                      'bias': False, 'n_layers': 1, 'n_activations': 0}
        self.ffn = init_ffn(output_method)(**kwargs)

        if self.skip_connection:
            skip = torch.randn(self.model_dim)  # replicate 5
            self.register('skip', skip, trainable=True, lr=None, wd=None)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        "Self-attention" implementation where x is a single B x L x D torch.Tensor

        Returns:
            `(a, None)` where `a` has shape (b, l, d); the second element is
            always None.
        """
        b, l, d = x.shape  # batch_size, sample_len, sample_dim (rank check)
                           # d is n_kernels * n_heads * head_dim

        # Compute Q, K, V by mixing across input dims
        # -> We just compute from x for q, k, and v assuming self-attention for now
        q = self.compute_qkv(x, self.q_mlp)  # shape is (b, l, d)
        k = self.compute_qkv(x, self.k_mlp)  # shape is (b, l, d)
        v = self.compute_qkv(x, self.v_mlp)  # shape is (b, l, d)

        # If layernorm is enabled, apply before transform to help with stability
        # (the same LayerNorm module is shared by q, k and v)
        q = self.layernorm(q)
        k = self.layernorm(k)
        v = self.layernorm(v)

        # Repeat head dimensions
        q = self.repeat_dims(q, value=False)  # shape is (b, l, nk * nh * hd * hd)
        k = self.repeat_dims(k, value=False)  # shape is (b, l, nk * nh * hd * hd)
        v = self.repeat_dims(v, value=True)   # shape is (b, l, nk * nh * hd * hd)

        # Compute (spiked linear) attention
        # e.g., q = e^(relu(q)) - 1 for spiked relu
        q, k = self.transform_qk(q, k)

        # Compute dot products b/t q and kv heads
        a = self.multiply_and_contract(q, k * v)  # a.shape = (b, l, d)

        # Normalize by sum over sequence of dot products b/t q and k heads
        if self.affine_qkv:
            a /= self.multiply_and_contract(q, k)

        # Mix across heads, apply dropout, add skip connection
        a = self.ffn(a)
        a = self.dropout(a)
        if self.skip_connection:  # Treat first x as input
            a = a + oe.contract('b l d, d -> b l d', x, self.skip)
        return a, None

    def init_heads(self, n_heads: int, head_dim: int, model_dim: int):
        """
        Initializes / figures out how to divide up the input dims (== model_dim) into
        number of kernels, number of heads per kernel, dim of each head.

        - If both `n_heads` and `head_dim` are given, `model_dim` is recomputed
          as n_kernels * n_heads * head_dim (any passed value is overridden).
        - If exactly one is given, the other is derived from `model_dim`.

        Also stores the three values on `self`.

        Returns:
            (head_dim, n_heads, model_dim)

        Raises:
            ValueError: if a value has to be derived and `model_dim` is
                missing or not divisible by n_kernels * the given value, or
                if both `n_heads` and `head_dim` are None.
        """
        if n_heads is None and head_dim is None:
            raise ValueError('At least one of n_heads and head_dim must be given')

        if n_heads is not None and head_dim is not None:
            model_dim = head_dim * n_heads * self.n_kernels
        else:
            if model_dim is None:
                raise ValueError(
                    'model_dim is required when n_heads or head_dim is None')
            given = n_heads if head_dim is None else head_dim
            per_unit = self.n_kernels * given
            if per_unit <= 0 or model_dim % per_unit != 0:
                raise ValueError(
                    f'model_dim ({model_dim}) must be divisible by '
                    f'n_kernels * {"n_heads" if head_dim is None else "head_dim"} '
                    f'({per_unit})')
            if head_dim is None:
                head_dim = model_dim // per_unit
            else:
                n_heads = model_dim // per_unit

        self.head_dim  = head_dim
        self.n_heads   = n_heads
        self.model_dim = model_dim
        return self.head_dim, self.n_heads, self.model_dim

    def init_qkv_mlps(self, input_dim: int, mlp_args: dict):
        """
        Returns MLPs for computing Q, K, V from input sequence

        `mlp_args` may be None; it is copied, so the caller's dict is not
        modified. `input_dim` and `output_dim` are both forced to `input_dim`.
        """
        mlp_args = dict(mlp_args) if mlp_args is not None else {}
        mlp_args['input_dim']  = input_dim
        mlp_args['output_dim'] = input_dim
        mlps = (HedgehogMLP(**mlp_args),
                HedgehogMLP(**mlp_args),
                HedgehogMLP(**mlp_args))  # one for each Q, K, V
        return mlps

    def compute_qkv(self, x: torch.Tensor, mlp) -> torch.Tensor:
        """
        Computes Q, K, V from input sequence x (where x.shape is b, l, d)

        The features are regrouped to (b, l, nh, nk * hd) so `mlp` acts on the
        nk * hd features of each head index, then flattened back to
        (b, l, nk * nh * hd).
        """
        x = rearrange(x, 'b l (nk nh hd) -> b l nh (nk hd)',
                      nk=self.n_kernels, nh=self.n_heads, hd=self.head_dim)
        x = mlp(x)
        x = rearrange(x, 'b l nh (nk hd) -> b l (nk nh hd)',
                      nk=self.n_kernels, nh=self.n_heads, hd=self.head_dim)
        return x

    def repeat_dims(self, x: torch.Tensor, value: bool=False) -> torch.Tensor:
        """
        Repeats dimensions from b, l, (nk nh hd) -> b, l (nk nh hd hd).
        Starting from last dims as [1, 2, 3, ..., hd],
        - If value is True, we get: [1, ..., 1,  2, ..., 2,  ..., hd, ..., hd]
        - Otherwise, we get:        [1, ..., hd, 1, ..., hd, ..., 1,  ..., hd]

        Laid out this way, an elementwise product of a tiled q or k with a
        tiled v enumerates every (v_i, q_j or k_j) pair within a head.
        """
        x = rearrange(x, 'b l (nk nh hd) -> b l nk nh hd',
                      nk=self.n_kernels, nh=self.n_heads, hd=self.head_dim)
        if value:
            x = repeat(x, 'b l nk nh hd -> b l nk nh (hd r)', r=self.head_dim)
        else:
            x = repeat(x, 'b l nk nh hd -> b l nk nh (r hd)', r=self.head_dim)
        # Here `hd` names the already-expanded (hd * hd) axis.
        return rearrange(x, 'b l nk nh hd -> b l (nk nh hd)')

    def multiply_and_contract(self, q: torch.Tensor,
                              k: torch.Tensor) -> torch.Tensor:
        """
        Accumulate `k` over the sequence axis (sum if bidirectional, cumsum
        otherwise), multiply elementwise with `q`, and sum over the innermost
        `hd` axis of the (nk, nh, r, hd) layout.

        Returns a tensor of shape (b, l, nk * nh * r).
        """
        # Assume q, k are shape (b, l, d*),
        # d* = n_kernels * n_heads_per_kernel * head_dim * head_dim (after repeat_dims)
        b, l, _ = q.shape
        # Compute sum_{i=1}^L k_ij * v_ij for all j in [1, ..., d]
        if self.bidirectional:
            k = torch.sum(k, dim=1, keepdim=True)  # shape is (b, 1, d*)
        else:
            k = torch.cumsum(k, dim=1)  # shape is (b, l, d*)
        # Multiplicative interactions
        a = q * k  # shape is (b, l, d*)
        # Contract back to shape (b, l, d)
        return rearrange(a, 'b l (nk nh r hd) -> b l nk nh r hd',
                         **self.dim_kwargs).sum(dim=-1).view(b, l, -1)

    def transform_qk(self, q, k):
        """
        Apply the feature map named by `self.attention` to q and k.

        Supported names: 'spiked_relu', 'spiked_relu_lse', 'spiked_relu_diff',
        'spiked_relu_diff_lse', 'relu', 'spiked', 'softmax', 'norm', 'linear'.

        Note that the 'spiked_relu' and 'spiked_relu_lse' branches apply ReLU
        in place to the tensors passed in, and the '*_lse' branches subtract
        the maximum over the entire tensor (all batch elements and positions).

        Raises:
            NotImplementedError: for any other name.
        """
        # Warning - this is not stable yet; things may go to NaN
        if self.attention == 'spiked_relu':
            # Numerical stability issues
            q = torch.clamp(torch.exp(F.relu(q, inplace=True)) - 1,
                            min=1e-8, max=None)
            k = torch.clamp(torch.exp(F.relu(k, inplace=True)) - 1,
                            min=1e-8, max=None)
        elif self.attention == 'spiked_relu_lse':
            # Numerical stability issues - try logsumexp + clamp
            # -> May hurt expressivity?
            q = F.relu(q, inplace=True)
            k = F.relu(k, inplace=True)
            q = torch.clamp(torch.exp(q - q.max()),
                            min=1e-8, max=1) * self.temperature_qkv
            k = torch.clamp(torch.exp(k - k.max()),
                            min=1e-8, max=1) * self.temperature_qkv
        elif self.attention == 'spiked_relu_diff':
            # Numerical stability issues
            # Zero-out if q and k are different enough
            # NOTE(review): k is computed from the already-updated q, so the
            # two lines are not symmetric - confirm this is intended.
            q = F.relu(q * (1. - q + k))
            k = F.relu(k * (1. - k + q))  # if q and k are diff signs, one of these will be 0
            q = torch.clamp(torch.exp(q) - 1, min=1e-8)
            k = torch.clamp(torch.exp(k) - 1, min=1e-8)
        elif self.attention == 'spiked_relu_diff_lse':
            # Numerical stability issues - try logsumexp + clamp
            # Zero-out if q and k are different enough
            # NOTE(review): same asymmetry as 'spiked_relu_diff' (k sees updated q).
            q = F.relu(q * (1. - q + k))
            k = F.relu(k * (1. - k + q))
            q = torch.clamp(torch.exp(q - q.max()),
                            min=1e-8, max=1) * self.temperature_qkv
            k = torch.clamp(torch.exp(k - k.max()),
                            min=1e-8, max=1) * self.temperature_qkv

        # Some other pre-dot-product transforms
        elif self.attention == 'relu':
            q = F.relu(q)
            k = F.relu(k)
        elif self.attention == 'spiked':
            q = torch.exp(q)
            k = torch.exp(k)
        elif self.attention == 'softmax':
            q = self._softmax_qk(q)
            k = self._softmax_qk(k)
        elif self.attention == 'norm':
            q = self._norm_qk(q)
            k = self._norm_qk(k)
        elif self.attention == 'linear':
            pass  # identity feature map
        else:
            raise NotImplementedError(f'{self.attention} attention not implemented')
        return q, k

    def _softmax_qk(self, x: torch.Tensor) -> torch.Tensor:
        """Softmax over the innermost `hd` axis of the (nk, nh, r, hd) layout."""
        x = rearrange(x, 'b l (nk nh r hd) -> b l nk nh r hd',
                      nk=self.n_kernels, nh=self.n_heads, hd=self.head_dim)
        x = F.softmax(x, dim=-1)
        x = rearrange(x, 'b l nk nh r hd -> b l (nk nh r hd)')
        return x

    def _norm_qk(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalize over the innermost `hd` axis of the (nk, nh, r, hd) layout."""
        x = rearrange(x, 'b l (nk nh r hd) -> b l nk nh r hd',
                      nk=self.n_kernels, nh=self.n_heads, hd=self.head_dim)
        x = F.normalize(x, p=2, dim=-1)
        x = rearrange(x, 'b l nk nh r hd -> b l (nk nh r hd)')
        return x
    
# -------------------    
# Feed-forward layers
# -------------------
def init_ffn(method):
    """Return the feed-forward layer class (not an instance) named by `method`.

    'mlp' -> HedgehogMLP, 'mixer' -> HedgehogMixer, 'identity' -> nn.Identity.
    Raises NotImplementedError for any other value.
    """
    if method == 'mlp':
        return HedgehogMLP
    if method == 'mixer':
        return HedgehogMixer
    if method == 'identity':
        return nn.Identity
    raise NotImplementedError(f'{method} feedforward network not implemented!')

    
class HedgehogMLP(nn.Module):
    """
    One of two ways to mix across multiple kernels ('heads') after
    Hedgehog multiplicative interactions (spiky linear attention)
    - Also general fully-connected MLP layer

    Args:
        input_dim: size of the feature axis of the input.
        output_dim: size of the feature axis of the output.
        hidden_dim: width of intermediate layers; defaults to `input_dim`
            when `n_layers >= 2`, and is forced to `output_dim` when
            `n_layers < 2`.
        activation: name passed to the project's `Activation` component.
        bias: whether the linear layers have a bias.
        dropout: dropout probability appended after the last layer. Values
            above 1 select `DropoutNd` with p = dropout - 1 (existing hack).
        layernorm: if True, apply LayerNorm(input_dim) over the feature axis
            before the layers.
        n_layers: number of linear layers.
        n_activations: number of activations interleaved after the linear
            layers; if it exceeds `n_layers`, one activation is placed first.
        pre_activation: if True, start with an activation.
        input_shape: 'bld' (features last) or 'bdl' (features on axis 1).
            'bdl' inputs are transposed to features-last; the output is
            returned features-last and is not transposed back.
        skip_connection: if True, add the (normalized) input to the output of
            the layers; this needs output_dim == input_dim.
        average_pool: if 'l', average the output over axis 1 (keepdim=True).
    """
    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: int=None,
                 activation: str=None,
                 bias: bool=True,
                 dropout: float=0.,
                 layernorm: bool=False,
                 n_layers: int=1,
                 n_activations: int=0,
                 pre_activation: bool=False,
                 input_shape: str='bld',
                 skip_connection: bool=False,
                 average_pool: str=None):
        super().__init__()
        self.input_dim     = input_dim
        self.hidden_dim    = hidden_dim
        self.output_dim    = output_dim
        self.input_shape   = input_shape

        self.activation      = activation
        self.bias            = bias
        self.dropout         = dropout
        self.layernorm       = nn.LayerNorm(input_dim) if layernorm else nn.Identity()
        self.n_layers        = n_layers
        self.n_activations   = n_activations
        self.pre_activation  = pre_activation
        self.skip_connection = skip_connection
        self.average_pool    = average_pool

        self.initialize_layers()

    def initialize_layers(self):
        """Build `self.layers` (an nn.Sequential) from the stored settings."""
        n_layers_to_init = self.n_layers
        n_activations_to_init = self.n_activations

        # With fewer than two layers there is no hidden layer, so the single
        # linear layer maps straight to output_dim.
        if self.n_layers < 2:
            self.hidden_dim = self.output_dim
        elif self.hidden_dim is None:
            self.hidden_dim = self.input_dim

        # Add layers
        if self.n_activations > self.n_layers or self.pre_activation:
            layers = [Activation(self.activation)]
            n_activations_to_init -= 1
        else:
            layers = []

        # Alternate linear layer / activation until both counters run out.
        while n_layers_to_init > 0 or n_activations_to_init > 0:
            if n_layers_to_init == self.n_layers:
                # First linear layer
                layers.append(nn.Linear(self.input_dim, self.hidden_dim,
                                        bias=self.bias))
            elif n_layers_to_init > 1:
                # Intermediate linear layer
                layers.append(nn.Linear(self.hidden_dim, self.hidden_dim,
                                        bias=self.bias))
            elif n_layers_to_init == 1:
                # Last linear layer
                layers.append(nn.Linear(self.hidden_dim, self.output_dim,
                                        bias=self.bias))
            if n_activations_to_init > 0:
                layers.append(Activation(self.activation))
            n_layers_to_init -= 1
            n_activations_to_init -= 1

        layers.append(self.init_dropout())
        self.layers = nn.Sequential(*layers)

    def init_dropout(self):
        """Return the dropout module selected by `self.dropout`."""
        if self.dropout > 1:  # Dropout hack for now, testing DropoutNd
            return DropoutNd(p=self.dropout-1.)
        elif self.dropout > 0:
            return nn.Dropout(self.dropout)
        else:
            return nn.Identity()

    def forward(self, x):
        """
        Apply optional layernorm, the layer stack, optional skip connection
        and optional average pooling. See the class docstring for shapes.
        """
        # Move features to the last axis first: both LayerNorm(input_dim) and
        # nn.Linear act on the last axis. Equivalent to 'b d l -> b l d'.
        if self.input_shape == 'bdl':
            x = x.transpose(1, 2)

        x = self.layernorm(x)

        if self.skip_connection:
            # Layernorm with skip connection
            x = self.layers(x) + x
        else:
            x = self.layers(x)

        if self.average_pool == 'l':
            x = x.mean(dim=1, keepdim=True)
        return x


class HedgehogMixer(nn.Module):
    """
    Output mixer applied after Hedgehog multiplicative interactions (spiky
    linear attention): mixes features within a kernel and across kernels
    with two linear maps, applied in both orders, and adds the results to
    the input.

    The same `mix_head` and `mix_kernel` linear layers are shared by the two
    composed paths (`mix_head_kernel` and `mix_kernel_head`).
    """
    def __init__(self,
                 n_kernels: int,
                 n_heads: int,
                 head_dim: int,
                 bias: bool=False,
                 mix_head_pattern: str='nk (nh hd)',
                 mix_kernel_pattern: str='nh hd nk'):
        super().__init__()
        self.n_kernels, self.n_heads, self.head_dim = n_kernels, n_heads, head_dim
        self.bias = bias
        self.mix_head_pattern, self.mix_kernel_pattern = (mix_head_pattern,
                                                          mix_kernel_pattern)

        # Axis sizes handed to einops when splitting the feature axis
        self.dims = dict(nk=n_kernels, nh=n_heads, hd=head_dim)
        # Build the linear maps and their two compositions
        self.init_weights()

    def init_weights(self):
        """Create the two shared linear layers and the two composed paths."""
        # Acts on a trailing axis of size nh * hd (features of one kernel)
        head_width = self.n_heads * self.head_dim
        self.mix_head = nn.Linear(head_width, head_width, bias=self.bias)
        # Acts on a trailing axis of size nk (same feature across kernels)
        self.mix_kernel = nn.Linear(self.n_kernels, self.n_kernels, bias=self.bias)

        # Both compositions receive identical constructor arguments
        shared_args = (self.n_kernels, self.n_heads, self.head_dim,
                       self.mix_head, self.mix_kernel,
                       self.mix_head_pattern, self.mix_kernel_pattern)
        self.mix_head_kernel = HeadKernelFFN(*shared_args)
        self.mix_kernel_head = KernelHeadFFN(*shared_args)

    def forward(self, x: torch.tensor) -> torch.tensor:
        """
        x.shape is b, l, n_kernels * n_heads * head_dims; the output has the
        same shape.
        """
        b, l, d = x.shape
        flat      = 'b l (nk nh hd)'
        by_head   = f'b l {self.mix_head_pattern}'
        by_kernel = f'b l {self.mix_kernel_pattern}'

        # Path 1: head mixing first, then kernel mixing (ends in kernel layout)
        head_first = self.mix_head_kernel(
            rearrange(x, f'{flat} -> {by_head}', **self.dims))
        head_first = rearrange(head_first, f'{by_kernel} -> {flat}', **self.dims)

        # Path 2: kernel mixing first, then head mixing (ends in head layout)
        kernel_first = self.mix_kernel_head(
            rearrange(x, f'{flat} -> {by_kernel}', **self.dims))
        kernel_first = rearrange(kernel_first, f'{by_head} -> {flat}', **self.dims)

        # Residual sum of the input and both paths
        return x + head_first + kernel_first
    
    
class HeadKernelFFN(nn.Module):
    """Apply `mix_head`, switch from the head layout to the kernel layout,
    then apply `mix_kernel`. The two linear layers are supplied by the
    caller rather than created here."""

    def __init__(self, n_kernels: int, n_heads: int, head_dim: int,
                 mix_head: nn.Linear, mix_kernel: nn.Linear,
                 mix_head_pattern: str='nk (nh hd)',
                 mix_kernel_pattern: str='nh hd nk'):
        super().__init__()
        # Axis sizes for the einops patterns
        self.dims = dict(nk=n_kernels, nh=n_heads, hd=head_dim)

        self.mix_head = mix_head
        self.mix_kernel = mix_kernel
        self.mix_head_pattern = mix_head_pattern
        self.mix_kernel_pattern = mix_kernel_pattern

    def forward(self, x: torch.tensor) -> torch.tensor:
        # Assume x is laid out as 'b l <mix_head_pattern>',
        # i.e. (..., n_heads * head_dim) with the default pattern
        to_kernel_layout = (f'b l {self.mix_head_pattern} -> '
                            f'b l {self.mix_kernel_pattern}')
        mixed = rearrange(self.mix_head(x), to_kernel_layout, **self.dims)
        return self.mix_kernel(mixed)
    
    
class KernelHeadFFN(nn.Module):
    """Apply `mix_kernel`, switch from the kernel layout to the head layout,
    then apply `mix_head`. The two linear layers are supplied by the
    caller rather than created here."""

    def __init__(self, n_kernels: int, n_heads: int, head_dim: int,
                 mix_head: nn.Linear, mix_kernel: nn.Linear,
                 mix_head_pattern: str='nk (nh hd)',
                 mix_kernel_pattern: str='nh hd nk'):
        super().__init__()
        # Axis sizes for the einops patterns
        self.dims = dict(nk=n_kernels, nh=n_heads, hd=head_dim)

        self.mix_head = mix_head
        self.mix_kernel = mix_kernel
        self.mix_head_pattern = mix_head_pattern
        self.mix_kernel_pattern = mix_kernel_pattern

    def forward(self, x: torch.tensor) -> torch.tensor:
        # Assume x is laid out as 'b l <mix_kernel_pattern>',
        # i.e. (..., n_kernels) with the default pattern
        to_head_layout = (f'b l {self.mix_kernel_pattern} -> '
                          f'b l {self.mix_head_pattern}')
        mixed = rearrange(self.mix_kernel(x), to_head_layout, **self.dims)
        return self.mix_head(mixed)
