import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans, kmeans_plusplus
from .base import BaseClient,  BaseServer    
from ..compressors import BaseCompressor
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from copy import deepcopy

    
class KMeansClient(BaseClient):
    """Federated k-means participant that owns one local shard of training data.

    Each call to :meth:`round` runs a single k-means iteration (``max_iter=1``)
    on the local shard, starting from the centers supplied by the server.
    """

    def __init__(self, id, data, config):
        super().__init__(id=id, data=data, config=config)

    def format_client_data(self, data):
        """Keep only the ``'train'`` split of ``data`` and clear any local centers.

        Presumably invoked by ``BaseClient.__init__`` — confirm against the base class.
        """
        self.data, self.centers = data['train'], None

    def format_client_config(self, config):
        """Read the number of clusters from ``config["num_clusters"]``."""
        self.num_centers = config["num_clusters"]

    def round(self, init_centers):
        """Run one local k-means iteration and return the updated centers.

        Args:
            init_centers: Array of starting centers, one row per cluster.

        Returns:
            The fitted ``cluster_centers_`` array of the local KMeans model.
        """
        # An explicit init array means only one initialization is performed;
        # max_iter=1 limits the local update to a single iteration.
        local_model = KMeans(
            n_clusters=self.num_centers,
            init=init_centers,
            max_iter=1,
            n_init='auto',
        )
        return local_model.fit(self.data).cluster_centers_
    
    
class KMeansServer(BaseServer):
    """Server driving federated k-means with compressed center aggregation.

    In every round each client runs one local k-means iteration from the current
    global centers. For each center, the server then asks the compressor for a
    decoded aggregate of the clients' versions of that center, and these become
    the new global centers. A KMeans fitted centrally on the pooled training data,
    starting from the same initialization, serves as the reference for metrics.
    """

    def __init__(self,  compressor: BaseCompressor, all_client_data, tb_writer : SummaryWriter, results_dir: Path,
                 T: int,  d : int, num_clusters: int):
        """
        Args:
            compressor: Compressor used to aggregate each center across clients.
            all_client_data: Mapping ``client_id -> {'train': array, 'test': array}``.
            tb_writer: TensorBoard writer, forwarded to ``BaseServer``.
            results_dir: Output directory of this run. The initial centers are
                cached in ``results_dir.parent.parent / "init_cluster_centers.npy"``.
            T: Number of federated rounds; also ``max_iter`` of the central KMeans.
            d: Forwarded to ``BaseServer`` (presumably the data dimension — confirm).
            num_clusters: Number of k-means clusters.

        Raises:
            ValueError: If a cached initialization exists but does not contain
                exactly ``num_clusters`` centers.
        """
        # Set before super().__init__: create_clients reads self.num_clusters and
        # is presumably called from BaseServer.__init__.
        self.num_clusters = num_clusters
        self.T = T

        super().__init__(compressor=compressor,
                         all_client_data=all_client_data,
                         tb_writer=tb_writer,
                         results_dir=results_dir,
                         d=d)

        # Reuse initial centers cached two levels above results_dir if present,
        # otherwise draw them with k-means++ on the pooled training data and cache them.
        seed_path = results_dir.parent.parent
        init_path = seed_path / "init_cluster_centers.npy"
        if init_path.exists():
            self.init = np.load(init_path)
            # A cache written for a different cluster count would otherwise only
            # fail later inside sklearn with a less obvious message.
            if self.init.ndim != 2 or self.init.shape[0] != self.num_clusters:
                raise ValueError(
                    f"Cached initial centers at {init_path} have shape {self.init.shape}, "
                    f"expected ({self.num_clusters}, n_features)."
                )
        else:
            self.init = kmeans_plusplus(n_clusters=self.num_clusters, X=self.all_data['train'])[0]
            np.save(init_path, self.init)

    def create_clients(self, all_client_data):
        """Build one KMeansClient per client and pool every client's data.

        Sets ``self.clients`` (``client_id -> KMeansClient``) and ``self.all_data``
        (``{"train": ..., "test": ...}``, each the row-wise concatenation of the
        clients' arrays in iteration order).
        """
        self.clients = {}
        all_data_train = []
        all_data_test = []
        for client_id, client_data in all_client_data.items():
            self.clients[client_id] = KMeansClient(client_id, client_data, {"num_clusters": self.num_clusters})
            all_data_train.append(client_data['train'])
            all_data_test.append(client_data['test'])

        self.all_data = {"train": np.concatenate(all_data_train, axis=0),
                         "test": np.concatenate(all_data_test, axis=0)}

    def _record_round(self, results, cluster_centers, comm_bits, dec_error, idx):
        """Evaluate ``cluster_centers``, append all metrics to ``results`` and log at step ``idx``."""
        train_score, test_score, center_error = self.metric(cluster_centers)
        results["train_score"].append(train_score)
        results["test_score"].append(test_score)
        results["center_error"].append(center_error)
        results["comm_bits"].append(comm_bits)
        results["dec_error"].append(dec_error)
        self.log_results(results=results, idx=idx)

    def run(self):
        """Run ``self.T`` rounds of federated k-means.

        Returns:
            ``(results, cluster_centers)``. ``results`` maps each metric name to a
            list of length ``T + 1``; entry 0 describes the initialization. It also
            holds the scalar central-baseline scores ``train_score_central`` and
            ``test_score_central``. ``cluster_centers`` are the final global centers.
        """
        cluster_centers = deepcopy(self.init)

        # Centralized reference on the pooled training data, from the same init.
        self.central_kmeans = KMeans(n_clusters=self.num_clusters, init=cluster_centers, max_iter=self.T, n_init='auto')
        self.central_kmeans.fit(self.all_data['train'])
        self.central_centers = self.central_kmeans.cluster_centers_

        # KMeans.score returns the negated k-means objective; negate it back.
        central_scores = {"train": -1*self.central_kmeans.score(self.all_data['train']),
                          "test": -1*self.central_kmeans.score(self.all_data['test'])}

        results = {"train_score": [], "center_error": [], "test_score": [], "comm_bits": [], "dec_error": []}

        # Round 0: metrics of the initialization; nothing has been communicated yet.
        self._record_round(results, cluster_centers, comm_bits=0, dec_error=0, idx=0)

        for idx in tqdm(range(self.T)):
            # Each client runs one local iteration from the current global centers.
            client_outputs = {client_id: client.round(cluster_centers) for client_id, client in self.clients.items()}

            dec_error = 0.0
            comm_bits = 0.0
            cluster_centers = np.zeros(cluster_centers.shape)
            # Shape (num_clients, num_clusters, n_features).
            client_arr = np.stack(list(client_outputs.values()), axis=0)
            # Exact uncompressed average, used only to measure the decoding error.
            true_centers = client_arr.mean(axis=0)

            for center_id in range(self.num_clusters):
                # Each center is aggregated separately across all clients;
                # compress() returns the decoded aggregate for this center.
                decoded_center = self.compressor.compress(client_arr[:, center_id, :])
                comm_bits_center = self.compressor.num_bits_float()
                cluster_centers[center_id] = decoded_center
                # Squared L2 error of the decoded center w.r.t. the exact mean.
                dec_error += np.linalg.norm(decoded_center - true_centers[center_id])**2
                comm_bits += comm_bits_center

            self._record_round(results, cluster_centers, comm_bits=comm_bits, dec_error=dec_error, idx=idx + 1)

        results["train_score_central"] = central_scores["train"]
        results["test_score_central"] = central_scores["test"]
        return results, cluster_centers

    def metric(self, cluster_centers):
        """Evaluate ``cluster_centers`` against the centralized reference.

        Requires :meth:`run` to have fitted ``self.central_kmeans`` first.

        Returns:
            ``(train_score, test_score, center_error)``. The two scores are the
            k-means objective (sum of squared distances to the nearest center) on
            the pooled train and test data. ``center_error`` is the mean squared
            L2 distance between ``cluster_centers`` and the central centers,
            matched row by row.
        """
        center_error = (np.linalg.norm(self.central_centers - cluster_centers, axis=1)**2).mean()

        # Score through the fitted central estimator by temporarily swapping in
        # the candidate centers. Cast them to the dtype sklearn fitted with (that
        # of the training data), which its scoring routines expect. Restore the
        # fitted centers afterwards so self.central_kmeans keeps describing the
        # centralized solution instead of the last evaluated candidate.
        fitted_centers = self.central_kmeans.cluster_centers_
        self.central_kmeans.cluster_centers_ = cluster_centers.astype(fitted_centers.dtype)
        try:
            train_score = -1*self.central_kmeans.score(X=self.all_data['train'])
            test_score = -1*self.central_kmeans.score(X=self.all_data['test'])
        finally:
            self.central_kmeans.cluster_centers_ = fitted_centers

        return train_score, test_score, center_error




