import copy
import pdb
from typing import List

import torch
from einops import rearrange
from torch import nn

# Name -> activation module registry used by the backbones in this module.
# NOTE: each value is a single module instance created once at import time.
# Storing an entry directly on a model shares it with every other model that
# does the same; for PReLU that includes its learnable ``weight``. Copy an
# entry (e.g. ``copy.deepcopy``) before attaching it to a model.
activation_functions = dict(
    ReLU=nn.ReLU(),
    Sigmoid=nn.Sigmoid(),
    Tanh=nn.Tanh(),
    LeakyReLU=nn.LeakyReLU(negative_slope=0.01),
    ELU=nn.ELU(alpha=1.0),
    PReLU=nn.PReLU(num_parameters=1, init=0.25),
    Mish=nn.Mish(),
    SELU=nn.SELU(),
    Hardshrink=nn.Hardshrink(),
)

class MLP(torch.nn.Module):
    """Four-layer SELU MLP returning a ``(-1, data_dim, 1 + conditional_dim)`` tensor.

    Args:
        data_dim: number of data features.
        w: hidden width of the three hidden layers.
        time_varying: if True, the input carries one extra feature
            (presumably time -- confirm with callers).
        conditional_dim: number of extra input features; also adds
            ``conditional_dim * data_dim`` output values.
    """

    def __init__(self, data_dim, w=64, time_varying=False, conditional_dim=1):
        super().__init__()
        self.time_varying = time_varying
        self.conditional_dim = conditional_dim
        self.data_dim = data_dim
        # Stored (not just a local) so MLP exposes ``out_dim`` like the other
        # backbones in this module; FlowNet reads ``net.out_dim``.
        self.out_dim = (conditional_dim + 1) * data_dim

        in_dim = data_dim + (1 if time_varying else 0) + conditional_dim
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, self.out_dim),
        )

    def forward(self, x):
        # One (data_dim, 1 + conditional_dim) matrix per flattened input row.
        return self.net(x).reshape(-1, self.data_dim, 1 + self.conditional_dim)

class DeepMLP(torch.nn.Module):
    """Feed-forward network whose hidden width grows by ``w`` at every layer.

    Hidden layer ``i`` (0-based) outputs ``w * (i + 1)`` features, so the last
    hidden layer is ``w * depth`` wide. A final linear head maps that to
    ``(conditional_dim + 1) * data_dim`` outputs. ``forward`` does no reshape.

    Args:
        data_dim: number of data features.
        w: width increment per hidden layer.
        time_varying: if True, the input carries one extra feature
            (presumably time -- confirm with callers).
        conditional_dim: number of extra input features; also adds
            ``conditional_dim * data_dim`` output values.
        activation: key into ``activation_functions``.
        depth: number of hidden layers; must be >= 1.

    Raises:
        KeyError: unknown ``activation`` name.
        ValueError: ``depth < 1`` (the head's input width would be 0).
    """

    def __init__(self, data_dim, w=64, time_varying=False, conditional_dim=1, activation='SELU', depth=5):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.time_varying = time_varying
        self.conditional_dim = conditional_dim
        self.data_dim = data_dim
        # Private copy of the registry entry. The dict holds one shared module
        # instance, and PReLU's learnable weight must not be shared across models.
        self.activation = copy.deepcopy(activation_functions[activation])
        self.out_dim = (conditional_dim + 1) * data_dim
        self.depth = depth
        in_dim = data_dim + (1 if time_varying else 0) + conditional_dim

        self.layers = torch.nn.ModuleList([torch.nn.Linear(in_dim, w)])
        for i in range(1, depth):
            self.layers.append(torch.nn.Linear(w * i, w * (i + 1)))

        self.final_layer = torch.nn.Linear(w * depth, self.out_dim)

    def forward(self, x):
        for layer in self.layers:
            x = self.activation(layer(x))
        return self.final_layer(x)



class DenseNet(torch.nn.Module):
    """Five-layer densely connected network.

    Each hidden layer ``l1``..``l4`` emits ``w`` features that are concatenated
    onto its input, so layer ``k`` sees ``in_dim + (k - 1) * w`` features.
    ``l5`` maps the final ``in_dim + 4 * w`` features to
    ``(conditional_dim + 1) * data_dim`` outputs. ``forward`` does no reshape.

    Args:
        data_dim: number of data features.
        w: features added by each hidden layer.
        time_varying: if True, the input carries one extra feature
            (presumably time -- confirm with callers).
        conditional_dim: number of extra input features; also adds
            ``conditional_dim * data_dim`` output values.
        activation: key into ``activation_functions``.
        **kwargs: accepted and silently ignored. NOTE(review): presumably so a
            shared config (e.g. one carrying ``depth``) can be passed to every
            backbone -- confirm.

    Raises:
        KeyError: unknown ``activation`` name.
    """

    def __init__(self, data_dim, w=64, time_varying=False, conditional_dim=1,
                 activation='SELU', **kwargs):
        super().__init__()
        self.time_varying = time_varying
        self.conditional_dim = conditional_dim
        self.data_dim = data_dim
        # Private copy of the registry entry. The dict holds one shared module
        # instance, and PReLU's learnable weight must not be shared across models.
        self.activation = copy.deepcopy(activation_functions[activation])
        self.out_dim = (conditional_dim + 1) * data_dim
        in_dim = data_dim + (1 if time_varying else 0) + conditional_dim
        self.l1 = torch.nn.Linear(in_dim, w)
        self.l2 = torch.nn.Linear(in_dim + w * 1, w)
        self.l3 = torch.nn.Linear(in_dim + w * 2, w)
        self.l4 = torch.nn.Linear(in_dim + w * 3, w)
        self.l5 = torch.nn.Linear(in_dim + w * 4, self.out_dim)

    def forward(self, x):
        # NOTE(review): the activation is applied to the whole concatenation,
        # so the raw input columns are re-activated after every layer -- confirm intended.
        for layer in (self.l1, self.l2, self.l3, self.l4):
            x = self.activation(torch.cat((x, layer(x)), dim=-1))
        return self.l5(x)


class DeepDenseNet(torch.nn.Module):
    """Densely connected network with a configurable number of layers.

    Hidden layer ``i`` (0-based) sees ``in_dim + i * w`` features and appends
    ``w`` more. A final linear head maps the ``in_dim + depth * w`` features to
    ``(conditional_dim + 1) * data_dim`` outputs. ``forward`` does no reshape.

    Args:
        data_dim: number of data features.
        w: features added by each hidden layer.
        time_varying: if True, the input carries one extra feature
            (presumably time -- confirm with callers).
        conditional_dim: number of extra input features; also adds
            ``conditional_dim * data_dim`` output values.
        activation: key into ``activation_functions``.
        depth: number of hidden layers; 0 means the head is applied directly
            to the input.

    Raises:
        KeyError: unknown ``activation`` name.
        ValueError: ``depth < 0``.
    """

    def __init__(self, data_dim, w=64, time_varying=False, conditional_dim=1, activation='SELU', depth=5):
        super().__init__()
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")
        self.time_varying = time_varying
        self.conditional_dim = conditional_dim
        self.data_dim = data_dim
        # Private copy of the registry entry. The dict holds one shared module
        # instance, and PReLU's learnable weight must not be shared across models.
        self.activation = copy.deepcopy(activation_functions[activation])
        self.out_dim = (conditional_dim + 1) * data_dim
        in_dim = data_dim + (1 if time_varying else 0) + conditional_dim

        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(in_dim + w * i, w) for i in range(depth)
        )
        self.final_layer = torch.nn.Linear(in_dim + w * depth, self.out_dim)

    def forward(self, x):
        # NOTE(review): the activation is applied to the whole concatenation,
        # so earlier columns are re-activated after every layer -- confirm intended.
        for layer in self.layers:
            x = self.activation(torch.cat((x, layer(x)), dim=-1))
        return self.final_layer(x)

class TCseparateDeepDenseNet(torch.nn.Module):
    """Two independent densely connected branches whose outputs are interleaved.

    Branch 0 (``layers0`` / ``final_layer0``) produces ``data_dim`` values per
    row. Branch 1 (``layers1`` / ``final_layer1``) produces
    ``conditional_dim * data_dim`` values per row. ``forward`` stacks them into
    a ``(batch, data_dim, 1 + conditional_dim)`` tensor, with branch 0 in slot 0
    of the last axis, and flattens that to
    ``(batch, data_dim * (1 + conditional_dim))``. Both branches share the
    single ``self.activation`` module.

    ``forward`` assumes a 2-D ``(batch, features)`` input: it unsqueezes
    dim 2 and uses ``x.shape[0]`` as the batch size.

    Raises:
        KeyError: unknown ``activation`` name.
        ValueError: ``depth < 0``.
    """

    def __init__(self, data_dim, w=64, time_varying=True, conditional_dim=1, activation='SELU', depth=5):
        super().__init__()
        if depth < 0:
            raise ValueError(f"depth must be non-negative, got {depth}")
        self.time_varying = time_varying
        self.conditional_dim = conditional_dim
        self.TCseparate = True
        self.data_dim = data_dim
        # Private copy of the registry entry. The dict holds one shared module
        # instance, and PReLU's learnable weight must not be shared across models.
        self.activation = copy.deepcopy(activation_functions[activation])
        self.out_dim0 = data_dim
        self.out_dim1 = conditional_dim * data_dim
        in_dim = data_dim + (1 if time_varying else 0) + conditional_dim

        self.layers0 = torch.nn.ModuleList()
        self.layers1 = torch.nn.ModuleList()
        # Layers are created interleaved (branch 0, then branch 1, per depth)
        # to keep the parameter-initialisation order stable.
        for i in range(depth):
            self.layers0.append(torch.nn.Linear(in_dim + w * i, w))
            self.layers1.append(torch.nn.Linear(in_dim + w * i, w))
        self.final_layer0 = torch.nn.Linear(in_dim + w * depth, self.out_dim0)
        self.final_layer1 = torch.nn.Linear(in_dim + w * depth, self.out_dim1)

    def _dense_branch(self, x, layers, final_layer):
        """Run one densely connected branch and its linear head on ``x``."""
        for layer in layers:
            x = self.activation(torch.cat((x, layer(x)), dim=-1))
        return final_layer(x)

    def forward(self, x):
        x0 = self._dense_branch(x, self.layers0, self.final_layer0).unsqueeze(2)
        x1 = self._dense_branch(x, self.layers1, self.final_layer1).reshape(
            -1, self.data_dim, self.conditional_dim
        )
        x_pair = torch.cat((x0, x1), dim=2)
        return x_pair.reshape(x.shape[0], -1)


        
class FlowNet(torch.nn.Module):
    """Wrap a backbone network and shape its output per the configured ``rival``.

    When ``rival`` is ``'guided'`` or ``'lfm'``, an extra Linear layer
    (``addendum_layer``) maps each flattened backbone output row to
    ``data_dim`` values, and the result is returned as
    ``(-1, data_dim, 1)``. Otherwise the backbone output is returned as
    ``(-1, data_dim, 1 + conditional_dim)``.

    Args:
        mynet: backbone module exposing ``data_dim`` and ``conditional_dim``.
            ``out_dim`` is used if present. Otherwise the backbone is assumed
            to emit ``(conditional_dim + 1) * data_dim`` values per sample.
        config: object with a ``rival`` attribute, or ``None`` (treated as no
            rival, i.e. the matrix-shaped output).
    """

    _VECTOR_RIVALS = ('guided', 'lfm')

    def __init__(self, mynet, config=None):
        super().__init__()
        self.net = mynet
        self.data_dim = mynet.data_dim
        self.conditional_dim = mynet.conditional_dim
        # ``config`` defaults to None, so it must not be dereferenced blindly.
        self.rival = config.rival if config is not None else None
        self.config = config
        # Width of one flattened backbone output row. Not every backbone
        # defines ``out_dim`` (e.g. TCseparateDeepDenseNet), so fall back to
        # the (conditional_dim + 1) * data_dim size used throughout.
        self._net_out_dim = getattr(
            mynet, 'out_dim', (self.conditional_dim + 1) * self.data_dim
        )
        if self.rival in self._VECTOR_RIVALS:
            # output is a vector per sample, not a matrix
            self.addendum_layer = torch.nn.Linear(self._net_out_dim, self.data_dim)

    def forward(self, x):
        if self.rival in self._VECTOR_RIVALS:
            # Flatten any per-sample matrix output (MLP returns 3-D) so the
            # addendum Linear sees one row of width ``_net_out_dim`` per sample.
            features = self.net(x).reshape(-1, self._net_out_dim)
            return self.addendum_layer(features).reshape(-1, self.data_dim, 1)
        return self.net(x).reshape(-1, self.data_dim, 1 + self.conditional_dim)


class DeepMLP4EFM(nn.Sequential):
    """Plain ``nn.Sequential`` MLP: Linear layers with an activation between each pair.

    Layer widths run ``in_features -> *hidden_features -> out_features``, where
    ``in_features = data_dim + conditional_dim (+1 if time_varying)`` and
    ``out_features = (conditional_dim + 1) * data_dim``. No activation follows
    the last Linear. ``forward`` is the inherited ``Sequential.forward``, with
    no reshape.

    Args:
        data_dim: number of data features.
        conditional_dim: number of extra input features; also adds
            ``conditional_dim * data_dim`` output values.
        activation: key into ``activation_functions``.
        time_varying: if True, the input carries one extra feature
            (presumably time -- confirm with callers).
        hidden_features: hidden layer widths. A copy is stored, so the shared
            default list is never aliased by an instance.

    Raises:
        KeyError: unknown ``activation`` name.
    """

    def __init__(
        self,
        data_dim: int,
        conditional_dim: int,
        activation: str = 'ELU',
        time_varying: bool = False,
        hidden_features: List[int] = [128]*6,
    ):
        # Copy so instances never alias the (shared) default list or a caller's list.
        hidden_features = list(hidden_features)
        in_features = data_dim + (1 if time_varying else 0) + conditional_dim
        out_features = (conditional_dim + 1) * data_dim
        template = activation_functions[activation]

        layers = []
        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            # Fresh activation per layer. The registry holds one shared
            # instance, and PReLU's learnable weight must not be tied across
            # layers or models.
            layers.extend([nn.Linear(a, b), copy.deepcopy(template)])

        # Drop the trailing activation so the network ends in a Linear layer.
        super().__init__(*layers[:-1])

        self.data_dim = data_dim
        self.conditional_dim = conditional_dim
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features



class GradModel(torch.nn.Module):
    """Return the input gradient of a wrapped ``action``, minus the last column.

    ``forward`` sums ``action(x)`` to a scalar, differentiates it with respect
    to ``x`` with ``create_graph=True`` (so the returned gradient is itself
    differentiable), and drops the last feature of the result.

    Args:
        action: callable mapping a tensor to a tensor; its output is summed
            before differentiation.
    """

    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x):
        # Gradients are needed even when the caller runs under torch.no_grad(),
        # so re-enable autograd locally.
        with torch.enable_grad():
            if not x.requires_grad:
                # Track a detached alias instead of flipping the flag on the
                # caller's tensor in place.
                x = x.detach().requires_grad_(True)
            grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0]
        # NOTE(review): the dropped last feature is presumably the time input -- confirm.
        return grad[..., :-1]