import time
from flcore.clients.clientmp import clientMP
from flcore.servers.serverbase import Server
from threading import Thread
import numpy as np
from utils.utils_spectral_dataloader import *


class FedMP(Server):
    """Federated server for the MP (prompt) setting.

    Clients are ``clientMP`` instances, each with its own ``backbone``. The
    server's ``global_model`` is passed as ``prompt_net`` during evaluation.
    Model distribution / collection / aggregation are done by the inherited
    ``Server.send_models`` / ``receive_models`` / ``aggregate_parameters``
    (their exact semantics live in ``Server`` -- not visible here).
    """

    def __init__(self, args):
        """Create the server and its ``clientMP`` clients from ``args``."""
        super().__init__(args)

        self.set_slow_clients()
        self.set_clients(args, clientMP)

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")

        # Not written anywhere in this class; presumably read/filled by shared
        # server code elsewhere -- kept for compatibility.
        self.Budget = []

    def train(self):
        """Run rounds ``last_train_prompt + 1`` .. ``global_rounds`` (inclusive).

        Each round selects clients, sends the models, runs
        ``client.train_test`` on every selected client, then receives and
        aggregates. The first round of this run uses ``trn_mode='WARMUP+ALT'``,
        later rounds ``'ALT'``. Every ``eval_gap`` rounds the global model is
        evaluated with :meth:`evaluate_MPT` and checkpointed.
        """
        if self.args.PTP:
            self.set_pretrained_client_backbones()

        first_round = self.args.last_train_prompt + 1
        for i in range(first_round, self.args.global_rounds + 1):
            print(f"\n------------------------------------Round number: {i}------------------------------------\n")
            self.selected_clients = self.select_clients()
            self.send_models()

            # Only the first round of a (possibly resumed) run does warmup.
            trn_mode = 'WARMUP+ALT' if i == first_round else 'ALT'

            # perf_counter is monotonic and meant for measuring durations.
            clients_start = time.perf_counter()
            for client in self.selected_clients:
                client.train_test(global_iter=i, trn_mode=trn_mode)
            clients_end = time.perf_counter()

            msg = 'current round %d, time cost%f' % (i, clients_end - clients_start)
            print(msg)

            self.receive_models()
            self.aggregate_parameters()

            if i % self.eval_gap == 0:
                print("\nEvaluate global model")
                psnr_mean_M = self.evaluate_MPT(glob_iter=i)
                self.checkpoint_global_clients(glob_iter=i, psnr_mean_M=psnr_mean_M)

    def test(self):
        """Evaluate once at round ``last_train_prompt + 1`` without training.

        Raises:
            AssertionError: if ``args.test_mode`` is falsy.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the exception type is kept for compatibility.
        if not self.args.test_mode:
            raise AssertionError('MUST Specify [test_mode]!')

        if self.args.PTP:
            self.set_pretrained_client_backbones()

        i = self.args.last_train_prompt + 1
        print(f"\n------------------------------------Round number: {i}------------------------------------\n")
        self.selected_clients = self.select_clients()
        self.send_models()

        print("\nEvaluate global model")
        self.evaluate_MPT(glob_iter=i)

    def evaluate_MPT(self, glob_iter):
        """Evaluate every client's backbone with the global prompt network.

        For ``mask_op == 'rand_crop'`` two passes are run over ``self.clients``
        and each is logged via ``gen_log``:

        1. ``mask_source='usr_union'``, logged under id ``args.num_clients``;
        2. ``mask_source='assign_usr'`` with ``for_client=True`` and each
           client's own ``id``.

        Args:
            glob_iter: current global round, forwarded as ``epoch``.

        Returns:
            Mean PSNR of the ``'usr_union'`` pass across all clients.

        Raises:
            NotImplementedError: for ``mask_op == 'fixed256'``.
            ValueError: for any other unsupported ``mask_op`` (previously
                this fell through and crashed on an unbound return value).
        """
        mask_op = self.args.mask_op
        if mask_op == 'fixed256':
            raise NotImplementedError
        if mask_op != 'rand_crop':
            raise ValueError(f"Unsupported mask_op: {mask_op!r}")

        psnr_u, ssim_u = self._run_MPT_trials(glob_iter, 'usr_union', per_client=False)
        psnr_mean_M = self._log_MPT_stats('usr_union', glob_iter, psnr_u, ssim_u)

        psnr_a, ssim_a = self._run_MPT_trials(glob_iter, 'assign_usr', per_client=True)
        self._log_MPT_stats('assign_usr', glob_iter, psnr_a, ssim_a)

        return psnr_mean_M

    def _run_MPT_trials(self, glob_iter, mask_source, per_client):
        """Call ``test_Mtrials_MPT`` for each client; return (psnr list, ssim list).

        ``per_client=True`` passes ``for_client=True`` and the client's own id;
        otherwise ``for_client`` is omitted and ``args.num_clients`` is the id.
        """
        psnr_all, ssim_all = [], []
        for c in self.clients:
            extra = {'for_client': True} if per_client else {}
            # Only the last two of the returned values (PSNR, SSIM) are used.
            *_, psnr_c, ssim_c = test_Mtrials_MPT(args=self.args,
                                                  epoch=glob_iter,
                                                  model_path=self.args.model_path,
                                                  net=c.backbone,
                                                  prompt_net=self.global_model,
                                                  test_data=self.args.test_data,
                                                  mask4d_ls=self.args.mask4d_ls,
                                                  mask_source=mask_source,
                                                  id=c.id if per_client else self.args.num_clients,
                                                  stay_log=False,
                                                  **extra)
            psnr_all.append(psnr_c)
            ssim_all.append(ssim_c)
        return psnr_all, ssim_all

    def _log_MPT_stats(self, mask_source, glob_iter, psnr_vals, ssim_vals):
        """Log mean/std of PSNR and SSIM via ``gen_log``; return the PSNR mean."""
        psnr_arr = np.array(psnr_vals)
        ssim_arr = np.array(ssim_vals)
        psnr_mean = np.mean(psnr_arr)
        # Std values use ``.5f`` like the means (was ``5f``: width 5, 6 decimals).
        msg = '===>mask:{}, trials:{}, Epoch {}: testing psnr = {:.5f}/{:.5f}(tsa), ssim = {:.5f}/{:.5f}'.format(
            mask_source,
            self.args.trial_num,
            glob_iter,
            psnr_mean,
            np.std(psnr_arr),
            np.mean(ssim_arr),
            np.std(ssim_arr))
        gen_log(model_path=self.args.model_path, msg=msg, user_id=self.args.num_clients)
        return psnr_mean