import random
import numpy as np
from tqdm import trange
import copy

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# NOTE(review): these cleanup calls run once at import time, before this module
# has allocated any tensors, so they likely have no effect here. They look like
# leftovers from interactive use; consider removing them or moving them under
# ``if __name__ == '__main__':``.
torch.cuda.empty_cache()

import gc
gc.collect()


class SimpleModel(nn.Module):
    """
    Fully connected binary classifier with three hidden layers.

    Widths run ``initial_layer_size -> first_hidden_layer_size ->
    second_hidden_layer_size -> third_hidden_layer_size -> 1``. Besides the
    regular ``forward`` pass, the model offers two bias-free, activation-free
    passes over transformed weights. They are used for the connectivity
    regularizer and for SynFlow importance scores.
    """

    def __init__(self, initial_layer_size, first_hidden_layer_size, second_hidden_layer_size, third_hidden_layer_size):
        """
        Record the layer widths and create the four linear layers (fc1..fc4).
        """
        super().__init__()

        self.initial_layer_size = initial_layer_size
        self.first_hidden_layer_size = first_hidden_layer_size
        self.second_hidden_layer_size = second_hidden_layer_size
        self.third_hidden_layer_size = third_hidden_layer_size

        # Creation order matters: it fixes both parameter registration order
        # and the order in which random initialization is drawn.
        self.fc1 = nn.Linear(initial_layer_size, first_hidden_layer_size)
        self.fc2 = nn.Linear(first_hidden_layer_size, second_hidden_layer_size)
        self.fc3 = nn.Linear(second_hidden_layer_size, third_hidden_layer_size)
        self.fc4 = nn.Linear(third_hidden_layer_size, 1)

    def _linear_layers(self):
        """Return the four linear layers, ordered from input to output."""
        return (self.fc1, self.fc2, self.fc3, self.fc4)

    def forward(self, x):
        """
        Standard pass: ReLU after every hidden layer, sigmoid on the single output unit.
        """
        *hidden_layers, output_layer = self._linear_layers()
        for layer in hidden_layers:
            x = torch.relu(layer(x))
        return torch.sigmoid(output_layer(x))

    def _weight_only_pass(self, x, weight_fn):
        """
        Chain ``F.linear`` through every layer using ``weight_fn(layer.weight)``.

        No bias and no activation are applied at any layer.
        """
        for layer in self._linear_layers():
            x = F.linear(x, weight_fn(layer.weight), bias=None)
        return x

    def reparameterized_forward(self, x):
        """
        Weight-only pass over per-layer normalized absolute weights (see
        ``transform_tensor``); its output feeds the connectivity regularizer.
        """
        return self._weight_only_pass(x, self.transform_tensor)

    def synflow_forward(self, x):
        """
        Weight-only pass over absolute weights; its output is used to compute
        SynFlow importance scores.
        """
        return self._weight_only_pass(x, torch.abs)

    def transform_tensor(self, param):
        """
        Return ``|param|`` divided by the sum of all its absolute entries.
        """
        magnitudes = param.abs()
        return magnitudes / magnitudes.sum()


def generate_data(train_size, batch_size, initial_layer_size, noise_std,
                  test_size=None, input_std=None, device=None):
    """
    Generate a synthetic binary-classification dataset and wrap it in data loaders.

    Features are i.i.d. ``N(0, input_std)`` draws of shape
    ``(n_samples, initial_layer_size)``. The label is ``1.0`` when
    ``x[0] + x[1] + eps > 0`` with ``eps ~ N(0, noise_std)``, and ``0.0``
    otherwise, so only the first two features carry signal.

    Args:
        train_size: Number of training samples.
        batch_size: Batch size used by both loaders.
        initial_layer_size: Number of input features (must be at least 2).
        noise_std: Standard deviation of the label noise.
        test_size: Number of test samples. If ``None``, the module-level
            ``test_size`` is used when it exists (the script defines it under
            ``__main__``); otherwise ``train_size // 4`` (at least 1).
        input_std: Standard deviation of the features. If ``None``, the
            module-level ``input_std`` is used when it exists, otherwise ``2``.
        device: Device the tensors are moved to. If ``None``, the module-level
            ``device`` is used when it exists, otherwise the CPU.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and the
        test loader does not. Labels have shape ``(n_samples, 1)``.

    Raises:
        ValueError: If ``initial_layer_size < 2``.
    """
    # The original signature silently read these three values from globals that
    # only exist when the file runs as a script. Falling back to them keeps the
    # 4-argument call equivalent. Explicit arguments, or the defaults, make the
    # function usable after a plain import.
    module_vars = globals()
    if test_size is None:
        test_size = module_vars.get("test_size", max(1, train_size // 4))
    if input_std is None:
        input_std = module_vars.get("input_std", 2)
    if device is None:
        device = module_vars.get("device", torch.device("cpu"))
    if initial_layer_size < 2:
        raise ValueError("initial_layer_size must be at least 2: labels depend on the first two features.")

    X_train = np.random.normal(0, input_std, size=(train_size, initial_layer_size)).astype(np.float32)
    X_test = np.random.normal(0, input_std, size=(test_size, initial_layer_size)).astype(np.float32)

    # Label noise is drawn after both feature matrices. This keeps the global
    # NumPy draw order unchanged, so seeded runs reproduce.
    y_train = (X_train[:, 0] + X_train[:, 1] + np.random.normal(0, noise_std, size=train_size) > 0).astype(np.float32)
    y_test = (X_test[:, 0] + X_test[:, 1] + np.random.normal(0, noise_std, size=test_size) > 0).astype(np.float32)

    def to_loader(features, labels, shuffle):
        # Tensors are placed on ``device`` up front, so batches come out ready for the model.
        features_tensor = torch.tensor(features).to(device)
        labels_tensor = torch.tensor(labels).view(-1, 1).to(device)
        return DataLoader(TensorDataset(features_tensor, labels_tensor), batch_size=batch_size, shuffle=shuffle)

    return to_loader(X_train, y_train, True), to_loader(X_test, y_test, False)


def run_experiment(nn_model, experiment_id, lr=0.01, weight_decay=5e-4, epochs=100, epochs_wo_con=10, l1_lambda=0.01, con_lambda=0.001, initial_layer_size=6, prune_percentage=85, batch_size=256):
    """
    Train, prune and finetune a copy of ``nn_model`` for every enabled regularizer.

    For each variant enabled by the module-level flags ``include_no_l1``,
    ``include_l1`` and ``include_con``, a deep copy of ``nn_model`` is:

    1. trained for ``epochs`` epochs;
    2. pruned with the module-level ``prune_strategy`` (``'magnitude'``,
       ``'synflow'`` or ``'synflow_global'``);
    3. finetuned for ``retraining_epochs`` epochs with the pruning masks held fixed.

    Args:
        nn_model: Template model. It is deep-copied per variant and never modified.
        experiment_id: Identifier of the run (currently unused).
        lr: Adam learning rate for the training phase.
        weight_decay: Adam weight decay for both training and finetuning.
        epochs: Number of training epochs before pruning.
        epochs_wo_con: Number of initial epochs without the connectivity regularizer.
        l1_lambda: Strength of the L1 regularizer (L1 variant only).
        con_lambda: Strength of the connectivity regularizer (CoNNect variant only).
        initial_layer_size: Number of input features.
        prune_percentage: Share of weights to prune. Values in ``[0, 1]`` are
            read as a fraction (e.g. ``0.96``); values in ``(1, 100]`` as a
            percentage (e.g. the default ``85``).
        batch_size: Batch size of the data loaders.

    Returns:
        A list with the final test accuracy of each pruned model, in training
        order: the last finetuning accuracy, or the accuracy right after
        pruning when ``retraining_epochs`` is 0.

    Raises:
        ValueError: If ``prune_percentage`` is outside ``[0, 100]``, or the
            pruning strategy is unknown.

    NOTE: This function also reads module-level names that the script defines
    under ``__main__``: ``criterion``, ``device``, ``include_no_l1``,
    ``include_l1``, ``include_con``, ``train_size``, ``noise_std``,
    ``prune_strategy``, ``ret_lr`` and ``retraining_epochs``.
    """
    # The script passes a fraction (0.96) while the default is a percentage (85).
    # Accepting both keeps the default from indexing past the end of a layer.
    prune_fraction = prune_percentage / 100 if prune_percentage > 1 else prune_percentage
    if not 0 <= prune_fraction <= 1:
        raise ValueError(f"prune_percentage must lie in [0, 100], got {prune_percentage!r}")

    def train_model(model, optimizer, train_loader, test_loader, epochs=200, l1_lambda=0, con_lambda=0, epochs_wo_con=0):
        """
        Train ``model`` with optional L1 / connectivity regularization, then prune and finetune a copy.

        Returns:
            ``(pruned_model, loss_values, accuracy_values, retrained_loss_values,
            retrained_accuracy_values)``. The first two lists hold the per-epoch
            mean training loss and test accuracy of the unpruned model. The last
            two hold the per-epoch finetuning loss and test accuracy of the
            pruned copy; they are empty when ``retraining_epochs`` is 0.
        """
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs))
        # Constant all-ones probe for the connectivity regularizer, built once
        # instead of once per batch.
        ones_input = torch.ones(1, initial_layer_size).to(device)

        loss_values = []
        accuracy_values = []
        for epoch in trange(epochs):
            model.train()

            running_loss = 0.0
            for X_batch, y_batch in train_loader:
                optimizer.zero_grad()
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)

                if con_lambda > 0 and epoch >= epochs_wo_con:
                    # Penalize -log of the summed output of the bias- and
                    # activation-free pass over per-layer normalized |weights|.
                    reparam_output = model.reparameterized_forward(ones_input)
                    regularization = torch.sum(reparam_output)
                    reg_transform = -1 * torch.log(regularization)
                    loss += con_lambda * reg_transform

                if l1_lambda > 0:
                    l1_penalty = 0
                    for name, param in model.named_parameters():
                        if 'weight' in name and 'bn' not in name:
                            l1_penalty += torch.sum(torch.abs(param))
                    loss += l1_lambda * l1_penalty

                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            scheduler.step()

            loss_values.append(running_loss / len(train_loader))
            accuracy_values.append(evaluate_model(model, test_loader))

        # Prune and finetune once, after training. Doing it here rather than
        # inside the last epoch keeps the results defined when epochs == 0.
        pruned_model, _, retrained_loss_values, retrained_accuracy_values = prune_and_validate(
            copy.deepcopy(model), prune_fraction, train_loader, test_loader, retraining_epochs)

        return pruned_model, loss_values, accuracy_values, retrained_loss_values, retrained_accuracy_values

    def evaluate_model(model, test_loader):
        """
        Return the test accuracy of ``model``, thresholding its output at 0.5.

        The model is put in eval mode for evaluation and left in train mode afterwards.
        """
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                outputs = model(X_batch)
                predicted = (outputs >= 0.5).float()
                correct += (predicted == y_batch).sum().item()
                total += y_batch.size(0)
        accuracy = correct / total
        model.train()
        return accuracy

    def prune_and_validate(model, prune_rate, train_loader, test_loader, epochs_ret):
        """
        Prune ``model`` in place, finetune it with the masks fixed, and report accuracies.

        Returns:
            ``(pruned_model, null_accuracy, retrained_loss_values,
            retrained_accuracy_values)``, where ``null_accuracy`` is the test
            accuracy before pruning. With ``epochs_ret == 0`` the model is
            pruned but not finetuned, and both lists are empty.
        """
        null_accuracy = evaluate_model(model, test_loader)

        if prune_strategy == 'magnitude':
            model, masks = magnitude_pruning_local(model, prune_rate)
        elif prune_strategy == 'synflow':
            model, masks = synflow_pruning_local(model, prune_rate)
        elif prune_strategy == 'synflow_global':
            model, masks = synflow_pruning_global(model, prune_rate)
        else:
            raise ValueError("Pruning strategy unknown. Choose 'magnitude', 'synflow' or 'synflow_global'.")

        optimizer = optim.Adam(model.parameters(), lr=ret_lr, weight_decay=weight_decay)
        # The cosine schedule must span the finetuning run (epochs_ret), not the
        # training run; otherwise the learning rate only decays part of the way.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs_ret))
        retrained_model, retrained_loss_values, retrained_accuracy_values = retrain_model_with_masks(
            model, optimizer, scheduler, train_loader, test_loader, masks, epochs_ret)

        return retrained_model, null_accuracy, retrained_loss_values, retrained_accuracy_values

    def magnitude_pruning_local(model, prune_rate):
        """
        Zero the ``prune_rate`` share of smallest-magnitude entries in every weight tensor.

        Returns:
            ``(model, masks)`` where ``masks`` maps each pruned parameter name
            to a float 0/1 mask of the same shape.
        """
        masks = {}
        for name, param in model.named_parameters():
            if 'weight' in name and 'bn' not in name:
                magnitudes = param.detach().abs()
                num_weights = magnitudes.numel()
                num_weights_to_prune = int(prune_rate * num_weights)

                if num_weights_to_prune >= num_weights:
                    # Pruning everything: there is no surviving magnitude to index.
                    mask = torch.zeros_like(magnitudes)
                else:
                    # Keep every weight at least as large as the smallest survivor.
                    threshold = torch.sort(magnitudes.flatten()).values[num_weights_to_prune]
                    mask = (magnitudes >= threshold).float()

                param.data *= mask
                masks[name] = mask

        return model, masks

    def synflow_pruning_local(model, prune_rate):
        """
        Per layer, zero the ``prune_rate`` share of weights with the lowest SynFlow score.

        The score is ``|grad * weight|``, where the gradient comes from the
        summed output of ``model.synflow_forward`` on an all-ones input.
        """
        model.eval()

        output = model.synflow_forward(torch.ones([1, initial_layer_size]).to(device))

        loss = output.sum()
        model.zero_grad()
        loss.backward()

        masks = {}

        for name, param in model.named_parameters():
            if param.requires_grad and "bias" not in name and "bn" not in name.lower():

                importance_score = param.grad.abs() * param.abs()
                flat_importance_score = importance_score.view(-1)

                num_prune = int(prune_rate * flat_importance_score.size(0))

                mask = torch.ones_like(flat_importance_score)
                if num_prune > 0:
                    prune_indices = torch.argsort(flat_importance_score)[:num_prune]
                    mask[prune_indices] = 0

                mask = mask.view_as(param.data)
                param.data.mul_(mask)
                masks[name] = mask

        return model, masks

    def synflow_pruning_global(model, prune_rate):
        """
        Zero the ``prune_rate`` share of weights with the lowest SynFlow score, ranked across all layers jointly.
        """
        model.eval()

        output = model.synflow_forward(torch.ones([1, initial_layer_size]).to(device))

        loss = output.sum()
        model.zero_grad()
        loss.backward()

        scored_params = []
        for name, param in model.named_parameters():
            if param.requires_grad and "bias" not in name and "bn" not in name.lower():
                score = (param.grad.detach().abs() * param.detach().abs()).view(-1)
                scored_params.append((name, param, score))

        if not scored_params:
            return model, {}

        global_scores = torch.cat([score for _, _, score in scored_params])
        num_params_to_prune = int(prune_rate * global_scores.numel())

        # Zero exactly the lowest ``num_params_to_prune`` scores. A ``>= threshold``
        # test would keep the threshold element itself and prune one weight too few.
        global_mask = torch.ones_like(global_scores)
        if num_params_to_prune > 0:
            global_mask[torch.argsort(global_scores)[:num_params_to_prune]] = 0

        masks = {}
        offset = 0
        for name, param, score in scored_params:
            count = score.numel()
            mask = global_mask[offset:offset + count].view_as(param)
            offset += count
            param.data.mul_(mask)
            masks[name] = mask

        return model, masks

    def retrain_model_with_masks(model, optimizer, scheduler, train_loader, test_loader, masks, epochs=500):
        """
        Finetune ``model`` while multiplying gradients by ``masks`` so pruned weights receive no updates.

        Returns:
            ``(model, loss_values, accuracy_values)`` with one entry per epoch.
        """
        loss_values = []
        accuracy_values = []
        model.train()
        for _ in trange(epochs):
            epoch_loss = 0.0
            for X_batch, y_batch in train_loader:
                optimizer.zero_grad()
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                loss.backward()

                for name, param in model.named_parameters():
                    if name in masks:
                        param.grad *= masks[name]

                optimizer.step()
                epoch_loss += loss.item()

            scheduler.step()

            loss_values.append(epoch_loss / len(train_loader))
            accuracy_values.append(evaluate_model(model, test_loader))
        return model, loss_values, accuracy_values

    train_loader, test_loader = generate_data(train_size, batch_size, initial_layer_size, noise_std)

    # (extra train_model kwargs, label) for every enabled variant.
    variants = []
    if include_no_l1:
        variants.append(({}, 'No Regularization'))
    if include_l1:
        variants.append(({'l1_lambda': l1_lambda}, 'L1 Regularization'))
    if include_con:
        variants.append(({'l1_lambda': 0., 'con_lambda': con_lambda, 'epochs_wo_con': epochs_wo_con}, 'CoNNect Regularization'))

    final_accuracies = []
    for extra_args, label in variants:
        # Every variant starts from an identical copy of the template model.
        model = copy.deepcopy(nn_model)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        print(f'Train (and finetune) model with {label}...')
        pruned_model, _, _, _, retrain_acc_vals = train_model(model, optimizer, train_loader, test_loader, epochs=epochs, **extra_args)

        # Without finetuning epochs there is no retraining curve, so score the
        # pruned model directly.
        final_accuracy = retrain_acc_vals[-1] if retrain_acc_vals else evaluate_model(pruned_model, test_loader)
        final_accuracies.append(final_accuracy)

        print(f'Finetuned accuracy {label}: {final_accuracy}')

    return final_accuracies


if __name__ == '__main__':
    # Script entry point.
    # NOTE: the names assigned below become module-level globals that
    # generate_data() and run_experiment() read directly (e.g. device,
    # criterion, train_size, test_size, input_std, noise_std, ret_lr,
    # retraining_epochs, prune_strategy and the include_* flags). Renaming or
    # removing any of them breaks those functions.

    # Seed the torch, NumPy and Python RNGs for reproducibility.
    torch.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    # Set PyTorch's CPU thread count and use the GPU when one is available.
    torch.set_num_threads(16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Flags selecting which regularization variants run_experiment trains
    include_no_l1, include_l1, include_con = True, True, True
    
    # Synthetic data settings (consumed by generate_data)
    batch_size = 256
    train_size = batch_size * 8
    test_size = batch_size * 2
    input_std = 2 # std of the input features
    noise_std = 0.25 # std of the label noise

    # Optimizer settings
    lr = 0.01 # learning rate for training
    ret_lr = lr * 0.1 # learning rate for finetuning the pruned model
    weight_decay = 5e-4 # Adam weight decay, used for training and finetuning

    # Objective
    criterion = nn.BCELoss() # binary cross-entropy on the model's sigmoid output
    l1_lambda = 0.001 # strength of the L1 regularizer
    con_lambda = 0.1 # strength of the CoNNect (connectivity) regularizer

    # Model architecture: 6 inputs -> 5 -> 5 -> 5 -> 1 output
    initial_layer_size = 6
    first_hidden_layer_size = 5
    second_hidden_layer_size = 5
    third_hidden_layer_size = 5

    # Training / pruning schedule
    epochs = 200 # number of training epochs before pruning
    epochs_wo_con = int(0*epochs) # initial epochs without the connectivity regularizer (0: active from the first epoch)
    retraining_epochs = 50 # number of finetuning epochs after pruning
    prune_percentage = 0.96 # fraction (0-1) of each layer's weights pruned after training
    prune_strategy = 'magnitude' # choose 'magnitude' / 'synflow'

    # NOTE(review): results is never filled, and run_experiment's result below
    # is not captured. Intended to hold the finetuned test accuracy of each
    # enabled variant.
    results = [[] for _ in range(sum([include_no_l1, include_l1, include_con]))] # store test accuracy for pruned model after finetuning, i.e., retraining.
    
    # Template model; run_experiment deep-copies it for each variant
    nn_model = SimpleModel(initial_layer_size, first_hidden_layer_size, second_hidden_layer_size, third_hidden_layer_size).to(device)
    
    # Train, prune and finetune every enabled variant
    run_experiment(nn_model, 
                   experiment_id=1, 
                   lr=lr, weight_decay=weight_decay, 
                   epochs=epochs, 
                   epochs_wo_con=epochs_wo_con, 
                   l1_lambda=l1_lambda, 
                   con_lambda=con_lambda, 
                   initial_layer_size=initial_layer_size, 
                   prune_percentage=prune_percentage)