import torch
import os
dirname = os.path.dirname(__file__)
import sys
sys.path.append(os.path.join(dirname, '../'))
import numpy as np
from copy import deepcopy
from worker import Worker
from utils import Server_Trainer
from typing import List
import gc
from collections import defaultdict


class Server(Server_Trainer):
    """Server that trains a weight-generating model from client feedback.

    Each round, ``self.model`` maps the participants' embeddings to flat
    weight vectors for ``self.clientmodel``. Each participant trains a masked
    copy of that client model locally and returns a flat gradient plus a dict
    of per-label prototypes. The gradients are back-propagated through
    ``self.model``; the prototypes are averaged into ``self.global_prototypes``
    and broadcast to all workers at the start of the next round.

    Attributes used here but defined elsewhere (presumably by
    ``Server_Trainer`` — confirm): ``args``, ``model``, ``clientmodel``,
    ``gpu``, ``soptimizer``, ``scheduler`` and ``test``.
    """

    def load_weights_to_model(self, model, flat_weight_vector):
        """Copy a flat weight vector into ``model`` in place.

        Elements are consumed sequentially in ``model.state_dict()`` order,
        which covers parameters and buffers alike, so the vector must
        contain exactly one value per state-dict element.

        Args:
            model: module whose state-dict tensors are overwritten.
            flat_weight_vector: tensor of generated weights (assumed 1-D).

        Raises:
            AssertionError: if the vector runs out before an entry is filled,
                or has elements left over at the end.
        """
        with torch.no_grad():
            # state_dict() tensors share storage with the module, so copy_()
            # writes straight into the model's parameters/buffers.
            model_dict = model.state_dict()
            pointer = 0
            for name, param in model_dict.items():
                num_params = param.numel()
                chunk = flat_weight_vector[pointer:pointer + num_params]
                assert chunk.numel() == num_params, f"Shape mismatch at {name}"
                param.copy_(chunk.view_as(param))
                pointer += num_params

            assert pointer == flat_weight_vector.numel(), "Some weights are unused or shape mismatch"

    def simulation_exp(self, worker_trainers: List[Worker], client_embeddings: List[torch.Tensor]):
        """Run ``self.args.T`` training rounds.

        Args:
            worker_trainers: one worker per client. Clients are assigned model
                sizes in order according to ``args.model_size`` and
                ``args.model_dist``.
            client_embeddings: one embedding tensor per client, indexed like
                ``worker_trainers``; they are stacked and fed to ``self.model``.

        Side effects:
            Parses ``args.model_size`` / ``args.model_dist`` in place, updates
            ``self.model`` through ``self.soptimizer`` and ``self.scheduler``,
            and sets ``self.global_prototypes``.
        """
        assert len(self.args.model_size) == len(self.args.model_dist)
        # NOTE(review): eval() executes arbitrary expressions from the run
        # configuration; it also accepts values like "1/4". Only feed it
        # trusted input. Non-string values are assumed already parsed and are
        # kept as-is, so calling this method twice no longer raises TypeError.
        self.args.model_size = [eval(v) if isinstance(v, str) else v
                                for v in self.args.model_size]
        self.args.model_dist = [eval(v) if isinstance(v, str) else v
                                for v in self.args.model_dist]
        assert sum(self.args.model_dist) == len(worker_trainers)

        # Maximum model size allowed on each client, in worker order.
        client_model_sizes = []
        for size, n_workers in zip(self.args.model_size,
                                   self.args.model_dist):
            client_model_sizes.extend([size] * n_workers)

        self.global_prototypes = None
        for t in range(self.args.T):
            participants = np.random.choice(len(worker_trainers),
                                            size=self.args.num_part,
                                            replace=False)
            # .tolist() prints plain ints (NumPy 2 would show np.int64(...)).
            print('Participants list:', (participants + 1).tolist(), flush=True)

            if self.global_prototypes is not None:
                for worker in worker_trainers:
                    worker.update_global_prototypes(self.global_prototypes)

            # Stacked outside no_grad: if the embeddings require grad, the
            # update pass below still back-propagates into them.
            embed_batch = torch.stack([client_embeddings[w_idx] for w_idx in participants], dim=0)

            # Generation pass: these weights only seed local training. The
            # graph used for the update is rebuilt below, so none is recorded.
            self.model.eval()
            with torch.no_grad():
                all_generated_weights = self.model(embed_batch)

            grad_matrix, client_prototypes = self._collect_client_updates(
                worker_trainers, participants, client_model_sizes, all_generated_weights)
            del all_generated_weights

            # Update pass: push the clients' gradients through a fresh
            # train-mode forward of the generator.
            self.model.train()
            self.soptimizer.zero_grad()
            all_generated_weights = self.model(embed_batch)
            # backward() requires the supplied gradient to match the output's
            # device and dtype; .to() is a no-op when they already do.
            grad_matrix = grad_matrix.to(device=all_generated_weights.device,
                                         dtype=all_generated_weights.dtype)
            all_generated_weights.backward(grad_matrix)
            self.soptimizer.step()
            self.scheduler.step()

            del all_generated_weights
            del grad_matrix

            self.global_prototypes = self._aggregate_prototypes(client_prototypes)
            del client_prototypes

            # flush=True: previously stdout was flushed *before* this print,
            # leaving the round marker buffered.
            print(f'Finish the training of Round {t}....', flush=True)

    def _collect_client_updates(self, worker_trainers, participants, client_model_sizes, all_generated_weights):
        """Run local training for each participant on its generated weights.

        Returns:
            ``(grad_matrix, client_prototypes)``:
            - ``grad_matrix``: the participants' detached flat gradients,
              stacked row-wise with NaN/inf replaced by 0.
            - ``client_prototypes``: their local prototype dicts.
            Both are in ``participants`` order.
        """
        flat_grads = []
        client_prototypes = []
        for i, w_idx in enumerate(participants):
            worker = worker_trainers[w_idx]
            m_size = client_model_sizes[w_idx]

            # Load this client's generated weights into self.clientmodel, then
            # mask and train a private copy so self.clientmodel is not masked.
            self.load_weights_to_model(self.clientmodel, all_generated_weights[i])
            self.clientmodel.to(self.gpu)

            client_model = deepcopy(self.clientmodel)
            client_model.generate_mask(model_size=m_size, topk=True, bern=True)
            flat_gradient, _loss, _iteration, local_prototypes = worker.local_training(client_model, m_size)

            client_prototypes.append(local_prototypes)
            flat_grads.append(flat_gradient.detach())

            # Drop per-client objects and cached GPU memory before the next participant.
            del client_model, flat_gradient, local_prototypes
            torch.cuda.empty_cache()
            gc.collect()

        # Non-finite entries are zeroed, presumably so that a single diverged
        # client cannot corrupt the generator update.
        grad_matrix = torch.stack(
            [torch.nan_to_num(g, nan=0.0, posinf=0.0, neginf=0.0) for g in flat_grads],
            dim=0,
        )
        return grad_matrix, client_prototypes

    @staticmethod
    def _aggregate_prototypes(client_prototypes):
        """Average prototype tensors per label across clients.

        Returns a ``defaultdict(list)`` that maps every label reported by any
        client to the element-wise mean of that label's prototypes.
        """
        grouped = defaultdict(list)
        for proto_dict in client_prototypes:
            for label, vec in proto_dict.items():
                grouped[label].append(vec)

        # Reassigning values (not keys) while iterating keys is safe.
        for label in grouped:
            grouped[label] = torch.stack(grouped[label]).mean(dim=0)
        return grouped

    def eval(self, client_model_sizes: list):
        """Test ``self.clientmodel`` once per distinct model size.

        For each unique value in ``client_model_sizes``, a masked copy of the
        client model is scored with ``self.test``, and its loss and accuracy
        are printed.

        NOTE(review): ``self.clientmodel`` holds whatever weights were last
        loaded into it; after ``simulation_exp`` that is the last
        participant's generated weights. Confirm this is the intended model.

        Args:
            client_model_sizes: per-client model sizes; duplicates are
                evaluated only once.
        """
        submodel_losses, submodel_accs = [], []
        # NOTE(review): these lists are filled but not returned or used in
        # this method as visible here; confirm whether callers expect them.

        for m_size in np.unique(client_model_sizes):
            temp_model = deepcopy(self.clientmodel)
            temp_model.generate_mask(m_size, topk=True, bern=True)
            test_loss, test_acc, _, _ = self.test(temp_model)
            print('global Model Size: {:.10f}\tTest Loss: {}\tTest Accuracy: {}'
                  .format(m_size, test_loss, test_acc * 100))
            submodel_accs.append(test_acc)
            submodel_losses.append(test_loss)
