import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn as nn

from numpy import linalg as LA
from tqdm import tqdm

# Default torch device: 'cuda' when a GPU is available, otherwise 'cpu'.
# neural_path_kmeans moves the tensors it returns onto this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def zonotope_kmeans(points, f, info, print_upper_bound=False):
    '''
        Implementation of Zonotope K-means algorithm.

        Clusters the rows of ``points`` into ``int(f * len(points))`` centers
        with K-means and splits every center into a weight vector (all but
        the last coordinate) and a bias (the last coordinate).

        Args:
            points: 2-D array-like of shape (N, D). Presumably each row is a
                neuron's weight vector with its bias appended -- TODO confirm
                against caller.
            f: fraction of the N points to keep as cluster centers.
            info: dict with keys 'name', 'dataset' and 'imsize'; only used to
                name the results file when ``print_upper_bound`` is True.
            print_upper_bound: if True, print the upper bound and append it
                to ``results_{name}_{dataset}_{imsize}.txt``. Skipped when no
                centers are requested, because the bound needs a fitted
                K-means model.

        Returns:
            (weights, bias): tensors of shapes (K, D - 1) and (K,) with
            K = int(f * N). When K == 0 both are empty 1-D tensors.
    '''
    n_centers = int(f * len(points))

    if n_centers == 0:
        # Nothing to keep: KMeans cannot be fit with zero clusters, and the
        # upper-bound experiment below would have no model to measure.
        return torch.tensor(np.array([])), torch.tensor(np.array([]))

    kmeans = KMeans(n_clusters=n_centers).fit(points)
    centers = kmeans.cluster_centers_

    weights = centers[:, :-1]
    bias = centers[:, -1]

    ############### ONLY FOR UPPER BOUND EXPERIMENT
    if print_upper_bound:
        # Size of the largest cluster.
        N_max = np.bincount(kmeans.labels_, minlength=n_centers).max()

        # Largest distance from any point to its nearest center.
        d_max = np.max(np.min(kmeans.transform(points), axis=1), axis=0)

        bound = n_centers * d_max + (1 - 1 / N_max) * sum([LA.norm(point) for point in points])

        with open('results_{}_{}_{}.txt'.format(info['name'], info['dataset'], info['imsize']),'a') as txt:
            print('Ratio, Upper bound for zonotope_kmeans: {:.3f} {:.2f}'.format(f, bound))
            txt.write('Ratio, Upper bound for zonotope_kmeans: {:.3f} {:.2f}\n'.format(f, bound))
    #################################

    return torch.tensor(np.array(weights)), torch.tensor(np.array(bias))


def neural_path_kmeans(points, d, m, f, info, print_upper_bound=False):
    '''
        Implementation of Neural Path K-means algorithm.

        Clusters the rows of ``points`` into ``int(f * len(points))`` centers
        with K-means and splits every center into three parts: the first
        ``d`` coordinates (a_i), the coordinate at index ``d`` (b_i) and the
        following ``m`` coordinates (c_i).

        Args:
            points: 2-D array-like with at least ``d + m + 1`` columns.
                Presumably each row is [incoming weights (d), bias (1),
                outgoing weights (m)] of one hidden neuron -- TODO confirm
                against caller.
            d: number of leading coordinates returned in ``a_i``.
            m: number of coordinates returned in ``c_i``.
            f: fraction of the points to keep as cluster centers.
            info: dict with keys 'name', 'dataset' and 'imsize'; only used to
                name the results file when ``print_upper_bound`` is True.
            print_upper_bound: if True, print the upper bound and append it
                to ``results_{name}_{dataset}_{imsize}.txt``.

        Returns:
            (a_i, b_i, c_i): tensors on ``device`` with shapes (K, d), (K,)
            and (K, m), where K = int(f * len(points)). When K == 0 all three
            are empty 1-D tensors.
    '''
    n_centers = int(f * len(points))

    if n_centers == 0:
        # KMeans cannot be fit with zero clusters: nothing is kept.
        return tuple(torch.tensor(np.array([])).to(device) for _ in range(3))

    kmeans = KMeans(n_clusters=n_centers).fit(points)
    centers = kmeans.cluster_centers_

    a_i = centers[:, :d]
    b_i = centers[:, d]
    c_i = centers[:, (d + 1):(d + m + 1)]

    ############### ONLY FOR UPPER BOUND EXPERIMENT
    if print_upper_bound:
        labels = kmeans.labels_

        cardinalities = np.bincount(labels, minlength=n_centers)
        N_max = cardinalities.max()
        N_min = cardinalities.min()

        # Largest distance from any point to its nearest center.
        d_max = np.max(np.min(kmeans.transform(points), axis=1), axis=0)

        pts = np.asarray(points)
        ab = pts[:, :(d + 1)]               # first d + 1 coordinates of each point
        c = pts[:, (d + 1):(d + m + 1)]     # next m coordinates of each point
        ab_norms = LA.norm(ab, axis=1)
        c_norms = LA.norm(c, axis=1)

        # Null generators: entries c_ji whose sign disagrees with the same
        # coordinate of the center point i was assigned to; each contributes
        # |c_ji| * ||ab_i||.
        assigned_c = centers[labels, (d + 1):(d + m + 1)]
        flipped = c * assigned_c < 0
        null_generators = (np.abs(c) * ab_norms[:, None])[flipped]

        sqrt_m = np.sqrt(m)
        bound = sqrt_m * n_centers * (d_max ** 2) + \
                sqrt_m * (1 - 1 / N_max) * np.sum(c_norms * ab_norms) + \
                ((sqrt_m * d_max) / N_min) * np.sum(c_norms + ab_norms) + \
                np.sum(null_generators)

        with open('results_{}_{}_{}.txt'.format(info['name'], info['dataset'], info['imsize']),'a') as txt:
            print('Ratio, Upper bound for neural_path_kmeans: {:.3f} {:.2f}'.format(f, bound))
            txt.write('Ratio, Upper bound for neural_path_kmeans: {:.3f} {:.2f}\n'.format(f, bound))
    #########################

    return (torch.tensor(np.array(a_i)).to(device),
            torch.tensor(np.array(b_i)).to(device),
            torch.tensor(np.array(c_i)).to(device))


def thinet(half_model, dataset, layer1, layer2, K):
    '''
    Implementation of ThiNet (Luo et al., 2017)
    adapted for FC layers compression.

    Greedily discards ``n - K`` hidden neurons sitting between ``layer1`` and
    ``layer2``. At each step it drops the neuron that, added to the set of
    already-discarded neurons, gives the smallest sum of squared contributions
    to one randomly sampled output per data sample.

    Args:
        half_model: callable mapping ``dataset`` to the activations fed into
            ``layer2``; assumed to return shape (num_samples, n) -- TODO
            confirm against caller.
        dataset: input batch; ``dataset.shape[0]`` is the number of samples.
        layer1: layer whose ``weight`` rows and ``bias`` entries are the
            hidden neurons (e.g. ``nn.Linear``).
        layer2: layer whose ``weight`` columns are the hidden neurons.
        K: number of hidden neurons to keep.

    Returns:
        (w1_new, b1_new, w2_new): ``layer1.weight`` restricted to the kept
        rows, ``layer1.bias`` restricted to the kept entries (``None`` if
        ``layer1`` has no bias) and ``layer2.weight`` restricted to the kept
        columns. Kept neurons stay in their original order.
    '''
    w1 = layer1.weight
    b1 = layer1.bias
    w2 = layer2.weight

    n = w1.shape[0]  # hidden layer size in neurons/channels
    m = w2.shape[0]  # output size in neurons/channels
    d = dataset.shape[0]  # number of dataset samples
    T = n - K  # number of discarded neurons/channels

    # One random output neuron per sample (drawn with replacement).
    sample_indexes = list(np.random.choice(range(m), d))

    # The greedy selection only needs values, not gradients; skipping the
    # autograd graph saves memory on large datasets.
    with torch.no_grad():
        x = half_model(dataset)
        # x_hat[s, j]: contribution of hidden neuron j to the sampled output of sample s
        x_hat = x * w2[sample_indexes, :]

        # Running sum of the contributions of the discarded neurons (empty set at start).
        run_sum = x_hat.new_zeros((x_hat.shape[0], 1))

        # idx contains the indices of all remaining neurons, in ascending order
        idx = list(range(n))

        for _ in range(T):
            # Remove the neuron leading to the smallest objective value in the current iteration
            objective = torch.sum((x_hat[:, idx] + run_sum) ** 2, dim=0)
            h = idx.pop(int(torch.argmin(objective)))

            # update current sum
            run_sum += x_hat[:, h].view(-1, 1)

    # Slice the original parameters outside no_grad so the results are
    # produced exactly as before.
    w1_new = w1[idx, :]
    b1_new = b1[idx] if b1 is not None else None
    w2_new = w2[:, idx]

    return w1_new, b1_new, w2_new
