import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from system.flcore.clients.clientpac import clientPAC
from system.flcore.servers.serverbase import Server
from system.flcore.clients.clientbase import load_item, save_item
from threading import Thread
from system.flcore.trainmodel.models import BaseHeadSplit
import cvxpy as cvx
import numpy as np
import math


class FedPAC(Server):
    """Federated server that averages the `.base` part of each client's 'global_model'.

    On construction, every client is trained once. The sample-weighted average of
    the clients' `running_mean` / `running_mean_g` is then sent to each client as
    `global_mean` / `global_mean_g`. Each communication round afterwards trains the
    selected clients in parallel threads, samples the surviving (non-dropped)
    clients, and averages their base parameters into the server's 'global_model'.
    """

    def __init__(self, args, times):
        super().__init__(args, times)
        # Create a fresh server-side global model unless the folder name contains
        # 'temp' without being exactly 'temp'.
        # NOTE(review): presumably that case means resuming from an existing save — confirm.
        if args.save_folder_name == 'temp' or 'temp' not in args.save_folder_name:
            if hasattr(args, 'global_model'):
                global_model = BaseHeadSplit(
                    args,
                    feature_dim=args.sub_feature_dim,
                    is_global=True
                ).to(args.device)
            else:
                global_model = BaseHeadSplit(
                    args, 0, args.sub_feature_dim).to(args.device)
            save_item(global_model, self.role, 'global_model', self.save_folder_name)

        # select slow clients
        self.set_slow_clients()
        self.set_clients(clientPAC)

        # Warm-up: train every client once so each one has running statistics
        # to contribute to the global means below.
        self.selected_clients = self.clients
        for client in self.selected_clients:
            client.train()  # no DBE

        self._record_upload_weights(self.selected_clients)

        # Sample-weighted average of the clients' running statistics.
        global_mean = 0
        global_mean_g = 0
        for cid, w in zip(self.uploaded_ids, self.uploaded_weights):
            global_mean += self.clients[cid].running_mean * w
            global_mean_g += self.clients[cid].running_mean_g * w
        print('>>>> global_mean <<<<', global_mean)
        print('>>>> global_mean_g <<<<', global_mean_g)
        # Each client gets its own copy so in-place updates are not shared.
        for client in self.selected_clients:
            client.global_mean = global_mean.data.clone()
            client.global_mean_g = global_mean_g.data.clone()

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")

        self.Budget = []
        print('featrue map shape: ', self.clients[0].client_mean.shape)
        print('featrue map numel: ', self.clients[0].client_mean.numel())
        print('featrue_g map shape: ', self.clients[0].client_mean_g.shape)
        print('featrue_g map numel: ', self.clients[0].client_mean_g.numel())

        self.global_class_centers = None

    def _record_upload_weights(self, clients):
        """Store the ids of `clients` and their train-sample-proportional weights.

        Sets `self.uploaded_ids` and `self.uploaded_weights`, which are parallel
        lists in the same order as `clients`. The weights sum to 1.
        """
        self.uploaded_ids = [client.id for client in clients]
        samples = [client.train_samples for client in clients]
        tot_samples = sum(samples)
        self.uploaded_weights = [s / tot_samples for s in samples]

    def train(self):
        """Run `global_rounds + 1` communication rounds, then report and save results.

        Each round selects clients, evaluates every `eval_gap` rounds, trains the
        selected clients concurrently, and aggregates the survivors' base
        parameters. It stops early when `auto_break` is set and `check_done`
        reports convergence.
        """
        for i in range(self.global_rounds + 1):
            s_t = time.time()
            self.selected_clients = self.select_clients()

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate heterogeneous models")
                self.evaluate()

            threads = [Thread(target=client.train)
                       for client in self.selected_clients]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            self.receive_ids()
            self.aggregate_parameters()

            self.Budget.append(time.time() - s_t)
            print('-' * 25, 'time cost', '-' * 25, self.Budget[-1])

            if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt):
                break

        print("\nBest accuracy.")
        print(max(self.rs_test_acc))
        print("\nAverage time cost per round.")
        # Round 0 is excluded from the average. If it is the only round that
        # ran, fall back to it instead of dividing by zero.
        timed_rounds = self.Budget[1:] or self.Budget
        if timed_rounds:
            print(sum(timed_rounds) / len(timed_rounds))

        self.save_results()

    def receive_ids(self):
        """Sample the clients whose updates survive this round and weight them.

        Draws `int((1 - client_drop_rate) * current_num_join_clients)` clients at
        random from `self.selected_clients`. It then records their ids and their
        sample-proportional aggregation weights.
        """
        assert (len(self.selected_clients) > 0)

        active_clients = random.sample(
            self.selected_clients, int((1 - self.client_drop_rate) * self.current_num_join_clients))

        self._record_upload_weights(active_clients)

    def aggregate_parameters(self):
        """Replace the server's global base with a weighted average of client bases.

        Loads the server's saved 'global_model' and zeroes the parameters of its
        `.base`. It then accumulates `w * param` from the `.base` of each uploaded
        client's saved 'global_model', and saves the full model back.
        """
        assert (len(self.uploaded_ids) > 0)

        global_model = load_item(self.role, 'global_model', self.save_folder_name)
        global_base = global_model.base
        for param in global_base.parameters():
            param.data.zero_()

        for cid, w in zip(self.uploaded_ids, self.uploaded_weights):
            client = self.clients[cid]
            client_model = load_item(client.role, 'global_model', client.save_folder_name)
            for server_param, client_param in zip(global_base.parameters(), client_model.base.parameters()):
                server_param.data += client_param.data * w

        # Save the whole model, not just `.base`. The stored item then keeps the
        # structure written in __init__, and the `.base` lookup above still works
        # next round.
        save_item(global_model, self.role, 'global_model', self.save_folder_name)