from abc import ABC, abstractmethod
import numpy as np
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from tqdm import tqdm
import random

from training import train, eval
from algorithms import *

# Device string used for every model/tensor this module moves around:
# 'cuda' when a GPU is visible to PyTorch, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class pruning_base(ABC):
    '''
    Base class for structured pruning of consecutive layer pairs of a model.

    ``prune_net`` walks over the prunable layers of ``self.model`` (the
    ``nn.Conv2d`` modules of ``model.features`` for convolutional models, the
    ``nn.Linear`` modules of ``model.classifier`` otherwise), asks the subclass
    for the reduced ``(w1, b1, w2)`` of each (layer1, layer2) pair, optionally
    refits ``w2`` with ``weight_rescale`` and writes the result back into the
    model in place.
    '''

    # Model names whose prunable layers are the Conv2d modules of ``model.features``.
    _CNN_MODEL_NAMES = ("deepCNN2D", "ConvCIFAR-VGG")

    def __init__(self, model, info, dataset, method='neural_path_kmeans', num_epochs=0):
        '''
        Args:
            model: network to prune; ``prune_net`` modifies it in place.
            info: dict providing at least 'name', 'dataset' and 'imsize'.
            dataset: data used by subclasses / ``weight_rescale`` (may be None).
            method: pruning method name, recorded in ``self.info['method']``.
            num_epochs: recorded in ``self.info['num_epochs']``.
        '''
        self.model = model
        self.info = {'name' : info['name'], 'dataset' : info['dataset'], 'imsize' : info['imsize'], 'method' : method, 'num_epochs' : num_epochs}
        self.dataset = dataset

    def _is_cnn(self):
        '''Return True when the model is pruned through its Conv2d feature layers.'''
        return self.info['name'] in self._CNN_MODEL_NAMES

    @abstractmethod
    def prune_layer(self, i, layer1, layer2, print_upper_bound, old_layer1):
        '''Return the reduced ``(w1, b1, w2)`` for the i-th pair of Linear layers.'''
        pass

    @abstractmethod
    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound, old_layer1):
        '''Return the reduced ``(w1, b1, w2)`` for the i-th pair of Conv2d layers.'''
        pass

    def prune_net(self, f, w2_rescale=False, normalize = False, print_upper_bound=False):
        '''
        Prune every consecutive pair of prunable layers of ``self.model`` in place.

        Args:
            f: fraction of hidden units to keep; the target width is ``int(f * n)``.
            w2_rescale: if True, refit the second-layer weights with ``weight_rescale``.
            normalize: accepted for API compatibility; not used here.
            print_upper_bound: forwarded to the subclass pruning routine.
        '''
        if self._is_cnn():
            container, layer_type = self.model.features, nn.Conv2d
        else:
            container, layer_type = self.model.classifier, nn.Linear
        self.layer_indices = [idx for idx, module in enumerate(container) if type(module) == layer_type]
        self.layers = [container[idx] for idx in self.layer_indices]

        prune_pair = self.prune_cnn_layer if self._is_cnn() else self.prune_layer

        for i in range(len(self.layers) - 1):
            layer1 = self.layers[i]
            layer2 = self.layers[i + 1]

            # Snapshot of layer1 before this step touches it; subclasses may
            # build their clustering vectors from it.
            old_layer1 = copy.deepcopy(layer1)

            self.n = layer1.weight.shape[0] # hidden layer dimension
            self.d = layer1.weight.shape[1] # input dimension
            self.m = layer2.weight.shape[0] # output dimension
            self.K = int(f * self.n) # reduced hidden layer dimension
            self.f = f

            w1, b1, w2 = prune_pair(i, layer1, layer2, print_upper_bound, old_layer1=old_layer1)

            if w2_rescale:
                half_model_original, half_model_compressed = self.half_model(self.model, self.layer_indices, i, w1, b1)
                # Pass the number of units actually kept, which may differ
                # from int(f * n) for methods that do not return exactly K units.
                w2 = weight_rescale(half_model_original, half_model_compressed, self.dataset,
                                    layer2.weight, w2, self.n, self.m, w1.shape[0])

            self.substitute_weights(f, layer1, layer2, w1, b1, w2)

    def substitute_weights(self, f, layer1, layer2, w1, b1, w2):
        '''
        Install ``w1``/``b1`` in ``layer1`` and ``w2`` in ``layer2`` in place.

        The channel/feature counts of both layers are taken from the new
        tensors, so they always agree with the installed weights. ``f`` is
        kept for backward compatibility.
        '''
        with torch.no_grad():
            if self._is_cnn():
                layer1.out_channels = w1.shape[0]
                layer2.in_channels = w2.shape[1]
            else:
                layer1.out_features = w1.shape[0]
                layer2.in_features = w2.shape[1]

            layer1.weight = nn.Parameter(w1.float())
            layer1.bias = nn.Parameter(b1.float())
            layer2.weight = nn.Parameter(w2.float())

    def half_model(self, model, layer_indices, i, w1_comp, b1_comp):
        '''
        Build the sub-network that produces the input of the (i+1)-th prunable layer.

        Returns ``(half_model_original, half_model_compressed)``. Both are
        ``nn.Sequential``s moved to ``device``. The second is a deep copy of the
        first whose i-th prunable layer carries ``w1_comp``/``b1_comp``.

        Note: for the CNN branches and the classifier part, ``half_model_original``
        holds the model's own modules, not copies.
        '''
        name = self.info['name']
        idx = layer_indices[i]

        if name == "ConvCIFAR-VGG":
            half_model_original = nn.Sequential()
            # The conv at ``idx`` plus up to 4 following modules, stopping before
            # the next Conv2d or at the end of ``features``.
            for j in range(idx + 5):
                if j >= len(model.features) or (j > idx and type(model.features[j]) == nn.Conv2d):
                    break
                half_model_original.add_module('layer_{}'.format(j), model.features[j])
            target = idx
        elif name == "deepCNN2D":
            half_model_original = nn.Sequential()
            for j in range(idx + 3):
                half_model_original.add_module('layer_{}'.format(j), model.features[j])
            target = idx
        else:
            half_model_original = copy.deepcopy(model.features)
            if name == "AlexNet":
                half_model_original.add_module('avgpool', nn.AdaptiveAvgPool2d(output_size=(1,1)))
            half_model_original.add_module('flatten', nn.Flatten())
            # Position of classifier[0] inside the combined Sequential.
            classifier_offset = len(half_model_original)
            for j in range(idx + 2):
                half_model_original.add_module('layer_{}'.format(j), model.classifier[j])
            target = classifier_offset + idx

        half_model_compressed = copy.deepcopy(half_model_original)
        half_model_compressed[target].weight = nn.Parameter(w1_comp.float())
        half_model_compressed[target].bias = nn.Parameter(b1_comp.float())

        half_model_original.to(device)
        half_model_compressed.to(device)

        return half_model_original, half_model_compressed

class zonotope_kmeans_pruning(pruning_base):
    '''
    Zonotope k-means pruning of a fully-connected layer pair.

    Hidden units are split by the sign of their outgoing weight, and each group
    is clustered separately with ``zonotope_kmeans``. Only fully-connected
    networks are supported.
    '''
    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''
        Return ``(w1, b1, w2)`` for the reduced pair.

        ``w2`` is a single row of +1 (positive clusters) followed by -1
        (negative clusters).
        '''
        # One row [w | b] per hidden unit. layer2's weight is flattened into one
        # coefficient per hidden unit, which assumes a single output unit.
        points = torch.cat((layer1.weight, layer1.bias.reshape(-1, 1)), dim=1)
        coef = layer2.weight.flatten()

        # Scale each unit by |coef| and split by sign; units whose outgoing
        # weight is exactly 0 are dropped.
        pos_points = [(points[k] * coef[k]).tolist() for k in range(points.shape[0]) if coef[k] > 0]
        neg_points = [(points[k] * -coef[k]).tolist() for k in range(points.shape[0]) if coef[k] < 0]

        pos_weights, pos_bias = zonotope_kmeans(pos_points, self.f, self.info, print_upper_bound=print_upper_bound)
        neg_weights, neg_bias = zonotope_kmeans(neg_points, self.f, self.info, print_upper_bound=print_upper_bound)

        w1 = torch.cat((pos_weights, neg_weights), dim=0)
        b1 = torch.cat((pos_bias, neg_bias), dim=0)
        w2 = torch.tensor([[1.0] * pos_weights.shape[0] + [-1.0] * neg_weights.shape[0]])

        return w1, b1, w2

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound, old_layer1=None):
        '''Not supported: raises instead of returning None (which broke the caller's unpacking).'''
        raise NotImplementedError("Zonotope k-means pruning cannot be used with convolutional networks.")

class improved_zonotope_kmeans_pruning(pruning_base):
    '''
    Improved zonotope k-means pruning of a fully-connected layer pair.

    Same split-by-sign scheme as the plain zonotope variant, but clustering is
    done with ``improved_zonotope_kmeans``. Only fully-connected networks are
    supported.
    '''
    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''
        Return ``(w1, b1, w2)`` for the reduced pair.

        ``w2`` is a single row of +1 (positive clusters) followed by -1
        (negative clusters).
        '''
        # One row [w | b] per hidden unit. layer2's weight is flattened into one
        # coefficient per hidden unit, which assumes a single output unit.
        points = torch.cat((layer1.weight, layer1.bias.reshape(-1, 1)), dim=1)
        coef = layer2.weight.flatten()

        # Scale each unit by |coef| and split by sign; units whose outgoing
        # weight is exactly 0 are dropped.
        pos_points = [(points[k] * coef[k]).tolist() for k in range(points.shape[0]) if coef[k] > 0]
        neg_points = [(points[k] * -coef[k]).tolist() for k in range(points.shape[0]) if coef[k] < 0]

        pos_weights, pos_bias = improved_zonotope_kmeans(pos_points, self.f, self.info, print_upper_bound=print_upper_bound)
        neg_weights, neg_bias = improved_zonotope_kmeans(neg_points, self.f, self.info, print_upper_bound=print_upper_bound)

        w1 = torch.cat((pos_weights, neg_weights), dim=0)
        b1 = torch.cat((pos_bias, neg_bias), dim=0)
        w2 = torch.tensor([[1.0] * pos_weights.shape[0] + [-1.0] * neg_weights.shape[0]])

        return w1, b1, w2

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound, old_layer1=None):
        '''Not supported: raises instead of returning None (which broke the caller's unpacking).'''
        raise NotImplementedError("The single-output zonotope algorithm cannot be used with convolutional networks.")


class neural_path_kmeans_pruning(pruning_base):
    '''
    Prune a layer pair with ``neural_path_kmeans``.

    Each hidden unit is represented by the concatenation of its incoming
    weights, its bias and its outgoing weights.
    '''
    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''Return the reduced ``(w1, b1, w2)`` for a Linear pair.'''
        incoming = layer1.weight
        bias_column = layer1.bias.reshape(-1, 1)
        outgoing = torch.transpose(layer2.weight, 0, 1)
        paths = torch.cat((incoming, bias_column, outgoing), dim=1)

        w1, b1, w2 = neural_path_kmeans(paths.cpu().detach(), self.d, self.m, self.f, self.info,
                                        print_upper_bound=print_upper_bound)
        return w1, b1, torch.transpose(w2, 0, 1).float()

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''Return the reduced ``(w1, b1, w2)`` for a Conv2d pair, with kernels flattened for clustering.'''
        _, in_ch, in_kh, in_kw = layer1.weight.shape
        out_ch, _, out_kh, out_kw = layer2.weight.shape

        paths = torch.cat((
            layer1.weight.flatten(1),
            layer1.bias.reshape(-1, 1),
            torch.transpose(layer2.weight, 0, 1).flatten(1),
        ), dim=1)
        w1, b1, w2 = neural_path_kmeans(paths.cpu().detach(), in_ch * in_kh * in_kw, out_ch * out_kh * out_kw,
                                        self.f, self.info, print_upper_bound=print_upper_bound)

        # Restore the kernel layout of both layers.
        w1 = w1.reshape(w1.shape[0], in_ch, in_kh, in_kw)
        b1 = b1.reshape(-1)
        w2 = torch.transpose(w2.reshape(w2.shape[0], out_ch, out_kh, out_kw), 0, 1).float()

        return w1, b1, w2
    
class tropnnc_pruning(pruning_base):
    '''
    Prune a layer pair with ``tropnnc``.

    Each hidden unit is represented by its incoming weights, bias and
    outgoing weights.
    '''
    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''Return the reduced ``(w1, b1, w2)`` for a Linear pair.'''
        points = torch.cat((layer1.weight, layer1.bias.reshape(-1, 1), torch.transpose(layer2.weight, 0, 1)), dim=1)
        w1, b1, w2 = tropnnc(points.cpu().detach(), self.d, self.m, self.f, self.info, print_upper_bound=print_upper_bound)
        w2 = torch.transpose(w2, 0, 1).float()

        return w1, b1, w2

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''
        Return the reduced ``(w1, b1, w2)`` for a Conv2d pair.

        The vectors passed as ``features`` are built from ``old_layer1``
        (defaulting to ``layer1``), while the clustered points are built from
        ``layer1``.
        '''
        if old_layer1 is None:
            old_layer1 = layer1
        in_shape = layer1.weight.shape
        out_shape = layer2.weight.shape
        new_d = in_shape[1] * in_shape[2] * in_shape[3]
        new_m = out_shape[0] * out_shape[2] * out_shape[3]
        outgoing = torch.transpose(layer2.weight, 0, 1).flatten(1)
        features = torch.cat((old_layer1.weight.flatten(1), old_layer1.bias.reshape(-1, 1), outgoing), dim=1)
        points = torch.cat((layer1.weight.flatten(1), layer1.bias.reshape(-1, 1), outgoing), dim=1)
        w1, b1, w2 = tropnnc(points.cpu().detach(), new_d, new_m, self.f, self.info, print_upper_bound=print_upper_bound, features=features.cpu().detach())

        # Restore the kernel layout of both layers.
        w1 = w1.reshape(w1.shape[0], in_shape[1], in_shape[2], in_shape[3])
        b1 = b1.reshape(-1)
        w2 = w2.reshape(w2.shape[0], out_shape[0], out_shape[2], out_shape[3])

        w2 = torch.transpose(w2, 0, 1).float()

        return w1, b1, w2

class thinet_pruning(pruning_base):
    '''
    Prune a layer pair with ``thinet``.

    The activations come from the original half model that ``half_model``
    builds for the pair.
    '''
    def _run_thinet(self, i, layer1, layer2):
        # Only the original half model is needed; the compressed copy is discarded.
        activations_model = self.half_model(self.model, self.layer_indices, i, layer1.weight, layer1.bias)[0]
        w1, b1, w2 = thinet(activations_model, self.dataset, layer1, layer2, self.K)
        return w1, b1, w2

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''Return the ThiNet-reduced ``(w1, b1, w2)`` for a Linear pair.'''
        return self._run_thinet(i, layer1, layer2)

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None):
        '''Return the ThiNet-reduced ``(w1, b1, w2)`` for a Conv2d pair.'''
        return self._run_thinet(i, layer1, layer2)

class random_pruning(pruning_base):
    '''Structured pruning that keeps ``self.K`` hidden units drawn uniformly at random.'''
    def _keep_random_units(self, layer1, layer2):
        # Same random.sample call as before, so the RNG stream is unchanged.
        kept = random.sample(range(self.n), self.K)
        return layer1.weight[kept, :], layer1.bias[kept], layer2.weight[:, kept]

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None, criterion='random'):
        '''Return ``(w1, b1, w2)`` restricted to a random subset of hidden units.'''
        return self._keep_random_units(layer1, layer2)

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None, criterion='random'):
        '''Return ``(w1, b1, w2)`` restricted to a random subset of channels.'''
        return self._keep_random_units(layer1, layer2)

class l1_pruning(pruning_base):
    '''
    Structured pruning that keeps the ``self.K`` hidden units with the largest
    L1 norm.

    The norm is taken over each unit's incoming weights, bias and outgoing
    weights.
    '''
    def _largest_l1_units(self, scored_layer, layer2):
        '''Indices of the K units of ``scored_layer`` with the largest L1 path norm.'''
        vectors = torch.cat((
            torch.reshape(scored_layer.weight, (self.n, -1)),
            torch.reshape(scored_layer.bias, (self.n, -1)),
            torch.reshape(torch.transpose(layer2.weight, 0, 1), (self.n, -1)),
        ), dim=1)
        norms = torch.norm(vectors, p=1, dim=1).cpu().detach().numpy()
        # argsort is ascending, so the last K entries are the largest norms.
        return list(np.argsort(norms)[(self.n - self.K) : self.n])

    def prune_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None, criterion='random'):
        '''Return ``(w1, b1, w2)`` restricted to the K largest-norm units of ``layer1``.'''
        neurons = self._largest_l1_units(layer1, layer2)
        return layer1.weight[neurons, :], layer1.bias[neurons], layer2.weight[:, neurons]

    def prune_cnn_layer(self, i, layer1, layer2, print_upper_bound=False, old_layer1=None, criterion='random'):
        '''Same as ``prune_layer``, but scores units with ``old_layer1`` (defaults to ``layer1``).'''
        if old_layer1 is None:
            old_layer1 = layer1
        neurons = self._largest_l1_units(old_layer1, layer2)
        return layer1.weight[neurons, :], layer1.bias[neurons], layer2.weight[:, neurons]


def _sampled_patch_responses(feature_maps, kernels, i_coords, j_coords, num_channels):
    '''
    Per-input-channel response of one kernel at one spatial position per sample.

    For sample ``s``, ``kernels[s]`` (shape ``(C, kh, kw)``) is multiplied
    elementwise with the window of the zero-padded ``feature_maps[s]`` centred
    at padded coordinates ``(i_coords[s], j_coords[s])``. The result is summed
    over the spatial axes, giving a ``(num_samples, num_channels)`` tensor.
    '''
    kh, kw = kernels.size(2), kernels.size(3)
    # F.pad takes (left, right, top, bottom): the last (width) dimension comes
    # first, so width padding uses kw and height padding uses kh.
    padded = nn.functional.pad(feature_maps, (kw // 2, kw // 2, kh // 2, kh // 2), mode='constant', value=0)
    responses = torch.zeros((feature_maps.shape[0], num_channels), device=feature_maps.device)
    for sample in range(feature_maps.shape[0]):
        top = int(i_coords[sample]) - kh // 2
        left = int(j_coords[sample]) - kw // 2
        # Window of exactly kh x kw (also correct for even kernel sizes).
        patch = padded[sample, :, top:top + kh, left:left + kw]
        responses[sample] = (patch * kernels[sample]).sum(dim=[1, 2])
    return responses


def weight_rescale(half_model_original, half_model_compressed, dataset,
                    w2_orig, w2_comp, n, m, K):
    '''
        Used for weight rescaling of the weights of the compressed model.
        Intended for use with the ThiNet algorithm.

        Each sample of ``dataset`` is paired with a random output unit of the
        second layer (and, for convolutions, a random spatial position). The
        target ``y_hat`` is the original network's pre-activation of that unit.
        ``x_hat`` holds the per-hidden-unit contributions in the compressed
        network. The per-unit scale ``s`` minimising ``||x_hat @ s - y_hat||``
        (least squares via pseudo-inverse) multiplies the columns of ``w2_comp``.

        Args:
            half_model_original / half_model_compressed: networks mapping
                ``dataset`` to the input of the second layer.
            dataset: batch tensor whose first dimension indexes samples.
            w2_orig: original second-layer weight, ``(m, n)`` or ``(m, n, kh, kw)``.
            w2_comp: compressed second-layer weight, ``(m, K)`` or ``(m, K, kh, kw)``.
            n, m, K: original hidden width, output width, compressed hidden width.

        Returns:
            A rescaled copy of ``w2_comp`` with the same shape.
    '''
    num_samples = dataset.shape[0]

    # Pair every sample with one randomly chosen output unit of the second layer.
    random_channel = list(np.random.choice(range(m), num_samples))

    with torch.no_grad():
        # Representations entering the second layer in both networks.
        x_original = half_model_original(dataset)
        w_original = w2_orig[random_channel, :]
        x_compressed = half_model_compressed(dataset)
        w_compressed = w2_comp[random_channel, :]

        if len(x_original.shape) > 2:
            kh, kw = w_original.size(2), w_original.size(3)
            # One random output position per sample, shared by both networks.
            i_coords = np.random.randint(low=kh // 2, high=x_original.shape[2] + kh // 2, size=num_samples)
            j_coords = np.random.randint(low=kw // 2, high=x_original.shape[3] + kw // 2, size=num_samples)
            x_hat_original = _sampled_patch_responses(x_original, w_original, i_coords, j_coords, n)
            x_hat = _sampled_patch_responses(x_compressed, w_compressed, i_coords, j_coords, K)
        else:
            x_hat_original = x_original * w_original
            x_hat = x_compressed * w_compressed
        y_hat = torch.sum(x_hat_original, dim=[1])

    # Scale second layer with the least squares solution.
    w_scale = torch.linalg.pinv(x_hat) @ y_hat

    if len(w2_comp.shape) > 2:
        return w_scale.view(1, -1, 1, 1) * w2_comp
    # unsqueeze(0) instead of the deprecated ``.T`` on a 1-D tensor.
    return w_scale.unsqueeze(0) * w2_comp

import torchvision
def print_model_param_nums(model=None):
    '''
    Print and return the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted. When ``model`` is
    None, a freshly constructed ``torchvision.models.alexnet()`` is measured.
    '''
    if model is None:
        model = torchvision.models.alexnet()
    total = sum(param.nelement() for param in model.parameters() if param.requires_grad)
    print('  + Number of params: %.2fM' % (total / 1e6))
    return total

def _sample_image_subset(dataset, subset_size=5000):
    '''Stack the first element of ``subset_size`` random samples of ``dataset`` into one tensor on ``device``.'''
    indices = torch.randperm(len(dataset))[:subset_size]
    subset = torch.utils.data.Subset(dataset, indices)
    return torch.stack([sample[0] for sample in subset]).to(device)


def _make_pruning_algorithm(method, compressed, info, subset, dataset):
    '''
    Instantiate the pruning class for ``method`` on the model ``compressed``.

    Raises:
        ValueError: if ``method`` is not a known pruning method.
    '''
    if method == 'zonotope_kmeans':
        return zonotope_kmeans_pruning(compressed, info, subset, method)
    if method == 'improved_zonotope_kmeans':
        return improved_zonotope_kmeans_pruning(compressed, info, subset, method)
    if method == 'neural_path_kmeans':
        return neural_path_kmeans_pruning(compressed, info, subset, method)
    if method in ('tropnnc', 'normalized_tropnnc'):
        return tropnnc_pruning(compressed, info, subset, method)
    if method == 'iterative_tropnnc':
        return tropnnc_pruning(compressed, info, subset, method, num_epochs=3)
    if method == 'thinet':
        return thinet_pruning(compressed, info, subset, method)
    # These two receive the full dataset rather than the ThiNet image subset.
    if method == 'random_structured':
        return random_pruning(compressed, info, dataset, method)
    if method == 'l1_structured':
        return l1_pruning(compressed, info, dataset, method)
    raise ValueError('Unknown pruning method: {}'.format(method))


def pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method = 'zonotope_kmeans', print_upper_bound=False,
                       fine_tune=False, dataset=None, w2_rescale=False, normalize=False, print_n_params=False):
    '''
    Prune ``model`` at every ratio in ``info['ratios']`` and report accuracies.

    For each ratio ``f``, pruning runs ``info['repetitions']`` times (once for
    'l1_structured'), each time on a fresh deep copy of ``model``. The copy
    with the best validation accuracy is then evaluated on ``test_loader``.
    When ``f == 1`` no pruning happens unless ``print_upper_bound`` is set,
    and the reported test accuracy is always that of the original ``model``.
    Results are printed and appended to
    ``results_<name>_<dataset>_<imsize>.txt``.

    ``train_loader`` and ``fine_tune`` are accepted but not used here.

    Raises:
        ValueError: if a pruning step is needed and ``method`` is unknown.
        RuntimeError: if no pruned model was produced for a ratio ``f != 1``.
    '''
    results_path = 'results_{}_{}_{}.txt'.format(info['name'], info['dataset'], info['imsize'])
    with open(results_path, 'a') as txt:
        print('Executing method {} on model {} and dataset {}'.format(method, info['name'], info['dataset']))
        txt.write('Executing method {} on model {} and dataset {}\n'.format(method, info['name'], info['dataset']))

        for f in info['ratios']:
            best_model = None
            best_accuracy = 0

            # Randomised methods are repeated and the best run is kept.
            repetitions = info['repetitions'] if method not in ['l1_structured'] else 1
            for _ in range(repetitions):
                # Drawn before the f == 1 check to keep the RNG sequence unchanged.
                if dataset is not None and method == "thinet":
                    subset = _sample_image_subset(dataset)
                else:
                    subset = None

                if f == 1 and not print_upper_bound:
                    continue

                compressed = copy.deepcopy(model)

                if info['name'] == 'CIFAR-VGG':
                    # Replace classifier[2] with a fresh BatchNorm1d of int(f * 512) features.
                    with torch.no_grad():
                        compressed.classifier[2] = nn.BatchNorm1d(int(f * 512))

                algorithm = _make_pruning_algorithm(method, compressed, info, subset, dataset)

                # Compress given network
                algorithm.prune_net(f, w2_rescale=w2_rescale, normalize=normalize, print_upper_bound=print_upper_bound)

                # move compressed model to GPU
                compressed.to(device)

                val_acc = eval(compressed, val_loader, criterion)

                # Keep the best validation model; the None check ensures one is
                # kept even if every run scores 0.
                if best_model is None or val_acc > best_accuracy:
                    best_model = compressed
                    best_accuracy = val_acc
                print('Ratio, Val accuracy: {:.3f} {:.2f}'.format(f, 100 * best_accuracy))
                txt.write('Ratio, Val accuracy: {:.3f} {:.2f}\n'.format(f, 100 * best_accuracy))

            if f == 1:
                evaluated = model
            elif best_model is None:
                raise RuntimeError('No pruned model was produced for ratio {}'.format(f))
            else:
                evaluated = best_model

            accuracy = eval(evaluated, test_loader, criterion)
            if print_n_params:
                n_params = print_model_param_nums(evaluated)

            print('Ratio, Test accuracy: {:.3f} {:.2f}'.format(f, 100 * accuracy))
            txt.write('Ratio, Test accuracy: {:.3f} {:.2f}\n'.format(f, 100 * accuracy))

            if print_n_params:
                print('Ratio, Test params: {:.3f} {:.2f}'.format(f, n_params))
                txt.write('Ratio, Test params: {:.3f} {:.2f}\n'.format(f, n_params))


