import os
import json
import logging
import copy
import torch
import numpy as np
from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.nlp.hetero_tasks.trainer.utils import ContrastiveMonitor
from federatedscope.nlp.hetero_tasks.dataset.utils import load_synth_data

logger = logging.getLogger(__name__)


class ATCServer(Server):
    """Server that keeps a separate model and task label per client.

    Compared with the base ``Server``, this worker holds ``self.models``
    (one deep copy of ``self.model`` per client) and ``self.tasks`` (one
    task label per client), and ``broadcast_model_para`` sends each client
    its own model parameters and task.

    When ``config.federate.atc_vanilla`` is True, the aggregator returns a
    single state dict that is loaded into every per-client model.
    Otherwise the aggregator is handed the per-client models and the
    communication neighbors. It then returns per-client parameters plus
    per-client tasks.
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 unseen_clients_id=None,
                 **kwargs):
        """Build the server and its per-client state.

        All arguments are forwarded unchanged to ``Server.__init__``.

        Besides the base state, this sets up:

        * ``self.models``: one model copy per client.
        * ``self.tasks``: initialised to the first entry of
          ``config.model.pretrain_tasks``, or None when that list is empty.
        * When ``config.model.stage == 'contrast'``, a ``ContrastiveMonitor``
          seeded with the synthetic data returned by ``load_synth_data``
          and shared with the aggregator.
        """
        super().__init__(ID=ID,
                         state=state,
                         config=config,
                         data=data,
                         model=model,
                         client_num=client_num,
                         total_round_num=total_round_num,
                         device=device,
                         strategy=strategy,
                         unseen_clients_id=unseen_clients_id,
                         **kwargs)

        # Multiple models are maintained, one independent copy per client.
        self.models = [
            copy.deepcopy(self.model) for _ in range(self.client_num)
        ]
        default_task = config.model.pretrain_tasks[0] \
            if config.model.pretrain_tasks else None
        self.tasks = [default_task for _ in range(self.client_num)]

        self.atc_vanilla = config.federate.atc_vanilla
        if not self.atc_vanilla:
            # The non-vanilla aggregator works on the per-client models and
            # needs the neighbor map to produce per-client results.
            self.aggregator.update_models(self.models)
            self.aggregator.update_neighbors(self.comm_manager.neighbors)

        self.use_contrastive_loss = self._cfg.model.use_contrastive_loss
        # NOTE(review): self.contrast_monitor is only created in the
        # 'contrast' stage, yet it is read whenever use_contrastive_loss is
        # True. Confirm that the two settings are always enabled together.
        if self._cfg.model.stage == 'contrast':
            # Load synthetic data for contrastive learning.
            synth_feats, synth_toks = load_synth_data(self._cfg.data)
            self.contrast_monitor = ContrastiveMonitor()
            self.contrast_monitor.update_enc_hidden(synth_feats)
            self.contrast_monitor.update_synth_tokens(synth_toks)
            self.aggregator.update_contrast_monitor(self.contrast_monitor)

    def _perform_federated_aggregation(self):
        """Aggregate this round's training feedback into the per-client models.

        Messages from ``self.msg_buffer['train'][self.state]`` are passed to
        the aggregator, ordered by client ID. The aggregation path depends on
        ``self.atc_vanilla``:

        * Vanilla: the feedback is ``[sample_size, model_para]`` pairs. The
          returned state dict is loaded into every model, and all tasks are
          reset to None.
        * Otherwise: the raw messages are passed in, and the aggregator
          returns ``(avg_models, tasks)``. ``avg_models['model_para'][i]``
          (when present) is loaded into model ``i``.

        Returns:
            The number of aggregated messages. Returns -1 when contrastive
            loss is enabled, ``model.task`` is not 'pretrain', and
            ``contrast_monitor.stat == 2``. In that case the train buffer of
            the current state is cleared and the current models are
            re-broadcast. The -1 is presumably a sentinel for the caller;
            confirm against the base ``Server``.
        """
        train_msg_buffer = self.msg_buffer['train'][self.state]
        # Feedback is ordered by client ID (the buffer key).
        msg_list = [
            msg for _, msg in sorted(train_msg_buffer.items(),
                                     key=lambda item: item[0])
        ]

        # Aggregate
        aggregated_num = len(msg_list)
        if self.atc_vanilla:
            agg_info = {
                'client_feedback': [[x['sample_size'], x['model_para']]
                                    for x in msg_list],
                'recover_fun': self.recover_fun,
            }
            avg_models = self.aggregator.aggregate(agg_info)
            tasks = [None for _ in range(self.client_num)]
            for model in self.models:
                model.load_state_dict(avg_models, strict=False)
        else:
            agg_info = {
                'client_feedback': msg_list,
                'recover_fun': self.recover_fun,
            }
            avg_models, tasks = self.aggregator.aggregate(agg_info)
            if avg_models is not None and 'model_para' in avg_models:
                for i in range(self.client_num):
                    self.models[i].load_state_dict(avg_models['model_para'][i],
                                                   strict=False)
        self.tasks = tasks

        if self.use_contrastive_loss:
            if self._cfg.model.task != 'pretrain' and \
                    self.contrast_monitor.stat == 2:
                # Drop this round's training feedback and re-send the
                # current models instead of reporting an aggregated count.
                self.msg_buffer['train'][self.state].clear()
                self.broadcast_model_para(
                    msg_type='model_para',
                    sample_client_num=self.sample_client_num)
                return -1
            if self.contrast_monitor.stat == 3:
                self.contrast_monitor.reset()

        return aggregated_num

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """Send each sampled client its own model parameters and task.

        Args:
            msg_type: type of the outgoing ``Message``.
            sample_client_num: number of clients drawn uniformly without
                replacement via ``np.random.choice``. A value <= 0 sends to
                every client.
            filter_unseen_clients: if True, mark ``self.unseen_clients_id``
                as 'unseen' in the sampler before sending, and as 'seen'
                afterwards.

        Raises:
            ValueError: from ``np.random.choice`` when ``sample_client_num``
                exceeds ``self.client_num``.
        """
        if filter_unseen_clients:
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        # NOTE(review): client selection below does not consult
        # self.sampler, so the unseen/seen toggling does not influence which
        # clients are picked here.
        if sample_client_num > 0:
            sample_ids = np.random.choice(np.arange(self.client_num),
                                          size=sample_client_num,
                                          replace=False).tolist()
        else:
            sample_ids = list(range(self.client_num))

        # Assumes the i-th sorted neighbor ID is the owner of
        # self.models[i] -- TODO confirm against comm_manager.
        receivers = sorted(list(self.comm_manager.neighbors.keys()))
        # For the 'local' and 'global' methods, empty parameter dicts are
        # sent in place of the model weights.
        skip_broadcast = self._cfg.federate.method in ['local', 'global']

        for i in sample_ids:
            # Only the sampled clients' state dicts are materialised.
            content = {
                'model_para':
                {} if skip_broadcast else self.models[i].state_dict(),
                'task': self.tasks[i],
            }
            if self.use_contrastive_loss:
                content['contrast_monitor'] = self.contrast_monitor
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=receivers[i],
                        state=self.state,
                        content=content))

        if filter_unseen_clients:
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def merge_eval_results_from_all_clients(self, final_round=False):
        """Merge, log and save the evaluation results of all clients.

        With 'group_avg' in ``cfg.eval.report``, the raw per-client buffer is
        handed to the monitor unchanged. Otherwise every metric is collected
        into a list of per-client floats. Nested dict results are flattened
        to ``'<key>_<subkey>'``.

        Args:
            final_round: if True, read the eval buffer of ``self.state - 1``
                instead of ``self.state``.

        Returns:
            The formatted logs produced by ``self._monitor.format_eval_res``.
        """
        state = self.state - 1 if final_round else self.state
        eval_msg_buffer = self.msg_buffer['eval'][state]

        if 'group_avg' in self._cfg.eval.report:
            metrics_all_clients = eval_msg_buffer
        else:
            metrics_all_clients = dict()
            for client_eval_results in eval_msg_buffer.values():
                for key, res in client_eval_results.items():
                    if isinstance(res, dict):
                        for k, v in res.items():
                            # Accumulate under the flattened key. The old
                            # check used `key`, which reset the list on every
                            # client and kept only the last client's value.
                            cur_key = key + '_' + k
                            metrics_all_clients.setdefault(
                                cur_key, []).append(float(v))
                    else:
                        metrics_all_clients.setdefault(key, []).append(
                            float(res))
        # NOTE(review): the reported round is self.state + 1 even when
        # final_round reads the buffer of self.state - 1; confirm that this
        # offset is intended.
        formatted_logs = self._monitor.format_eval_res(
            metrics_all_clients,
            rnd=self.state + 1,
            role='Server #',
            forms=self._cfg.eval.report)
        logger.info(formatted_logs)
        self._monitor.save_formatted_results(formatted_logs)
        return formatted_logs
