import numpy as np
from tqdm import tqdm

from .base import BaseClient,  BaseServer    
from ..compressors import BaseCompressor
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path

class LinRegClient(BaseClient):
    """Federated client that holds a local least-squares training set.

    Each call to ``round`` returns a mini-batch estimate of the gradient of
    ``0.5 * mean((X @ w - y) ** 2)`` on the local data, evaluated at the
    weights sent by the server.
    """

    def __init__(self, id, data, config):
        super().__init__(id=id, data=data, config=config)

    def format_client_data(self, data):
        """Keep the local training split.

        ``data['train']`` is indexed as ``(features, targets)``.
        """
        features, targets = data['train'][0], data['train'][1]
        self.X_train = features
        self.y_train = targets

    def format_client_config(self, config: dict):
        """Read the mini-batch size from the client config."""
        self.batch_size = config["batch_size"]

    def round(self, w_server):
        """Return a stochastic gradient of the local loss at ``w_server``.

        Draws ``min(batch_size, n_local)`` row indices with ``np.random.choice``
        and averages the per-sample gradients ``(x @ w - y) * x`` over them.
        """
        num_local = self.y_train.shape[0]
        # NOTE(review): np.random.choice samples WITH replacement by default,
        # so capping the size at num_local does not guarantee that every local
        # row is used once; pass replace=False if that was the intent.
        picked = np.random.choice(num_local, min(self.batch_size, num_local))
        features = self.X_train[picked]
        targets = self.y_train[picked]

        # Residual per sample, broadcast against the feature rows.
        residual = features @ w_server - targets
        return (residual.reshape(-1, 1) * features).mean(axis=0)

class LinRegServer(BaseServer):
    """Server that runs compressed-gradient SGD for linear regression.

    In each of ``T`` rounds it:
    - collects one gradient per client;
    - passes the stacked gradients through ``compressor.compress`` to get the
      aggregated ("decoded") gradient;
    - takes a step of size ``lr``;
    - records the train/test loss, the bit count reported by the compressor,
      and the squared error between the decoded gradient and the true mean
      gradient.
    """

    def __init__(self, compressor: BaseCompressor, all_client_data, tb_writer: SummaryWriter, results_dir: Path,
                 T: int, lr: float, batch_size: int, d: int):
        # These attributes are set before super().__init__ because
        # create_clients reads self.batch_size. Presumably BaseServer.__init__
        # invokes create_clients -- confirm against BaseServer.
        self.d = d
        self.T = T
        self.batch_size = batch_size
        self.lr = lr
        super().__init__(d=d, compressor=compressor, all_client_data=all_client_data,
                         tb_writer=tb_writer, results_dir=results_dir)

    def create_clients(self, all_client_data):
        """Build one LinRegClient per entry and pool all data for evaluation.

        Args:
            all_client_data: mapping ``client_id -> data``, where
                ``data['train']`` and ``data['test']`` are indexed as ``(X, y)``.

        Raises:
            ValueError: if ``all_client_data`` is empty, since there is no data
                to pool.
        """
        if not all_client_data:
            raise ValueError("all_client_data must contain at least one client")

        self.clients = {}
        train_X, train_y, test_X, test_y = [], [], [], []
        for client_id, client_data in all_client_data.items():
            self.clients[client_id] = LinRegClient(client_id, client_data, config={"batch_size": self.batch_size})
            train_X.append(client_data['train'][0])
            train_y.append(client_data['train'][1])
            test_X.append(client_data['test'][0])
            test_y.append(client_data['test'][1])

        # Pooled datasets, used by metric() to evaluate the global model.
        self.all_data = {"train": (np.concatenate(train_X, axis=0), np.concatenate(train_y)),
                         "test": (np.concatenate(test_X, axis=0), np.concatenate(test_y))}

    def run(self):
        """Run ``T`` rounds of compressed SGD starting from the all-ones vector.

        Returns:
            ``(results, w)``. ``results`` maps "train_loss", "test_loss",
            "comm_bits" and "dec_error" to per-round lists, and ``w`` is the
            final weight vector.
        """
        w = np.ones(self.d)
        results = {"train_loss": [], "test_loss": [], "comm_bits": [], "dec_error": []}

        for idx in tqdm(range(self.T)):
            # Every client computes its local stochastic gradient at the current w.
            client_grads = [client.round(w) for client in self.clients.values()]
            client_arr = np.stack(client_grads, axis=0)

            # Uncompressed mean gradient. It is computed once, before compress()
            # runs, so the reference does not depend on anything the compressor
            # does with its input array.
            true_grad = client_arr.mean(axis=0)

            # Aggregated gradient as reconstructed by the compressor.
            decoded_grad = self.compressor.compress(client_arr)
            comm_bits = self.compressor.num_bits_float()

            # Gradient step using the decoded (compressed) gradient.
            w = w - self.lr * decoded_grad

            train_loss, test_loss = self.metric(w)

            # Squared L2 error introduced by compression in this round.
            dec_error = np.linalg.norm(decoded_grad - true_grad) ** 2

            results["train_loss"].append(train_loss)
            results["test_loss"].append(test_loss)
            results["comm_bits"].append(comm_bits)
            results["dec_error"].append(dec_error)
            self.log_results(results=results, idx=idx)

        return results, w

    def metric(self, w):
        """Return ``(train_loss, test_loss)`` of ``w`` on the pooled data.

        Both losses are ``0.5 * mean((X @ w - y) ** 2)``. This is the objective
        whose gradient LinRegClient.round estimates.
        """
        train_loss = self._half_mse(*self.all_data['train'], w)
        test_loss = self._half_mse(*self.all_data['test'], w)
        return train_loss, test_loss

    @staticmethod
    def _half_mse(X, y, w):
        """Return half the mean squared residual of the linear model ``X @ w``."""
        return 0.5 * ((X @ w - y) ** 2).mean()