import numpy as np
from tqdm import tqdm

from .base import BaseClient, BaseServer

from ..compressors import BaseCompressor
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from copy import deepcopy

    
class PowerIterClient(BaseClient):
    """Client holding one local data shard for distributed power iteration.

    On setup the client keeps its ``'train'`` split and the uncentered
    second-moment matrix ``X^T X / n`` of that split. Each round it applies
    this matrix to the estimate broadcast by the server.
    """

    def __init__(self, id, data, config):
        # NOTE(review): BaseClient.__init__ presumably dispatches to
        # format_client_data / format_client_config — confirm in base.py.
        super().__init__(id=id, data=data, config=config)

    def format_client_data(self, data):
        """Store the training split and its normalized Gram matrix ``X^T X / n``."""
        train = data['train']
        self.data = train
        n_rows = train.shape[0]
        self.cov = (train.T @ train) / n_rows

    def format_client_config(self, *args, **kwargs):
        """Accept and ignore any config: this client has no tunable parameters."""
        return None

    def round(self, x_hat):
        """Return the local matrix applied to the current estimate ``x_hat``."""
        return self.cov @ x_hat
    
    
class PowerIterServer(BaseServer):
    """Server running distributed power iteration with compressed client uploads.

    Each round the current unit-norm estimate ``x_hat`` is sent to every
    client. Each client returns its local second-moment matrix applied to
    ``x_hat``, and the stacked responses go through the compressor. The
    decoded aggregate is renormalized to give the next estimate. Progress is
    measured against the dominant eigenvector of the pooled training matrix.

    Args:
        compressor: Object whose ``compress`` maps the stacked client outputs
            (shape ``(n_clients, d)``) to a decoded vector. Presumably this is
            an estimate of their mean — TODO confirm.
        all_client_data: Mapping ``client_id -> {'train': X, 'test': X}``. Each
            ``X`` has shape ``(n_points, d)`` (required so ``X.T @ X`` fits
            the ``(d, d)`` accumulators).
        tb_writer: TensorBoard writer forwarded to ``BaseServer``.
        results_dir: Results directory. The shared initial vector is cached
            two levels above it.
        T: Number of power-iteration rounds.
        d: Dimension of the data and of the estimate vector.
    """

    def __init__(self, compressor: BaseCompressor, all_client_data, tb_writer: SummaryWriter,
                 results_dir: Path, T: int, d: int):
        self.d = d
        self.T = T
        # Cache the random start two directories above results_dir so every
        # run sharing that directory starts from the same x_hat.
        init_path = results_dir.parent.parent / "init_x_hat.npy"
        if init_path.exists():
            self.init = np.load(init_path)
        else:
            # Uniform [0, 1) entries, normalized to unit length.
            init = np.random.rand(self.d)
            self.init = init / np.linalg.norm(init)
            np.save(init_path, self.init)

        # Attributes above are set first. NOTE(review): BaseServer.__init__
        # presumably calls create_clients — confirm in base.py.
        super().__init__(d=d, compressor=compressor, all_client_data=all_client_data,
                         tb_writer=tb_writer, results_dir=results_dir)

    def create_clients(self, all_client_data):
        """Build one PowerIterClient per entry and pool the data statistics.

        Sets:
            ``self.clients``: ``client_id -> PowerIterClient``.
            ``self.all_cov``: pooled ``X^T X / n`` for the 'train' and 'test'
                splits across all clients.
            ``self.true_evec``: eigenvector of the pooled training matrix with
                the LARGEST eigenvalue, i.e. the fixed point of exact power
                iteration.

        Raises:
            ValueError: If the clients hold no training points at all.
        """
        self.clients = {}
        train_cov = np.zeros((self.d, self.d))
        test_cov = np.zeros((self.d, self.d))
        num_pts_train = 0
        num_pts_test = 0

        for client_id, client_data in all_client_data.items():
            # PowerIterClient has no configurable parameters, so it gets an empty config.
            self.clients[client_id] = PowerIterClient(client_id, client_data, {})

            train, test = client_data['train'], client_data['test']
            train_cov += train.T @ train
            num_pts_train += train.shape[0]
            test_cov += test.T @ test
            num_pts_test += test.shape[0]

        if num_pts_train == 0:
            raise ValueError("PowerIterServer needs at least one training point")

        self.all_cov = {"train": train_cov / num_pts_train, "test": test_cov / num_pts_test}

        # eigh returns eigenvalues in ASCENDING order, so the dominant
        # eigenvector (what power iteration converges to) is the last column.
        _, evecs = np.linalg.eigh(self.all_cov['train'])
        self.true_evec = evecs[:, -1]

    def run(self):
        """Run ``T`` rounds of compressed power iteration.

        Returns:
            ``(results, x_hat)``.

            ``results`` maps 'train_eigval', 'test_eigval', 'l2_error',
            'comm_bits' and 'dec_error' to per-round lists. Index 0 is the
            initial vector, before any communication. It also holds the scalars
            'train_central_eigval' / 'test_central_eigval': the Rayleigh
            quotients of ``self.true_evec`` on the pooled matrices.

            ``x_hat`` is the final unit-norm estimate.
        """
        true_eigvals = {
            split: (self.true_evec.T @ self.all_cov[split] @ self.true_evec).item()
            for split in ["train", "test"]
        }
        results = {"train_eigval": [], "test_eigval": [], "l2_error": [], "comm_bits": [], "dec_error": []}

        # Copy so iterating never mutates the cached initial vector.
        x_hat = self.init.copy()
        self._append_round_results(results, x_hat, comm_bits=0, dec_error=0, idx=0)

        for idx in tqdm(range(self.T)):
            # Clients compute their local step from the current estimate.
            client_arr = np.stack([client.round(x_hat) for client in self.clients.values()], axis=0)

            # Uncompressed reference step, taken from the SAME x_hat the
            # clients just used (x_hat is overwritten below).
            true_x_hat = self.all_cov['train'] @ x_hat
            true_x_hat = true_x_hat / np.linalg.norm(true_x_hat)

            decoded_x_hat = self.compressor.compress(client_arr)
            comm_bits = self.compressor.num_bits_float()
            x_hat = decoded_x_hat / np.linalg.norm(decoded_x_hat)

            # Squared distance between the compressed step and the exact step.
            dec_error = np.linalg.norm(x_hat - true_x_hat)**2
            self._append_round_results(results, x_hat, comm_bits=comm_bits, dec_error=dec_error, idx=idx + 1)

        results["train_central_eigval"] = true_eigvals['train']
        results["test_central_eigval"] = true_eigvals['test']
        return results, x_hat

    def _append_round_results(self, results, x_hat, comm_bits, dec_error, idx):
        """Append the metrics of ``x_hat`` plus the given round stats, then log."""
        train_eigval, test_eigval, l2_error = self.metric(x_hat)
        results["train_eigval"].append(train_eigval)
        results["test_eigval"].append(test_eigval)
        results["l2_error"].append(l2_error)
        results["comm_bits"].append(comm_bits)
        results["dec_error"].append(dec_error)
        self.log_results(results=results, idx=idx)

    def metric(self, x_hat):
        """Score an estimate against the centralized solution.

        Returns:
            ``(train_eigval, test_eigval, l2_error)``.

            The eigvals are ``x_hat^T C x_hat`` on the pooled train/test
            matrices (Rayleigh quotients for unit-norm ``x_hat``).

            ``l2_error`` is the squared distance to ``self.true_evec``.
            Eigenvectors are defined only up to sign, so it is measured to
            whichever of ``±true_evec`` is closer.
        """
        l2_error = min(np.linalg.norm(x_hat - self.true_evec)**2,
                       np.linalg.norm(x_hat + self.true_evec)**2)
        train_eigval = (x_hat.T @ self.all_cov['train'] @ x_hat).item()
        test_eigval = (x_hat.T @ self.all_cov['test'] @ x_hat).item()
        return train_eigval, test_eigval, l2_error