import torch
import os
import numpy as np
import h5py
from tqdm import tqdm

from utils.model_utils import Metrics
import copy
# from Setting import *

class Server:
    """Base federated-learning server: global model, client lists and metric histories.

    NOTE(review): several methods read attributes that ``__init__`` never sets
    (``beta``, ``L_k``, ``K``, ``personal_learning_rate``, ``frac``; ``train_users``
    / ``test_users`` come from ``meta_split_users``). Presumably subclasses or
    callers assign them — confirm before relying on those methods.
    """

    def __init__(self, train_A, train_B, experiment, device, dataset, algorithm,
                 model, model_server, batch_size, learning_rate,
                 num_glob_iters, local_epochs, optimizer, num_users, times, args):
        """Store the run configuration and create empty metric histories.

        Args:
            train_A, train_B, optimizer, args: accepted for interface
                compatibility; not used here.
            experiment: optional metric logger; when truthy, the evaluate
                methods call ``experiment.log_metric(name, value)``.
            device: stored as ``self.device``.
            dataset: stored as-is (``save_model``/``save_results`` index it
                with ``[1]`` to build paths).
            algorithm: algorithm name, used in result file names.
            model, model_server: deep-copied so the server never aliases the
                caller's objects.
            batch_size, learning_rate, num_glob_iters, local_epochs,
            num_users, times: hyper-parameters kept as attributes.
        """
        # Run configuration.
        self.device = device
        self.dataset = dataset
        self.algorithm = algorithm
        self.experiment = experiment
        self.times = times
        self.sub_data = 0

        # Hyper-parameters.
        self.num_glob_iters = num_glob_iters
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_users = num_users

        # Private copies of the models.
        self.model = copy.deepcopy(model)
        self.model_server = copy.deepcopy(model_server)

        # Client bookkeeping.
        self.total_train_samples = 0
        self.users = []
        self.selected_users = []

        # Metric histories for the global model (each a separate list).
        self.rs_train_acc = []
        self.rs_train_loss = []
        self.rs_glob_acc = []
        self.rs_avg_acc = []
        self.rs_test_loss = []

        # Metric histories for personalised models.
        self.rs_train_acc_per = []
        self.rs_train_loss_per = []
        self.rs_glob_acc_per = []
        self.rs_avg_acc_per = []
        self.rs_c_gen_acc = []

        # Additional histories.
        self.rs_rec_loss = []
        self.rs_kt_loss = []
        self.rs_global_rec_loss = []
        self.rs_global_kt_loss = []
        self.rs_local_f1_acc = []
        self.c_avg_test = []

    def get_data(self, train, test):
        """Return tail slices of ``train`` and ``test``.

        With ``self.sub_data == 1`` the last 5% of ``train`` and the last 20% of
        ``test`` are kept; for any other value the last 20% of ``train`` and the
        last 40% of ``test``.
        """
        if self.sub_data == 1:
            train_cut, test_cut = 0.95, 0.8
        else:
            train_cut, test_cut = 0.8, 0.6
        train_tail = train[int(train_cut * len(train)):]
        test_tail = test[int(test_cut * len(test)):]
        return train_tail, test_tail

    def get_partion(self, total_users):
        """Draw a random subset of user indices according to ``self.sub_data``.

        ``sub_data == 1`` keeps 90% of the users; any other truthy value keeps
        50%. When ``sub_data`` is falsy every index is returned in random order.
        Previously that case raised ``UnboundLocalError`` because ``partion`` was
        never assigned.

        Args:
            total_users: number of users to choose from.

        Returns:
            numpy array of distinct indices drawn from ``range(total_users)``
            using the global NumPy RNG.
        """
        if not self.sub_data:
            partion = total_users
        elif self.sub_data == 1:
            partion = int(0.9 * total_users)
        else:
            partion = int(0.5 * total_users)
        return np.random.choice(range(0, total_users), partion, replace=False)

    def aggregate_grads(self):
        """Set the global model's gradients to the sample-weighted sum of user gradients.

        Each user is weighted by ``user.train_samples / self.total_train_samples``.
        NOTE(review): ``total_train_samples`` starts at 0 in ``__init__`` and is not
        updated in this class; presumably it is set elsewhere — with 0 the
        division raises ``ZeroDivisionError``.
        """
        assert (self.users is not None and len(self.users) > 0)
        # Reset gradients so the loop below produces a clean weighted sum.
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p.data)
        total = self.total_train_samples
        for client in self.users:
            self.add_grad(client, client.train_samples / total)

    def add_grad(self, user, ratio):
        """Add ``ratio`` times ``user``'s gradients to the global model's ``.grad``.

        ``user.get_grads()`` is indexed positionally, in the order of
        ``self.model.parameters()``; every ``param.grad`` must already be a tensor.
        """
        grads = user.get_grads()
        for position, param in enumerate(self.model.parameters()):
            scaled = grads[position].clone() * ratio
            param.grad = param.grad + scaled

    def send_parameters(self):
        """Push the global model to every registered user via ``set_parameters``."""
        assert (self.users is not None and len(self.users) > 0)
        global_model = self.model
        for client in self.users:
            client.set_parameters(global_model)

    
    def send_meta_parameters(self):
        """Push the global model to every registered user via ``set_meta_parameters``."""
        assert (self.users is not None and len(self.users) > 0)
        global_model = self.model
        for client in self.users:
            client.set_meta_parameters(global_model)


    def send_meta_parameters_totest(self):
        """Push the global model to every held-out meta-test user.

        ``self.test_users`` is populated by ``meta_split_users``. The
        precondition now checks ``test_users`` itself; it previously checked
        ``self.users``, which is not the list being iterated.
        """
        assert (self.test_users is not None and len(self.test_users) > 0)
        for user in self.test_users:
            user.set_meta_parameters(self.model)

    def add_parameters(self, user, ratio):
        """Add ``ratio`` times ``user``'s parameters into the global model in place.

        Tensors are paired positionally (``zip``) with ``user.get_parameters()``.
        The user's tensors are never modified, because the multiplication
        already produces a new tensor.
        """
        for server_param, user_param in zip(self.model.parameters(), user.get_parameters()):
            server_param.data = server_param.data + user_param.data * ratio

    def aggregate_parameters(self):
        """Replace the global model with the sample-weighted average of the selected users.

        Each user in ``self.selected_users`` is weighted by
        ``user.train_samples / sum(train_samples)``. If the selection holds no
        samples at all (for example, it is empty), the global model is left
        unchanged. Previously it was zeroed first and never refilled.
        """
        assert (self.users is not None and len(self.users) > 0)
        total_train = sum(user.train_samples for user in self.selected_users)
        if not total_train:
            # Nothing to average: keep the current model rather than wiping it.
            return
        for param in self.model.parameters():
            param.data = torch.zeros_like(param.data)
        for user in self.selected_users:
            self.add_parameters(user, user.train_samples / total_train)

    def add_parameters_A(self, user, ratio):
        """Add ``ratio`` times ``user.model``'s modality-A entries into the global model.

        Every ``state_dict`` key containing the substring ``"A"`` is updated in
        place. This is a plain substring test, so any key with a capital ``A``
        anywhere matches. Each state dict is now fetched once rather than once
        per key. Its tensors share storage with the module, so the in-place
        ``+=`` still updates the model.
        """
        server_state = self.model.state_dict()
        user_state = user.model.state_dict()
        for key, server_tensor in server_state.items():
            if "A" in key:
                server_tensor += user_state[key] * ratio

    def add_parameters_B(self, user, ratio):
        """Add ``ratio`` times ``user.model``'s modality-B entries into the global model.

        Every ``state_dict`` key containing the substring ``"B"`` is updated in
        place. This is a plain substring test, so any key with a capital ``B``
        anywhere matches. Each state dict is now fetched once rather than once
        per key. Its tensors share storage with the module, so the in-place
        ``+=`` still updates the model.
        """
        server_state = self.model.state_dict()
        user_state = user.model.state_dict()
        for key, server_tensor in server_state.items():
            if "B" in key:
                server_tensor += user_state[key] * ratio


    def aggregate_parameters_multimodal(self):
        assert (self.users is not None and len(self.users) > 0)
        for param in self.model.parameters():
            param.data = torch.zeros_like(param.data)
        total_train = 0
        # if(self.num_users = self.to)
        n_A=0
        n_B=0
        for user in self.selected_users:
            # total_train += user.train_samples
            if user.modality == "A" or user.modality == "AB":
                n_A += user.get_weight()
                # print("weight of user is",user.get_weight())
            if user.modality == "B"  or user.modality == "AB":
                n_B += user.get_weight()
        # print("user weight is",n_A)
        # print("user weight is", n_B)

        for user in self.selected_users:
            if user.modality == "A" or user.modality == "AB":
                self.add_parameters_A(user, user.get_weight()/n_A)
            if user.modality == "B" or user.modality == "AB":
                self.add_parameters_B(user, user.get_weight()/ n_B)

        # print(self.model.state_dict())
        # print("finish aggregation")
    def aggregate_meta_parameters(self):
        """Replace the global model with the sample-weighted average of the meta-train users.

        Each user in ``self.train_users`` (set by ``meta_split_users``) is
        weighted by ``user.train_samples / sum(train_samples)``. If those users
        hold no samples at all, the model is left unchanged instead of being
        zeroed.
        """
        assert (self.users is not None and len(self.users) > 0)
        total_train = sum(user.train_samples for user in self.train_users)
        if not total_train:
            # Nothing to average: keep the current model rather than wiping it.
            return
        for param in self.model.parameters():
            param.data = torch.zeros_like(param.data)
        for user in self.train_users:
            self.add_parameters(user, user.train_samples / total_train)

    def save_model(self):
        """Pickle the whole global model to ``models/<dataset[1]>/server.pt``.

        The target directory is created when it does not exist yet.
        """
        target_dir = os.path.join("models", self.dataset[1])
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        target_file = os.path.join(target_dir, "server.pt")
        torch.save(self.model, target_file)

    def load_model(self):
        """Load the global model written by ``save_model``.

        Reads ``models/<dataset[1]>/server.pt``, the same path ``save_model``
        writes. It previously used ``self.dataset`` without the ``[1]`` index,
        so it did not match the saved file. The file holds a whole pickled
        module, so it is loaded with ``weights_only=False``: PyTorch >= 2.6
        defaults to ``True``, which rejects full modules. Only load trusted
        files, because unpickling can execute arbitrary code.

        Raises:
            AssertionError: if the file does not exist.
        """
        model_path = os.path.join("models", self.dataset[1], "server.pt")
        assert (os.path.exists(model_path))
        try:
            self.model = torch.load(model_path, weights_only=False)
        except TypeError:
            # PyTorch < 1.13 has no ``weights_only`` keyword.
            self.model = torch.load(model_path)

    def model_exists(self):
        """Return True if the file ``save_model`` writes (``models/<dataset[1]>/server.pt``) exists.

        Uses the same ``dataset[1]`` path component as ``save_model``; it
        previously used ``self.dataset`` unindexed, so it checked a different
        path.
        """
        return os.path.exists(os.path.join("models", self.dataset[1], "server.pt"))
    
    def select_users(self, round, fac_users):
        """Pick the users taking part in a communication round.

        Args:
            round: current round index; not used by the selection.
            fac_users: fraction of ``self.users`` to select. Exactly ``1``
                prints a notice and returns ``self.users`` itself. Otherwise
                ``min(int(fac_users * len(self.users)), len(self.users))``
                users are drawn without replacement.

        Returns:
            ``self.users`` when ``fac_users == 1``, otherwise a numpy array of
            user objects.

        NOTE(review): for ``fac_users < 1`` the global NumPy RNG is reseeded
        with 100 on every call. The same subset is therefore drawn every round,
        and any other ``np.random`` consumer is reset as a side effect. Confirm
        this is intended.
        """
        if fac_users == 1:
            print("All users are selected")
            return self.users
        population = len(self.users)
        count = min(int(fac_users * population), population)
        if fac_users < 1.0:
            np.random.seed(100)
        return np.random.choice(self.users, count, replace=False)

    def select_clients(self, clients):
        """Randomly choose ``int(len(clients) * self.frac)`` distinct clients.

        ``self.frac`` is not set in ``__init__``; it must be assigned before
        this is called.

        Args:
            clients: a list of Client objects.

        Returns:
            numpy array of the chosen clients, drawn without replacement from
            the global NumPy RNG.
        """
        how_many = int(len(clients) * self.frac)
        return np.random.choice(clients, how_many, replace=False)

    def meta_split_users(self, ratio=0.8):
        """Split ``self.users`` into meta-train and meta-test users.

        The first ``int(len(self.users) * ratio)`` users become
        ``self.train_users`` and the rest ``self.test_users``. ``ratio`` was
        previously ignored in favour of a hard-coded 0.8; the default keeps
        that split.
        """
        len_train = int(len(self.users) * ratio)
        self.train_users = self.users[:len_train]
        self.test_users = self.users[len_train:]

    def select_sub_train_users(self, num_users):
        """Return ``num_users`` random meta-train users, or all of them.

        When ``num_users`` is at least ``len(self.train_users)``, a notice is
        printed and the list itself is returned. Otherwise the result is a
        numpy array of ``num_users`` distinct users drawn from the global
        NumPy RNG.
        """
        pool = self.train_users
        if num_users < len(pool):
            return np.random.choice(pool, num_users, replace=False)
        print("All users are selected")
        return pool
        
    # define function for persionalized agegatation.
    def persionalized_update_parameters(self,user, ratio):
        # only argegate the local_weight_update
        for server_param, user_param in zip(self.model.parameters(), user.local_weight_updated):
            server_param.data = server_param.data + user_param.data.clone() * ratio


    def persionalized_aggregate_parameters(self):
        """Blend the sample-weighted average of the selected users into the global model.

        Computes ``avg = sum(n_k / n * w_k)`` over ``self.selected_users`` and
        then sets ``w = (1 - self.beta) * w_prev + self.beta * avg``.
        ``self.beta`` is not set in ``__init__`` and must be assigned before
        this is called.

        If the selected users hold no samples at all (for example, the
        selection is empty), the model is left unchanged. Previously an empty
        selection zeroed the average and shrank the model toward zero by
        ``1 - beta``.
        """
        assert (self.users is not None and len(self.users) > 0)
        total_train = sum(user.train_samples for user in self.selected_users)
        if not total_train:
            return

        # Snapshot the current weights before they are overwritten.
        previous_param = [param.data.clone() for param in self.model.parameters()]
        for param in self.model.parameters():
            param.data = torch.zeros_like(param.data)
        for user in self.selected_users:
            self.add_parameters(user, user.train_samples / total_train)

        # Interpolate between the previous model and the new average.
        for pre_param, param in zip(previous_param, self.model.parameters()):
            param.data = (1 - self.beta) * pre_param + self.beta * param.data

    def persionalized_aggregate_meta_parameters(self):
        """Blend the sample-weighted average of the meta-train users into the global model.

        Same as ``persionalized_aggregate_parameters`` but averages over
        ``self.train_users`` (set by ``meta_split_users``):
        ``w = (1 - self.beta) * w_prev + self.beta * avg``. If those users hold
        no samples at all, the model is left unchanged instead of being shrunk
        toward zero.
        """
        assert (self.users is not None and len(self.users) > 0)
        total_train = sum(user.train_samples for user in self.train_users)
        if not total_train:
            return

        # Snapshot the current weights before they are overwritten.
        previous_param = [param.data.clone() for param in self.model.parameters()]
        for param in self.model.parameters():
            param.data = torch.zeros_like(param.data)
        for user in self.train_users:
            self.add_parameters(user, user.train_samples / total_train)

        # Interpolate between the previous model and the new average.
        for pre_param, param in zip(previous_param, self.model.parameters()):
            param.data = (1 - self.beta) * pre_param + self.beta * param.data

    # Save loss, accurancy to h5 fiel
    def save_results(self):
        dir_path = "./results"
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        alg = self.dataset[1] + "_" + self.algorithm
        alg = alg + "_" + str(self.learning_rate) + "_" + str(self.beta) + "_" + str(self.L_k) + "_" + str(self.num_users) + "u" + "_" + str(self.batch_size) + "b" + "_" + str(self.local_epochs) 
        if(self.algorithm == "pFedMe" or self.algorithm == "pFedMe_p"):
            alg = alg + "_" + str(self.K) + "_" + str(self.personal_learning_rate)
        elif(self.algorithm == "SSGD"):
             alg = alg + "_" + str(self.K)
        if(self.sub_data):
            alg = alg + "_" + "subdata"
        alg = alg + "_" + str(self.times)
        if (len(self.rs_glob_acc) != 0 &  len(self.rs_train_acc) & len(self.rs_train_loss)) :
            with h5py.File("./results/"+'{}.h5'.format(alg, self.local_epochs), 'w') as hf:
                hf.create_dataset('rs_glob_acc', data=self.rs_glob_acc)
                hf.create_dataset('rs_train_acc', data=self.rs_train_acc)
                hf.create_dataset('rs_train_loss', data=self.rs_train_loss)
                hf.create_dataset('rs_avg_acc', data=self.rs_avg_acc)
                hf.close()
        
        # store persionalized value
        alg = self.dataset[1] + "_" + self.algorithm + "_p"
        alg = alg  + "_" + str(self.learning_rate) + "_" + str(self.beta) + "_" + str(self.L_k) + "_" + str(self.num_users) + "u" + "_" + str(self.batch_size) + "b"+ "_" + str(self.local_epochs)
        if(self.algorithm == "pFedMe" or self.algorithm == "pFedMe_p"):
            alg = alg + "_" + str(self.K) + "_" + str(self.personal_learning_rate)
        if(self.sub_data):
            alg = alg + "_" + "subdata"
        alg = alg + "_" + str(self.times)
        if (len(self.rs_glob_acc_per) != 0 &  len(self.rs_train_acc_per) & len(self.rs_train_loss_per)) :
            with h5py.File("./results/"+'{}.h5'.format(alg, self.local_epochs), 'w') as hf:
                hf.create_dataset('rs_avg_acc', data=self.rs_avg_acc_per)
                hf.create_dataset('rs_glob_acc', data=self.rs_glob_acc_per)
                hf.create_dataset('rs_train_acc', data=self.rs_train_acc_per)
                hf.create_dataset('rs_train_loss', data=self.rs_train_loss_per)
                hf.close()

    def test(self):
        '''tests self.latest_model on given clients
        '''
        num_samples = []
        tot_correct = []
        losses = []
        mean_accurancy = []
        for c in self.users:
            ct, ns, ma = c.test()
            tot_correct.append(ct*1.0)
            num_samples.append(ns)
            mean_accurancy.append(ma)
        ids = [c.id for c in self.users]

        return ids, num_samples, tot_correct, mean_accurancy

    def train_error_and_loss(self):
        """Run ``c.train_error_and_loss()`` on every user and collect the results.

        Returns:
            tuple ``(ids, num_samples, tot_correct, losses)`` of parallel lists.
            Correct counts and losses are multiplied by ``1.0``.
        """
        results = [c.train_error_and_loss() for c in self.users]
        tot_correct = [ct * 1.0 for ct, _, _ in results]
        losses = [cl * 1.0 for _, cl, _ in results]
        num_samples = [ns for _, _, ns in results]
        ids = [c.id for c in self.users]
        return ids, num_samples, tot_correct, losses

    def test_persionalized_model(self):
        '''tests self.latest_model on given clients
        '''
        num_samples = []
        tot_correct = []
        mean_accurancy = []
        for c in self.users:
            ct, ns, ma = c.test_persionalized_model()
            tot_correct.append(ct*1.0)
            num_samples.append(ns)
            mean_accurancy.append(ma)
        ids = [c.id for c in self.users]

        return ids, num_samples, tot_correct, mean_accurancy

    def train_error_and_loss_persionalized_model(self):
        """Run ``c.train_error_and_loss_persionalized_model()`` on every user.

        Returns:
            tuple ``(ids, num_samples, tot_correct, losses)`` of parallel lists.
            Correct counts and losses are multiplied by ``1.0``.
        """
        results = [c.train_error_and_loss_persionalized_model() for c in self.users]
        tot_correct = [ct * 1.0 for ct, _, _ in results]
        losses = [cl * 1.0 for _, cl, _ in results]
        num_samples = [ns for _, _, ns in results]
        ids = [c.id for c in self.users]
        return ids, num_samples, tot_correct, losses

    def test_gen(self):
        '''tests generalization performance of each personalized model
        '''
        c_gen_acc = []
        for c in self.users:
            num_samples = 0.
            tot_correct = 0.
            for n in self.users:
                c_test_model = copy.deepcopy(c.persionalized_model_bar)
                ct, ns = n.test_gen(c_test_model)
                tot_correct += ct * 1.0
                num_samples += ns
            # print("***Test Client ",c.id,":",tot_correct / num_samples)
            c_gen_acc.append(tot_correct / num_samples)

        return c_gen_acc

    def evaluating_global(self,i):
        """Record and report the pooled test accuracy of the global model at round ``i``.

        The accuracy (total correct / total samples from ``self.test()``) is
        appended to ``rs_glob_acc`` and written through ``tqdm.write``.
        """
        stats = self.test()
        correct_counts = stats[2]
        sample_counts = stats[1]
        accuracy = np.sum(correct_counts) * 1.0 / np.sum(sample_counts)
        self.rs_glob_acc.append(accuracy)
        tqdm.write('At round {} global testing accuracy: {}'.format(i, accuracy))




    def evaluate(self):
        """Evaluate the global model on all users, then record, log and print the metrics.

        Appends to:

        * ``rs_avg_acc``: mean of the per-user mean accuracies;
        * ``rs_glob_acc``: pooled test accuracy;
        * ``rs_train_acc``: pooled training accuracy;
        * ``rs_train_loss``: unweighted mean of the per-user losses.

        The values are logged to ``self.experiment`` when it is set.
        """
        test_stats = self.test()
        train_stats = self.train_error_and_loss()

        glob_acc = np.sum(test_stats[2]) * 1.0 / np.sum(test_stats[1])
        train_acc = np.sum(train_stats[2]) * 1.0 / np.sum(train_stats[1])
        glob_acc_avg = np.mean(test_stats[3])
        train_loss = np.mean(list(train_stats[3]))

        self.rs_avg_acc.append(glob_acc_avg)
        self.rs_glob_acc.append(glob_acc)
        self.rs_train_acc.append(train_acc)
        self.rs_train_loss.append(train_loss)

        if self.experiment:
            for metric_name, metric_value in (
                ("glob_acc", glob_acc),
                ("train_acc", train_acc),
                ("train_loss", train_loss),
                ("glob_avg", glob_acc_avg),
            ):
                self.experiment.log_metric(metric_name, metric_value)

        print("Average Global Accurancy: ", glob_acc)
        print("Average Global AVG Accurancy: ", glob_acc_avg)
        print("Average Global Trainning Accurancy: ", train_acc)
        print("Average Global Trainning Loss: ", train_loss)

    def evaluate_personalized_model(self):
        """Evaluate every user's personalised model, then record, log and print the results.

        Appends pooled test accuracy, pooled training accuracy, mean training
        loss and mean per-user accuracy to the ``*_per`` histories. Also appends
        the average C-GEN score from ``test_gen`` to ``rs_c_gen_acc``. The
        values are logged to ``self.experiment`` when it is set.
        """
        test_stats = self.test_persionalized_model()
        train_stats = self.train_error_and_loss_persionalized_model()

        glob_acc = np.sum(test_stats[2]) * 1.0 / np.sum(test_stats[1])
        train_acc = np.sum(train_stats[2]) * 1.0 / np.sum(train_stats[1])
        glob_acc_avg = np.mean(test_stats[3])
        train_loss = np.mean(list(train_stats[3]))

        self.rs_glob_acc_per.append(glob_acc)
        self.rs_train_acc_per.append(train_acc)
        self.rs_train_loss_per.append(train_loss)
        self.rs_avg_acc_per.append(glob_acc_avg)

        # C-GEN: each personalised model evaluated on every user's data.
        c_gen_acc = np.average(self.test_gen())
        self.rs_c_gen_acc.append(c_gen_acc)

        if self.experiment:
            for metric_name, metric_value in (
                ("glob_acc_persionalized", glob_acc),
                ("train_acc_persionalized", train_acc),
                ("train_loss_persionalized", train_loss),
                ("glob_persionalized_avg", glob_acc_avg),
                ("C-GEN", c_gen_acc),
            ):
                self.experiment.log_metric(metric_name, metric_value)

        print("Average Personal Accurancy: ", glob_acc)
        print("Average Personal Mean Accurancy: ", glob_acc_avg)
        print("Average Personal Trainning Accurancy: ", train_acc)
        print("Average Personal Trainning Loss: ", train_loss)
        print("Average C-GEN: ", c_gen_acc)


    def evaluate_one_step(self):
        """Evaluate users after one local training step, then restore their local models.

        Every user runs ``train_one_step()``. Test and training statistics are
        then gathered, and each user is reset with
        ``update_parameters(c.local_model)``. The pooled metrics are appended
        to the ``*_per`` histories, logged to ``self.experiment`` when it is
        set, and printed.
        """
        for client in self.users:
            client.train_one_step()

        test_stats = self.test()
        train_stats = self.train_error_and_loss()

        # Hand each client its local model back so training continues from it.
        for client in self.users:
            client.update_parameters(client.local_model)

        glob_acc = np.sum(test_stats[2]) * 1.0 / np.sum(test_stats[1])
        train_acc = np.sum(train_stats[2]) * 1.0 / np.sum(train_stats[1])
        glob_acc_avg = np.mean(test_stats[3])
        train_loss = np.mean(list(train_stats[3]))

        self.rs_glob_acc_per.append(glob_acc)
        self.rs_train_acc_per.append(train_acc)
        self.rs_train_loss_per.append(train_loss)
        self.rs_avg_acc_per.append(glob_acc_avg)

        if self.experiment:
            for metric_name, metric_value in (
                ("glob_acc", glob_acc),
                ("train_acc", train_acc),
                ("train_loss", train_loss),
                ("glob_avg", glob_acc_avg),
            ):
                self.experiment.log_metric(metric_name, metric_value)

        print("Average Personal Accurancy: ", glob_acc)
        print("Average Meta AVG Accurancy: ", glob_acc_avg)
        print("Average Personal Trainning Accurancy: ", train_acc)
        print("Average Personal Trainning Loss: ", train_loss)
        
        

    def meta_evaluate(self):
        """Evaluate the meta-test users, then record, log and print the aggregate metrics.

        Appends pooled test accuracy, pooled training accuracy and the mean
        training loss (from ``meta_test`` / ``meta_train_error_and_loss``) to
        ``rs_glob_acc``, ``rs_train_acc`` and ``rs_train_loss``. The values are
        logged to ``self.experiment`` when it is set.
        """
        test_stats = self.meta_test()
        train_stats = self.meta_train_error_and_loss()

        glob_acc = np.sum(test_stats[2]) * 1.0 / np.sum(test_stats[1])
        train_acc = np.sum(train_stats[2]) * 1.0 / np.sum(train_stats[1])
        train_loss = np.mean(list(train_stats[3]))

        self.rs_glob_acc.append(glob_acc)
        self.rs_train_acc.append(train_acc)
        self.rs_train_loss.append(train_loss)

        if self.experiment:
            for metric_name, metric_value in (
                ("glob_acc", glob_acc),
                ("train_acc", train_acc),
                ("train_loss", train_loss),
            ):
                self.experiment.log_metric(metric_name, metric_value)

        print("Average Meta Accurancy: ", glob_acc)
        print("Average Meta Trainning Accurancy: ", train_acc)
        print("Average Meta Trainning Loss: ", train_loss)


    def meta_test(self):
        """Run ``c.test()`` on the held-out meta-test users.

        Returns:
            tuple ``(ids, num_samples, tot_correct)`` of parallel lists over
            ``self.test_users``. ``ids`` previously listed every user in
            ``self.users``, so it did not line up with the other two lists.
        """
        num_samples = []
        tot_correct = []
        for c in self.test_users:
            ct, ns, _ = c.test()
            tot_correct.append(ct * 1.0)
            num_samples.append(ns)
        ids = [c.id for c in self.test_users]
        return ids, num_samples, tot_correct

    def meta_train_error_and_loss(self):
        """Run ``c.train_error_and_loss()`` on the held-out meta-test users.

        Returns:
            tuple ``(ids, num_samples, tot_correct, losses)`` of parallel lists
            over ``self.test_users``. ``ids`` previously listed every user in
            ``self.users``, so it did not line up with the other lists.
        """
        num_samples = []
        tot_correct = []
        losses = []
        for c in self.test_users:
            ct, cl, ns = c.train_error_and_loss()
            tot_correct.append(ct * 1.0)
            num_samples.append(ns)
            losses.append(cl * 1.0)
        ids = [c.id for c in self.test_users]
        return ids, num_samples, tot_correct, losses