""" Utility nn components, in particular handling activations, initializations, and normalization layers """

from functools import partial
import math
from typing import ForwardRef
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from opt_einsum import contract


class Difference(nn.Module):
    """Take repeated first differences along the length axis of a BLD tensor.

    Only the first ``L - horizon`` positions take part: the trailing
    ``horizon`` steps are dropped by the first differencing pass. Each pass
    shortens the length axis by one, so for ``order >= 1`` the output has
    length ``L - horizon - order``.

    Args:
        order: number of differencing passes to apply.
        horizon: number of trailing time steps excluded from differencing.
    """

    def __init__(self, order, horizon):
        super().__init__()
        self.order = order
        self.horizon = horizon

    def forward(self, x):
        # x must be BLD (batch, length, dim)
        lag = x.shape[1] - self.horizon
        # Bug fix: the loop previously iterated over an undefined name
        # `order` (NameError); the configured count lives on `self.order`.
        for ix in range(self.order):
            # After ix passes the working length is lag - ix, so these two
            # slices are that window offset by one step: x[t + 1] - x[t].
            x = x[:, 1:lag - ix, :] - x[:, :(lag - 1) - ix]
        return x

    
class InverseDifference(nn.Module):
    """Undo a first-order difference by cumulative summation along length.

    Only ``order`` values below 2 are accepted (checked with ``assert`` in
    the constructor). ``horizon`` is stored but not read by ``forward``.
    """

    def __init__(self, order, horizon):
        super().__init__()
        self.order, self.horizon = order, horizon
        # Higher-order inversion is not implemented.
        assert self.order < 2

    def forward(self, x, first_val):
        """Return ``first_val`` plus the running sum of ``x`` over dim 1.

        Assumes x and first_val are B x L x D, with first_val presumably
        having L == 1 so it broadcasts (TODO confirm at call sites).
        """
        running_total = torch.cumsum(x, dim=1)
        return first_val + running_total
    
    
    
class Normalization(nn.Module):
    """Select and apply a normalization layer by name.

    Args:
        d: number of features (channels) being normalized.
        transposed: whether the length dimension is -1 (True) or -2 (False).
        _name_: one of 'layer', 'instance', 'batch', 'group', 'none'.
        **kwargs: forwarded to the underlying normalization constructor.

    Raises:
        NotImplementedError: for an unrecognized ``_name_``.
    """

    def __init__(
        self,
        d,
        transposed=False, # Length dimension is -1 or -2
        _name_='layer',
        **kwargs
    ):
        super().__init__()
        self.transposed = transposed

        if _name_ == 'layer':
            self.channel = True # Normalize over channel dimension
            if self.transposed:
                # NOTE(review): TransposedLN is not defined in this file's
                # visible source; presumably provided elsewhere — confirm.
                self.norm = TransposedLN(d, **kwargs)
            else:
                self.norm = nn.LayerNorm(d, **kwargs)
        elif _name_ == 'instance':
            self.channel = False
            norm_args = {'affine': False, 'track_running_stats': False}
            norm_args.update(kwargs)
            self.norm = nn.InstanceNorm1d(d, **norm_args) # (True, True) performs very poorly
        elif _name_ == 'batch':
            self.channel = False
            norm_args = {'affine': True, 'track_running_stats': True}
            norm_args.update(kwargs)
            self.norm = nn.BatchNorm1d(d, **norm_args)
        elif _name_ == 'group':
            self.channel = False
            # Bug fix: this was `*kwargs`, which passed the keyword *names*
            # (e.g. the string 'eps') as positional arguments to GroupNorm.
            self.norm = nn.GroupNorm(1, d, **kwargs)
        elif _name_ == 'none':
            self.channel = True
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Normalization: '{_name_}' is not supported")

    def forward(self, x):
        # The cases of LayerNorm / no normalization are automatically handled in all cases
        # Instance/Batch Norm work automatically with transposed axes
        if self.channel or self.transposed:
            return self.norm(x)
        else:
            # Instance/Batch/Group norms take channels on axis -2, so swap
            # the last two axes around the call.
            x = x.transpose(-1, -2)
            x = self.norm(x)
            x = x.transpose(-1, -2)
            return x
    

class TSNormalization(nn.Module):
    """Scale a BLD time series by a per-sample, per-feature magnitude.

    ``forward`` stores the scale on ``self.scale`` so a paired inverse
    module can undo it later.

    Args:
        method: 'mean' divides by the mean absolute value over all but the
            last ``horizon`` steps; 'last' divides by the absolute value at
            step ``-horizon - 1``. Any other value returns the input
            unchanged and does not set ``self.scale``.
        horizon: number of trailing steps excluded from the scale.
    """

    def __init__(self, method, horizon):
        super().__init__()
        self.method = method
        self.horizon = horizon

    def forward(self, x):
        # x must be BLD
        if self.method == 'mean':
            # Bug fix: `[:, :-self.horizon]` is an empty slice when
            # horizon == 0, giving a NaN scale. Use an explicit end index.
            context_len = x.shape[1] - self.horizon
            self.scale = x.abs()[:, :context_len].mean(dim=1)[:, None, :]
            return x / self.scale
        elif self.method == 'last':
            self.scale = x.abs()[:, -self.horizon-1][:, None, :]
            return x / self.scale
        return x

    
class TSInverseNormalization(nn.Module):
    """Reverse a scaling normalization by multiplying by the recorded scale.

    Args:
        method: normalization method; only 'mean' and 'last' rescale, any
            other value passes the input through.
        normalizer: object whose ``scale`` attribute is read each call.
    """

    def __init__(self, method, normalizer):
        super().__init__()
        self.method = method
        self.normalizer = normalizer

    def forward(self, x):
        if self.method not in ('mean', 'last'):
            return x
        return x * self.normalizer.scale

    
class ReversibleInstanceNorm1dInput(nn.Module):
    """Normalize each series to zero mean / unit std along the length axis.

    The per-series mean and (std + 1e-4) are stored on ``self.m`` and
    ``self.s``, presumably so a later step can reverse the transform.

    Args:
        d: number of features; sizes the InstanceNorm1d held on ``self.norm``
            (which ``forward`` does not apply).
        transposed: False for BLD inputs, True for BDL inputs.
    """

    def __init__(self, d, transposed=False):
        super().__init__()
        # BLD if transposed is False, otherwise BDL
        self.transposed = transposed
        # Not applied in forward (its affine step is commented out there);
        # left in place so the module's parameters / state_dict keys match.
        self.norm = nn.InstanceNorm1d(d, affine=True, track_running_stats=False)

    def forward(self, x):
        # Put the length axis last so statistics are taken over it.
        if not self.transposed:
            x = x.transpose(-1, -2)

        s, self.m = torch.std_mean(x, dim=-1, unbiased=False, keepdim=True)
        # Bug fix: the original `self.s += 1e-4` modified std_mean's output in
        # place. Autograd keeps that output for the backward pass, so
        # backpropagating into x raised a RuntimeError. Add out of place.
        self.s = s + 1e-4

        x = (x - self.m) / self.s
        # x = self.norm.weight.unsqueeze(-1) * x + self.norm.bias.unsqueeze(-1)

        if not self.transposed:
            return x.transpose(-1, -2)
        return x
    
    
class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True):
        """ Dropout whose mask can be shared across every length axis.

        p: probability of zeroing an entry; must lie in [0, 1).
        tie: if True, draw one mask value per (batch, channel) pair and
            broadcast it over all trailing axes (like Dropout1d/2d/3d);
            if False, draw an independent mask value for every element.
        """
        super().__init__()
        if p >= 1 or p < 0:
            raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        # Not used by forward (the mask comes from torch.rand); kept as an attribute.
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """ X: (batch, dim, lengths...) """
        if not self.training:
            return X
        if self.tie:
            # Singleton trailing axes so the mask broadcasts over lengths.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1. - self.p
        # Rescale survivors so the expected value is unchanged.
        return X * keep * (1.0 / (1 - self.p))

    
def Activation(activation=None, size=None, dim=-1):
    """Build an activation module from its name.

    Args:
        activation: activation name; None, 'id', 'identity' and 'linear'
            all give nn.Identity.
        size: feature size, only used for 'modrelu'.
        dim: axis for 'glu'; also the argument handed to TransposedLN for 'ln'.

    Raises:
        NotImplementedError: for an unknown name.

    NOTE(review): Modrelu, SquaredReLU and TransposedLN are not defined in
    this file's visible source, so those options presumably rely on
    definitions elsewhere (confirm). 'ln' passes `dim` (an axis index) to
    TransposedLN, where a feature size may have been intended.
    """
    if activation in (None, 'id', 'identity', 'linear'):
        return nn.Identity()

    # (accepted names, zero-argument builder). Builders that use names not
    # defined in this file are lambdas, so those names are looked up only
    # when that option is actually requested.
    builders = (
        (('tanh',), nn.Tanh),
        (('relu',), nn.ReLU),
        (('gelu',), nn.GELU),
        (('swish', 'silu'), nn.SiLU),
        (('glu',), lambda: nn.GLU(dim=dim)),
        (('sigmoid',), nn.Sigmoid),
        (('modrelu',), lambda: Modrelu(size)),
        (('sqrelu',), lambda: SquaredReLU()),
        (('ln',), lambda: TransposedLN(dim)),
    )
    for names, build in builders:
        if activation in names:
            return build()
    raise NotImplementedError("hidden activation '{}' is not implemented".format(activation))

        
def get_initializer(name, activation=None):
    """Return a torch.nn.init function selected by name.

    Args:
        name: 'uniform' / 'normal' (Kaiming), 'xavier' (Xavier normal),
            'zero' or 'one' (constant fill).
        activation: activation following the layer; selects the
            ``nonlinearity`` used for the Kaiming gain.

    Returns:
        A callable taking a tensor and filling it in place.

    Raises:
        NotImplementedError: if the activation or the initializer name is
            unsupported (the activation is checked first).
    """
    if activation in (None, 'id', 'identity', 'linear', 'modrelu'):
        gain_name = 'linear'
    elif activation in ('relu', 'tanh', 'sigmoid'):
        gain_name = activation
    elif activation in ('gelu', 'swish', 'silu'):
        gain_name = 'relu' # ReLU-like curves: borrow ReLU's gain
    else:
        raise NotImplementedError(f"get_initializer: activation {activation} not supported")

    init = torch.nn.init
    if name == 'uniform':
        return partial(init.kaiming_uniform_, nonlinearity=gain_name)
    if name == 'normal':
        return partial(init.kaiming_normal_, nonlinearity=gain_name)
    if name == 'xavier':
        return init.xavier_normal_
    if name == 'zero':
        return partial(init.constant_, val=0)
    if name == 'one':
        return partial(init.constant_, val=1)
    raise NotImplementedError(f"get_initializer: initializer type {name} not supported")
    
    
def LinearActivation(
        d_input, d_output, bias=True,
        zero_bias_init=False,
        transposed=False,
        initializer=None,
        activation=None,
        activate=False, # Apply activation as part of this module
        weight_norm=False,
        **kwargs,
    ):
    """ Returns a linear nn.Module with control over axes order, initialization, and activation

    Args:
        d_input: input feature size.
        d_output: output feature size. Doubled when ``activation`` is 'glu',
            since GLU halves its input along the gated axis.
        bias: include an additive bias.
        zero_bias_init: zero the bias after construction.
        transposed: use TransposedLinear (features on axis 1) instead of
            nn.Linear (features on the last axis).
        initializer: optional name passed to get_initializer for the weight.
        activation: activation name; also selects the initializer gain.
        activate: if True and ``activation`` is not None, append the
            activation to the returned module with nn.Sequential.
        weight_norm: wrap the linear layer with nn.utils.weight_norm.
        **kwargs: forwarded to the linear layer's constructor.
    """

    # Construct core module
    linear_cls = TransposedLinear if transposed else nn.Linear
    if activation == 'glu':
        d_output *= 2
    linear = linear_cls(d_input, d_output, bias=bias, **kwargs)

    # Initialize weight
    if initializer is not None:
        get_initializer(initializer, activation)(linear.weight)

    # Initialize bias
    if bias and zero_bias_init:
        nn.init.zeros_(linear.bias)

    # Weight norm
    if weight_norm:
        linear = nn.utils.weight_norm(linear)

    if activate and activation is not None:
        # (A leftover debug `print(activation)` used to run here on every call.)
        activation = Activation(activation, d_output, dim=1 if transposed else -1)
        linear = nn.Sequential(linear, activation)
    return linear

    
# Below here are standard wrapper classes to handle
# (1) Non-linearity
# (2) Integration with the Hippo Code base
class NonLinear(nn.Module):
    """Activation, then dropout, then a projection from h*channels to h features.

    Args:
        h: output feature size.
        channels: the input carries ``h * channels`` features.
        ln: accepted for interface compatibility; unused (the normalization
            it would control is commented out).
        transposed: forwarded to LinearActivation (True: features on axis 1).
        dropout: dropout probability; 0 replaces dropout with nn.Identity.
        postact: activation applied after the output projection.
        activation: activation applied before dropout and projection.
        initializer: initializer name for the projection weight.
        weight_norm: apply weight normalization to the projection.
    """
    def __init__(self, h, channels,
                 ln=False, # Extra normalization
                 transposed=True,
                 dropout=0.0,
                 postact=None, # activation after FF
                 activation='gelu', # activation in between SS and FF
                 initializer=None, # initializer on FF
                 weight_norm=False, # weight normalization on FF
                 ):
        super().__init__()
        dropout_fn = DropoutNd # nn.Dropout2d bugged in PyTorch 1.11
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()
        #norm = Normalization(h*channels, transposed=transposed) if ln else nn.Identity()

        self.activation_fn = Activation(activation)

        self.output_linear = LinearActivation(
            h*channels,
            h,
            transposed=transposed,
            initializer=initializer,
            activation=postact,
            activate=True,
            weight_norm=weight_norm,
        )
        # The same three modules as a Sequential; kept so the module tree
        # and state_dict keys (`f.*`) are unchanged.
        self.f = nn.Sequential(self.activation_fn, self.dropout, self.output_linear)

    def forward(self, x):  # Always (B H L)
        # Each step used to sit in a bare `except:` that called
        # `breakpoint()`, so errors dropped into the debugger instead of
        # propagating. Errors now propagate normally.
        x = self.activation_fn(x)
        x = self.dropout(x)
        return self.output_linear(x)
    
    
class TransposedLinear(nn.Module):
    """ Linear module on the second-to-last dimension
    Assumes shape (B, D, L), where L can be 1 or more axis

    The contraction is over axis 1 (D), so any number of trailing length
    axes is supported. Weight and bias use nn.Linear's default init.

    Args:
        d_input: size of axis 1 of the input.
        d_output: size of axis 1 of the output.
        bias: add a learnable per-output-feature bias.
    """

    def __init__(self, d_input, d_output, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.empty(d_output, d_input))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init
        # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent

        if bias:
            self.bias = nn.Parameter(torch.empty(d_output))
            bound = 1 / math.sqrt(d_input)
            nn.init.uniform_(self.bias, -bound, bound)
            # Per-parameter optimizer hint (no weight decay on the bias),
            # presumably consumed by optimizer setup elsewhere.
            setattr(self.bias, "_optim", {"weight_decay": 0.0})
        else:
            self.bias = 0.0

    def forward(self, x):
        # y[b, v, ...] = sum_u x[b, u, ...] * weight[v, u]
        y = torch.einsum('bu...,vu->bv...', x, self.weight)
        # Bug fix: with bias=False, self.bias is the float 0.0, which has no
        # `.view`; only add a bias when it is actually a tensor.
        if isinstance(self.bias, torch.Tensor):
            num_axis = len(x.shape[2:])  # num_axis in L, for broadcasting bias
            y = y + self.bias.view(-1, *[1] * num_axis)
        return y
    
    
class FeedForward(nn.Module):
    """Residual two-layer MLP: x + linear_2(dropout(linear_1(x))).

    ``linear_1`` includes its activation; ``linear_2`` does not.

    Args:
        h: with ``channels``, sets the block width to ``h * channels``
            (input and output, so the residual add lines up).
        channels: see ``h``.
        d_hidden: hidden width.
        dropout: dropout probability between the two layers.
        transposed: forwarded to LinearActivation (True: features on axis 1).
        activation: activation applied after the first layer.
        initializer: optional initializer name for both weights.
        weight_norm: apply weight normalization to both layers.
    """
    def __init__(self, h, channels,
                 d_hidden=2048,
                 dropout = 0.0,
                 transposed=True,
                 activation='gelu',
                 initializer=None, # initializer on FF
                 weight_norm=False, # weight normalization on FF
                ):
        super().__init__()
        self.linear_1 = LinearActivation(h * channels,
                                         d_hidden,
                                         bias=True,
                                         zero_bias_init=False,
                                         transposed=transposed,
                                         initializer=initializer,
                                         activation=activation,
                                         activate=True,  # Apply activation after 1st layer
                                         weight_norm=weight_norm)
        self.dropout = nn.Dropout(dropout)
        # Bug fix: LinearActivation doubles d_output whenever activation is
        # 'glu', even with activate=False. Passing 'glu' here made linear_2
        # output 2 * h * channels features and broke the residual add.
        out_activation = None if activation == 'glu' else activation
        self.linear_2 = LinearActivation(d_hidden,
                                         h * channels,
                                         bias=True,
                                         zero_bias_init=False,
                                         transposed=transposed,
                                         initializer=initializer,
                                         activation=out_activation,
                                         activate=False,
                                         weight_norm=weight_norm)

    def forward(self, x):
        # Residual connection
        x = x + self.linear_2(self.dropout(self.linear_1(x)))
        return x