import torch.nn as nn
import torch 
import pdb 
import numpy as np

def combine_means(mu1, mu2, m, n):
    """Return the pooled mean of two groups of samples.

    Args:
        mu1: Mean of the first group (``m`` samples).
        mu2: Mean of the second group (``n`` samples).
        m: Number of samples behind ``mu1``.
        n: Number of samples behind ``mu2``.

    Returns:
        The mean of all ``m + n`` samples, i.e. the two means weighted by
        their share of the total sample count.
    """
    total = m + n
    weight_old = m / total
    weight_new = n / total
    return weight_old * mu1 + weight_new * mu2

def combine_vars(v1, v2, mu1, mu2, m, n):
    """Return the pooled variance of two groups of samples.

    Args:
        v1: Variance of the first group (``m`` samples).
        v2: Variance of the second group (``n`` samples).
        mu1: Mean of the first group.
        mu2: Mean of the second group.
        m: Number of samples behind ``v1`` / ``mu1``.
        n: Number of samples behind ``v2`` / ``mu2``.

    Returns:
        The count-weighted average of the two variances plus a term that
        accounts for the distance between the two group means.
    """
    total = m + n
    # Spread inside each group, weighted by group size.
    within = (m / total) * v1 + (n / total) * v2
    # Extra spread caused by the two groups having different means.
    between = (m * n / total ** 2) * (mu1 - mu2) ** 2
    return within + between

class MLP(nn.Module):
    """Fully connected ReLU network with 28 * 28 inputs and 10 outputs.

    The network is ``input_layer`` (784 -> width), ``num_layers`` hidden
    layers (width -> width) and ``output_layer`` (width -> 10).  Weights can
    either be trained (Xavier-uniform initialised) or drawn once from a
    chosen distribution and frozen, in which case only biases keep
    ``requires_grad=True``.

    Args:
        num_layers: Number of hidden width -> width layers.
        width: Number of units in every hidden representation.
        c: Scale of frozen weights: bound of ``uniform(-c, c)`` or standard
            deviation of ``normal(0, c)``.
        weight_distribution: ``"uniform"``, ``"normal"`` or ``"saxe"``
            (orthogonal).  Only used when ``train_weights`` is false; any
            other value keeps PyTorch's default ``nn.Linear`` weights.
        weight_gain: Gain passed to the orthogonal weight initialiser.
        bias_distribution: ``"xavier"``, ``"saxe"`` (orthogonal) or
            ``"none"`` (keep PyTorch's default bias initialisation).
        bias_gain: Gain passed to the orthogonal bias initialiser.
        train_weights: If true, weights are Xavier-uniform initialised and
            stay trainable; otherwise they are frozen after initialisation.
        input_layer_bias: Whether the input layer has a bias.
        output_layer_bias: Whether the output layer has a bias.
        middle_layers_bias: Whether the hidden layers have biases.
        l1_weight: Coefficient of the activation penalty in ``train_epoch``.
        bias_l1_weight: Coefficient of the bias penalty in ``train_epoch``.
        bias_l1_baseline: Stored on the instance; not read anywhere in this
            class.
    """

    def __init__(
        self,
        num_layers,
        width,
        c,
        weight_distribution,
        weight_gain,
        bias_distribution,
        bias_gain,
        train_weights=False,
        input_layer_bias=True,
        output_layer_bias=True,
        middle_layers_bias=True,
        l1_weight=0,
        bias_l1_weight=0,
        bias_l1_baseline=-1
        ):
        super().__init__()
        self.input_layer = nn.Linear(28 * 28, width, bias=input_layer_bias)
        self.layers = nn.ModuleList()
        self.output_layer = nn.Linear(width, 10, bias=output_layer_bias)
        print("Training weights" if train_weights else "Not training weights")

        # The order of the initialisation calls below (hidden layers, then
        # input/output biases, then input/output weights) is kept fixed so
        # that a given random seed consumes the RNG stream the same way.
        for _ in range(num_layers):
            layer = nn.Linear(width, width, bias=middle_layers_bias)
            self._init_weight(layer, train_weights, weight_distribution, c, weight_gain)
            self._init_bias(layer.bias, bias_distribution, bias_gain)
            self.layers.append(layer)

        self._init_bias(self.input_layer.bias, bias_distribution, bias_gain)
        self._init_bias(self.output_layer.bias, bias_distribution, bias_gain)

        self._init_weight(self.input_layer, train_weights, weight_distribution, c, weight_gain)
        self._init_weight(self.output_layer, train_weights, weight_distribution, c, weight_gain)

        self.l1_weight = l1_weight
        self.bias_l1_weight = bias_l1_weight
        self.bias_l1_baseline = bias_l1_baseline

    @staticmethod
    def _init_weight(layer, train_weights, distribution, c, gain):
        """Initialise ``layer.weight``; freeze it unless weights are trained."""
        if train_weights:
            nn.init.xavier_uniform_(layer.weight)
            return

        if distribution == "uniform":
            layer.weight.data.uniform_(-c, c)
        elif distribution == "normal":
            # Zero-mean for every layer (hidden layers previously used mean -c
            # while the input/output layers used mean 0).
            layer.weight.data.normal_(0, c)
        elif distribution == "saxe":
            nn.init.orthogonal_(layer.weight, gain=gain)
        # Unknown names fall through and keep PyTorch's default weights.
        layer.weight.requires_grad = False

    @staticmethod
    def _init_bias(bias, distribution, gain):
        """Initialise a bias vector in place; no-op when the layer has no bias."""
        if bias is None:
            return
        # xavier_uniform_ and orthogonal_ reject tensors with fewer than two
        # dimensions, so the bias is initialised through a (1, n) view that
        # shares its storage.
        as_row = bias.detach().unsqueeze(0)
        if distribution == "xavier":
            nn.init.xavier_uniform_(as_row)
        elif distribution == "saxe":
            nn.init.orthogonal_(as_row, gain=gain)
        # "none" (and any other value) keeps PyTorch's default bias.

    def set_weights(self, weights_path):
        """Load weights saved as a sequence [input, hidden..., output].

        Args:
            weights_path: Path handed to ``torch.load``.
        """
        # NOTE: torch.load unpickles the file; only load trusted files.
        weights = torch.load(weights_path)
        self.input_layer.weight.data = weights[0]
        for i, layer in enumerate(self.layers, start=1):
            layer.weight.data = weights[i]
        self.output_layer.weight.data = weights[-1]

    def disable_weights_training(self):
        """Freeze the weight matrices of every layer (biases are untouched)."""
        for layer in (self.input_layer, *self.layers, self.output_layer):
            layer.weight.requires_grad = False

    def only_finetune_output(self):
        """Freeze input and hidden layers; make the output layer trainable."""
        for layer in (self.input_layer, *self.layers):
            layer.weight.requires_grad = False
            if layer.bias is not None:
                layer.bias.requires_grad = False

        self.output_layer.weight.requires_grad = True
        if self.output_layer.bias is not None:
            self.output_layer.bias.requires_grad = True

    def set_biases(self, biases_path):
        """Load biases saved as a sequence [input, hidden..., output].

        Layers without a bias are skipped, but their position in the loaded
        sequence is still counted.

        Args:
            biases_path: Path handed to ``torch.load``.
        """
        # NOTE: torch.load unpickles the file; only load trusted files.
        biases = torch.load(biases_path)
        if self.input_layer.bias is not None:
            self.input_layer.bias.data = biases[0]
        for i, layer in enumerate(self.layers, start=1):
            if layer.bias is not None:
                layer.bias.data = biases[i]
        if self.output_layer.bias is not None:
            self.output_layer.bias.data = biases[-1]

    def save_weights(self, path):
        """Save the list of weight parameters to ``path`` with ``torch.save``."""
        torch.save(self.get_weights_and_biases()["weights"], path)

    def save_biases(self, path):
        """Save the ``"biases"`` list of ``get_weights_and_biases`` to ``path``."""
        torch.save(self.get_weights_and_biases()["biases"], path)

    def get_weights_and_biases(self):
        """Return the weight parameters and the input-layer bias.

        Returns:
            A dict with ``"weights"`` (input, hidden..., output weight
            parameters) and ``"biases"`` (a list holding only the input-layer
            bias, or empty when that layer has no bias).
        """
        # NOTE(review): hidden/output biases are not collected here, whereas
        # set_biases indexes one entry per layer -- confirm this is intended
        # before round-tripping save_biases -> set_biases.
        biases = []
        if self.input_layer.bias is not None:
            biases.append(self.input_layer.bias)

        weights = [self.input_layer.weight]
        weights.extend(layer.weight for layer in self.layers)
        weights.append(self.output_layer.weight)
        return {
            "weights": weights,
            "biases": biases
        }

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch accepted by the 784-feature input layer.

        Returns:
            A tuple ``(logits, activations)`` where ``activations`` is the
            concatenation along dim 1 of the post-ReLU outputs of the input
            layer and of every hidden layer.
        """
        x = torch.relu(self.input_layer(x))
        intermediate = [x]
        for layer in self.layers:
            x = torch.relu(layer(x))
            intermediate.append(x)
        x = self.output_layer(x)
        return x, torch.cat(intermediate, dim=1)

    def get_max_bias(self):
        """Return the largest input-layer bias, or -1000000 if it has no bias.

        Only the input layer is inspected; hidden and output biases are not
        considered.
        """
        max_bias = -1000000
        if self.input_layer.bias is not None:
            max_bias = max(max_bias, torch.max(self.input_layer.bias).item())
        return max_bias

    def get_biases(self):
        """Return all existing biases concatenated into one NumPy array."""
        return self.get_biases_tensor().detach().numpy()

    def get_biases_tensor(self):
        """Return all existing biases (input, hidden, output) as one tensor.

        The result stays attached to the autograd graph so it can be used in
        a loss term.
        """
        all_layers = (self.input_layer, *self.layers, self.output_layer)
        biases = [layer.bias for layer in all_layers if layer.bias is not None]
        return torch.cat(biases)

    def get_min_bias(self):
        """Return the smallest input-layer bias, or 1000000 if it has no bias.

        Only the input layer is inspected; hidden and output biases are not
        considered.
        """
        min_bias = 1000000
        if self.input_layer.bias is not None:
            min_bias = min(min_bias, torch.min(self.input_layer.bias).item())
        return min_bias

    def drop_units_below(self, threshold):
        """Zero out units whose bias is below ``threshold``.

        For the input layer and every hidden layer that has a bias, the
        weight rows and bias entries of units with ``bias < threshold`` are
        set to zero.  The output layer is left untouched.

        Args:
            threshold: Units with a bias strictly below this value are dropped.

        Returns:
            The number of input-layer units dropped (0 when the input layer
            has no bias).
        """
        for layer in self.layers:
            if layer.bias is not None:
                mask = layer.bias < threshold
                layer.weight.data[mask] = 0
                layer.bias.data[mask] = 0

        input_bias = self.input_layer.bias
        if input_bias is None:
            # Nothing to compare against; previously this path raised.
            return 0

        mask = input_bias < threshold
        num_biases_deleted = torch.sum(mask).item()
        self.input_layer.weight.data[mask] = 0
        input_bias.data[mask] = 0
        return num_biases_deleted

    def train_epoch(self, trainloader, optimizer, criterion, noise_variance=0):
        """Train for one pass over ``trainloader``.

        Args:
            trainloader: Iterable of ``(inputs, labels)`` batches.
            optimizer: Optimizer stepping this model's parameters.
            criterion: Loss called as ``criterion(outputs, labels)``.
            noise_variance: Factor multiplied with standard-normal noise that
                is added to the inputs (it scales the noise directly, i.e. it
                acts as a standard deviation).

        Returns:
            The mean of the per-batch total losses.
        """
        self.train()
        train_running_loss = 0.0

        for inputs, labels in trainloader:
            inputs = inputs + torch.randn_like(inputs) * noise_variance
            optimizer.zero_grad()
            outputs, intermediate = self.forward(inputs)
            train_loss = criterion(outputs, labels)

            if self.l1_weight != 0:
                # L1 penalty on the hidden activations, averaged per sample.
                l1_loss = self.l1_weight * torch.sum(torch.abs(intermediate)) / inputs.shape[0]
                train_loss += l1_loss

            if self.bias_l1_weight != 0:
                # NOTE(review): this is the signed sum of the biases (no abs)
                # and bias_l1_baseline is not used -- confirm intended.
                bias_l1_loss = self.bias_l1_weight * torch.sum(self.get_biases_tensor())
                train_loss += bias_l1_loss

            train_loss.backward()
            optimizer.step()
            train_running_loss += train_loss.item()

        return train_running_loss / len(trainloader)

    def eval_epoch(self, testloader, criterion, noise_variance=0):
        """Evaluate on ``testloader`` without gradient tracking.

        Args:
            testloader: Iterable of ``(inputs, labels)`` batches.
            criterion: Loss called as ``criterion(outputs, labels)``.
            noise_variance: Factor multiplied with standard-normal noise that
                is added to the inputs.

        Returns:
            A tuple ``(mean batch loss, accuracy, inputs)`` where ``inputs``
            is the last (noisy) input batch.
        """
        self.eval()
        correct = 0
        total = 0
        val_running_loss = 0.0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs = inputs + torch.randn_like(inputs) * noise_variance

                outputs, _ = self.forward(inputs)
                val_loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_running_loss += val_loss.item()

        return val_running_loss / len(testloader), correct / total, inputs

    def eval_epoch_with_variances(self, testloader, criterion, noise_variance=0):
        """Evaluate and collect per-unit activation statistics.

        Args:
            testloader: Iterable of ``(inputs, labels)`` batches.
            criterion: Loss called as ``criterion(outputs, labels)``.
            noise_variance: Factor multiplied with standard-normal noise that
                is added to the inputs.

        Returns:
            A tuple ``(mean batch loss, accuracy, inputs, var, mean,
            label_corr)``: ``inputs`` is the last (noisy) batch, ``var`` and
            ``mean`` are the per-unit population variance and mean of the
            concatenated hidden activations over all samples, and
            ``label_corr`` holds the Pearson correlation of each unit's
            activation with the labels.
        """
        self.eval()
        correct = 0
        total = 0
        val_running_loss = 0.0
        running_mean = None
        running_var = None
        num_samples_so_far = 0
        labels_list = []
        intermediates = []
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs = inputs + torch.randn_like(inputs) * noise_variance

                outputs, intermediate = self.forward(inputs)
                intermediates.append(intermediate)
                labels_list.append(labels)

                batch_size = inputs.shape[0]
                batch_mean = torch.mean(intermediate, dim=0)
                # Population variance: the pooling formula in combine_vars is
                # exact only for it, and it is well defined for a batch of
                # one sample (the unbiased estimator would give NaN).
                batch_var = torch.var(intermediate, dim=0, unbiased=False)
                if running_mean is None:
                    running_mean = batch_mean
                    running_var = batch_var
                else:
                    # Variance first: it needs the not-yet-updated mean.
                    running_var = combine_vars(
                        running_var, batch_var, running_mean, batch_mean,
                        num_samples_so_far, batch_size,
                    )
                    running_mean = combine_means(
                        running_mean, batch_mean, num_samples_so_far, batch_size,
                    )
                num_samples_so_far += batch_size

                val_loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_running_loss += val_loss.item()

        # Move to host memory explicitly so this also works off-CPU.
        intermediates = torch.cat(intermediates, dim=0).cpu().numpy()
        labels_list = torch.cat(labels_list, dim=0).cpu().numpy()
        num_units = intermediates.shape[1]
        label_corr = torch.zeros(num_units)
        for j in range(num_units):
            label_corr[j] = np.corrcoef(intermediates[:, j], labels_list)[0, 1]

        return val_running_loss / len(testloader), correct / total, inputs, running_var, running_mean, label_corr

def combine_means(mu1, mu2, m, n):
    """Merge the mean of ``m`` earlier samples with the mean of ``n`` new ones.

    NOTE: a function with the same name and formula is also defined near the
    top of this module; this later definition is the one the name is bound
    to once the module has been fully executed.

    Args:
        mu1: Mean of the ``m`` earlier samples.
        mu2: Mean of the ``n`` new samples.
        m: Count of earlier samples.
        n: Count of new samples.

    Returns:
        The mean over all ``m + n`` samples.
    """
    sample_count = m + n
    old_part = (m / sample_count) * mu1
    new_part = (n / sample_count) * mu2
    return old_part + new_part

def combine_vars(v1, v2, mu1, mu2, m, n):
    """Merge the variance of ``m`` earlier samples with that of ``n`` new ones.

    NOTE: a function with the same name and formula is also defined near the
    top of this module; this later definition is the one the name is bound
    to once the module has been fully executed.

    Args:
        v1: Variance of the ``m`` earlier samples.
        v2: Variance of the ``n`` new samples.
        mu1: Mean of the earlier samples.
        mu2: Mean of the new samples.
        m: Count of earlier samples.
        n: Count of new samples.

    Returns:
        The variance over all ``m + n`` samples: the size-weighted average of
        both variances plus a correction for the gap between the two means.
    """
    sample_count = m + n
    weighted_old = (m / sample_count) * v1
    weighted_new = n / sample_count * v2
    mean_gap_term = m * n / sample_count ** 2 * (mu1 - mu2) ** 2
    return weighted_old + weighted_new + mean_gap_term

