import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import opt_einsum as oe

from torch.nn.utils import weight_norm
from torchvision import transforms, utils
from torchvision.ops import MLP
from einops import rearrange, repeat


class LSTM_NR(nn.Module):
    """Encoder/decoder LSTM pair followed by a linear read-out.

    ``lstm_en`` consumes ``past``; its final ``(h, c)`` state is used as the
    initial state of ``lstm_de``, which runs over ``x`` with the last feature
    column removed.  Both LSTMs are created with ``batch_first=True``.

    ``hyper_params`` must provide the keys ``input_size``, ``latent_size``,
    ``num_layers`` and ``dropout``.
    """

    def __init__(self, hyper_params, output_size=1):
        super().__init__()
        n_features = hyper_params['input_size']
        latent = hyper_params['latent_size']
        # Settings common to the encoder and the decoder stack.
        shared = dict(
            hidden_size=latent,
            num_layers=hyper_params['num_layers'],
            batch_first=True,
            dropout=hyper_params['dropout'],
        )
        # The encoder sees one more input feature than the decoder.
        self.lstm_en = nn.LSTM(input_size=n_features + 1, **shared)
        self.lstm_de = nn.LSTM(input_size=n_features, **shared)
        self.proj = nn.Linear(latent, output_size)

    def forward(self, past, s, x):
        """Encode ``past``, decode ``x`` and project every decoder step.

        Args:
            past: encoder input, last dimension ``input_size + 1``.
            s: not used by this model; the parameter mirrors the
                ``forward(past, s, x)`` signature of the other ``*_NR``
                models in this file.
            x: decoder input; its last feature column is dropped before it is
                fed to ``lstm_de``.

        Returns:
            ``proj`` applied to the decoder outputs (last dimension is
            ``output_size``).
        """
        _, encoder_state = self.lstm_en(past)
        decoded, _ = self.lstm_de(x[:, :, :-1], encoder_state)
        return self.proj(decoded)


class BNODE(nn.Module):
    """Residual recurrent cell built from two MLPs.

    ``f`` maps ``[x, hidden]`` to an increment that is added to ``hidden``;
    ``h`` maps ``[new_hidden, x]`` to the cell output.

    Args:
        input_size: width of the per-step input ``x``.
        output_size: width of the per-step output.
        mlp_size: width of every hidden MLP layer.
        hidden_size: width of the recurrent state.
        activation: activation layer class handed to ``MLP``.
        dropout: dropout probability handed to ``MLP``.
        num_hidden_layer: number of ``mlp_size``-wide layers in each MLP.
    """

    def __init__(self, input_size, output_size, mlp_size, hidden_size,\
                 activation, dropout, num_hidden_layer):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Both MLPs share the hidden layout and differ only in their last layer.
        mlp_struct1=[mlp_size]*num_hidden_layer
        mlp_struct1.append(hidden_size)
        mlp_struct2=[mlp_size]*num_hidden_layer
        mlp_struct2.append(output_size)
        self.f = MLP(input_size+hidden_size,mlp_struct1,\
                     activation_layer=activation, inplace=None,dropout=dropout)
        self.h = MLP(input_size+hidden_size,mlp_struct2,\
                     activation_layer=activation, inplace=None,dropout=dropout)

    def forward(self, x, hidden):
        """Advance the state by one step.

        Returns:
            ``(output, new_hidden)`` where ``new_hidden = hidden + f([x, hidden])``
            and ``output = h([new_hidden, x])``.
        """
        # Original author's note on the state layout: hidden=[i1,w1,i2,w2,i3,w3]
        # (meaning of the entries is not visible here - TODO confirm).
        dh = self.f(torch.cat([x, hidden], dim=-1))
        new_hidden = hidden + dh
        return self.h(torch.cat([new_hidden, x], dim=-1)), new_hidden

    def initHidden(self,batch_size):
        """Return an all-zero state of shape ``(batch_size, hidden_size)``.

        The tensor is created on the device and with the dtype of the module's
        parameters, so it can be fed to ``forward`` after ``.to(...)`` /
        ``.double()`` without an explicit move.
        """
        ref = next(self.parameters(), None)
        if ref is None:
            # No parameters to copy placement from: fall back to the defaults.
            return torch.zeros(batch_size, self.hidden_size)
        return torch.zeros(batch_size, self.hidden_size,
                           device=ref.device, dtype=ref.dtype)


class BNODE_NR(nn.Module):
    """LSTM encoder that initialises a ``BNODE`` cell rolled out over ``x``.

    ``hyper_params`` must provide ``input_size``, ``latent_size``,
    ``mlp_size``, ``activation``, ``dropout`` and ``num_hidden_layers``.
    The encoder LSTM always has two layers and uses ``batch_first=True``.
    """

    def __init__(self,hyper_params,output_size=1):
        super().__init__()
        self.lstm=nn.LSTM(input_size=hyper_params['input_size']+1,\
                          hidden_size=hyper_params['latent_size'],\
                          num_layers=2,\
                          batch_first=True)
        self.bnode_cell=BNODE(input_size=hyper_params['input_size'],\
                                output_size=output_size,\
                                mlp_size=hyper_params['mlp_size'],\
                                hidden_size=hyper_params['latent_size'],\
                                activation=hyper_params['activation'],\
                                dropout=hyper_params['dropout'],\
                                num_hidden_layer=hyper_params['num_hidden_layers'])

    def forward(self,past,s,x):
        """Encode ``past`` and unroll the cell over the time axis of ``x``.

        Args:
            past: encoder input (original note: "past is N*L*5").
            s: tensor placed in front of the encoder state; it must make the
                concatenated state ``latent_size`` wide (presumably one
                column per sample - TODO confirm against the caller).
            x: per-step inputs, indexed as ``x[:, step, :-1]``; must contain
                at least one time step.

        Returns:
            The per-step cell outputs stacked along dimension 1.
        """
        _, (h_n, _) = self.lstm(past)
        # Final hidden state of the top LSTM layer (index 1 of the 2 layers).
        top = h_n[-1]
        # Drop the first latent entry and put ``s`` in its place.
        hidden = torch.cat([s, top[:, 1:]], dim=-1)
        # Collect the steps and stack once instead of re-concatenating the
        # growing prediction tensor on every iteration.
        preds = []
        for step in range(x.shape[1]):
            step_pred, hidden = self.bnode_cell(x[:, step, :-1], hidden)
            preds.append(step_pred)
        return torch.stack(preds, dim=1)




# Short aliases for converting between a complex tensor and its real view
# with a trailing dimension of size 2.
_c2r = torch.view_as_real
_r2c = torch.view_as_complex


class S4DKernel(nn.Module):
    """Generates a length-``L`` convolution kernel from diagonal SSM parameters.

    The module stores ``C`` (complex, kept as a real view), ``log_dt``,
    ``log_A_real`` and ``A_imag``.  ``lr`` is forwarded to :meth:`register`
    for the last three.

    Args:
        d_model: number of independent channels ``H``.
        N: state size; ``N // 2`` complex entries are stored per channel.
        dt_min, dt_max: ``log_dt`` is initialised uniformly between their logs.
        lr: see :meth:`register`.
    """

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        H = d_model
        half_state = N // 2

        # Step sizes: uniform in log-space over [log(dt_min), log(dt_max)).
        log_lo = math.log(dt_min)
        log_hi = math.log(dt_max)
        log_dt = torch.rand(H) * (log_hi - log_lo) + log_lo

        # Complex C is stored through its real view so it can be a Parameter.
        complex_C = torch.randn(H, half_state, dtype=torch.cfloat)
        self.C = nn.Parameter(_c2r(complex_C))
        self.register("log_dt", log_dt, lr)

        # Real part of A starts at -0.5 (stored as log(0.5)); the imaginary
        # part of entry n starts at pi * n.
        log_A_real = torch.log(0.5 * torch.ones(H, half_state))
        A_imag = math.pi * repeat(torch.arange(half_state), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """Return the real kernel of shape ``(H, L)``."""
        step = torch.exp(self.log_dt)                               # (H)
        coeff = _r2c(self.C)                                        # (H N)
        poles = -torch.exp(self.log_A_real) + 1j * self.A_imag      # (H N)

        # Powers exp(dt * A * l) for l = 0 .. L-1 (Vandermonde structure).
        scaled = poles * step.unsqueeze(-1)                         # (H N)
        exponent = scaled.unsqueeze(-1) * torch.arange(L, device=poles.device)  # (H N L)
        coeff = coeff * (torch.exp(scaled) - 1.) / poles
        # Sum over the state dimension; only N // 2 entries are stored, hence
        # the factor 2 on the real part.
        return 2 * torch.einsum('hn, hnl -> hl', coeff, torch.exp(exponent)).real

    def register(self, name, tensor, lr=None):
        """Attach ``tensor`` to the module under ``name``.

        ``lr == 0.0`` registers a buffer.  Any other value registers a
        parameter and tags it with an ``_optim`` dict holding
        ``weight_decay = 0.0`` plus ``lr`` when one is given (presumably read
        by optimizer-setup code elsewhere - not visible in this file).
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        optim = {"weight_decay": 0.0}
        if lr is not None:
            optim["lr"] = lr
        getattr(self, name)._optim = optim


class S4D(nn.Module):
    """S4D layer: long convolution with a generated kernel plus a skip term.

    Args:
        d_model: number of channels ``H``.
        d_state: state size passed to ``S4DKernel`` as ``N``.
        dropout: channel-wise dropout probability applied after the
            activation; ``0.0`` disables dropout.
        gating: when true the output mixer is ``Conv1d(H, 2H, 1)`` followed by
            ``GLU`` over the channel dimension, otherwise ``Conv1d(H, H, 1)``
            followed by ``GELU``.
        **kernel_args: forwarded to ``S4DKernel``.
    """

    def __init__(self, d_model, d_state=64, dropout=0.0, gating=True, **kernel_args):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h

        # Per-channel weight of the direct input -> output path.
        self.D = nn.Parameter(torch.randn(self.h))

        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise
        self.activation = nn.GELU()
        # The activations are (B, H, L).  Dropout2d only handles that shape
        # through a fallback that PyTorch warns will change; Dropout1d is the
        # channel-wise dropout meant for it.  Older PyTorch releases without
        # Dropout1d keep the previous Dropout2d behaviour.
        dropout_fn = getattr(nn, "Dropout1d", nn.Dropout2d)
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        if gating:
            self.output_linear = nn.Sequential(
                nn.Conv1d(d_model, 2*d_model, 1),
                nn.GLU(dim=-2),
            )
        else:
            self.output_linear = nn.Sequential(
                nn.Conv1d(d_model, d_model, 1),
                nn.GELU()
            )

    def forward(self, u, **kwargs):
        """Apply the layer.  Input and output shape ``(B, H, L)``.

        Extra keyword arguments are accepted and ignored.
        """
        L = u.size(-1)

        # Compute SSM Kernel
        k = self.kernel(L=L) # (H L)

        # Convolution via FFT; padding to 2L keeps it linear (non-circular),
        # and only the first L outputs are kept.
        k_f = torch.fft.rfft(k, n=2*L) # (H L+1)
        u_f = torch.fft.rfft(u, n=2*L) # (B H L+1)
        y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        return y
    
class S4D_NR(nn.Module):
    """Linear embed -> ``S4D`` -> linear read-out over the joined sequence.

    ``hyper_params`` must provide ``d_model``, ``d_state``, ``dropout`` and
    ``input_size``.  ``window`` is the number of trailing time steps returned
    by ``forward``.  ``feature_size`` is accepted but not used.
    """

    def __init__(self, hyper_params, feature_size=5, window=12):
        super().__init__()
        width = hyper_params['d_model']
        self.s4d = S4D(d_model=width,
                       d_state=hyper_params['d_state'],
                       dropout=hyper_params['dropout'])
        self.linear1 = nn.Linear(hyper_params['input_size'], width)
        self.linear2 = nn.Linear(width, 1)
        self.window = window

    def forward(self, past, s, x):
        """Return the last ``window`` outputs, with a trailing dimension of 1.

        ``x`` is shifted right by one feature column: a zero column is
        inserted in front and the last column is dropped.  The new leading
        entry of the first time step is then overwritten with ``s``.  The
        result is appended to ``past`` along dimension 1.
        """
        # ``pad`` returns a new tensor, so the caller's ``x`` is not modified.
        shifted = torch.nn.functional.pad(x, (1, 0), "constant", 0)
        shifted[:, 0, :1] = s
        sequence = torch.concat([past, shifted[:, :, :-1]], dim=1)
        # S4D works channel-first: (B, L, H) -> (B, H, L) and back again.
        embedded = self.linear1(sequence).transpose(1, 2)
        mixed = self.s4d(embedded).transpose(1, 2)
        return self.linear2(mixed)[:, -self.window:]

class Chomp1d(nn.Module):
    """Remove the last ``chomp_size`` entries along the final dimension.

    Used after a padded convolution to trim the extra trailing outputs.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of ``x`` without its last ``chomp_size`` steps."""
        if self.chomp_size == 0:
            # ``x[:, :, :-0]`` is ``x[:, :, :0]`` - an empty tensor - so a zero
            # chomp (e.g. kernel_size == 1, padding == 0) must be a no-op.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Two weight-normalised convolutions with a residual connection.

    Each convolution is followed by ``Chomp1d(padding)``, ``ReLU`` and
    ``Dropout``.  When ``n_inputs != n_outputs`` a 1x1 convolution adapts the
    residual path.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, dilation=dilation)

        # First conv -> chomp -> relu -> dropout stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second stage, operating on n_outputs channels.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        stages = (
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )
        self.net = nn.Sequential(*stages)

        # 1x1 conv on the residual path only when the channel counts differ.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw the convolution weights from N(0, 0.01).

        NOTE(review): ``conv1``/``conv2`` are wrapped by ``weight_norm``;
        presumably their ``weight`` is recomputed from the normalisation
        parameters, in which case this in-place write would not persist -
        confirm against the ``torch.nn.utils.weight_norm`` documentation.
        """
        for conv in (self.conv1, self.conv2):
            conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + residual)``."""
        branch = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(branch + shortcut)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` levels with exponentially growing dilation.

    Level ``i`` uses dilation ``2 ** i``, padding ``(kernel_size - 1) * 2 ** i``
    and stride 1, and maps ``num_channels[i - 1]`` (``num_inputs`` for the
    first level) to ``num_channels[i]`` channels.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        # Input width of each level: the previous level's output width.
        widths_in = [num_inputs] + list(num_channels[:-1])
        blocks = [
            TemporalBlock(
                width_in, width_out, kernel_size, stride=1,
                dilation=2 ** level,
                padding=(kernel_size - 1) * 2 ** level,
                dropout=dropout,
            )
            for level, (width_in, width_out) in enumerate(zip(widths_in, num_channels))
        ]
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every level in order."""
        return self.network(x)
    
    
class TCN(nn.Module):
    """``TemporalConvNet`` followed by a linear layer on the last time step."""

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super().__init__()
        self.tcn = TemporalConvNet(
            input_size, num_channels, kernel_size=kernel_size, dropout=dropout
        )
        # The read-out consumes the channels of the final TCN level.
        self.linear = nn.Linear(num_channels[-1], output_size)
        self.init_weights()

    def init_weights(self):
        """Draw the read-out weights from N(0, 0.01); the bias is left as is."""
        self.linear.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply the TCN and project the features of the final position.

        The TCN output is indexed as ``[:, :, -1]``, i.e. the last entry of
        its third dimension is what gets projected.
        """
        features = self.tcn(x)
        last_step = features[:, :, -1]
        return self.linear(last_step)
    
class TCN_NR(nn.Module):
    """TCN that emits a whole forecast window from the joined sequence.

    ``hyper_params`` must provide ``input_size``, ``conv_size``,
    ``num_layers``, ``kernel_size`` and ``dropout``.

    Args:
        hyper_params: see above.
        output_size: number of values predicted per forecast step.
        window: number of forecast steps.  Defaults to 12, the value that was
            previously hard-coded.
    """

    def __init__(self, hyper_params, output_size=1, window=12):
        super().__init__()
        self.window = window
        self.output_size = output_size
        # The TCN head emits all window * output_size values at once.
        self.TCN = TCN(input_size=hyper_params['input_size'],\
                       output_size=window * output_size, \
                       num_channels=[hyper_params['conv_size']]*hyper_params['num_layers'],\
                       kernel_size=hyper_params['kernel_size'],\
                       dropout=hyper_params['dropout'])

    def forward(self, past,s,x):
        """Return predictions of shape ``(batch, window, output_size)``.

        ``x`` is shifted right by one feature column: a zero column is
        inserted in front and the last column is dropped.  The new leading
        entry of the first time step is then overwritten with ``s``.  The
        result is appended to ``past`` along dimension 1.
        """
        # ``pad`` returns a new tensor, so the caller's ``x`` is not modified.
        x=torch.nn.functional.pad(x,(1,0), "constant", 0)
        x[:,0,0:1]=s
        seq_in=torch.cat([past,x[:,:,:-1]],dim=1)
        # Conv1d layers expect channels before time.
        seq_in=torch.permute(seq_in,(0,2,1))
        output=self.TCN(seq_in)
        # With the defaults this equals the former ``unsqueeze(-1)``.
        return output.reshape(output.shape[0], self.window, self.output_size)

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings, then applies dropout.

    Args:
        d_model: embedding width (odd values are supported).
        dropout: dropout probability applied to the sum.
        max_len: longest sequence length that can be encoded.
        device: device on which the table is initially created.  The table is
            a buffer, so it also follows later ``.to(...)`` calls on the
            module.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000,device=None):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model, device=device)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Non-persistent buffer: moves/casts with the module, while the
        # state_dict stays free of a "pe" entry, exactly as before.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            ``dropout(x + pe[:seq_len])`` with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
class Transformer_TS(nn.Module):
    """``nn.Transformer`` (batch-first) with linear input/output projections.

    ``hyper_params`` must provide ``d_model``, ``nhead``,
    ``num_encoder_layers``, ``num_decoder_layers``, ``dim_feedforward``,
    ``dropout`` and ``input_size``.  Positional encoding is applied to the
    source sequence only.
    """

    def __init__(self, hyper_params, output_size=1, device=None):
        super().__init__()
        self.trans=nn.Transformer(d_model=hyper_params['d_model'],\
                                  nhead=hyper_params['nhead'],\
                                  num_encoder_layers=hyper_params['num_encoder_layers'],\
                                  num_decoder_layers=hyper_params['num_decoder_layers'],\
                                  dim_feedforward=hyper_params['dim_feedforward'],\
                                  dropout=hyper_params['dropout'],\
                                  batch_first=True)
        self.pos_enc=PositionalEncoding(hyper_params['d_model'],hyper_params['dropout'],device=device)
        self.src_encoder=nn.Linear(hyper_params['input_size'],hyper_params['d_model'])
        self.tgt_encoder=nn.Linear(output_size,hyper_params['d_model'])
        self.decoder=nn.Linear(hyper_params['d_model'],output_size)

    def forward(self,src,tgt,mask=None):
        """Run the transformer.

        Args:
            src: source sequence, ``(batch, src_len, input_size)`` or unbatched
                ``(src_len, input_size)``.
            tgt: target sequence, last dimension ``output_size``.
            mask: optional mask forwarded as ``tgt_mask``.

        Returns:
            ``decoder`` applied to the transformer output.
        """
        src=self.src_encoder(src)
        tgt=self.tgt_encoder(tgt)
        # ``pos_enc`` indexes positions along dim 0 and broadcasts over dim 1
        # (sequence-first layout), whereas this model is batch-first.  Put the
        # time axis first for the encoding and restore the layout afterwards,
        # so that each time step - not each batch element - gets its own
        # encoding.
        if src.dim() == 3:
            src=self.pos_enc(src.transpose(0,1)).transpose(0,1)
        else:
            src=self.pos_enc(src.unsqueeze(1)).squeeze(1)
        # ``tgt_mask=None`` is what the transformer uses when no mask is given.
        out=self.trans(src,tgt,tgt_mask=mask)
        return self.decoder(out)