from abc import ABC, abstractmethod
import numpy as np
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from tqdm import tqdm
import random

from training import train, eval
from algorithms import *

# Target device for the compressed models: 'cuda' when a GPU is available, else 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class pruning_base(ABC):
    '''
        Base class for structured (neuron-level) pruning of the fully connected
        part of a network.

        prune_net walks over consecutive nn.Linear layers of `model.classifier`.
        For every pair (layer1, layer2), it asks the subclass-specific
        prune_layer for a reduced hidden layer: new incoming weights `w1`,
        bias `b1` and outgoing weights `w2`. These then replace the parameters
        of the two layers in place.
    '''

    def __init__(self, model, info, dataset, method='neural_path_kmeans'):
        '''
            model   : network pruned in place; must expose `classifier`
                      (and `features` when half_model is used).
            info    : experiment description; only 'name', 'dataset' and
                      'imsize' are copied here.
            dataset : data forwarded to thinet / weight_rescale.
            method  : name of the pruning method, stored in self.info.
        '''
        self.model = model
        self.info = {'name' : info['name'], 'dataset' : info['dataset'], 'imsize' : info['imsize'], 'method' : method}
        self.dataset = dataset

    @abstractmethod
    def prune_layer(self, i, layer1, layer2, print_upper_bound):
        '''
            Return (w1, b1, w2) describing the reduced hidden layer between
            layer1 and layer2. `i` is the position of layer1 among the Linear
            layers of the classifier.
        '''
        pass

    def prune_net(self, f, w2_rescale=False, print_upper_bound=False):
        '''
            Prune every hidden layer of `self.model.classifier` in place.

            f                 : fraction of hidden neurons to keep (K = int(f * n)).
            w2_rescale        : if True, refit the outgoing weights with
                                weight_rescale after each layer is pruned.
            print_upper_bound : forwarded to prune_layer.
        '''
        classifier = self.model.classifier
        # Exact type match (not isinstance): only plain nn.Linear layers are pruned.
        self.layer_indices = [idx for idx, layer in enumerate(classifier)
                              if type(layer) == nn.modules.linear.Linear]
        self.layers = [classifier[idx] for idx in self.layer_indices]

        for i, (layer1, layer2) in enumerate(zip(self.layers, self.layers[1:])):
            self.n = layer1.weight.shape[0] # hidden layer dimension
            self.d = layer1.weight.shape[1] # input dimension
            self.m = layer2.weight.shape[0] # output dimension
            self.K = int(f * self.n) # reduced hidden layer dimension
            self.f = f

            w1, b1, w2 = self.prune_layer(i, layer1, layer2, print_upper_bound)

            if w2_rescale:
                half_model_original, half_model_compressed = self.half_model(self.model, self.layer_indices, i, w1, b1)
                w2 = weight_rescale(half_model_original, half_model_compressed, self.dataset,
                                    layer2.weight, w2, self.n, self.m, self.K)

            self.substitute_weights(f, layer1, layer2, w1, b1, w2)

    def substitute_weights(self, f, layer1, layer2, w1, b1, w2):
        '''
            Install the pruned parameters into layer1 / layer2 (cast to float32).

            The feature counts are taken from the shapes of the new tensors rather
            than int(f * old_size). Some methods (e.g. zonotope k-means, which
            clusters positive and negative neurons separately) may not return
            exactly int(f * n) neurons. `f` is kept for interface compatibility.
        '''
        with torch.no_grad():
            layer1.out_features = w1.shape[0]
            layer2.in_features = w2.shape[1]

            layer1.weight = nn.Parameter(w1.float())
            layer1.bias = nn.Parameter(b1.float())
            layer2.weight = nn.Parameter(w2.float())

    def half_model(self, model, layer_indices, i, w1_comp, b1_comp):
        '''
            Build two truncated stacks (features, Flatten, classifier[0..]) that
            end shortly after the Linear layer at classifier index layer_indices[i].

            The first stack keeps the original weights. In the second stack, that
            Linear layer uses (w1_comp, b1_comp).

            Returns (half_model_original, half_model_compressed). Only the
            compressed stack is moved to `device` here.
        '''
        idx = layer_indices[i]
        # Position of classifier[idx] inside the stack: all feature modules,
        # then the Flatten, then classifier[0..].
        linear_pos = len(model.features) + 1 + idx

        half_model_original = copy.deepcopy(model.features)
        half_model_original.add_module('flatten', nn.Flatten())
        # Classifier modules are added by reference (shared with `model`), not copied.
        for j in range(idx + 2):
            half_model_original.add_module('layer_{}'.format(j), model.classifier[j])

        if (self.info['name'] == 'CIFAR-VGG'):
            # add_module with an existing name replaces that entry in place, so
            # classifier[idx + 2] takes the slot of classifier[idx + 1] instead of
            # being appended. NOTE(review): pruning_experiment resizes the CIFAR-VGG
            # BatchNorm to the pruned width before pruning, so keeping it here would
            # likely mismatch the unpruned Linear output. Confirm intent before changing.
            half_model_original.add_module('layer_{}'.format(idx + 1), model.classifier[idx + 2])

        half_model_compressed = copy.deepcopy(half_model_original)
        half_model_compressed[linear_pos].weight = nn.Parameter(w1_comp.float())
        half_model_compressed[linear_pos].bias = nn.Parameter(b1_comp.float())

        if (self.info['name'] == 'CIFAR-VGG'):
            # Fresh BatchNorm sized to the compressed hidden width K.
            with torch.no_grad():
                half_model_compressed[linear_pos + 1] = nn.BatchNorm1d(self.K)

        half_model_compressed.to(device)

        return half_model_original, half_model_compressed


class zonotope_kmeans_pruning(pruning_base):
    '''
        Zonotope k-means pruning. Each hidden neuron, written as [w_j, b_j], is
        scaled by the magnitude of its outgoing coefficient. Neurons are split
        by the sign of that coefficient, and each group is clustered
        separately by `zonotope_kmeans`.
    '''

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False):
        '''
            Return (w1, b1, w2) for the compressed hidden layer.

            The outgoing coefficients are read from layer2.weight flattened and
            indexed by hidden unit. NOTE(review): this only covers row 0 of
            layer2.weight, i.e. it assumes a single-output layer2, and the
            returned w2 is a single row of +1 / -1 signs. Neurons with a zero
            coefficient are dropped.
        '''
        # Rows are [incoming weights, bias] for every hidden neuron.
        augmented = torch.cat((layer1.weight, layer1.bias.reshape(-1, 1)), dim=1)
        out_coef = layer2.weight.flatten()

        pos_points = []
        neg_points = []
        for row, c in zip(augmented, out_coef):
            if c > 0:
                pos_points.append((row * c).tolist())
            elif c < 0:
                neg_points.append((row * -c).tolist())

        pos_w, pos_b = zonotope_kmeans(pos_points, self.f, self.info, print_upper_bound=print_upper_bound)
        neg_w, neg_b = zonotope_kmeans(neg_points, self.f, self.info, print_upper_bound=print_upper_bound)

        w1 = torch.cat((pos_w, neg_w), dim=0)
        b1 = torch.cat((pos_b, neg_b), dim=0)
        # Positive clusters feed the output with +1, negative clusters with -1.
        signs = [1.0] * pos_w.shape[0] + [-1.0] * neg_w.shape[0]
        w2 = torch.tensor([signs])

        return w1, b1, w2


class neural_path_kmeans_pruning(pruning_base):
    '''
        Neural-path k-means pruning. Every hidden neuron j is represented by the
        concatenation [layer1.weight[j, :], layer1.bias[j], layer2.weight[:, j]]
        and the points are handed to `neural_path_kmeans`.
    '''

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False):
        '''
            Return (w1, b1, w2) from neural_path_kmeans. w2 is transposed back
            to layer2's (out, hidden) orientation and cast to float32.
        '''
        incoming = layer1.weight
        bias_column = layer1.bias.reshape(-1, 1)
        outgoing = layer2.weight.t()

        # Detached CPU copy: the clustering helper works outside autograd.
        paths = torch.cat((incoming, bias_column, outgoing), dim=1).cpu().detach()
        w1, b1, w2_hidden_major = neural_path_kmeans(paths, self.d, self.m, self.f, self.info,
                                                     print_upper_bound=print_upper_bound)

        return w1, b1, w2_hidden_major.transpose(0, 1).float()


class thinet_pruning(pruning_base):
    '''
        ThiNet pruning: selects neurons using the activations of the network
        truncated after layer1 (with its original weights).
    '''

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False):
        '''
            Delegate to `thinet` with the uncompressed truncated model.
            `print_upper_bound` is not used.
        '''
        truncated = self.half_model(self.model, self.layer_indices, i, layer1.weight, layer1.bias)
        prefix_model = truncated[0]

        w1, b1, w2 = thinet(prefix_model, self.dataset, layer1, layer2, self.K)
        return w1, b1, w2


class random_pruning(pruning_base):
    '''
        Baseline: keep self.K hidden neurons chosen uniformly at random
        (without replacement) using Python's `random` module.
    '''

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, criterion='random'):
        '''
            Return the rows of layer1 and the matching columns of layer2 for a
            random subset of the self.n hidden neurons.
            `i`, `print_upper_bound` and `criterion` are not used.
        '''
        kept = random.sample(range(self.n), self.K)

        return (layer1.weight[kept, :],
                layer1.bias[kept],
                layer2.weight[:, kept])

class l1_pruning(pruning_base):
    '''
        Magnitude baseline: keep the self.K hidden neurons whose vector
        [incoming weights, bias, outgoing weights] has the largest L1 norm.
    '''

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, criterion='random'):
        '''
            Return (w1, b1, w2) restricted to the self.K highest-L1-norm neurons.
            `i`, `print_upper_bound` and `criterion` are not used.
        '''
        n = self.n
        vectors = torch.cat((torch.reshape(layer1.weight, (n, -1)),
                             torch.reshape(layer1.bias, (n, -1)),
                             torch.reshape(torch.transpose(layer2.weight, 0, 1), (n, -1))), dim=1)
        # detach(): the parameters require grad, and Tensor.numpy() refuses
        # tensors that are part of the autograd graph.
        scores = torch.norm(vectors.detach(), p=1, dim=1).cpu().numpy()
        # argsort is ascending, so the last K positions hold the largest norms.
        neurons = [int(j) for j in np.argsort(scores)[(n - self.K) : n]]

        w1, b1, w2 = layer1.weight[neurons, :], layer1.bias[neurons], layer2.weight[:, neurons]

        return w1, b1, w2


def weight_rescale(half_model_original, half_model_compressed, dataset, 
                    w2_orig, w2_comp, n, m, K):
    '''
        Rescale the columns of the compressed second linear layer by least squares.

        For every sample in `dataset`, one output channel is drawn at random.
        The target is that channel's pre-activation in the original network:
        sum_k x_original[k] * w2_orig[c, k]. The per-hidden-unit contributions
        x_compressed[k] * w2_comp[c, k] of the compressed network are then
        fitted to these targets with one scale per hidden unit.

        The original author noted this is meant for the ThiNet algorithm.
        prune_net, however, calls it for any method when w2_rescale=True.

        Returns w2_comp with column k multiplied by the fitted scale k.
        `n` and `K` are not used.
    '''
    D = dataset.shape[0]

    # One random output channel per sample (plain ints for tensor indexing).
    random_channel = np.random.choice(range(m), D).tolist()

    with torch.no_grad():
        # Use first layers to get representation before FC of original network
        x_original = half_model_original(dataset)
        # Align the weight slice with the activations; some pruners return CPU tensors.
        w_original = w2_orig[random_channel, :].to(x_original.device)
        y_hat = torch.sum(x_original * w_original, dim=1)

        # Use first layers to get representation before FC of altered network
        x_compressed = half_model_compressed(dataset)
        w_compressed = w2_comp[random_channel, :].to(x_compressed.device)
        x_hat = x_compressed * w_compressed

    # Least-squares scale per hidden unit: shape (K,), broadcast over the rows of w2_comp.
    # (Transposing the 1-D solution was a no-op and is deprecated in PyTorch.)
    w_scale = torch.linalg.pinv(x_hat) @ y_hat
    w2_new = w_scale.to(w2_comp.device) * w2_comp

    return w2_new


def _make_pruning_algorithm(method, model, info, dataset):
    '''
        Return the pruner instance for `method`, built around `model`.
        Raises ValueError for an unknown method name.
    '''
    pruners = {
        'zonotope_kmeans': zonotope_kmeans_pruning,
        'neural_path_kmeans': neural_path_kmeans_pruning,
        'thinet': thinet_pruning,
        'random_structured': random_pruning,
        'l1_structured': l1_pruning,
    }
    if method not in pruners:
        raise ValueError('Unknown pruning method: {}'.format(method))
    return pruners[method](model, info, dataset, method)


def pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method = 'zonotope_kmeans', print_upper_bound=False,
                       fine_tune=False, dataset=None, w2_rescale=False):
    '''
        Prune `model` at every ratio in info['ratios'] and report accuracies.

        For each ratio f, the pruning is run info['repetitions'] times. The
        deterministic methods 'l1_structured' and 'thinet' run only once. The
        compressed model with the best validation accuracy is then evaluated
        on `test_loader`; for f == 1 the unpruned `model` is evaluated instead.
        Results are printed and appended to
        results_{name}_{dataset}_{imsize}.txt.

        `train_loader` and `fine_tune` are not used in this function.
        Raises ValueError for an unknown `method` when pruning is needed.
    '''
    with open('results_{}_{}_{}.txt'.format(info['name'], info['dataset'], info['imsize']),'a') as txt:
        print('Executing method {} on model {} and dataset {}'.format(method, info['name'], info['dataset']))
        txt.write('Executing method {} on model {} and dataset {}\n'.format(method, info['name'], info['dataset']))

        for f in info['ratios']:

            # Reset per ratio so a model from a previous ratio is never reused;
            # -inf guarantees the first evaluated model is always kept.
            best_accuracy = float('-inf')
            best_model = None

            # in some experiments we repeat multiple times the random algorithm to ensure better compression
            n_runs = info['repetitions'] if (method not in ['l1_structured', 'thinet']) else 1
            for _ in range(n_runs):
                compressed = copy.deepcopy(model)

                if (f != 1 or print_upper_bound):

                    if (info['name'] == 'CIFAR-VGG'):
                        with torch.no_grad():
                            compressed.classifier[2] = nn.BatchNorm1d(int(f * 512))

                    algorithm = _make_pruning_algorithm(method, compressed, info, dataset)

                    # Compress given network
                    algorithm.prune_net(f, w2_rescale=w2_rescale, print_upper_bound=print_upper_bound)

                    # move compressed model to GPU
                    compressed.to(device)

                    val_acc = eval(compressed, val_loader, criterion)

                    # we choose the model with best validation accuracy
                    if val_acc > best_accuracy:
                        best_model = compressed
                        best_accuracy = val_acc
                    print('Ratio, Val accuracy: {:.3f} {:.2f}'.format(f, 100 * best_accuracy))
                    txt.write('Ratio, Val accuracy: {:.3f} {:.2f}\n'.format(f, 100 * best_accuracy))

            if (f == 1):
                accuracy = eval(model, test_loader, criterion)
            else:
                if best_model is None:
                    raise RuntimeError('No compressed model was evaluated for ratio {}'.format(f))
                accuracy = eval(best_model, test_loader, criterion)

            print('Ratio, Test accuracy: {:.3f} {:.2f}'.format(f, 100 * accuracy))
            txt.write('Ratio, Test accuracy: {:.3f} {:.2f}\n'.format(f, 100 * accuracy))

