import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from get_args import get_args
import os
import wandb
from models.MLP import MLP
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from data.data_loaders import get_dataloaders

# Number of groups KMeans partitions the units into.
NUM_CLUSTERS = 20
# Small constant added to denominators so an all-zero row cannot divide by zero.
EPSILON = 1e-5
# Datasets evaluated in turn; this order fixes the row order of the collected
# per-dataset statistics and the colour each dataset gets in the scatter plot.
datasets = (
    "ethiopic fashion_mnist kannada kmnist mnist nko osmanya vai"
).split()

def _collect_dataset_stats(net, criterion, args):
    """Evaluate ``net`` once per entry of ``datasets`` using that dataset's biases.

    For each dataset name, ``args.dataset`` is overwritten, a test loader is
    built with ``get_dataloaders(args)``, the biases found at
    ``args.load_biases_path`` (with ``"mnist"`` replaced by the dataset name)
    are loaded via ``net.set_biases``, and ``net.eval_epoch_with_variances``
    is run on that loader.

    Returns:
        ``(variances, biases, means, label_corrs)`` as NumPy arrays whose first
        axis follows the order of ``datasets``; ``biases`` rows come from
        ``net.get_biases()`` after loading.
    """
    variance_rows, bias_rows, mean_rows, label_corr_rows = [], [], [], []
    for dataset in datasets:
        args.dataset = dataset
        _, testloader = get_dataloaders(args)
        # NOTE(review): str.replace rewrites *every* "mnist" in the path, so this
        # assumes the configured path names the MNIST run exactly once.
        net.set_biases(args.load_biases_path.replace("mnist", dataset))
        _, _, _, variance, mean, label_corr = net.eval_epoch_with_variances(
            testloader, criterion
        )
        variance_rows.append(variance)
        bias_rows.append(net.get_biases())
        mean_rows.append(mean)
        label_corr_rows.append(label_corr)
    return (
        np.array(variance_rows),
        np.array(bias_rows),
        np.array(mean_rows),
        np.array(label_corr_rows),
    )


def _plot_bias_vs_variance(biases, variances):
    """Scatter each dataset's bias row against its variance row, one colour per dataset."""
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
    for i, (bias_row, variance_row) in enumerate(zip(biases, variances)):
        # Cycle the palette so listing more datasets than colours cannot raise IndexError.
        # NOTE(review): 'w' (white) is likely invisible on the default background.
        plt.scatter(bias_row, variance_row, alpha=0.1, s=0.3, color=colors[i % len(colors)])
    # NOTE(review): this script never saves or shows the figure; add
    # plt.savefig(...) here if the plot is actually wanted.


def main(args):
    """Cluster network units by how their variance changes across datasets.

    Builds an ``MLP`` from ``args``, loads the weights at
    ``args.load_weights_path`` once, then, for every dataset in ``datasets``,
    swaps in that dataset's biases and records the variance, mean and label
    correlation reported by ``net.eval_epoch_with_variances`` together with
    ``net.get_biases()``. Each unit's variance row (one entry per dataset) is
    divided by its maximum and clustered with KMeans into ``NUM_CLUSTERS``
    groups; all per-unit arrays are then reordered so units of the same
    cluster are contiguous and saved as ``.npy`` files under
    ``{args.results_path}/data``.

    The per-dataset statistics are assumed to be 1-D arrays of equal length
    (presumably one entry per hidden unit — TODO confirm), since KMeans needs
    a 2-D (units x datasets) matrix.

    Args:
        args: options from ``get_args()``. ``args.dataset`` is overwritten
            while iterating over ``datasets``.
    """
    net = MLP(
        args.num_hidden_layers,
        args.width,
        args.c,
        args.weight_distribution,
        args.weight_gain,
        args.bias_distribution,
        args.bias_gain,
        args.train_weights,
        args.input_layer_bias,
        args.output_layer_bias,
        args.middle_layers_bias,
        args.l1_weight,
        args.bias_l1_weight,
        args.bias_l1_baseline
        )

    criterion = nn.CrossEntropyLoss()
    net.set_weights(args.load_weights_path)

    # Rows = datasets (in ``datasets`` order), columns = units.
    variances, biases, means, label_corrs = _collect_dataset_stats(net, criterion, args)
    _plot_bias_vs_variance(biases, variances)

    # From here on: rows = units, columns = datasets.
    variances = variances.transpose()
    biases = biases.transpose()
    means = means.transpose()
    label_corrs = label_corrs.transpose()

    # Pearson correlation between each unit's bias and variance across datasets,
    # shaped (num_units, 1). A unit whose bias or variance is constant across
    # datasets gets NaN from np.corrcoef.
    corr_coefs = np.array(
        [np.corrcoef(bias_row, variance_row)[0, 1] for bias_row, variance_row in zip(biases, variances)]
    ).reshape(-1, 1)

    # Divide each unit's row by its own maximum; EPSILON guards an all-zero row.
    normalized_variances = variances / (np.expand_dims(np.max(variances, axis=1), axis=1) + EPSILON)

    # NOTE(review): no random_state is given, so cluster ids differ between runs.
    kmeans = KMeans(n_clusters=NUM_CLUSTERS)
    kmeans.fit(normalized_variances)
    labels = kmeans.labels_

    # Stable sort by cluster id: all units of cluster 0 first, then cluster 1,
    # ..., with units inside a cluster keeping their original relative order.
    order = np.argsort(labels, kind="stable")

    # Transpose back to rows = datasets, columns = units; corr_coefs and labels
    # become a single row each.
    outputs = {
        "clustered_variances": variances[order].transpose(),
        "clustered_normalized_variances": normalized_variances[order].transpose(),
        "clustered_biases": biases[order].transpose(),
        "clustered_means": means[order].transpose(),
        "clustered_corr_coefs": corr_coefs[order].transpose(),
        "clustered_labels": np.expand_dims(labels[order], axis=0),
        "clustered_label_corrs": label_corrs[order].transpose(),
    }

    out_dir = f"{args.results_path}/data"
    # np.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for name, array in outputs.items():
        np.save(f"{out_dir}/{name}.npy", array)
    

if __name__ == "__main__":
    # Script entry point: parse command-line options, then run the full pipeline.
    args = get_args()
    main(args=args)