import torch; from torch import nn; import math
from torch.nn.parameter import Parameter
from torch.autograd import Variable

class NET(nn.Module):
    """Single-step recurrent cell whose update rule is chosen by ``override``.

    Supported ``override`` values:

    * ``"RNN"``       -- tanh recurrent cell.
    * ``"LSTM"``      -- LSTM-style cell whose gates can be lesioned through
      the ``lesion`` keyword argument.
    * ``"FF"``        -- feedforward map of the first input feature only.
    * ``"FF_mem"``    -- feedforward map of a rolling buffer holding the last
      ``hidden_dim`` values of the first input feature.
    * ``"inte_prob"`` -- running sum (``DV``) of a linear map of the first
      input feature.

    Any other value (including the default ``None``) still constructs, but
    ``forward`` raises ``ValueError``.

    ``memory`` (``"FF_mem"``) and ``DV`` (``"inte_prob"``) persist across
    ``forward`` calls; nothing in this class resets them.
    """

    def __init__(self, input_dim, hidden_dim,  override = None, device=None, **kwargs):
        """Store the configuration, build the layers and move them to ``device``.

        Args:
            input_dim: Number of input features (used by "RNN" and "LSTM").
            hidden_dim: Number of hidden units.
            override: Name of the update rule, see the class docstring.
            device: Device the layers and constant tensors are created on.
            **kwargs: Only ``lesion`` is read. It is tested with ``in`` for the
                tokens 'STM', 'LTM', 'FORGET', 'INPUT' and 'OUTPUT', so it may
                be a string or a collection of strings. ``None`` (the default)
                means no lesion.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.override = override
        self.lesion = kwargs.get("lesion", None)
        self.device = device
        self.construction()
        self.to(device)

    def construction(self):
        """Create the layers and constant tensors required by ``self.override``."""
        # Placeholder gates: created unconditionally (and first) so that the
        # weight/bias aliases below exist for every mode and the order in which
        # layers are initialised stays the same for all modes.
        self.igates = torch.nn.Linear(1, 1).to(self.device)
        self.hgates = torch.nn.Linear(1, 1).to(self.device)

        if self.override == "RNN":
            self.igates = torch.nn.Linear(self.input_dim, self.hidden_dim).to(self.device)
            self.hgates = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        elif self.override == "LSTM":
            # One fused projection for the four gates (forget, input, cell, output).
            self.igates = torch.nn.Linear(self.input_dim, 4 * self.hidden_dim).to(self.device)
            self.hgates = torch.nn.Linear(self.hidden_dim, 4 * self.hidden_dim).to(self.device)
        elif self.override == 'FF':
            # no recurrent net, only feedforward to next layer
            self.pavlovian = torch.nn.Linear(1, self.hidden_dim).to(self.device)
        elif self.override == 'FF_mem':
            # no recurrent net, only feedforward with mem of last (N = hidden Dim) inputs
            self.pavlovian = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
            self.memory = torch.zeros(1, self.hidden_dim, device=self.device)
        elif self.override == 'inte_prob':
            # forward() maps the first input feature through ``pavlovian`` and
            # accumulates the result in ``DV``; the layer was previously never
            # created for this mode, so forward() raised AttributeError.
            self.pavlovian = torch.nn.Linear(1, self.hidden_dim).to(self.device)
            self.DV = 0

        # Disabled experimental variants, kept for reference.
        # if self.override == 'integrator2':
        #     self.pavlovian = torch.nn.Linear(self.input_dim, self.hidden_dim).to(self.device)
        #     self.inp2recur = torch.nn.Linear(self.input_dim, self.hidden_dim).to(self.device)
        #     self.recur = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        #     self.DV = 0

        # if self.override == 'integrator3':
        #     self.hidden = torch.nn.Linear(self.input_dim + 1, self.hidden_dim).to(self.device)
        #     self.DV = torch.zeros(1,1).to(self.device)

        # if self.override == 'integrator4':
        #     self.hiddenA = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        #     self.hiddenB = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        #     self.DV = torch.zeros(1, self.hidden_dim).to(self.device)

        # Aliases of the gate parameters under torch.nn.RNN/LSTM-style names.
        self.weight_ih_l0 = self.igates.weight
        self.weight_hh_l0 = self.hgates.weight
        self.bias_ih_l0 = self.igates.bias
        self.bias_hh_l0 = self.hgates.bias
        # Constant gate values: all-ones ("open") and all-zeros ("closed").
        self.open = torch.ones(1, self.hidden_dim, requires_grad = False, device = self.device)
        self.closed = torch.zeros(1, self.hidden_dim, requires_grad = False, device = self.device)

    def forward(self, inp, recur):
        """Advance the cell by one step.

        Args:
            inp: Input tensor; its leading dimension is squeezed when it has
                size 1. Presumably shaped (1, batch, features) -- TODO confirm
                against the caller.
            recur: Pair ``(STM, LTM)`` of state tensors, squeezed the same way.

        Returns:
            Tuple ``(STM, (STM, LTM), gates)`` where ``gates`` is the
            concatenation of ``f, i, c, o`` along dimension 0.

        Raises:
            ValueError: If ``self.override`` is not one of the supported modes.
        """
        inp, STM, LTM = [x.squeeze(0) for x in [inp, recur[0], recur[1]]]

        if self.override == "RNN":
            out = torch.tanh((inp @ self.weight_ih_l0.T) + (STM @ self.weight_hh_l0.T) + self.bias_ih_l0 + self.bias_hh_l0)
            LTM = STM = out.unsqueeze(0)
            # No real gates in this mode; report the (detached) activation for all four.
            f = i = c = o = out.detach()

        elif self.override == "LSTM":
            # ``lesion`` defaults to None; treat that as "no lesion" instead of
            # failing on ``'STM' in None``.
            lesion = () if self.lesion is None else self.lesion

            # An 'STM' lesion zeroes the recurrent short-term input.
            STM = (1 - int('STM' in lesion)) * STM
            gates = inp @ self.weight_ih_l0.T + STM @ self.weight_hh_l0.T + self.bias_ih_l0 + self.bias_hh_l0
            # NOTE(review): chunks are read as (f, i, c, o), whereas
            # torch.nn.LSTM documents an (i, f, g, o) layout -- weights are
            # presumably not interchangeable with nn.LSTM; confirm before
            # loading pretrained weights.
            f, i, c, o = gates.chunk(4, 1)

            # Lesions pin a gate to a constant; 'LTM' takes precedence over 'FORGET'.
            if 'LTM' in lesion:
                f = self.closed
            elif 'FORGET' in lesion:
                f = self.open
            else:
                f = torch.sigmoid(f)
            i = self.open if 'INPUT' in lesion else torch.sigmoid(i)
            o = self.open if 'OUTPUT' in lesion else torch.sigmoid(o)

            c = torch.tanh(c)
            LTM = i * c + LTM * f
            STM = torch.tanh(LTM) * o
            LTM = LTM.unsqueeze(0)
            STM = STM.unsqueeze(0)

        # integrators for testing derivative-input RL theory
        elif self.override == 'FF':
            STM = self.pavlovian(inp[:,0]).unsqueeze(0).unsqueeze(0)
            LTM, f, i, c, o = self.empty_layers()

        elif self.override == 'FF_mem':
            # Shift the buffer left by one and append the newest value.
            self.memory = torch.roll(self.memory, -1, 1)
            self.memory[:, -1] = inp[:, 0]
            STM = self.pavlovian(self.memory).unsqueeze(0)
            LTM, f, i, c, o = self.empty_layers()

        elif self.override == 'inte_prob':
            update = self.pavlovian(inp[:,0])
            self.DV = self.DV + update
            STM = self.DV * self.open.unsqueeze(0)
            LTM, f, i, c, o = self.empty_layers()

        # Disabled experimental variants, kept for reference.
        # if self.override == 'integrator2':
        #     if inp[:, 0] == 0:
        #     # if inp[:, 0] == -1:
        #         self.DV = 0
        #     else:
        #         self.DV += 1

        #     pavlov_bias = self.pavlovian(inp)
        #     LTM =  torch.sigmoid(self.recur(LTM) + self.inp2recur(inp))
        #     STM =  LTM * self.DV + pavlov_bias

        #     LTM = LTM.unsqueeze(0)
        #     STM = STM.unsqueeze(0)
        #     f = (LTM * self.DV).detach().squeeze(0)
        #     i =  self.inp2recur(inp).detach()
        #     c = pavlov_bias.detach()
        #     o = self.recur(LTM).detach().squeeze(0)

        # if self.override == 'integrator3':
        #     # if inp[:, 0] == 0:
        #     if inp[:, 0] == -1:
        #         self.DV[0] = 0
        #     else:
        #         self.DV[0] = self.DV[0] + 1

        #     STM = self.hidden(torch.cat((inp, self.DV), 1))
        #     f=i=c=o=LTM =STM.detach()
        #     LTM = LTM.unsqueeze(0)
        #     STM = STM.unsqueeze(0)

        # if self.override == 'integrator4':
        #     self.DV = torch.roll(self.DV, -1, 1)
        #     self.DV[:, -1] = inp[:, 0]
        #     f =  torch.sigmoid(self.hiddenA(self.DV))
        #     i = torch.tanh(self.hiddenB(self.DV))
        #     STM = f*i

        #     # STM = torch.tanh(self.hiddenA(self.DV))
        #     # f=i = STM.detach()

        #     c=o=LTM=STM.detach()
        #     LTM = LTM.unsqueeze(0)
        #     STM = STM.unsqueeze(0)

        else:
            # Previously this fell through to the return statement and failed
            # with an UnboundLocalError on ``f``.
            raise ValueError(
                f"Unsupported override {self.override!r}; expected one of "
                "'RNN', 'LSTM', 'FF', 'FF_mem', 'inte_prob'."
            )

        return STM, (STM, LTM), torch.cat((f, i, c, o), 0)

    def empty_layers(self):
        """Return placeholder ``(LTM, f, i, c, o)`` for modes without real gates.

        ``LTM`` is the all-zeros tensor with an extra leading dimension; the
        four gate placeholders are the all-ones tensor.
        """
        return self.closed.unsqueeze(0), self.open, self.open, self.open, self.open
            
