import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn as nn

from numpy import linalg as LA
from tqdm import tqdm

import copy

# Target device for the tensor work in this module: the GPU when available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

### A GPU implementation of KMeans, which runs faster than sklearn's. However, it leads to slightly
### worse results, and thus we do not use it.
# class KMeans:
#     def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, random_state=None, n_init=10, device=None):
#         self.n_clusters = n_clusters
#         self.max_iter = max_iter
#         self.tol = tol
#         self.random_state = random_state
#         self.n_init = n_init
#         self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#         self.cluster_centers_ = None
#         self.labels_ = None
#         self.inertia_ = None

#     def fit(self, X):
#         X = X.to(self.device)
#         best_inertia = float('inf')
#         best_centers = None
#         best_labels = None

#         for _ in range(self.n_init):
#             inertia, centers, labels = self._run_kmeans(X)
#             if inertia < best_inertia:
#                 best_inertia = inertia
#                 best_centers = centers
#                 best_labels = labels

#         self.cluster_centers_ = best_centers.cpu().numpy()
#         self.labels_ = best_labels.cpu().numpy()
#         self.inertia_ = best_inertia
#         return self

#     def _run_kmeans(self, X):
#         n_samples, n_features = X.shape

#         if self.random_state:
#             torch.manual_seed(self.random_state)

#         indices = torch.randperm(n_samples)[:self.n_clusters].to(self.device)
#         centers = X[indices]

#         for i in range(self.max_iter):
#             distances = torch.cdist(X, centers)
#             labels = torch.argmin(distances, dim=1)
#             new_centers = torch.stack([X[labels == j].mean(dim=0) for j in range(self.n_clusters)])

#             shift = torch.norm(centers - new_centers)
#             if shift < self.tol:
#                 break

#             centers = new_centers

#         inertia = self._compute_inertia(X, labels, centers)
#         return inertia, centers, labels

#     def predict(self, X):
#         X = torch.tensor(X, dtype=torch.float32).to(self.device)
#         distances = torch.cdist(X, torch.tensor(self.cluster_centers_, dtype=torch.float32).to(self.device))
#         return torch.argmin(distances, dim=1).cpu().numpy()

#     def fit_predict(self, X):
#         self.fit(X)
#         return self.labels_

#     def _compute_inertia(self, X, labels, centers):
#         inertia = 0.0
#         for i in range(self.n_clusters):
#             cluster_points = X[labels == i]
#             inertia += torch.sum(torch.norm(cluster_points - centers[i], dim=1) ** 2)
#         return inertia.item()


def zonotope_kmeans(points, f, info, print_upper_bound=False):
    '''
        Implementation of Zonotope K-means algorithm.

        Clusters the rows of ``points`` into ``int(f * len(points))`` clusters
        with K-means and splits every cluster center ``c`` into a weight vector
        ``c[:-1]`` and a bias ``c[-1]``.

        ``info`` and ``print_upper_bound`` are not used here; they are accepted
        so that all compression routines share a calling convention.

        Returns:
            ``(weights, bias)`` tensors of shape ``(n_centers, dim - 1)`` and
            ``(n_centers,)``; two empty 1-D float64 tensors when the requested
            number of centers rounds down to zero.
    '''
    n_centers = int(f * len(points))

    # Nothing to cluster: hand back empty tensors instead of calling K-means.
    if n_centers == 0:
        return torch.tensor(np.array([])), torch.tensor(np.array([]))

    centers = KMeans(n_clusters=n_centers, n_init=10).fit(points).cluster_centers_
    weight_part = np.array(centers[:, :-1])
    bias_part = np.array(centers[:, -1])
    return torch.tensor(weight_part), torch.tensor(bias_part)

def improved_zonotope_kmeans(points, f, info, print_upper_bound=False):
    '''
        Implementation of single-output TropNNC.

        Clusters the rows of ``points`` into ``int(f * len(points))`` clusters
        with K-means and, instead of the cluster centers, uses the SUM of the
        rows assigned to each cluster. Every summed row ``s`` is split into a
        weight vector ``s[:-1]`` and a bias ``s[-1]``.

        ``info`` and ``print_upper_bound`` are not used here; they are accepted
        so that all compression routines share a calling convention.

        Args:
            points: 2-D array-like or CPU tensor, one row per point.
            f: fraction of the rows to keep as clusters.

        Returns:
            ``(weights, bias)`` tensors of shape ``(n_centers, dim - 1)`` and
            ``(n_centers,)`` in the dtype of the K-means centers; two empty 1-D
            float64 tensors when the number of centers rounds down to zero.
    '''
    n_centers = int(f * len(points))

    if n_centers == 0:
        return torch.tensor(np.array([])), torch.tensor(np.array([]))

    kmeans = KMeans(n_clusters=n_centers, n_init=10).fit(points)
    centers = torch.tensor(kmeans.cluster_centers_)

    # as_tensor accepts tensors too (torch.tensor on a tensor warns and copies).
    points = torch.as_tensor(points)
    labels = torch.as_tensor(kmeans.labels_, dtype=torch.long)

    # Sum the rows of each cluster in one pass; index_add_ needs matching
    # dtypes, so accumulate in the centers' dtype.
    cluster_sums = torch.zeros_like(centers).index_add_(0, labels, points.to(centers.dtype))

    weights = cluster_sums[:, :-1].clone()
    bias = cluster_sums[:, -1].clone()
    return weights, bias

#### Simple implementations of the iterative algorithm, which are slow but easy to understand

# def solve_for_c(points, res, d, m):
#     points_mean_norm = torch.norm(res[:d+1])
#     for j in range(m):
#         sum = torch.zeros((d+1, )).double()
#         for i in range(len(points)):
#             cji = points[i,d+1+j]
#             sum += cji * points[i, :d+1]
#             res[d+1+j] = torch.dot(sum, res[:d+1])/points_mean_norm
#     return res

# def solve_for_a_b(points, res, d, m):
#     sum_cjk_2 = torch.sum(res[d+1:]**2)
#     sum = torch.zeros((d+1, ))
#     for i in range(len(points)):
#         a_b_i = points[i, :d+1]
#         temp = 0
#         for j in range(m):
#             cji = points[i, d+1+j]
#             cjk = res[d+1+j]
#             temp += cji * cjk
#         sum += temp * a_b_i
#     res[:d+1] = sum / sum_cjk_2
#     return res

# def solve_approx(points, res, d, m, num_epochs):
#     for i in range(num_epochs):
#         res = solve_for_c(points, res, d, m)
#         res = solve_for_a_b(points, res, d, m)
#     return res

#### Faster implementations of the iterative algorithm, which use tensor operations
#### and are more complex

# def solve_for_c(points, res, d, m):
#     points_mean_norm = torch.norm(res[:d+1])
#     for j in range(m):
#         cji = points[:, d+1+j]
#         sum = torch.sum((cji.unsqueeze(1) * points[:, :d+1]), dim=0).double()
#         res[d+1+j] = torch.dot(sum, res[:d+1]) / (points_mean_norm**2)
#     return res

def solve_for_c(points, res, d, m):
    '''
        Coefficient step of the alternating approximation.

        Keeps ``res[:d+1]`` fixed and overwrites ``res[d+1:d+1+m]`` in place with
            sum_i points[i, d+1+j] * <points[i, :d+1], res[:d+1]> / ||res[:d+1]||^2
        for every j in ``range(m)``.

        Returns ``res`` itself (modified in place).
    '''
    head = res[:d + 1]
    denom = torch.norm(head)
    coeff_cols = points[:, d + 1:d + 1 + m]
    head_cols = points[:, :d + 1]
    # (m, d+1): for each j, the c_ij-weighted sum of the (a, b) parts of the points.
    projected = (coeff_cols.T @ head_cols).double()
    res[d + 1:d + 1 + m] = (projected @ head) / (denom ** 2)
    return res

def solve_for_a_b(points, res, d, m):
    '''
        (a, b) step of the alternating approximation.

        Keeps the coefficients ``c = res[d+1:d+1+m]`` fixed and overwrites
        ``res[:d+1]`` in place with
            sum_i <points[i, d+1:d+1+m], c> * points[i, :d+1] / sum(res[d+1:]**2)

        Returns ``res`` itself (modified in place).
    '''
    scale = torch.sum(res[d + 1:] ** 2)
    coeffs = res[d + 1:d + 1 + m]
    # Per-point weight: inner product of the point's coefficients with c.
    per_point = torch.sum(points[:, d + 1:d + 1 + m] * coeffs, dim=1)
    scaled_heads = (points[:, :d + 1].T * per_point).T
    res[:d + 1] = scaled_heads.sum(dim=0) / scale
    return res

def solve_approx(points, res, d, m, num_epochs):
    '''
        Run ``num_epochs`` rounds of the alternating approximation.

        Moves ``points`` and ``res`` to the module-level ``device``, then
        applies ``solve_for_c`` followed by ``solve_for_a_b`` once per epoch.
        Returns the updated ``res`` (on ``device``). Because ``.to`` returns the
        same tensor when it is already on ``device``, the caller's ``res`` is
        then updated in place as well.
    '''
    pts, state = points.to(device), res.to(device)
    for _ in range(num_epochs):
        state = solve_for_a_b(pts, solve_for_c(pts, state, d, m), d, m)
    return state

def tropnnc(points, d, m, f, info, print_upper_bound=False, features=None):
    '''
        Implementation of TropNNC.

        Each row of ``points`` is read as ``(a, b, c)`` with ``a = row[:d]``,
        ``b = row[d]`` and ``c = row[d+1:d+1+m]``. The rows are grouped by
        K-means (run on ``features`` when given, otherwise on ``points``) into
        ``int(f * len(points))`` clusters. For every cluster, ``(a, b)`` is
        initialised to the cluster mean and ``c`` to the cluster sum, and the
        result is refined with ``solve_approx`` for ``info['num_epochs']`` epochs.

        Args:
            points: 2-D array-like or tensor, one row per path/neuron.
            d: length of the ``a`` part of each row.
            m: length of the ``c`` part of each row.
            f: fraction of the rows to keep as clusters.
            info: dict; only ``info['num_epochs']`` is read.
            print_upper_bound: unused; kept for a shared calling convention.
            features: optional array-like with one row per row of ``points``,
                clustered instead of ``points``; its width may differ.

        Returns:
            ``(a_i, b_i, c_i)``: float64 tensors on ``device`` with shapes
            ``(n_centers, d)``, ``(n_centers,)`` and ``(n_centers, m)``.
    '''
    # `features == None` is elementwise (and raises in `if`) for numpy arrays.
    if features is None:
        features = points
    n_centers = int(f * len(points))
    kmeans = KMeans(n_clusters=n_centers, n_init=10).fit(features)

    if not isinstance(points, torch.Tensor):
        points = torch.tensor(points)

    # Rows of `points` belonging to each cluster (reused below).
    members = [points[kmeans.labels_ == k] for k in range(n_centers)]

    # Size the statistics by the width of `points`, not of the K-means
    # centers: the two differ when clustering on separate `features`.
    cluster_sums = torch.zeros((n_centers, points.shape[1]), dtype=torch.float64)
    cluster_means = torch.zeros_like(cluster_sums)
    for k, cluster_points in enumerate(members):
        cluster_sums[k] = cluster_points.sum(dim=0)
        cluster_means[k] = cluster_points.mean(dim=0)

    # Initial guess: mean for (a, b), sum for c.
    final_clusters = torch.cat([cluster_means[:, :d + 1], cluster_sums[:, d + 1:]], dim=1)

    for k, cluster_points in enumerate(members):
        final_clusters[k] = solve_approx(cluster_points, final_clusters[k], d, m, num_epochs=info['num_epochs'])

    a_i = final_clusters[:, :d].clone()
    b_i = final_clusters[:, d].clone()
    c_i = final_clusters[:, d + 1:d + m + 1].clone()
    return a_i.to(device), b_i.to(device), c_i.to(device)

def neural_path_kmeans(points, d, m, f, info, print_upper_bound=False):
    '''
        Implementation of Neural Path K-means algorithm.

        Clusters the rows of ``points`` into ``int(f * len(points))`` clusters
        with K-means and splits every center into ``a = center[:d]``,
        ``b = center[d]`` and ``c = center[d+1:d+m+1]``.

        ``info`` and ``print_upper_bound`` are not used here; they are accepted
        so that all compression routines share a calling convention.

        Returns:
            ``(a_i, b_i, c_i)`` tensors on ``device``.
    '''
    n_clusters = int(f * len(points))
    centers = KMeans(n_clusters=n_clusters, n_init=10).fit(points).cluster_centers_

    def _on_device(part):
        # Copy a slice of the centers into a fresh tensor on the target device.
        return torch.tensor(np.array(part)).to(device)

    return (
        _on_device(centers[:, :d]),
        _on_device(centers[:, d]),
        _on_device(centers[:, d + 1:d + m + 1]),
    )


def thinet(half_model, dataset, layer1, layer2, K):
    '''
    Implementation of ThiNet (Luo et al., 2017).

    Greedily removes ``n - K`` of the ``n`` neurons/channels produced by
    ``layer1`` and consumed by ``layer2``. For every sample one output
    neuron/channel of ``layer2`` (and, for 4-D activations, one spatial
    location) is drawn at random, and ``x_hat[s, c]`` holds the contribution of
    hidden channel ``c`` to that output. At each step the channel that, added to
    the channels already removed, gives the smallest summed squared
    contribution over the samples is discarded.

    Args:
        half_model: callable applied to ``dataset``; its output is used as the
            input of ``layer2`` (2-D ``(samples, n)`` or 4-D
            ``(samples, n, H, W)``).
        dataset: input batch; ``dataset.shape[0]`` is the number of samples.
        layer1: layer whose ``weight`` rows and ``bias`` entries are pruned.
        layer2: layer whose ``weight`` columns (input channels) are pruned.
        K: number of neurons/channels to keep.

    Returns:
        ``(w1_new, b1_new, w2_new)``: the kept rows of ``layer1.weight`` and
        ``layer1.bias`` and the kept columns of ``layer2.weight``, in their
        original order.
    '''

    w1 = layer1.weight
    b1 = layer1.bias
    w2 = layer2.weight

    n = w1.shape[0]  # hidden layer size in neurons/channels
    m = w2.shape[0]  # output size in neurons/channels
    d = dataset.shape[0]  # number of dataset samples
    T = n - K  # number of discarded neurons/channels

    # The selection only needs values; skipping autograd avoids building a
    # graph over the d per-sample assignments below.
    with torch.no_grad():
        # One randomly chosen output neuron/channel of layer2 per sample.
        sample_indexes = list(np.random.choice(range(m), d))
        x = half_model(dataset)
        w = w2[sample_indexes, :]

        # x_hat[s, c]: contribution of hidden channel c for sample s.
        x_hat = torch.zeros((d, n))

        if len(x.shape) > 2:
            # Kernel height goes with rows (dim 2), kernel width with columns (dim 3).
            # NOTE(review): the window below has 2*(k//2)+1 taps, so this
            # assumes odd kernel sizes.
            pad_h = w.size(2) // 2
            pad_w = w.size(3) // 2

            # Random output locations, in padded-image coordinates.
            i_coords = np.random.randint(low=pad_h, high=x.shape[2] + pad_h, size=d)
            j_coords = np.random.randint(low=pad_w, high=x.shape[3] + pad_w, size=d)
            i_coords = torch.tensor(i_coords, dtype=torch.long)
            j_coords = torch.tensor(j_coords, dtype=torch.long)

            for sample in range(d):
                i = i_coords[sample]
                j = j_coords[sample]

                # F.pad pads the last dim first: (left, right, top, bottom).
                image = nn.functional.pad(x[sample], (pad_w, pad_w, pad_h, pad_h), mode='constant', value=0)
                kernel = w[sample]

                # Receptive field of the sampled output location.
                patch = image[:, i - pad_h:i + pad_h + 1, j - pad_w:j + pad_w + 1]

                # Per-channel contribution to the output at (i, j).
                x_hat[sample] = (patch * kernel).sum(dim=[1, 2])
        else:
            x_hat = x * w

        # Running sum of the contributions of the channels removed so far.
        run_sum = 0

        # idx contains the indices of all remaining neurons
        idx = list(range(n))

        for _ in range(T):
            # Remove the neuron leading to the smallest objective value in the current iteration
            objective = torch.sum((x_hat[:, idx] + run_sum) ** 2, dim=0)
            h = idx[torch.argmin(objective)]
            idx.remove(h)

            # update current sum
            run_sum += x_hat[:, h].view(-1, 1)

    # Keep the surviving neurons/channels of both layers.
    w1_new = w1[idx, :]
    b1_new = b1[idx]
    w2_new = w2[:, idx]

    return w1_new, b1_new, w2_new
