import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
from collections import defaultdict

import federated
import utils
from .base import *


class FedCIOController(FederatedController):
    """Federated controller that keeps one model per cluster of clients.

    Clients are grouped with k-means over their L2-normalised, locally trained
    parameter vectors. Each cluster then trains its own model through
    ``federated.train_fed``, and the cluster models are combined into
    ``self.server_model`` with ``federated.aggr_params`` using the share of
    clients in each cluster as the weight.

    When a client leaves, only the model of the cluster that contained it is
    reset to the initial server model.
    """

    def __init__(self, *args, num_clusters=3, seed=None, **kwargs):
        """Create per-client and per-cluster copies of the initial model.

        Args:
            *args: Forwarded unchanged to ``FederatedController.__init__``.
            num_clusters: Number of k-means clusters (and cluster models).
            seed: ``random_state`` passed to ``KMeans``; ``None`` leaves the
                clustering unseeded.
            **kwargs: Forwarded unchanged to ``FederatedController.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.client_models = {
            client_id: copy.deepcopy(self.init_server_model)
            for client_id in self.client_loaders.keys()
        }
        # Mapping cluster id -> set of client ids; filled by client_clustering().
        self.clusters = None
        self.num_clusters = num_clusters
        self.cluster_models = {
            cluster_id: copy.deepcopy(self.init_server_model)
            for cluster_id in range(self.num_clusters)
        }
        self.seed = seed

    def client_clustering(self):
        """Train every client locally and cluster clients by their parameters.

        Each client model is trained on that client's ``"train"`` loader with
        ``self.config.local``. The flattened parameters are L2-normalised per
        client and clustered with k-means. The result is stored in
        ``self.clusters`` as ``{cluster_id: {client_id, ...}}``.
        """
        # Snapshot the ids once so training, feature extraction and label
        # assignment all walk the clients in the same order.
        client_ids = list(self.client_loaders.keys())

        for client_id in client_ids:
            federated.train_model(self.client_models[client_id],
                                  self.client_loaders[client_id]["train"],
                                  config=self.config.local,
                                  verbose=True,
                                  device=self.device)

        all_parameters = [
            torch.nn.utils.parameters_to_vector(
                self.client_models[client_id].parameters()
            ).detach().cpu().numpy()
            for client_id in client_ids
        ]
        all_parameters = normalize(all_parameters)

        kmeans = KMeans(n_clusters=self.num_clusters, random_state=self.seed)
        kmeans.fit(all_parameters)

        self.clusters = defaultdict(set)
        for client_id, label in zip(client_ids, kmeans.labels_):
            # Store plain ints rather than numpy scalars so the keys match
            # the integer keys of self.cluster_models.
            self.clusters[int(label)].add(client_id)

    def train(self, num_rounds):
        """Run ``num_rounds`` rounds of per-cluster training plus aggregation.

        In every round each non-empty cluster runs a single federated round on
        its own model. The cluster models are then aggregated into
        ``self.server_model`` and the server model is evaluated.

        Args:
            num_rounds: Number of outer rounds to run.
        """
        if self.clusters is None:
            self.client_clustering()
        print('Clusters:', self.clusters)
        self.evaluate(self.server_model, verbose=True)

        # train_fed reads the round count from the shared config; it is forced
        # to 1 for the inner calls and restored afterwards, even on failure.
        fed_num_rounds = self.config.federated.num_rounds
        try:
            for rnd in range(num_rounds):
                for cluster_id in range(self.num_clusters):
                    # .get() avoids inserting empty entries into the
                    # defaultdict; an empty cluster has nothing to train on.
                    members = self.clusters.get(cluster_id)
                    if not members:
                        continue
                    self.config.federated.num_rounds = 1
                    cluster_loaders = {c_idx: self.client_loaders[c_idx] for c_idx in members}
                    self.cluster_models[cluster_id].to(self.device)
                    federated.train_fed(self.cluster_models[cluster_id],
                                        cluster_loaders,
                                        self.config,
                                        eval_fn=None,
                                        device=self.device,
                                        is_llm=self.is_llm,
                                        tokenizer=self.tokenizer)

                # Build the weights in the iteration order of cluster_models
                # so that weight i belongs to model i.
                num_clients = len(self.client_loaders)
                cluster_weights = [
                    len(self.clusters.get(cluster_id, ())) / num_clients
                    for cluster_id in self.cluster_models
                ]
                federated.aggr_params(self.server_model, self.cluster_models.values(), cluster_weights)

                # Compare against the num_rounds argument: the config value is
                # temporarily 1 here and would flag the first round as last.
                is_last_round = (rnd + 1) == num_rounds
                verbose = (rnd + 1) % self.config.federated.log_frequency == 0 or is_last_round
                self.evaluate(self.server_model, verbose=verbose)
        finally:
            self.config.federated.num_rounds = fed_num_rounds

    def leave(self, unlearned_client):
        """Remove a client and reset the model of the cluster that held it.

        Args:
            unlearned_client: Id of the client to remove.
        """
        super().leave(unlearned_client)

        self.client_loaders.pop(unlearned_client)
        self.client_models.pop(unlearned_client)

        if self.clusters is None:
            # Clustering has not run yet, so no cluster model saw this client.
            return

        target_cluster = next(
            (cluster_id for cluster_id, members in self.clusters.items()
             if unlearned_client in members),
            None,
        )
        if target_cluster is None:
            # The client belongs to no cluster; there is nothing to reset.
            return

        print(f'Restart cluster {target_cluster}')
        self.cluster_models[target_cluster] = copy.deepcopy(self.init_server_model)
        self.clusters[target_cluster].remove(unlearned_client)

    def join(self, client_id, dataloader, attach_to=None):
        """Register a new client and give it a fresh copy of the initial model.

        Args:
            client_id: Id of the joining client.
            dataloader: Forwarded to ``FederatedController.join``.
            attach_to: Forwarded to ``FederatedController.join``.
        """
        super().join(client_id, dataloader, attach_to)

        self.client_models[client_id] = copy.deepcopy(self.init_server_model)
        # NOTE(review): the new client is not added to self.clusters here, so
        # it will not take part in train() unless the base class handles it.
        # NOTE(review): train() calls evaluate(self.server_model, ...); confirm
        # the base evaluate() accepts being called without a model.
        self.evaluate()
    
    # def evaluate(self, verbose=False):
    #     super().prepare_evaluation()

    #     # client_weighted = {}
    #     # for id in self.client_loaders.keys():
    #     #     client_weighted[id] = len(self.client_loaders[id]["train"].dataset)
    #     # fed_samples = sum(client_weighted.values())
    #     # client_weighted = {k: v / fed_samples for k, v in client_weighted.items()}

    #     cluster_weights = {cluster_id: 1/self.num_clusters for cluster_id in range(self.num_clusters)}

    #     train_acc = self.check_sisa_accuracy(
    #         self.cluster_models,
    #         self.eval_loader["train"], 
    #         cluster_weights)
    #     test_acc = self.check_sisa_accuracy(
    #         self.cluster_models,
    #         self.eval_loader["test"], 
    #         cluster_weights)

    #     metrics = {"train_acc": train_acc, "test_acc": test_acc}
    #     utils.log_stats(self.log_path, **metrics)
    #     if verbose:
    #         print('eval:', metrics)
    #     return metrics
    
    # def check_sisa_accuracy(self, cluster_models, dataloader, pred_weights):
    #     num_correct = 0
    #     num_samples = 0
    #     with torch.no_grad():
    #         for x, y in dataloader:
    #             x = x.to(device=self.device)
    #             y = y.to(device=self.device)
    #             scores = None
    #             # ensemble predictions from all (retained) clients
    #             for cluster_id, model in cluster_models.items():
    #                 weight = pred_weights[cluster_id]
    #                 model.to(self.device)
    #                 model.eval()
    #                 if scores is None:
    #                     scores = weight * model(x)
    #                 else:
    #                     scores += weight * model(x)
    #                 model.to("cpu")
    #             _, predictions = scores.max(1)
    #             num_correct += (predictions == y).sum()
    #             num_samples += predictions.size(0)
    #     return float(num_correct)/float(num_samples)*100