import torch; from torch import nn; import math
from torch.nn.parameter import Parameter
from torch.autograd import Variable

class NET(nn.Module):
    """Single-step recurrent / feed-forward cell selected by ``override``.

    Supported ``override`` values:

    * ``"RNN"``    - tanh recurrent cell.
    * ``"LSTM"``   - LSTM-style cell whose gates can be clamped via ``lesion``.
    * ``"FF"``     - linear map of the first input feature, no recurrence.
    * ``"FF_mem"`` - linear map of a rolling window holding the last
      ``hidden_dim`` values of the first input feature.

    ``lesion`` is only consulted in ``"LSTM"`` mode. It may be ``None`` (no
    lesion) or any container/string supporting ``in``; the recognised tags are
    ``'STM'``, ``'LTM'``, ``'FORGET'``, ``'INPUT'`` and ``'OUTPUT'``.

    ``leak`` is stored on the instance but not used by this class.
    """

    # Modes that forward() knows how to evaluate.
    _SUPPORTED_OVERRIDES = ("RNN", "LSTM", "FF", "FF_mem")

    def __init__(self, input_dim, hidden_dim, lesion=None, override=None, device=None, leak=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.override = override
        self.lesion = lesion
        self.leak = leak
        self.device = device
        self.construction()
        self.to(device)

    def construction(self):
        """Create the layers and constant gate tensors for the chosen mode."""
        # Placeholder 1x1 layers are created unconditionally so the weight/bias
        # aliases below always exist; the RNN/LSTM branches replace them with
        # correctly sized layers.
        self.igates = torch.nn.Linear(1, 1).to(self.device)
        self.hgates = torch.nn.Linear(1, 1).to(self.device)

        if self.override == "RNN":
            self.igates = torch.nn.Linear(self.input_dim, self.hidden_dim).to(self.device)
            self.hgates = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
        elif self.override == "LSTM":
            # One fused projection per source; forward() splits it into 4 gates.
            self.igates = torch.nn.Linear(self.input_dim, 4 * self.hidden_dim).to(self.device)
            self.hgates = torch.nn.Linear(self.hidden_dim, 4 * self.hidden_dim).to(self.device)
        elif self.override == 'FF':
            # no recurrent net, only feedforward to next layer
            self.pavlovian = torch.nn.Linear(1, self.hidden_dim).to(self.device)
        elif self.override == 'FF_mem':
            # no recurrent net, only feedforward with mem of last (N = hidden Dim) inputs
            self.pavlovian = torch.nn.Linear(self.hidden_dim, self.hidden_dim).to(self.device)
            self.memory = torch.zeros(1, self.hidden_dim).to(self.device)

        # Aliases onto the Linear parameters, used directly by forward().
        self.weight_ih_l0 = self.igates.weight
        self.weight_hh_l0 = self.hgates.weight
        self.bias_ih_l0 = self.igates.bias
        self.bias_hh_l0 = self.hgates.bias

        # Constant "fully open" / "fully closed" gate values.
        # NOTE: these (and ``memory``) are plain tensors, not registered
        # buffers, so a later ``module.to(...)`` call will not move them.
        self.open = torch.ones(1, self.hidden_dim, requires_grad=False, device=self.device)
        self.closed = torch.zeros(1, self.hidden_dim, requires_grad=False, device=self.device)

    def _lesioned(self, tag):
        """Return True if ``tag`` is listed in ``self.lesion`` (None means no lesions)."""
        return self.lesion is not None and tag in self.lesion

    def forward(self, inp, recur):
        """Advance the cell by one step.

        Args:
            inp: input tensor; a leading dimension of size 1 is squeezed away.
            recur: pair ``(STM, LTM)`` of state tensors from the previous step;
                each has a leading size-1 dimension squeezed away as well.

        Returns:
            ``(STM, (STM, LTM), gates)`` where ``gates`` is the concatenation
            along dim 0 of the forget, input, candidate and output values
            (in that order) used for this step.

        Raises:
            ValueError: if ``self.override`` is not one of the supported modes.
        """
        if self.override not in self._SUPPORTED_OVERRIDES:
            # Without this guard the gate variables would be unbound at the
            # return statement and surface as an opaque UnboundLocalError.
            raise ValueError(
                f"unsupported override {self.override!r}; "
                f"expected one of {self._SUPPORTED_OVERRIDES}"
            )

        inp, STM, LTM = [x.squeeze(0) for x in (inp, recur[0], recur[1])]

        if self.override == "RNN":
            out = torch.tanh(
                (inp @ self.weight_ih_l0.T)
                + (STM @ self.weight_hh_l0.T)
                + self.bias_ih_l0
                + self.bias_hh_l0
            )
            LTM = STM = out.unsqueeze(0)
            # No real gates in this mode: report the activation for all four.
            f = i = c = o = out.detach()

        elif self.override == "LSTM":
            # An 'STM' lesion zeroes the recurrent (short-term) input.
            STM = (1 - int(self._lesioned('STM'))) * STM
            gates = (
                inp @ self.weight_ih_l0.T
                + STM @ self.weight_hh_l0.T
                + self.bias_ih_l0
                + self.bias_hh_l0
            )
            f, i, c, o = gates.chunk(4, 1)

            # 'LTM' takes precedence over 'FORGET': it closes the forget gate
            # so nothing of the previous LTM is carried over.
            if self._lesioned('LTM'):
                f = self.closed
            elif self._lesioned('FORGET'):
                f = self.open
            else:
                f = torch.sigmoid(f)
            i = self.open if self._lesioned('INPUT') else torch.sigmoid(i)
            o = self.open if self._lesioned('OUTPUT') else torch.sigmoid(o)

            c = torch.tanh(c)
            LTM = i * c + LTM * f
            STM = torch.tanh(LTM) * o
            LTM = LTM.unsqueeze(0)
            STM = STM.unsqueeze(0)

        elif self.override == 'FF':
            # Only the first input feature drives the output.
            STM = self.pavlovian(inp[:, 0]).unsqueeze(0).unsqueeze(0)
            LTM, f, i, c, o = self.empty_layers()

        else:  # 'FF_mem'
            # Shift the window left by one and append the newest value.
            self.memory = torch.roll(self.memory, -1, 1)
            self.memory[:, -1] = inp[:, 0]
            STM = self.pavlovian(self.memory).unsqueeze(0)
            LTM, f, i, c, o = self.empty_layers()

        return STM, (STM, LTM), torch.cat((f, i, c, o), 0)

    def empty_layers(self):
        """Return placeholder ``(LTM, f, i, c, o)`` for the gate-less modes.

        LTM is the ``closed`` tensor with an extra leading dimension; every
        gate slot is the ``open`` tensor.
        """
        return self.closed.unsqueeze(0), self.open, self.open, self.open, self.open
            
