
import torch
import time
import os
dirname = os.path.dirname(__file__)
import sys
sys.path.append(os.path.join(dirname, '../'))
import numpy as np
from copy import deepcopy
from typing import List
import gc

from worker import Worker
from utils import Server_Trainer, log, count_params, calculate_entropy


class Server(Server_Trainer):
    """Federated-learning server that simulates every client in one process.

    Relies on ``Server_Trainer`` for ``self.model``, ``self.args`` and
    ``self.test``; the project-level ``log`` / ``count_params`` helpers are used
    for reporting.
    """

    def run_with_multiprocessing(self):
        """Not supported by this server; use :meth:`simulation_exp` instead."""
        raise NotImplementedError

    def _log_model_summary(self):
        """Print the model and log every named parameter plus the total element count."""
        print(self.model)
        log(self.args.save_log, '')
        n_network = 0
        for i, (name, param) in enumerate(self.model.named_parameters()):
            n_network += param.numel()
            log(self.args.save_log, f'[{i}] Name: {name}, Shape: {tuple(param.shape)}, Elements: {param.numel()}')
        log(self.args.save_log, f'n_network: {n_network}')
        log(self.args.save_log, '')

    @staticmethod
    def _zeros_like_state(state):
        """Return a new dict holding a zero tensor shaped like each entry of ``state``."""
        return {k: torch.zeros_like(v) for k, v in state.items()}

    @staticmethod
    def _masked_average(client_states, template):
        """Average client states element-wise, counting only non-zero contributions.

        For every key of ``template`` the result is
        ``sum(values) / count(values != 0)`` over ``client_states``. Positions
        where no client contributed a non-zero value (0/0 -> NaN) become 0, and
        any +/-inf is also mapped to 0.
        """
        tot_param = {name: torch.zeros_like(p) for name, p in template.items()}
        tot_mask = {name: torch.zeros_like(p) for name, p in template.items()}
        for state in client_states:
            for name, param in state.items():
                tot_param[name] += param
                tot_mask[name] += (param != 0)
        return {
            name: torch.nan_to_num(tot_param[name] / tot_mask[name], nan=0.0, posinf=0.0, neginf=0.0)
            for name in template
        }

    @staticmethod
    def _trailing_mean(values, window):
        """Mean of the last ``window`` entries of ``values``, including the newest one."""
        return np.average(values[-window:])

    def simulation_exp(self, worker_trainers: List[Worker]):
        """Run ``self.args.T`` rounds of simulated federated training.

        Per round:
          1. Interpolate the global state (weight ``self.args.m``) and load it
             into ``self.model``.
          2. Sample ``self.args.num_part`` distinct workers. Each trains a copy
             of the model with its own correction term ``lambdai``. These look
             like dual variables of an ADMM-style scheme — TODO confirm.
          3. Update each participant's ``lambdai`` from the change in its
             parameters.
          4. Aggregate client states by averaging over non-zero entries.
          5. On evaluation rounds (every ``save_freq`` rounds, the first round,
             and the last 50 rounds), test globally and on every worker, then log.

        Args:
            worker_trainers: one ``Worker`` per client. ``sum(self.args.model_dist)``
                must equal its length.
        """
        assert len(self.args.model_size) == len(self.args.model_dist)
        # NOTE(review): ``eval`` executes arbitrary code from CLI-supplied strings;
        # if these are always plain numbers, ``ast.literal_eval`` would be safer.
        self.args.model_size = list(map(eval, self.args.model_size))
        self.args.model_dist = list(map(eval, self.args.model_dist))
        assert sum(self.args.model_dist) == len(worker_trainers)

        self._log_model_summary()

        # Model size assigned to each client, in worker order.
        client_model_sizes = []
        for size, n_workers in zip(self.args.model_size, self.args.model_dist):
            client_model_sizes.extend([size] * n_workers)

        n_clients = len(worker_trainers)
        model_params = self.model.state_dict()
        lambdais = [self._zeros_like_state(model_params) for _ in range(n_clients)]
        agg_params = [deepcopy(model_params) for _ in range(self.args.num_part)]
        lambdais_m = [self._zeros_like_state(model_params) for _ in range(n_clients)]

        # Start Training
        global_model_params = deepcopy(self.model.state_dict())
        global_model_params_m = deepcopy(self.model.state_dict())
        global_model_params_t_1 = deepcopy(self.model.state_dict())
        submodel_accs_R, submodel_losses_R, submodel_local_accs_R = [], [], []
        for t in range(self.args.T):
            time_r_start = time.time()
            k = t + 1

            for name in global_model_params:
                new_m = self.args.m * global_model_params[name] + (1 - self.args.m) * global_model_params_t_1[name]
                # NOTE(review): this extrapolated ``global_model_params`` is
                # overwritten after client training before anything reads it
                # (clients start from ``global_model_params_m``). Confirm this
                # is intended.
                extrapolated = (2 * k + self.args.alpha_p) * (new_m + global_model_params_t_1[name]) / 2 - k * global_model_params_m[name]
                global_model_params[name] = extrapolated / (k + self.args.alpha_p)
                global_model_params_m[name] = deepcopy(new_m)

            self.model.load_state_dict(global_model_params_m)

            # Select participants
            participants = np.random.choice(n_clients, size=self.args.num_part, replace=False)
            log(self.args.save_log, '')
            log(self.args.save_log, f'[R{t}]Participants list:{list(participants+1)}')
            log(self.args.save_log, f'[R{t}]{self.args.save_log}')
            train_loss = 0.0
            time_train_st = time.time()
            self.model.train()
            for ww_idx, w_idx in enumerate(participants):
                worker = worker_trainers[w_idx]
                m_size = client_model_sizes[w_idx]

                client_model = deepcopy(self.model)
                start_params = deepcopy(client_model.state_dict())
                # Private copy: the pre-update lambda is needed below.
                lambdai = deepcopy(lambdais[w_idx])
                _, loss, _ = worker.local_training(client_model, lambdai)

                client_param = deepcopy(client_model.state_dict())
                for name in client_param:
                    # Step lambda against the parameter change made by local training.
                    lambdais[w_idx][name] += - self.args.lrd * self.args.delta * client_param[name] + self.args.lrd * self.args.delta * start_params[name]
                    agg_params[ww_idx][name] = deepcopy(client_param[name])
                    lambda_new = lambdais[w_idx][name]
                    lambdais[w_idx][name] = (2 * k + self.args.alpha_p) * (lambda_new + lambdai[name]) / 2 - k * lambdais_m[w_idx][name]
                    lambdais[w_idx][name] = lambdais[w_idx][name] / (k + self.args.alpha_p)
                    lambdais_m[w_idx][name] = deepcopy(lambda_new)
                total_nonzero_elements, total_elements = count_params(lambdais[w_idx])
                log(self.args.save_log, f'[R{t}]w_idx: {w_idx} m_size{m_size}-------------------------------------------------------- lambdais[w_idx]: total_nonzero_elements/total_elements: {total_nonzero_elements}/{total_elements}; S: {1-(total_nonzero_elements/total_elements):.6f}   m_size{m_size}')

                train_loss = train_loss + loss / len(participants)
                torch.cuda.empty_cache()

            # Snapshot the state the clients started from; this must happen
            # before the aggregated state is loaded below.
            global_model_params_t_1 = deepcopy(self.model.state_dict())
            global_model_params = self._masked_average(agg_params, self.model.state_dict())
            self.model.load_state_dict(global_model_params)

            time_train_end = time.time() - time_train_st

            # ---------------------------------------------------------------- evaluation
            time_eval_st = time.time()

            if (t + 1) % self.args.save_freq == 0 or t == 0 or t >= self.args.T - 50:
                submodel_losses, submodel_accs = [], []
                submodel_local_accs = [0.0] * len(client_model_sizes)

                for m_size in self.args.model_size:
                    test_model = deepcopy(self.model)
                    test_loss, test_acc, _, _ = self.test(test_model)
                    submodel_accs.append(test_acc)
                    submodel_losses.append(test_loss)

                    # Local accuracy of every client assigned this model size.
                    for i, (client, size) in enumerate(zip(worker_trainers, client_model_sizes)):
                        if size == m_size:
                            _, acc, _, _ = client.test(test_model)
                            submodel_local_accs[i] = acc

                log(self.args.save_size, '')
                log(self.args.save_acc, f'[R{t}]Average accuracy: {np.average(submodel_accs)*100:.2f}%; var: {np.var(submodel_accs)*100:.2f}%; maxmin: {(np.max(submodel_accs) - np.min(submodel_accs))*100:.2f}%; max: {np.max(submodel_accs)*100:.2f}%; min: {np.min(submodel_accs)*100:.2f}%; \tLocal mean accuracy: {np.average(submodel_local_accs)*100:.2f}%; var: {np.var(submodel_local_accs)*100:.2f}%; maxmin: {(np.max(submodel_local_accs) - np.min(submodel_local_accs))*100:.2f}%; max: {np.max(submodel_local_accs)*100:.2f}%; min: {np.min(submodel_local_accs)*100:.2f}%')

                submodel_accs_R.append(np.average(submodel_accs))
                submodel_losses_R.append(np.average(submodel_losses))
                submodel_local_accs_R.append(np.average(submodel_local_accs))
                sys.stdout.flush()
                log(self.args.save_eval_acc, '')
                log(self.args.save_eval_acc, f'[R{t}END-Mean1] Global_acc_final: {submodel_accs_R[-1]*100:.2f}%; Llobal_acc_final: {submodel_local_accs_R[-1]*100:.2f}%;')

                log(self.args.save_eval_loss, f'[R{t}END-Mean1] Global_losses_avg: {submodel_losses_R[-1]:.4f}; submodel_losses: {submodel_losses}')

            time_eval_end = time.time() - time_eval_st
            time_r_end = time.time() - time_r_start
            log(self.args.save_log, '')
            log(self.args.save_log, '')
            log(self.args.save_train_loss, f'[R{t}END]train_loss: {train_loss:.4f}; time_r: {time_r_end/60:.2f}min; time_train: {time_train_end/60:.4f}min; time_eval: {time_eval_end/60:.4f}min')

            # Round 0 is always evaluated, so the *_R lists are non-empty here.
            # On non-evaluation rounds these report the most recent evaluation.
            log(self.args.save_eval_loss, f'[R{t}END-Mean1] Global_acc_final: {submodel_accs_R[-1]*100:.2f}%')
            log(self.args.save_eval_loss, f'[R{t}END-Mean1] Llobal_acc_final: {submodel_local_accs_R[-1]*100:.2f}%')
            # Trailing means over the last k evaluations, including the newest
            # one (the old ``[-k:-1]`` slices dropped it and gave NaN at t=0).
            log(self.args.save_log, f'[R{t}END-Mean50] Global_acc_avg: {self._trailing_mean(submodel_accs_R, 50)*100:.2f}%')
            log(self.args.save_log, f'[R{t}END-Mean20] Global_acc_avg: {self._trailing_mean(submodel_accs_R, 20)*100:.2f}%')
            log(self.args.save_log, f'[R{t}END-Mean10] Global_acc_avg: {self._trailing_mean(submodel_accs_R, 10)*100:.2f}%')
            log(self.args.save_log, f'[R{t}END-Mean50] Lobal_acc_avg: {self._trailing_mean(submodel_local_accs_R, 50)*100:.2f}%')
            log(self.args.save_log, f'[R{t}END-Mean20] Lobal_acc_avg: {self._trailing_mean(submodel_local_accs_R, 20)*100:.2f}%')
            log(self.args.save_log, f'[R{t}END-Mean10] Lobal_acc_avg: {self._trailing_mean(submodel_local_accs_R, 10)*100:.2f}%')
            log(self.args.save_log, '')
        