from scipy import cluster
import torchvision.transforms as transforms
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import sys
import os
from k_means_constrained import KMeansConstrained

sys.path.append('..') 
from core.utils import fix_all_seed
from FedLab.datasets.pickle_dataset import PickleDataset
    
def divide_dataset_indices(dataset, num_clients, num_chunks=1):
    '''
    Splits dataset indices between clients in a label-sorted fashion.

    The indices are ordered by `dataset.targets` and cut into
    `num_clients * num_chunks` consecutive chunks of equal size. The chunks are
    shuffled with the global NumPy random state and every client receives
    `num_chunks` of them. Samples left over by the integer division are not
    assigned to any client.

    Returns a list with one array of dataset indices per client.
    '''
    total_chunks = num_clients * num_chunks
    chunk_size = len(dataset) // total_chunks
    # Indices ordered by label, so each chunk spans a narrow range of labels
    order = np.argsort(dataset.targets)
    chunks = [order[k * chunk_size:(k + 1) * chunk_size] for k in range(total_chunks)]
    # The in-place shuffle decides which chunks end up on the same client
    np.random.shuffle(chunks)
    return [
        np.concatenate(chunks[c * num_chunks:(c + 1) * num_chunks])
        for c in range(num_clients)
    ]

def subset_by_indices(dataset, indices_list):
    '''
    Builds a `torch.utils.data.Subset` of `dataset`.

    A Python list is treated as a list of index arrays and concatenated into
    one index array first; any other value is used as the indices as is.
    '''
    if not isinstance(indices_list, list):
        return torch.utils.data.Subset(dataset, indices_list)
    return torch.utils.data.Subset(dataset, np.concatenate(indices_list))

def dataset_average(dataset):
    '''
    Returns the element-wise mean of all inputs of `dataset`.

    The whole dataset is loaded as a single batch, the inputs are averaged over
    the batch dimension and a leading dimension of size 1 (if there is one) is
    dropped from the result. Targets are ignored.
    '''
    full_batch_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    inputs, _ = next(iter(full_batch_loader))
    mean_input = inputs.mean(axis=0)
    return mean_input.squeeze(0)

def clusterize_dataset_kmeans(dataset, n_clusters, seed):
    '''
    Clusters the samples of `dataset` with KMeans on their flattened inputs.

    Returns a list of `n_clusters` arrays; the i-th array holds the positions
    (in dataset order) of the samples assigned to cluster i.
    '''
    # Load every sample at once and flatten each input into one feature vector
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    features, _ = next(iter(loader))
    features = features.view(features.shape[0], -1)
    kmeans = KMeans(init="k-means++", n_clusters=n_clusters, random_state=seed)
    kmeans.fit(features)
    labels = kmeans.labels_
    return [np.where(labels == cluster_id)[0] for cluster_id in range(n_clusters)]

def cluster_to_clients(cluster_indices, num_clients, seed):
    '''
    Distributes cluster indices to clients.

    The indices are randomly permuted with a generator seeded by `seed` and then
    split into `num_clients` parts of nearly equal size. The input collection is
    not modified.

    Args:
        cluster_indices: collection of sample indices belonging to one cluster.
        num_clients: number of parts to produce.
        seed: seed of the permutation; the same seed yields the same split.

    Returns:
        List of `num_clients` index arrays whose sizes differ by at most one.
    '''
    # np.random.shuffle accepts no `random_state` argument (passing one raises
    # TypeError), so a dedicated seeded generator does the shuffling instead.
    rng = np.random.default_rng(seed)
    shuffled_indices = rng.permutation(np.asarray(cluster_indices))
    # Split the shuffled indices into num_clients parts
    return np.array_split(shuffled_indices, num_clients)
    

class DatasetBase:
    '''
    Common interface of the federated datasets in this module.

    The constructor passes `seed` to `fix_all_seed` and remembers the number of
    workers and the IID degree. Subclasses provide `__call__` and
    `classes_count`.
    '''

    def __init__(self, workers_count, iid_degree, seed):
        fix_all_seed(seed)
        self.iid_degree = iid_degree
        self.workers_count = workers_count

    def classes_count(self):
        '''Number of target classes; has to be implemented by subclasses.'''
        raise NotImplementedError

    def __call__(self, train=True):
        '''Builds the train or test dataset; has to be implemented by subclasses.'''
        raise NotImplementedError


class DatasetEMNIST(DatasetBase):
    '''
    EMNIST dataset ("balanced" split) read from `./data` without downloading.

    For the train part the sample indices are distributed between the workers
    with `divide_dataset_indices` using two chunks per worker; for the test
    part no per-worker indices are produced.
    '''

    transform_emnist = transforms.Compose([
        # PIL image -> tensor, followed by normalisation with fixed statistics
        transforms.ToTensor(),
        transforms.Normalize(mean=0.1736, std=0.3317),
    ])

    def classes_count(self):
        '''Number of classes reported for this dataset.'''
        return 47

    def __call__(self, train=True):
        '''
        Returns `(dataset, workers_data_indices)`; the second item is None
        when `train` is false.
        '''
        dataset = torchvision.datasets.EMNIST(
            root='./data',
            split='balanced',
            train=train,
            download=False,
            transform=self.transform_emnist,
        )
        if not train:
            return dataset, None
        workers_data_indices = divide_dataset_indices(dataset, self.workers_count, num_chunks=2)
        return dataset, workers_data_indices

class DatasetFemnist(DatasetBase):
    r'''
    FEMNIST dataset assembled from the per-client pickles produced by FedLab.

    The dataset is based on the following generation with FedLab commands:
    cd FedLab/datasets/femnist \
    bash ./preprocess.sh -s niid --sf 1.0 -k 0 -t sample --tf 0.9 --smlpseed 42 --spltseed 42
    Using seed 1706121644

    Thus the dataset consists of 3597 clients with mean 227.37 and std 88.84 samples per client.

    Every worker receives the merged data of ceil(iid_degree / MEAN) randomly
    chosen FEMNIST clients; no FEMNIST client is given to two workers.
    '''

    # Mean number of samples per FEMNIST client (see the class docstring)
    MEAN = 227.37
    # Number of FEMNIST clients available in the generated pickles
    CLIENTS_NUMBER = 3597

    def __init__(self, workers_count, iid_degree, seed, n_clusters=1):
        '''
        Samples the FEMNIST clients of every worker and opens the pickle storage.

        Args:
            workers_count: number of workers the data is divided between.
            iid_degree: divided by MEAN and rounded up to get the number of
                FEMNIST clients merged into each worker.
            seed: passed to the base class, which seeds via `fix_all_seed`.
            n_clusters: stored on the instance; not used by this class itself.

        Raises:
            ValueError: if all workers together need more FEMNIST clients than
                CLIENTS_NUMBER.
        '''
        super().__init__(workers_count, iid_degree, seed)
        worker_average_number = int(np.ceil(iid_degree / self.MEAN))
        assert worker_average_number > 0 and worker_average_number <= self.CLIENTS_NUMBER

        # Sampling is done without replacement, so the total must fit into the
        # client pool; fail with a clear message instead of inside NumPy.
        total_clients = worker_average_number * workers_count
        if total_clients > self.CLIENTS_NUMBER:
            raise ValueError(
                f"{workers_count} workers with {worker_average_number} clients each need "
                f"{total_clients} FEMNIST clients, but only {self.CLIENTS_NUMBER} exist"
            )

        # sample worker_average_number * workers_count clients
        samples_clients = np.random.choice(self.CLIENTS_NUMBER, size=total_clients, replace=False)
        # divide them into workers_count groups
        self.clients = np.split(samples_clients, workers_count)

        # The pickles are looked up relative to the parent of the current working directory
        ABS_PATH = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
        self.pdataset = PickleDataset(pickle_root=os.path.join(ABS_PATH, "FedLab/datasets/pickle_datasets"),
                                      dataset_name="femnist")
        self.n_clusters = n_clusters

    def __call__(self, train=True):
        '''
        Builds the combined dataset and the index ranges of every worker.

        Returns:
            `(dataset, workers_data_indices)` where `dataset` concatenates the
            merged data of all workers in worker order and
            `workers_data_indices[k]` is the contiguous range of positions that
            belongs to worker k.
        '''
        dataset_type = "train" if train else "test"
        merged_datasets = [
            self.pdataset.get_dataset_pickle(dataset_type=dataset_type, client_id=client_ids[0])
            for client_ids in self.clients
        ]
        # Append the remaining clients of every group to the worker's dataset
        # (presumably the loaded datasets concatenate on `+` -- verify against PickleDataset)
        for i in range(1, len(self.clients[0])):
            for j, client_ids in enumerate(self.clients):
                merged_datasets[j] += self.pdataset.get_dataset_pickle(dataset_type=dataset_type, client_id=client_ids[i])
        # Worker k owns positions [offsets[k], offsets[k + 1]) of the concatenation;
        # cumulative offsets stay valid even when some worker has no samples.
        offsets = np.cumsum([0] + [len(merged) for merged in merged_datasets])
        workers_data_indices = [np.arange(start, stop) for start, stop in zip(offsets[:-1], offsets[1:])]
        dataset = torch.utils.data.ConcatDataset(merged_datasets)
        return dataset, workers_data_indices

    def classes_count(self):
        '''Number of classes reported for this dataset.'''
        return 62
    
def plot_dataset_distribution(dataset, workers_data_indices):
    '''
    Shows a 30-bin histogram of the number of samples held by each worker.
    '''
    plt.figure(figsize=(15, 5))
    plt.title("Dataset distribution")
    plt.xlabel("Samples number")
    plt.ylabel("Number of workers")
    # One entry per worker: the size of the subset selected by its indices
    samples_per_worker = [
        len(subset_by_indices(dataset, worker_indices))
        for worker_indices in workers_data_indices
    ]
    plt.hist(samples_per_worker, bins=30)
    plt.show()

def tsne_clusterization(dataset, workers_data_indices, seed=42):
    '''
    Draws a 2-D TSNE embedding of the flattened dataset inputs, one scatter series per worker.

    All samples are embedded together; afterwards the points of each worker are
    selected with that worker's indices and scattered under the label
    "Worker <i>". The figure is only prepared here: no legend is drawn and
    `plt.show()` is not called.
    '''
    # Load every sample at once and flatten each input into one feature vector
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    features, _ = next(iter(loader))
    features = features.view(features.shape[0], -1)
    embedding = TSNE(n_components=2, random_state=seed).fit_transform(features)
    plt.figure(figsize=(15, 5))
    plt.title("Dataset clusterization")
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    # One scatter series per worker, selected by the worker's sample indices
    for worker_id in range(len(workers_data_indices)):
        worker_indices = workers_data_indices[worker_id]
        plt.scatter(embedding[worker_indices, 0], embedding[worker_indices, 1], label=f"Worker {worker_id}")
    
def federated_dataset_clusterisation(dataset, workers_data_indices, clusters_number,
                                     min_cluster_size=None,
                                     max_cluster_size=None,
                                     seed=42,
                                     visualize=False):
    '''
    Groups workers into clusters with size-constrained KMeans.

    Every worker is represented by the flattened average input of its part of
    `dataset`; these averages are clustered with `KMeansConstrained` using the
    given cluster count and size limits. With `visualize` set, the result is
    also plotted through `visualize_workers_clusterization`.

    Returns a dict that maps the worker position to its cluster label.
    '''
    # One flattened mean input per worker
    worker_means = []
    for indices in workers_data_indices:
        worker_subset = subset_by_indices(dataset, indices)
        worker_means.append(dataset_average(worker_subset).numpy().flatten())
    model = KMeansConstrained(
        n_clusters=clusters_number,
        size_min=min_cluster_size,
        size_max=max_cluster_size,
        random_state=seed,
    )
    model.fit_predict(worker_means)
    worker_to_cluster = dict(enumerate(model.labels_))
    if visualize:
        visualize_workers_clusterization(worker_means, model.labels_, clusters_number, seed=seed)
    return worker_to_cluster

def visualize_workers_clusterization(averaged_workers_datasets, labels, clusters_number, seed=42):
    '''
    Does a TSNE visualisation of the workers' averaged datasets and marks the clusters.

    The plot is saved to "clusterization.png" in the current directory and then shown.

    Args:
        averaged_workers_datasets: one feature vector per worker.
        labels: cluster label of every worker, in the same order.
        clusters_number: number of clusters; labels 0..clusters_number-1 are plotted.
        seed: random state of the TSNE embedding.

    At least two workers are needed, because TSNE requires a positive
    perplexity that is smaller than the number of samples.
    '''
    features = np.array(averaged_workers_datasets)
    # Coerce to an array so that `labels == i` yields a boolean mask for list input too
    labels = np.asarray(labels)
    workers_count = len(features)
    # TSNE's default perplexity (30.0) must stay below the number of samples,
    # so it is lowered when there are 30 workers or fewer.
    perplexity = min(30.0, float(workers_count - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=seed)
    inputs = tsne.fit_transform(features)
    # Plotting
    plt.figure(figsize=(15, 5))
    plt.title("Dataset clusterization")
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    # One scatter series per cluster, selected by the workers' cluster labels
    for i in range(clusters_number):
        idx = labels == i
        plt.scatter(inputs[idx, 0], inputs[idx, 1], label=f"Cluster {i}")
    # Save the picture of the plot before showing it
    plt.legend()
    plt.savefig("clusterization.png")
    plt.show()
