#%%
import torch as to
import torch.nn as nn
from torch.utils.data import DataLoader

import wandb as wb


def sigmoid(x):
    """
    Apply the logistic sigmoid to ``x`` element-wise.

    Function-style wrapper around ``torch.nn.Sigmoid``.
    """
    return to.nn.Sigmoid()(x)


def relu(x):
    """
    Apply the rectified linear unit to ``x`` element-wise.

    Function-style wrapper around ``torch.nn.ReLU``.
    """
    return to.nn.ReLU()(x)


def tanh(x):
    """
    functional form of pytorch tanh

    Applies the hyperbolic tangent to ``x`` element-wise via
    ``torch.nn.Tanh`` and returns the result.
    """

    fn = to.nn.Tanh()

    return fn(x)

def uniform(a, b, size, gen):
    """
    Draw a tensor of shape ``size`` with entries spread uniformly between
    ``a`` and ``b``.

    Each entry is ``(a - b) * u + b`` where ``u`` comes from ``torch.rand``
    using the random generator ``gen``.
    """
    unit_samples = to.rand(size=size, generator=gen)
    span = a - b

    return span * unit_samples + b

def normal(mu, sig, size, gen):
    """
    Draw a tensor of shape ``size`` from a normal distribution with mean
    ``mu`` and standard deviation ``sig``, using the random generator ``gen``.
    """
    mean_tensor = mu * to.ones(size=size)

    return to.normal(mean=mean_tensor, std=sig, generator=gen)

def bimodal(mu, sig, size, gen):
    """
    Draw a tensor of shape ``size`` from an equal mixture of two normal
    distributions centred at ``+mu`` and ``-mu`` with standard deviation ``sig``.

    A fair coin (``torch.bernoulli`` with probability 0.5) picks the sign of
    the mean for each entry; the coin is drawn from ``gen`` before the
    normal samples.
    """
    coin = to.bernoulli(input=0.5 * to.ones(size=size), generator=gen)
    # coin is 0 or 1, so coin + (coin - 1) is -1 or +1
    signs = coin + (coin - 1)
    centres = mu * signs

    return to.normal(mean=centres, std=sig, generator=gen)

def dl_normal(mu, sig, size, gen):
    """
    Draw a tensor of shape ``size`` of signed log-normal values.

    Each entry is ``s * exp(z)`` where ``z`` is normal with mean ``mu`` and
    standard deviation ``sig`` and ``s`` is ``+1`` or ``-1`` with equal
    probability. The signs are drawn from ``gen`` before the normal samples.

    The magnitude is ``exp(z)`` rather than ``log(z)``: the logarithm of a
    normal sample is NaN for every non-positive draw.
    """

    b1 = to.bernoulli(input=0.5 * to.ones(size=size), generator=gen)
    # b1 is 0 or 1, so b1 + (b1 - 1) is -1 or +1
    signs = b1 + (b1 - 1)

    log_magnitude = to.normal(mean=mu * to.ones(size=size), std=sig, generator=gen)

    return signs * to.exp(log_magnitude)


class NN(nn.Module):
    """
    Network with one hidden layer: ``T2(activation(T1(x)))``.

    ``T1`` is a linear layer with bias and ``T2`` a linear layer without bias.
    ``to_learn`` selects what the internal Adam optimizer updates:

    * ``'biases'``   -- only ``T1.bias``
    * ``'softmask'`` -- only ``alphas`` (one value per hidden unit); hidden
      activations are multiplied by ``sigmoid(alphas / tau)``
    * ``'all'``      -- every registered parameter of the module
    """

    def __init__(self, n_input, n_hidden, n_output, learning_rate=0.01, to_learn='biases', crit=nn.MSELoss(), activation_type='sigmoid'):
        """
        Build the two layers, the optimizer and the bookkeeping attributes.

        n_input, n_hidden, n_output: layer sizes.
        learning_rate: learning rate passed to Adam.
        to_learn: 'biases', 'softmask' or 'all' (see class docstring).
        crit: loss callable applied as ``crit(output, target)``.
        activation_type: 'sigmoid', 'relu' or 'tanh'.

        Raises ValueError when ``to_learn`` or ``activation_type`` is not one
        of the supported names.
        """
        super().__init__()
        self.T1 = nn.Linear(n_input, n_hidden, bias=True)
        self.T2 = nn.Linear(n_hidden, n_output, bias=False)

        if to_learn == 'biases':
            self.optimizer = to.optim.Adam([self.T1.bias], lr=learning_rate)
            # placeholder so that init_uniform has something to write into
            self.alphas = to.empty(1)
        elif to_learn == 'softmask':
            self.alphas = nn.Parameter(to.zeros(n_hidden))
            self.optimizer = to.optim.Adam([self.alphas], lr=learning_rate)
        elif to_learn == 'all':
            self.optimizer = to.optim.Adam(self.parameters(), lr=learning_rate)
            self.alphas = to.empty(1)
        else:
            # without this the object would be left with no optimizer at all
            raise ValueError(f"unknown to_learn mode: {to_learn!r}")

        self.loss_history = []
        self.n_hidden = n_hidden
        self.n_input = n_input
        self.n_output = n_output
        # temperature of the soft mask; set per epoch by mytrain
        self.tau = None
        self.to_learn = to_learn
        self.wandb = False

        self.loss = crit

        if activation_type == 'sigmoid':
            self.activation = sigmoid
        elif activation_type == 'relu':
            self.activation = relu
        elif activation_type == 'tanh':
            self.activation = tanh
        else:
            # __init__ must not return a value, so signal the error by raising
            raise ValueError(f"unknown activation type: {activation_type!r}")
        self.activation_name = activation_type

    def init_uniform(self, a, b, generator):
        """
        Overwrite ``alphas`` and ``T1.bias`` with samples from ``uniform(a, b, ...)``.
        """
        with to.no_grad():
            self.alphas.copy_(uniform(a, b, size=self.alphas.shape, gen=generator))
            self.T1.bias.copy_(uniform(a, b, size=self.T1.bias.shape, gen=generator))

    def init_weights_uniform(self, a, b, generator):
        """
        Overwrite the weights of ``T1`` and ``T2`` with samples from ``uniform(a, b, ...)``.
        """
        with to.no_grad():
            self.T1.weight.copy_(uniform(a, b, size=self.T1.weight.shape, gen=generator))
            self.T2.weight.copy_(uniform(a, b, size=self.T2.weight.shape, gen=generator))

    def init_weights_normal(self, mu, sig, generator):
        """
        Overwrite the weights of ``T1`` and ``T2`` with samples from ``normal(mu, sig, ...)``.
        """
        with to.no_grad():
            self.T1.weight.copy_(normal(mu, sig, size=self.T1.weight.shape, gen=generator))
            self.T2.weight.copy_(normal(mu, sig, size=self.T2.weight.shape, gen=generator))

    def init_weights_bimodal(self, mu, sig, generator):
        """
        Overwrite the weights of ``T1`` and ``T2`` with samples from ``bimodal(mu, sig, ...)``.
        """
        with to.no_grad():
            self.T1.weight.copy_(bimodal(mu, sig, size=self.T1.weight.shape, gen=generator))
            self.T2.weight.copy_(bimodal(mu, sig, size=self.T2.weight.shape, gen=generator))

    def init_weights_dln(self, mu, sig, generator):
        """
        Overwrite the weights of ``T1`` and ``T2`` with samples from ``dl_normal(mu, sig, ...)``.
        """
        with to.no_grad():
            self.T1.weight.copy_(dl_normal(mu, sig, size=self.T1.weight.shape, gen=generator))
            self.T2.weight.copy_(dl_normal(mu, sig, size=self.T2.weight.shape, gen=generator))

    def forward(self, x):
        """
        Return ``T2`` applied to the hidden states of ``x`` (see hidden_states).
        """
        return self.T2(self.hidden_states(x))

    def mask_forward(self, x, mask):
        """
        Forward pass with the hidden activations multiplied by ``mask``.

        The learned ``alphas`` gate is not applied here, whatever ``to_learn`` is.
        """
        return self.T2(self.activation(self.T1(x)) * mask)

    def hidden_states(self, x):
        """
        Return the hidden-layer output for ``x``.

        In 'softmask' mode the activations are multiplied by
        ``sigmoid(alphas / tau)``; ``tau`` starts as None and is set by
        mytrain, so it must be assigned before this is called in that mode.
        In 'biases' and 'all' mode the plain activations are returned.

        Raises ValueError if ``to_learn`` holds an unsupported name.
        """
        if self.to_learn == 'softmask':
            return self.activation(self.T1(x)) * sigmoid(self.alphas / self.tau)
        if self.to_learn in ('biases', 'all'):
            return self.activation(self.T1(x))
        raise ValueError(f"unknown to_learn mode: {self.to_learn!r}")

    def backprop(self, loss):
        """
        Back-propagate ``loss`` and take one optimizer step.
        """
        loss.backward()
        self.optimizer.step()

    def print_update(self, size, batch, batch_size, loss, epoch):
        """
        On every 10th batch: record the loss, print progress, and log to
        wandb when ``self.wandb`` is set. Other batches are ignored.
        """
        if batch % 10 != 0:
            return

        loss_value = loss.item()
        self.loss_history.append(loss_value)

        # the last batch may be short, so never report more than the dataset size
        seen = min((batch + 1) * batch_size, size)
        print(f"epoch: {epoch}; loss: {loss_value:>7f}  [{seen:>5d}/{size:>5d}]")

        if self.wandb:
            wb.log({'epoch': epoch, 'loss': self.loss_history[-1]})

    def train_loop(self, dataloader, epoch=1):
        """
        Run one pass over ``dataloader``, updating parameters after every batch.
        """
        size = len(dataloader.dataset)

        for batch, (x, y) in enumerate(dataloader):
            self.optimizer.zero_grad()
            out = self.forward(x)
            loss = self.loss(out, y)
            self.backprop(loss)
            self.print_update(size, batch, dataloader.batch_size, loss, epoch)

    def mytrain(self, dataloader, n_epochs=1, tau=None, wandb=False):
        """
        Train for ``n_epochs`` epochs.

        tau: indexable with one temperature per epoch (``tau[epoch]``);
             defaults to all ones.
        wandb: when true, print_update also logs to wandb.
        """
        self.train()
        self.wandb = wandb

        if tau is None:
            tau = to.ones(n_epochs)

        for epoch in range(n_epochs):
            self.tau = tau[epoch]
            self.train_loop(dataloader, epoch)

    def _accuracy(self, testloader, predict):
        # Shared by evaluate and mask_evaluate: fraction of samples whose
        # arg-max output (over dim 1) equals labels[:, 0].
        self.eval()
        correct = 0
        total = 0
        with to.no_grad():
            for (inputs, labels) in testloader:
                outputs = predict(inputs)
                _, predicted = to.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels[:, 0]).sum()

        return correct / total

    def evaluate(self, testloader):
        """
        Return the classification accuracy of ``forward`` over ``testloader``.

        The predicted class is the arg-max of the output along dim 1 and is
        compared with column 0 of the labels.
        """
        return self._accuracy(testloader, self.forward)

    def mask_evaluate(self, testloader, mask):
        """
        Same as evaluate, but predictions come from ``mask_forward(inputs, mask)``.
        """
        return self._accuracy(testloader, lambda inputs: self.mask_forward(inputs, mask))

    def get_hidden_states(self, dataset):
        """
        Return an ``(len(dataset), n_hidden)`` tensor holding the hidden
        states of every sample in ``dataset``, computed one sample at a time.
        """
        self.eval()
        N = len(dataset)
        loader = DataLoader(dataset, batch_size=1)

        hs = to.zeros((N, self.n_hidden))
        with to.no_grad():
            for ix, (inputs, _) in enumerate(loader):
                hs[ix, :] = self.hidden_states(inputs)

        return hs

