import copy
import logging
import random

import numpy as np
import torch
import wandb
import pickle
import os
from .client import Client
import logging
import copy

class CentralizedAPI(object):
    """Centralized training driver built on the ``Client`` abstraction.

    The global train and test loaders are wrapped in a single ``Client``
    (see ``_setup_clients``).  ``train`` repeatedly trains that client,
    pushes the resulting weights into ``model_trainer``, periodically saves
    a checkpoint and periodically evaluates on the train and test data.
    """

    # (bucket name used in the aggregates, key in the dict returned by
    # ``client.local_test``).
    _METRIC_SOURCES = (
        ("num_samples", "test_total"),
        ("num_correct", "test_correct"),
        ("losses", "test_loss"),
        ("eo_gap", "eo_gap"),
        ("dp_gap", "dp_gap"),
    )
    # Gap metrics, in the order they are sent to wandb.
    _GAP_KEYS = ("eo_gap", "dp_gap")

    def __init__(self, args, device, dataset, model, model_trainer=None):
        """Unpack the dataset bundle and create the single training client.

        Args:
            args: Run configuration.  This class reads ``comm_round``,
                ``save_epoches``, ``frequency_of_the_test``, ``run_folder``,
                ``save_model_name`` and ``enable_wandb`` from it.
            device: Device handle forwarded to ``Client``.
            dataset: Sequence of exactly nine items, in this order:
                train_data_num, test_data_num, train_data_global,
                test_data_global, train_data_local_num_dict,
                train_data_local_dict, test_data_local_dict,
                val_data_local_dict, class_num.
            model: Only logged here; the model that is trained and saved is
                ``model_trainer.model``.
            model_trainer: Object providing ``get_model_params``,
                ``set_model_params`` and a ``model`` attribute.
        """
        self.device = device
        self.args = args
        (
            train_data_num,
            test_data_num,
            train_data_global,
            test_data_global,
            _train_data_local_num_dict,
            _train_data_local_dict,
            test_data_local_dict,
            val_data_local_dict,
            _class_num,
        ) = dataset
        self.train_global = train_data_global
        self.test_global = test_data_global
        self.train_data_num_in_total = train_data_num
        self.test_data_num_in_total = test_data_num

        self.client_list = []

        self.test_data_local_dict = test_data_local_dict
        self.val_data_local_dict = val_data_local_dict

        logging.info("model = {}".format(model))
        self.model_trainer = model_trainer
        logging.info("self.model_trainer = {}".format(self.model_trainer))

        self._setup_clients(self.train_global, self.test_global, model_trainer)

    def _setup_clients(
        self,
        train_data_global,
        test_data_global,
        model_trainer,
    ):
        """Create one ``Client`` over the global data and register it."""
        client = Client(
            train_data_global,
            test_data_global,
            self.args,
            self.device,
            model_trainer,
        )
        self.client_list.append(client)
        logging.info("############setup_clients (END)#############")

    def train(self):
        """Run ``args.comm_round`` rounds of training.

        Each round trains every registered client in sequence (threading the
        weights from one to the next), stores the result in the model
        trainer, saves a checkpoint every ``args.save_epoches`` rounds and
        evaluates every ``args.frequency_of_the_test`` rounds as well as on
        the final round.
        """
        w_global = self.model_trainer.get_model_params()
        last_round = self.args.comm_round - 1
        for round_idx in range(self.args.comm_round):
            for client in self.client_list:
                w_global = client.train(w_global)
            self.model_trainer.set_model_params(w_global)

            if round_idx % self.args.save_epoches == 0:
                self._save_state_dict(
                    "%s_at_%s.pt" % (self.args.save_model_name, round_idx)
                )

            if (
                round_idx == last_round
                or round_idx % self.args.frequency_of_the_test == 0
            ):
                self._local_test_on_all_clients(round_idx)

    def _client_sampling(self):
        """Placeholder; does nothing and always returns ``False``."""
        return False

    def _generate_validation_set(self, num_samples=10000):
        """Placeholder; ignores ``num_samples`` and always returns ``False``."""
        return False

    def _aggregate(self, w_locals):
        """Placeholder; ignores ``w_locals`` and always returns ``False``."""
        return False

    def _aggregate_noniid_avg(self, w_locals):
        """Placeholder; ignores ``w_locals`` and always returns ``False``."""
        return False

    @classmethod
    def _append_metrics(cls, buckets, local_metrics):
        """Append one client's ``local_test`` result to the aggregate lists."""
        for bucket, source_key in cls._METRIC_SOURCES:
            buckets[bucket].append(copy.deepcopy(local_metrics[source_key]))

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or NaN when it is undefined."""
        if denominator == 0:
            return float("nan")
        return numerator / denominator

    def _local_test_on_all_clients(self, round_idx):
        """Evaluate every registered client on its train and test data.

        Accuracy and loss are normalised by the total number of samples.
        The fairness gaps (``eo_gap``, ``dp_gap``) are averaged over the
        clients that were actually evaluated, so the logged value and the
        wandb value agree.

        Args:
            round_idx: Round number attached to the log lines and to every
                wandb record.

        Returns:
            dict with keys ``train_acc``, ``train_loss``, ``train_dp_gap``,
            ``train_eo_gap`` and the matching ``test_*`` keys.  Entries are
            NaN when there was nothing to evaluate.
        """
        logging.info("################local_test_on_all_clients : {}".format(round_idx))

        train_metrics = {bucket: [] for bucket, _ in self._METRIC_SOURCES}
        test_metrics = {bucket: [] for bucket, _ in self._METRIC_SOURCES}

        # Iterate the clients themselves: pairing them with ``args.users``
        # only truncated the loop and never used the user id.
        for client in self.client_list:
            self._append_metrics(train_metrics, client.local_test(False))
            self._append_metrics(test_metrics, client.local_test(True))

        num_clients = len(self.client_list)
        summary = {}
        for split, metrics in (("train", train_metrics), ("test", test_metrics)):
            total_samples = sum(metrics["num_samples"])
            if total_samples == 0:
                logging.warning(
                    "no %s samples were evaluated in round %s; metrics are NaN",
                    split,
                    round_idx,
                )
            summary[split + "_acc"] = self._ratio(
                sum(metrics["num_correct"]), total_samples
            )
            summary[split + "_loss"] = self._ratio(
                sum(metrics["losses"]), total_samples
            )
            for gap in self._GAP_KEYS:
                summary["%s_%s" % (split, gap)] = self._ratio(
                    sum(metrics[gap]), num_clients
                )

        logging.info('Train acc: {} Train Loss: {}, Test acc: {} Test Loss: {}'.format(
            summary["train_acc"], summary["train_loss"],
            summary["test_acc"], summary["test_loss"]))
        logging.info('Train dp gap: {} Train eo gap: {}, Test dp gap: {} Test eo gap: {}'.format(
            summary["train_dp_gap"], summary["train_eo_gap"],
            summary["test_dp_gap"], summary["test_eo_gap"]))

        if self.args.enable_wandb:
            # One wandb.log call per metric, in the original order, so the
            # wandb step counter advances exactly as before.
            wandb.log({"Train/Acc": summary["train_acc"], "round": round_idx})
            wandb.log({"Train/Loss": summary["train_loss"], "round": round_idx})

            wandb.log({"Test/Acc": summary["test_acc"], "round": round_idx})
            wandb.log({"Test/Loss": summary["test_loss"], "round": round_idx})

            for gap in self._GAP_KEYS:
                wandb.log({"Test/%s" % gap: summary["test_%s" % gap], "round": round_idx})
                wandb.log({"Train/%s" % gap: summary["train_%s" % gap], "round": round_idx})

        return summary

    def _local_test_on_validation_set(self, round_idx):
        """Placeholder; ignores ``round_idx`` and always returns ``False``."""
        return False

    def _save_state_dict(self, filename):
        """Write the trainer model's ``state_dict`` to ``args.run_folder``."""
        torch.save(
            self.model_trainer.model.state_dict(),
            os.path.join(self.args.run_folder, filename),
        )

    def save(self):
        """Save the current model as ``<args.save_model_name>.pt``."""
        self._save_state_dict("%s.pt" % (self.args.save_model_name))

