import numpy as np
import torch
from numpy import genfromtxt
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from src.utils import to_one_hot


def load_jester(paths=('./data/jester-data-1.csv', './data/jester-data-2.csv')):
    """Load and stack the Jester rating CSVs, keeping only fully-populated rows.

    Each file is read with ``genfromtxt`` and the rows are concatenated. Only
    rows whose first column equals 100 are kept (presumably the per-user count
    of rated jokes -- TODO confirm against the dataset docs), and that first
    column is then dropped.

    Args:
        paths: iterable of CSV paths to read; defaults to the two original
            Jester files under ``./data``.

    Returns:
        2-D float numpy array with one row per kept user and one column per
        remaining CSV column.

    Raises:
        ValueError: if any kept row still contains the value 99 (which looks
            like the dataset's "not rated" marker -- verify).
    """
    # atleast_2d: genfromtxt returns a 1-D array for a single-row file, which
    # would otherwise break the column indexing below.
    data = np.concatenate(
        [np.atleast_2d(genfromtxt(path, delimiter=',')) for path in paths], axis=0
    )
    data = data[data[:, 0] == 100, 1:]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if np.any(data == 99):
        raise ValueError("Unexpected 99 (unrated) entries in fully-rated rows")
    return data


def cluster_z(data, num_clusters, device=None, max_iters=100):
    """Fit ``num_clusters`` binary (0/1) center vectors to the rows of ``data``.

    Centers are initialised by thresholding sklearn KMeans centroids at ``> 0``
    and then refined for up to ``max_iters`` rounds:
      1. each row is assigned to the center(s) with the largest dot product;
         rows tied between several centers (e.g. duplicate centers) are counted
         in every tied cluster;
      2. each center becomes the indicator of where the summed rows of its
         cluster are positive (an empty cluster sums to 0 and becomes the
         all-zero center).
    The loop stops early once a round leaves the centers unchanged. This gives
    the same result as running all rounds, because each round is a
    deterministic function of the current centers.

    Args:
        data: numpy array of shape (N, d).
        num_clusters: number of centers M.
        device: torch device for the refinement. Defaults to CUDA when
            available, otherwise CPU.
        max_iters: maximum number of refinement rounds.

    Returns:
        (centers, loss):
        - ``centers`` is an (M, d) float tensor of 0/1 values on ``device``.
        - ``loss`` is a 0-d numpy array: the mean over rows of (sum of the
          row's positive entries) minus (its best dot product with any center).
          Because centers are 0/1, this gap is never negative.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_torch = torch.from_numpy(data).float().to(device)  # N * d
    kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(data)
    centers = torch.from_numpy(kmeans.cluster_centers_ > 0).float().to(device)  # M * d
    for _ in range(max_iters):
        scores = torch.mm(data_torch, centers.t())  # N * M
        # Multi-hot mask of the highest-scoring center(s) for each row.
        label = (scores == scores.max(dim=-1, keepdim=True)[0]).float()  # N * M
        # A cluster's sum has the same sign as its mean, so no count
        # normalisation is needed before thresholding.
        sum_prods = torch.mm(label.t(), data_torch)  # M * d
        new_centers = (sum_prods > 0).float()  # M * d
        if torch.equal(new_centers, centers):
            break  # fixed point: further rounds would not change anything
        centers = new_centers
    best_scores = torch.mm(data_torch, centers.t()).max(dim=-1)[0]  # N
    loss = torch.mean(torch.sum(data_torch.clamp(0), dim=-1) - best_scores)
    return centers, loss.cpu().numpy()


def get_z(return_data=False, num_clusters=10):
    """Load the Jester ratings and return binary cluster centers fitted to them.

    Args:
        return_data: when True, also return the loaded rating matrix.
        num_clusters: number of centers to fit (defaults to the original 10).

    Returns:
        ``z``, the centers tensor produced by ``cluster_z``, or ``(z, data)``
        when ``return_data`` is True.
    """
    data = load_jester()
    z = cluster_z(data, num_clusters)[0]
    # Diagnostic: shape of the de-duplicated centers, i.e. how many distinct
    # binary centers survived.
    print(np.unique(z.cpu().numpy(), axis=0).shape)
    if return_data:
        return z, data
    return z


def plot_kmeans(num_clusters=None):
    """Plot the ``cluster_z`` loss against the number of clusters.

    Args:
        num_clusters: iterable of cluster counts to evaluate. Defaults to
            ``np.arange(10, 500, 10)``.
    """
    data = load_jester()
    # Materialise the iterable: it is traversed twice (fitting and plotting).
    num_clusters = np.arange(10, 500, 10) if num_clusters is None else list(num_clusters)
    losses = []
    for n in num_clusters:
        print("Fitting %d Clusters" % n)
        _, loss = cluster_z(data, n)
        losses.append(loss)
    print(losses)
    plt.plot(num_clusters, losses)
    plt.xlabel("Number of Clusters")
    # cluster_z's loss is the mean gap between each row's positive-entry sum and
    # its best dot product with a 0/1 center. It is not a squared distance.
    plt.ylabel("Mean Gap to Best Binary Center")
    plt.show()


if __name__ == "__main__":
    # Fit the default centers, then report the shape of the loaded rating matrix.
    _centers, ratings = get_z(return_data=True)
    print(ratings.shape)