import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import weight_norm
from torchvision import transforms, utils
from torchvision.ops import MLP



class MNODE_decoder(nn.Module):
    """One-step residual decoder with a separate small MLP per state variable.

    ``DAG`` is read as a sequence with one entry per state variable:
    ``DAG[i][0]`` indexes the columns of ``hidden`` fed to MLP ``i``,
    ``DAG[i][1]`` indexes the columns of ``inputs`` fed to MLP ``i``, and
    ``DAG[i][-1][0]`` is used as the input width of MLP ``i``.
    NOTE(review): presumably ``DAG[i][-1][0] == len(DAG[i][0]) + len(DAG[i][1])``;
    nothing here checks it -- confirm against the code that builds the DAG.

    Args:
        DAG: wiring description, see above.
        _unused1: ignored here; subclasses pass their edge map in this slot.
        _unused2: ignored here; subclasses pass initial edge weights in this slot.
        input_size: accepted for signature compatibility; not used by this class.
        output_ind: column of the state that is returned as the prediction.
        mlp_size: width of every hidden layer of every MLP.
        num_hidden_layers: number of hidden layers of every MLP.
        activation: activation layer class handed to ``torchvision.ops.MLP``.
        dropout: dropout probability handed to ``torchvision.ops.MLP``.
    """

    def __init__(self, DAG, _unused1, _unused2, input_size=5,
                 output_ind=0, mlp_size=16, num_hidden_layers=2,
                 activation=nn.ReLU, dropout=0.0):
        super().__init__()
        self.dag = DAG
        self.state_size = len(DAG)
        self.output_ind = output_ind
        self.mlp_inputs = [node[-1][0] for node in DAG]
        # `num_hidden_layers` layers of width `mlp_size`, then a single output
        # unit: each MLP produces the increment of exactly one state variable.
        layer_widths = [mlp_size] * num_hidden_layers + [1]
        self.f = nn.ModuleList([
            MLP(in_width, layer_widths,
                activation_layer=activation, inplace=None, dropout=dropout)
            for in_width in self.mlp_inputs
        ])

    def forward(self, inputs, hidden):
        """Advance the state by one step.

        Args:
            inputs: 2-D tensor whose columns are selected by ``DAG[i][1]``.
            hidden: 2-D tensor whose columns are selected by ``DAG[i][0]``
                (assumed to be ``(batch, state_size)`` -- TODO confirm).

        Returns:
            ``(prediction, new_hidden)`` where ``new_hidden`` is ``hidden`` plus
            the concatenated MLP outputs and ``prediction`` is the single
            column ``output_ind`` of ``new_hidden`` (kept 2-D).
        """
        deltas = []
        for node, mlp in zip(self.dag, self.f):
            mlp_in = torch.concat([hidden[:, node[0]], inputs[:, node[1]]], dim=-1)
            deltas.append(mlp(mlp_in))
        new_hidden = hidden + torch.concat(deltas, dim=-1)
        return new_hidden[:, self.output_ind:self.output_ind + 1], new_hidden

    def initHidden(self, batch_size):
        """Return an all-zero state of shape ``(batch_size, state_size)``.

        The tensor is created with the device and dtype of the module's
        parameters so it can be fed to ``forward`` directly after the module
        has been moved (``.to(device)``) or cast (``.double()``). A module
        without parameters falls back to the torch defaults.
        """
        reference = next(self.parameters(), None)
        if reference is None:
            return torch.zeros(batch_size, self.state_size)
        return torch.zeros(batch_size, self.state_size,
                           device=reference.device, dtype=reference.dtype)

class MNODE_decoder_masked(MNODE_decoder):
    """Decoder whose MLP inputs are scaled by caller-supplied edge weights.

    ``edge_map[i][0]`` and ``edge_map[i][1]`` index into the ``ew`` tensor given
    to ``forward``; the selected entries are concatenated in the same order as
    the state/input features of MLP ``i`` and multiplied onto them (by absolute
    value) before the MLP is applied.
    """

    def __init__(self, DAG, edge_map, _unused2, input_size, output_ind, mlp_size,
                 num_hidden_layers, activation, dropout):
        super().__init__(DAG, edge_map, _unused2, input_size, output_ind, mlp_size,
                         num_hidden_layers, activation, dropout)
        self.edge_map = edge_map

    def forward(self, inputs, hidden, ew):
        """Advance the state by one step using the edge weights ``ew``.

        Returns ``(prediction, new_hidden)``: ``new_hidden`` is ``hidden`` plus
        the concatenated MLP outputs, ``prediction`` is its column
        ``output_ind`` (kept 2-D).
        """
        updates = []
        for idx, mlp in enumerate(self.f):
            node = self.dag[idx]
            features = torch.concat([hidden[:, node[0]], inputs[:, node[1]]], axis=-1)
            edges = self.edge_map[idx]
            gate = torch.abs(torch.concat([ew[edges[0]], ew[edges[1]]], axis=-1))
            updates.append(mlp(features * gate))
        stepped = hidden + torch.concat(updates, axis=-1)
        col = self.output_ind
        return stepped[:, col:col + 1], stepped

        
class MNODE_decoder_GL(MNODE_decoder):
    """Decoder with learnable edge weights that scale every MLP input.

    The weights live in the parameter ``ew``. ``edge_map[i][0]`` and
    ``edge_map[i][1]`` index into ``ew``; the selected entries are concatenated
    in the same order as the state/input features of MLP ``i`` and multiplied
    onto them (by absolute value) before the MLP is applied.

    Args:
        init_ew: initial edge-weight tensor, or ``None`` to start every weight
            at ``1e-3`` (one weight per MLP input, summed over all MLPs).
        (remaining arguments are forwarded unchanged to ``MNODE_decoder``)
    """

    def __init__(self, DAG, edge_map, init_ew, input_size, output_ind, mlp_size,
                 num_hidden_layers, activation, dropout):
        super().__init__(DAG, edge_map, init_ew, input_size, output_ind, mlp_size,
                         num_hidden_layers, activation, dropout)
        self.edge_map = edge_map
        # Identity check: `== None` on a tensor is routed through the tensor's
        # own comparison operator instead of plainly asking "was it omitted?".
        if init_ew is None:
            init_ew = torch.ones(int(sum(self.mlp_inputs))) * 1e-3
        self.ew = nn.Parameter(init_ew)

    def forward(self, inputs, hidden):
        """Advance the state by one step.

        Returns ``(prediction, new_hidden)``: ``new_hidden`` is ``hidden`` plus
        the concatenated MLP outputs, ``prediction`` is its column
        ``output_ind`` (kept 2-D).
        """
        d_hidden = []
        for node, edges, mlp in zip(self.dag, self.edge_map, self.f):
            mlp_in = torch.concat([hidden[:, node[0]], inputs[:, node[1]]], dim=-1)
            edge_weights = torch.concat([self.ew[edges[0]], self.ew[edges[1]]], dim=-1)
            d_hidden.append(mlp(mlp_in * torch.abs(edge_weights)))
        new_hidden = hidden + torch.concat(d_hidden, dim=-1)
        return new_hidden[:, self.output_ind:self.output_ind + 1], new_hidden

    def return_edge_weights(self):
        """Return the absolute values of the edge weights."""
        return torch.abs(self.ew)

    def freeze_ew(self):
        """Stop gradient computation for the edge weights."""
        self.ew.requires_grad = False

    def set_ew(self, weights):
        """Replace the edge-weight parameter.

        ``weights`` may be an ``nn.Parameter`` (installed as is), ``None``, or a
        plain tensor. ``nn.Module`` refuses to assign a plain tensor to a name
        that is registered as a parameter, so plain tensors are wrapped first;
        the wrapper keeps the current ``requires_grad`` flag, so a previous
        ``freeze_ew()`` stays in effect.

        The parameter object itself is replaced, so an optimizer built earlier
        keeps pointing at the old one.
        """
        if weights is not None and not isinstance(weights, nn.Parameter):
            weights = nn.Parameter(weights, requires_grad=self.ew.requires_grad)
        self.ew = weights

class MNODE_NR(nn.Module):
    """LSTM encoder followed by a step-by-step ``MNODE_decoder`` roll-out.

    Args:
        DAG: wiring description handed to the decoder; ``len(DAG)`` is also
            the LSTM hidden size.
        edge_map: stored and forwarded to the decoder.
        init_ew: forwarded to the decoder.
        hyper_params: mapping with the keys ``'input_size'``, ``'mlp_size'``,
            ``'num_hidden_layers'``, ``'activation'`` and ``'dropout'``.
            Required despite the ``None`` default.

    Raises:
        TypeError: if ``hyper_params`` is omitted.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__()
        if hyper_params is None:
            # Same exception type as subscripting None, but with a message
            # that tells the caller what is actually missing.
            raise TypeError(
                "hyper_params is required: expected a mapping with keys "
                "'input_size', 'mlp_size', 'num_hidden_layers', 'activation' "
                "and 'dropout'"
            )
        self.encoder = nn.LSTM(input_size=hyper_params['input_size'],
                               hidden_size=len(DAG),
                               num_layers=2,
                               batch_first=True)
        self.decoder = MNODE_decoder(DAG, edge_map, init_ew,
                                     input_size=hyper_params['input_size'],
                                     output_ind=0, mlp_size=hyper_params['mlp_size'],
                                     num_hidden_layers=hyper_params['num_hidden_layers'],
                                     activation=hyper_params['activation'],
                                     dropout=hyper_params['dropout'])
        self.dag = DAG
        self.edge_map = edge_map
        self.hp = hyper_params

    def forward(self, past, s, x):
        """Encode ``past`` and roll the decoder over every step of ``x``.

        Args:
            past: batch-first sequence given to the LSTM encoder.
            s: tensor placed in front of the encoded state (assumed to be
                ``(batch, 1)`` so the state keeps ``len(DAG)`` columns --
                TODO confirm).
            x: decoder inputs; ``x[:, j]`` is the input of step ``j``. Must
                contain at least one step.

        Returns:
            The per-step decoder predictions stacked along dimension 1.
        """
        _, (h_n, _) = self.encoder(past)
        # Final hidden state of the last LSTM layer; its first component is
        # dropped and `s` takes its place.
        hidden = torch.concat([s, h_n[-1][:, 1:]], dim=-1)
        # Collect the steps and join them once: growing the result with a
        # concat per step copies the whole history on every iteration.
        preds = []
        for j in range(x.shape[1]):
            step_pred, hidden = self.decoder(x[:, j], hidden)
            preds.append(step_pred)
        return torch.stack(preds, dim=1)

    def return_dag(self):
        """Return the DAG passed at construction."""
        return self.dag

    def return_edge_map(self):
        """Return the edge map passed at construction."""
        return self.edge_map

                
class MNODE_NS(MNODE_NR):
    """``MNODE_NR`` variant that feeds a sampled 0/1 edge mask to a masked decoder.

    The mask has one entry per MLP input (summed over all DAG nodes). It is
    rebuilt on every forward pass from the parameters ``z`` (scores) and
    ``eps`` (``K`` rows of values drawn uniformly from [0, 1) at construction).

    ``hyper_params`` needs the key ``'k'`` (number of rows of ``eps``) in
    addition to the keys required by ``MNODE_NR``.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__(DAG, edge_map, init_ew, hyper_params)
        self.ew_size = sum(node[-1][0] for node in DAG)
        self.z = nn.Parameter(torch.rand(size=(self.ew_size,)))
        # Replaces the plain decoder created by the parent constructor.
        self.decoder = MNODE_decoder_masked(DAG, edge_map, init_ew,
                                            input_size=hyper_params['input_size'],
                                            output_ind=0, mlp_size=hyper_params['mlp_size'],
                                            num_hidden_layers=hyper_params['num_hidden_layers'],
                                            activation=hyper_params['activation'],
                                            dropout=hyper_params['dropout'])
        self.K = hyper_params['k']
        self.eps = nn.Parameter(torch.rand(size=(self.K, self.ew_size)))

    def _sample_edge_mask(self):
        """Build the 0/1 edge mask from ``z`` and ``eps``."""
        # `z` is 1-D; naming the dimension avoids the deprecated implicit choice.
        log_pi = torch.log(nn.functional.softmax(self.z, dim=0))
        selected = torch.zeros_like(self.z, dtype=torch.bool)
        for i in range(self.K):
            w = torch.exp((log_pi - torch.log(-torch.log(self.eps[i]))) * 10)
            w = torch.round(w / torch.sum(w), decimals=2)
            # An edge is kept when any of the K rows gives it a non-zero
            # (after rounding to two decimals) normalised weight.
            selected = torch.logical_or(selected, w.bool())
        # NOTE(review): rounding and the bool cast are not differentiable, so
        # no gradient reaches `z` or `eps` through this mask -- confirm that
        # this is intended.
        # Follow the dtype of `z` rather than hard-coding float64, so the mask
        # does not promote the MLP inputs of a float32 model to float64.
        return selected.to(self.z.dtype)

    def forward(self, past, s, x):
        """Encode ``past`` and roll the masked decoder over every step of ``x``.

        Arguments and return value are the same as for ``MNODE_NR.forward``;
        one edge mask is sampled per call and reused for every step.
        """
        ew = self._sample_edge_mask()

        _, (h_n, _) = self.encoder(past)
        # Final hidden state of the last LSTM layer; its first component is
        # dropped and `s` takes its place.
        hidden = torch.concat([s, h_n[-1][:, 1:]], dim=-1)
        # Collect the steps and join them once instead of re-copying the
        # growing result on every iteration.
        preds = []
        for j in range(x.shape[1]):
            step_pred, hidden = self.decoder(x[:, j], hidden, ew)
            preds.append(step_pred)
        return torch.stack(preds, dim=1)
    
    
class MNODE_GL(MNODE_NR):
    """``MNODE_NR`` whose decoder is an ``MNODE_decoder_GL`` (learnable edge weights).

    The methods below mostly delegate to the decoder.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__(DAG, edge_map, init_ew, hyper_params)
        hp = hyper_params
        decoder_options = dict(
            input_size=hp['input_size'],
            output_ind=0,
            mlp_size=hp['mlp_size'],
            num_hidden_layers=hp['num_hidden_layers'],
            activation=hp['activation'],
            dropout=hp['dropout'],
        )
        # Replaces the plain decoder created by the parent constructor.
        self.decoder = MNODE_decoder_GL(DAG, edge_map, init_ew, **decoder_options)

    def return_edge_norms(self):
        """Return the sum of the decoder's (absolute) edge weights."""
        return self.decoder.return_edge_weights().sum()

    def return_edge_weights(self):
        """Return the decoder's edge weights as reported by the decoder."""
        return self.decoder.return_edge_weights()

    def freeze_ew(self):
        """Ask the decoder to freeze its edge weights."""
        self.decoder.freeze_ew()

    def set_ew(self, weights):
        """Hand ``weights`` to the decoder's ``set_ew``."""
        self.decoder.set_ew(weights)

    def init_ew(self):
        """Return a tensor of ones with one entry per MLP input across the DAG."""
        n_edges = np.sum([node[-1][0] for node in self.dag])
        return torch.ones(n_edges)
        

class MNODE_EN(MNODE_GL):
    """``MNODE_GL`` under a separate class name; adds no attributes or methods.

    NOTE(review): presumably a distinct type so that calling code can pick a
    training/regularisation scheme by class -- confirm against the callers.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__(DAG, edge_map, init_ew=init_ew, hyper_params=hyper_params)

class MNODE_EGL(MNODE_GL):
    """``MNODE_GL`` under a separate class name; adds no attributes or methods.

    NOTE(review): presumably a distinct type so that calling code can pick a
    training/regularisation scheme by class -- confirm against the callers.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__(DAG, edge_map, init_ew=init_ew, hyper_params=hyper_params)
        
class MNODE_GD(MNODE_GL):
    """``MNODE_GL`` under a separate class name; adds no attributes or methods.

    NOTE(review): presumably a distinct type so that calling code can pick a
    training/regularisation scheme by class -- confirm against the callers.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__(DAG, edge_map, init_ew=init_ew, hyper_params=hyper_params)

class MNODE_RD(MNODE_GD):
    """``MNODE_GD`` under a separate class name; adds no attributes or methods.

    NOTE(review): presumably a distinct type so that calling code can pick a
    training/regularisation scheme by class -- confirm against the callers.
    """

    def __init__(self, DAG, edge_map, init_ew=None, hyper_params=None):
        super().__init__(DAG, edge_map, init_ew=init_ew, hyper_params=hyper_params)
    