import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader 
from torchvision.transforms import ToTensor
from time import time
import numpy as np
import copy
from model import net
from server import Server_Class
from Split_Data import Non_iid_split_fmnist, Non_iid_split_cifar
from client import Client_Class
from utils import*

class Simulator():
    """Drive a federated-learning simulation over a set of local clients.

    The simulator owns one server object and one client object per
    (train loader, test loader) pair. Two training drivers are provided:

    * ``FL_loop`` -- operates on ``self.Clients_list`` (built by
      ``initialization`` when ``args.mask == 1``) and tracks both accuracy
      and density per round.
    * ``FedAvg`` -- operates on ``self.Clients_list_fedavg`` (built by
      ``initialization`` otherwise) and tracks accuracy only.

    Attributes read from ``args`` in this class: ``mask``, ``th_update``,
    ``comm_rounds`` and ``num_clients``.
    """

    def __init__(self, args, logger, local_tr_data_loaders, local_te_data_loaders, device):
        """Store the configuration; clients and server are created later.

        Args:
            args: Run configuration object (see class docstring for the
                attributes used here).
            logger: Object exposing ``info(message)``.
            local_tr_data_loaders: One training loader per client.
            local_te_data_loaders: One test loader per client, in the same
                order as ``local_tr_data_loaders``.
            device: Device handle forwarded unchanged to every client.
        """
        self.args = args
        self.logger = logger
        # Populated by initialization(); exactly one of the two lists is built.
        self.Clients_list = None
        self.Clients_list_fedavg = None
        self.Server = None
        self.local_tr_data_loaders = local_tr_data_loaders
        self.local_te_data_loaders = local_te_data_loaders
        self.device = device

    @staticmethod
    def _mean_or_nan(values):
        """Return ``np.mean(values)``, or NaN (without a warning) if empty."""
        if not values:
            return float('nan')
        return np.mean(values)

    def _average_flops(self, clients):
        """Return the summed ``FLOPs`` of ``clients`` scaled by 1/num_clients."""
        total = 0
        for client in clients:
            total += client.FLOPs
        total *= 1/self.args.num_clients
        return total

    def initialization(self, model):
        """Create the server and one client per local data-loader pair.

        When ``args.mask == 1`` the clients are stored in
        ``self.Clients_list`` and start from a deep copy of the server's
        ``initial_model``; otherwise they are stored in
        ``self.Clients_list_fedavg`` and start from a deep copy of the
        server's ``global_model``.

        Args:
            model: Model handed to the server constructor.
        """
        # A single loss object is shared by all clients.
        loss = nn.CrossEntropyLoss()

        self.Server = Server_Class.Server(self.args, model)

        def build_clients(start_model):
            # Client ids follow the order of the loader lists.
            loader_pairs = zip(self.local_tr_data_loaders, self.local_te_data_loaders)
            return [
                Client_Class.Client(self.args, copy.deepcopy(start_model), loss,
                                    client_id, tr_loader, te_loader, self.device, scheduler=None)
                for client_id, (tr_loader, te_loader) in enumerate(loader_pairs)
            ]

        if self.args.mask == 1:
            self.Clients_list = build_clients(self.Server.initial_model)
        else:
            self.Clients_list_fedavg = build_clients(self.Server.global_model)

    def FL_loop(self):
        """Run ``args.comm_rounds`` rounds on ``self.Clients_list``.

        Each round: sample clients, train the sampled ones, aggregate on
        the server, broadcast to all clients, optionally call each client's
        ``th_update``, then test every client and log the round averages.
        After the last round a summary (best accuracy, lowest density,
        histories and average FLOPs) is logged. Nothing is returned.

        Requires ``self.Clients_list`` to have been built by
        ``initialization`` (i.e. ``args.mask == 1``).
        """
        best_acc = 0.
        keep_ratio_at_best_acc = 0.
        best_keep_ratio = 1.
        acc_at_best_keep_ratio = 0.
        acc_history = []
        density_history = []

        for rounds in np.arange(self.args.comm_rounds):
            begin_time = time()
            self.logger.info("-"*30 + "Epoch start" + "-"*30)

            sampled_clients = self.Server.sample_clients()

            for client_idx in sampled_clients:
                self.Clients_list[client_idx].local_training(rounds)

            self.Server.aggregation(self.Clients_list, sampled_clients)

            self.Server.broadcast(self.Clients_list)

            if self.args.th_update == 1:
                for client in self.Clients_list:
                    client.th_update(self.Server.global_difference)

            # Clients reporting the 'zero' sentinel are left out of the
            # averages (per the original note: clients without any batches
            # in their test data).
            avg_acc = []
            avg_density = []
            for client in self.Clients_list:
                acc, _loss, density = client.local_test()
                if acc != 'zero':
                    avg_acc.append(acc)
                    avg_density.append(density)
            print("acc_list: ", avg_acc)
            print("density_list: ",avg_density)

            avg_acc_round = self._mean_or_nan(avg_acc)
            avg_density_round = self._mean_or_nan(avg_density)

            acc_history.append(avg_acc_round)
            density_history.append(avg_density_round)

            self.logger.info('round: %d, avg_acc: %.3f, avg_density: %.3f, spent: %.2f' %(rounds, avg_acc_round,
                                                                                          avg_density_round, time()-begin_time))
            self.logger.info(">>>>> Accuracy history during training so far rounds: %d" %( rounds))
            self.logger.info(acc_history)
            self.logger.info(">>>>> Density history during training so far rounds: %d" %( rounds))
            self.logger.info(density_history)
            self.logger.info(">>>>> Average FLOPs so far rounds: %d" %(rounds))
            self.logger.info(self.FLOP_cal())

            cur_acc = avg_acc_round
            if cur_acc > best_acc:
                best_acc = cur_acc
                if self.args.mask:
                    keep_ratio_at_best_acc = avg_density_round

            # Density bookkeeping only applies when masking is enabled.
            if self.args.mask and avg_density_round < best_keep_ratio:
                best_keep_ratio = avg_density_round
                acc_at_best_keep_ratio = cur_acc

        avg_total_FLOPS = self.FLOP_cal()

        self.logger.info(">>>>> Training process finish")
        self.logger.info("Best keep ratio {:.4f}, acc at best keep ratio {:.4f}".format(best_keep_ratio, acc_at_best_keep_ratio))
        self.logger.info("Best acc {:.4f}, keep ratio at best acc {:.4f}".format(best_acc, keep_ratio_at_best_acc))

        self.logger.info(">>>>> Accuracy history during training")
        self.logger.info(acc_history)
        self.logger.info(">>>>> Density history during training")
        self.logger.info(density_history)
        self.logger.info(">>>>> Average FLOPs")
        self.logger.info(avg_total_FLOPS)

    def FedAvg(self):
        """Run ``args.comm_rounds`` rounds on ``self.Clients_list_fedavg``.

        Each round: sample clients, broadcast to the sampled ones, test
        every client via ``local_fedavg_test`` against the server's global
        model, train the sampled clients and aggregate. After the last
        round the server broadcasts to all ``args.num_clients`` clients,
        each client is tested once more with ``local_test``, and a summary
        is logged. Nothing is returned.

        Requires ``self.Clients_list_fedavg`` to have been built by
        ``initialization`` (i.e. ``args.mask != 1``).
        """
        best_acc = 0
        acc_history = []

        for rounds in np.arange(self.args.comm_rounds):
            begin_time = time()
            self.logger.info("-"*30 + "Epoch start" + "-"*30)

            sampled_clients = self.Server.sample_clients()
            self.Server.broadcast(self.Clients_list_fedavg, sampled_clients)

            # Clients reporting the 'zero' sentinel are skipped.
            avg_acc = []
            for client in self.Clients_list_fedavg:
                acc, _loss = client.local_fedavg_test(self.Server.global_model)
                if acc != 'zero':
                    avg_acc.append(acc)

            for client_idx in sampled_clients:
                self.Clients_list_fedavg[client_idx].local_training(rounds)

            self.Server.aggregation(self.Clients_list_fedavg, sampled_clients)

            avg_acc_round = self._mean_or_nan(avg_acc)
            acc_history.append(avg_acc_round)

            self.logger.info('round: %d, avg_acc: %.3f, spent: %.2f' %(rounds, avg_acc_round,
                                                                       time()-begin_time))

            if avg_acc_round > best_acc:
                best_acc = avg_acc_round

        avg_total_FLOPS = self.FLOP_cal_fedavg()

        # Final accuracy check with the latest server state on every client.
        self.Server.broadcast(self.Clients_list_fedavg, range(0, self.args.num_clients))
        final_acc = []
        for client_idx, client in enumerate(self.Clients_list_fedavg):
            # FL_loop unpacks local_test() into three values, so only the
            # leading accuracy is taken here and any trailing values are
            # ignored. NOTE(review): confirm local_test()'s return shape
            # against the client implementation.
            acc, *_ = client.local_test()
            if acc != 'zero':
                final_acc.append(acc)
                self.logger.info('client_id: %d , final acc: %.3f' %(
                                client_idx, acc))
        final_avg_acc = self._mean_or_nan(final_acc)

        self.logger.info(">>>>> Training process finish")
        self.logger.info("Best test accuracy {:.4f}".format(best_acc))
        self.logger.info("Final test accuracy {:.4f}".format(final_avg_acc))
        self.logger.info(">>>>> Accuracy history during training")
        self.logger.info(acc_history)
        self.logger.info(">>>>> Average FLOPs")
        self.logger.info(avg_total_FLOPS)

    def get_sparsity(self, round):
        """Log the per-client and average density for a round.

        Calls ``print_layer_keep_ratio`` for every client in
        ``self.Clients_list``, sums the returned values, divides by
        ``args.num_clients`` and logs the result. Nothing is returned.

        Args:
            round: Round index, forwarded to ``print_layer_keep_ratio`` and
                used in the log line. (The name shadows the builtin but is
                kept because it is part of the method's interface.)
        """
        total_density = 0
        for client_idx, client in enumerate(self.Clients_list):
            total_density += print_layer_keep_ratio(client.model, round, client_idx, self.logger)
        avg_density_round = total_density/self.args.num_clients
        self.logger.info('round: %d, avg_density: %.4f' %(round, avg_density_round))

    def FLOP_cal(self):
        """Return the sum of ``FLOPs`` over ``self.Clients_list`` scaled by 1/num_clients."""
        return self._average_flops(self.Clients_list)

    def FLOP_cal_fedavg(self):
        """Return the sum of ``FLOPs`` over ``self.Clients_list_fedavg`` scaled by 1/num_clients."""
        return self._average_flops(self.Clients_list_fedavg)
