import copy
import torch
import torch.nn as nn
import numpy as np

from tqdm import tqdm
from src.client import Client
from utils.get_dataset import get_dataset
from utils.log_utils import Logger, client_sampling, VariableMonitor
from torch.utils.data import DataLoader, Dataset


class Server:
    """Coordinate federated training across all ``Client`` instances.

    The server loads the global datasets, builds one ``Client`` per user and, in
    ``train``, runs communication rounds:
    1. A random subset of clients trains locally.
    2. Parameters are aggregated according to ``args.policy``.
    3. Every client is evaluated.
    4. Per-round metrics are printed, logged and sent to TensorBoard.
    """

    def __init__(self, device, local_model, args, logger=None, tensorboardLogger=None):
        """Load the datasets and create one ``Client`` per user.

        Args:
            device: device handed through to every ``Client``.
            local_model: template model; each client receives its own deep copy.
            args: experiment configuration (``num_users``, ``epochs``, ``frac``,
                ``policy``, ``prompt_num_tokens_g``, ``feature_dim``, ...).
            logger: optional object with an ``info`` method. When ``None``,
                round messages are only printed.
            tensorboardLogger: optional object with ``add_scalars_dict``. When
                ``None``, no scalars are recorded.
        """
        # base parameters
        self.device = device
        self.args = args
        self.get_global_dataset(self.args)
        self.total_clients = self.args.num_users
        self.indexes = list(range(self.total_clients))
        self.logger = logger
        self.tensorboardLogger = tensorboardLogger

        if self.args.prompt_num_tokens_g > 0:
            global_prompt = nn.Parameter(torch.randn(1, self.args.prompt_num_tokens_g, args.feature_dim))
            # `xavier_uniform` (no trailing underscore) is a deprecated alias that
            # only adds a warning; the in-place variant is the supported API.
            torch.nn.init.xavier_uniform_(global_prompt, gain=1)
        else:
            global_prompt = None

        # initialize clients; each one gets independent copies of the model and prompt
        self.clients = [
            Client(device=device, local_model=copy.deepcopy(local_model),
                   global_prompt=copy.deepcopy(global_prompt), train_dataset=self.train_dataset,
                   test_dataset=self.test_dataset, train_idxs=self.train_user_groups[idx],
                   test_idxs=self.test_user_groups[idx], args=args, index=idx, logger=logger)
            for idx in self.indexes
        ]

        self.best_accuracy = 0  # best average client test accuracy seen during training
        self.best_accuracy_each_client = []  # per-client test accuracies from that best round

        self.test_dataloader = DataLoader(self.test_dataset, batch_size=100, shuffle=True)
        self.best_accuracy_server = 0  # best accuracy of client 0 on the global test loader

    def get_global_dataset(self, args):
        """Load the datasets and the per-client index splits via ``get_dataset(args)``.

        Sets ``train_dataset``, ``test_dataset``, ``train_user_groups`` and
        ``test_user_groups``. The two user-group containers are indexed by
        client id in ``__init__``.
        """
        self.train_dataset, self.test_dataset, self.train_user_groups, self.test_user_groups = get_dataset(args)

    def _is_shared_moco_key(self, key):
        """Return True if moco parameter *key* takes part in server aggregation.

        Always excluded: names containing ``queue``, ``encoder_k`` or ``fc_k``.
        Names containing ``attention`` are excluded when
        ``args.personalized_attention == 1``. Names containing ``prompt`` are
        excluded when ``args.personalized_g_prompt == 1``.
        """
        if any(tag in key for tag in ('queue', 'encoder_k', 'fc_k')):
            return False
        if self.args.personalized_attention == 1 and 'attention' in key:
            return False
        if self.args.personalized_g_prompt == 1 and 'prompt' in key:
            return False
        return True

    def average_weights(self):
        """Average the shared moco parameters over the first ``args.num_users`` clients.

        Only names yielded by ``named_parameters()`` (not buffers) that pass
        ``_is_shared_moco_key`` are averaged.

        Returns:
            dict mapping parameter name -> averaged tensor. Client models are
            not modified.
        """
        reference = self.clients[0].moco
        reference_state = reference.state_dict()
        used_keys = [key for key, _ in reference.named_parameters() if self._is_shared_moco_key(key)]
        # clone so the in-place accumulation below never touches client 0's weights
        w_avg = {key: reference_state[key].clone() for key in used_keys}
        # build each client's state_dict once (not once per key); per-key summation
        # order (client 0, 1, 2, ...) is the same as a key-major loop
        for client in range(1, self.args.num_users):
            client_state = self.clients[client].moco.state_dict()
            for key in w_avg:
                w_avg[key] += client_state[key]
        for key in w_avg:
            w_avg[key] = torch.div(w_avg[key], float(self.args.num_users))
        return w_avg

    def average_weight_classifier(self):
        """Average every classifier ``state_dict`` entry over the first ``args.num_users`` clients.

        Returns:
            The averaged state dict, built from client 0's freshly created
            ``state_dict()`` so its metadata is kept. Client models are not
            modified.
        """
        w_avg = self.clients[0].classifier.state_dict()
        # state_dict() tensors share storage with the module; clone before accumulating
        for key in list(w_avg):
            w_avg[key] = w_avg[key].clone()
        for client in range(1, self.args.num_users):
            client_state = self.clients[client].classifier.state_dict()
            for key in w_avg:
                w_avg[key] += client_state[key]
        for key in w_avg:
            w_avg[key] = torch.div(w_avg[key], float(self.args.num_users))
        return w_avg

    def send_parameters(self):
        """Aggregate parameters and push them back to every client.

        ``args.policy == 2`` is the only policy that aggregates:
        - The averaged shared moco parameters are merged into each client's moco.
        - The classifier is averaged and loaded into every client only when
          ``args.personalized_classifier == 0``.

        ``policy == 1`` (separate training) and any other value do nothing.
        """
        if self.args.policy == 1:   # separate training
            return
        elif self.args.policy == 2:
            # 1. aggregate moco (only the shared keys are overwritten)
            w_avg = self.average_weights()
            for client in range(self.args.num_users):
                state_dict = self.clients[client].moco.state_dict()
                state_dict.update(w_avg)
                self.clients[client].moco.load_state_dict(state_dict)
            # 2. aggregate classifier
            if self.args.personalized_classifier == 0:
                w_avg = self.average_weight_classifier()
                for client in range(self.args.num_users):
                    self.clients[client].classifier.load_state_dict(w_avg)
            return
        else:
            return

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of a non-empty sequence."""
        return sum(values) / len(values)

    def _log(self, message):
        """Print *message* and forward it to ``self.logger`` when one is configured."""
        print(message)
        if self.logger is not None:
            self.logger.info(message)

    def train(self):
        """Run ``args.epochs`` communication rounds.

        Each round:
        1. ``max(int(frac * num_users), 1)`` randomly chosen clients call
           ``train()``.
        2. ``send_parameters`` aggregates. It runs before evaluation, or after
           it when ``args.ft == 1``.
        3. Every client is evaluated with ``inference``, ``inference_train`` and
           ``inference_contrastive``.
        4. Client 0 is evaluated on the global test loader.
        5. Metrics are printed, logged (if a logger is set) and written to
           TensorBoard (if a TensorBoard logger is set).

        Client models are saved whenever the average test accuracy ties or
        beats the best so far and ``args.save_model != 0``.
        """
        train_losses_contrastive = []
        train_losses_classification = []
        train_acc = []

        train_losses_classification_inference = []
        train_acc_inference = []
        test_losses_classification_contrastive = []
        test_acc_contrastive = []
        test_losses = []
        test_acc = []
        for epoch in tqdm(range(self.args.epochs)):
            self._log(f'Start Training round: {epoch}')

            local_train_losses_contrastive = []
            local_train_losses_classification = []
            local_train_acc = []

            local_train_losses_classification_inference = []
            local_train_acc_inference = []

            local_test_losses_classification_contrastive = []
            local_test_acc_contrastive = []

            local_test_losses = []
            local_test_acc = []

            # select clients to train their local model (at least one per round)
            idxs = np.random.choice(self.indexes, max(int(self.args.frac * self.total_clients), 1), replace=False)
            for client in idxs:
                self.clients[client].train()
                local_train_losses_contrastive.append(self.clients[client].trainingLoss_contrastive)
                local_train_losses_classification.append(self.clients[client].trainingLoss_classification)
                local_train_acc.append(self.clients[client].trainAcc)

            train_losses_contrastive.append(self._mean(local_train_losses_contrastive))
            train_losses_classification.append(self._mean(local_train_losses_classification))
            train_acc.append(self._mean(local_train_acc))

            # with ft == 1 clients are evaluated on their locally trained weights and
            # aggregation happens afterwards; otherwise aggregate first, then evaluate
            evaluate_before_send = (self.args.ft == 1)
            if not evaluate_before_send:
                self.send_parameters()

            # test on each client
            for client in range(self.args.num_users):
                acc, loss = self.clients[client].inference()
                local_test_acc.append(copy.deepcopy(acc))
                local_test_losses.append(copy.deepcopy(loss))

                acc, loss = self.clients[client].inference_train()
                local_train_acc_inference.append(copy.deepcopy(acc))
                local_train_losses_classification_inference.append(copy.deepcopy(loss))

                acc, loss = self.clients[client].inference_contrastive()
                local_test_acc_contrastive.append(copy.deepcopy(acc))
                local_test_losses_classification_contrastive.append(copy.deepcopy(loss))

            test_losses.append(self._mean(local_test_losses))
            test_acc.append(self._mean(local_test_acc))

            train_losses_classification_inference.append(self._mean(local_train_losses_classification_inference))
            train_acc_inference.append(self._mean(local_train_acc_inference))

            test_losses_classification_contrastive.append(self._mean(local_test_losses_classification_contrastive))
            test_acc_contrastive.append(self._mean(local_test_acc_contrastive))

            # update the best accuracy (ties also trigger a save)
            if test_acc[-1] >= self.best_accuracy:
                self.best_accuracy = test_acc[-1]
                self.best_accuracy_each_client = local_test_acc
                if self.args.save_model != 0:
                    self.save_model()

            if evaluate_before_send:
                self.send_parameters()

            # test on balance server test dataset
            acc_server, _loss_server = self.clients[0].inference_specific_dataset(self.test_dataloader, prompt=None)
            if acc_server > self.best_accuracy_server:
                self.best_accuracy_server = acc_server

            # report the training information of this round (stdout + optional logger)
            round_messages = (
                f'Communication Round: {epoch}   Policy: {self.args.policy}',
                f'Constrastive training Loss for each client: {local_train_losses_contrastive}',
                f'Classification training Loss for each client: {local_train_losses_classification}',
                f'Testing Loss for each client: {local_test_losses}',
                f'Testing Acc for each client: {local_test_acc}',
                f'Each clients accuracy when best avg acc: {self.best_accuracy_each_client}',
                f'Avg constrative training Loss: {train_losses_contrastive[-1]}',
                f'Avg classification training Loss: {train_losses_classification[-1]}',
                f'Avg training Accuracy: {train_acc[-1]}',
                f'Avg training Loss Inference: {train_losses_classification_inference[-1]}',
                f'Avg training Accuracy Inference: {train_acc_inference[-1]}',
                f'Avg testing contrastive Loss: {test_losses_classification_contrastive[-1]}',
                f'Avg testing contrastive Accuracy: {test_acc_contrastive[-1]}',
                f'Avg testing Loss: {test_losses[-1]}',
                f'Avg testing Accuracy: {test_acc[-1]}',
                f'Test accuracy on server: {acc_server}, Best accuracy on server: {self.best_accuracy_server}',
                f'Best Accuracy up to now: {self.best_accuracy}',
            )
            for message in round_messages:
                self._log(message)

            if self.tensorboardLogger is not None:
                train_info = {
                    'train_losses_contrastive': train_losses_contrastive[-1],
                    'train_losses_classification': train_losses_classification[-1],
                    'train_acc': train_acc[-1],
                    'train_acc_inference': train_acc_inference[-1],
                }
                self.tensorboardLogger.add_scalars_dict(prefix='train', dic=train_info, rnd=epoch)

                test_info = {
                    'test_losses': test_losses[-1],
                    'test_acc': test_acc[-1],
                    'contrastive_finetune': test_acc_contrastive[-1],
                    'test_acc_server': acc_server,
                    'best_acc_server': self.best_accuracy_server,
                    'best_acc': self.best_accuracy
                }
                self.tensorboardLogger.add_scalars_dict(prefix='test', dic=test_info, rnd=epoch)

        return

    def save_model(self):
        """Ask every client to save its own model via ``Client.save_model``."""
        for client in range(self.args.num_users):
            self.clients[client].save_model()
        return





