import os
import re
import copy
import random
import torch
import numpy as np
import logging
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from federatedscope.core.aggregators import ClientsAvgAggregator
from federatedscope.core.configs.config import global_cfg

logger = logging.getLogger(__name__)


class ATCAggregator(ClientsAvgAggregator):
    """Aggregator that keeps one model per client and mixes updates per client.

    For every client ``i`` a separate aggregate is produced: client ``j``'s
    ``model_grads`` are weighted by ``sample_size_j * inside_weight`` when
    ``j`` is considered close to ``i`` and by ``sample_size_j *
    outside_weight`` otherwise, and each row of weights is normalised to sum
    to one. "Close" means "same cluster" when ``config.model.task`` is
    ``'pretrain'`` and "among ``i``'s top-k nearest clients by cosine
    distance of their flattened ``model_grads``" otherwise.
    """

    # Checkpoint file names written by `save_model`.
    _GLOBAL_CKPT_PATTERN = re.compile(r'global_model_(\d+)\.pt')

    def __init__(self, model=None, config=None, device='cpu'):
        super().__init__(model=model, config=config, device=device)
        self.client_num = config.federate.client_num
        self.task = config.model.task

        self.pretrain_tasks = config.model.pretrain_tasks
        self.num_agg_groups = config.aggregator.num_agg_groups
        # Zipped per client in `_compute_client_groups`; presumably a list of
        # length client_num (one k per client) -- TODO confirm in config.
        self.num_agg_topk = config.aggregator.num_agg_topk
        self.inside_weight = config.aggregator.inside_weight
        self.outside_weight = config.aggregator.outside_weight
        self.models = []  # per-client models, set via `update_models`
        self.neighbors = {}  # keyed by client id; keys name checkpoint files
        self.client_id2group = [None for _ in range(self.client_num)]
        self.client_id2topk = [[] for _ in range(self.client_num)]
        self.client_id2all = [[] for _ in range(self.client_num)]

        self.use_contrastive_loss = config.model.use_contrastive_loss
        # Defined unconditionally so the attribute always exists; it is only
        # read on the contrastive-loss code paths.
        self.contrast_monitor = None

    def update_models(self, models):
        """Replace the list of per-client models."""
        self.models = models

    def update_neighbors(self, neighbors):
        """Replace the neighbour mapping whose sorted keys name checkpoints."""
        self.neighbors = neighbors

    def update_contrast_monitor(self, contrast_monitor):
        """Set the monitor object used by the contrastive-loss stages."""
        self.contrast_monitor = contrast_monitor

    def aggregate(self, agg_info):
        """Aggregate client feedback.

        Args:
            agg_info: dict holding ``'client_feedback'`` (list of per-client
                feedback dicts) and optionally ``'recover_fun'``, which is
                only forwarded when ``global_cfg.federate.use_ss`` is set.

        Returns:
            ``(model_params, tasks)`` as produced by `_para_weighted_avg`.
        """
        models = agg_info["client_feedback"]
        recover_fun = agg_info['recover_fun'] if (
                'recover_fun' in agg_info and global_cfg.federate.use_ss) \
            else None
        avg_models, tasks = self._para_weighted_avg(models,
                                                    recover_fun=recover_fun)
        return avg_models, tasks

    def update(self, model_parameters):
        """Load ``model_parameters[i]`` into ``self.models[i]`` (non-strict)."""
        for i, param in enumerate(model_parameters):
            self.models[i].load_state_dict(param, strict=False)

    def save_model(self, path, cur_round=-1):
        """Save every model in ``self.models`` under ``<path>/global``.

        Model ``i`` is written to ``global_model_<id>.pt``, where ``<id>`` is
        the ``i``-th smallest key of ``self.neighbors``. Each file holds a
        dict with ``'cur_round'`` and ``'model'`` (the model's state dict).
        """
        assert self.models is not None

        global_dir = os.path.join(path, 'global')
        os.makedirs(global_dir, exist_ok=True)
        neighbor_ids = sorted(self.neighbors)
        for i, model in enumerate(self.models):
            ckpt = {'cur_round': cur_round, 'model': model.state_dict()}
            torch.save(
                ckpt,
                os.path.join(global_dir,
                             'global_model_{}.pt'.format(neighbor_ids[i])))

    def load_model(self, path):
        """Restore ``self.models`` from checkpoints written by `save_model`.

        Files named ``global_model_<id>.pt`` in ``<path>/global`` are sorted
        by id and loaded into ``self.models`` in that order; if a matching
        ``<path>/client/client_model_<id>.pt`` exists, its entries override
        the global ones. Other files in the directory are ignored.

        Returns:
            The ``'cur_round'`` of the last loaded global checkpoint, or None
            when there are no models to restore.

        Raises:
            AssertionError: the number of checkpoints differs from the number
                of models.
            ValueError: an expected global checkpoint file is missing.
        """
        if not getattr(self, 'models', None):
            return None

        global_dir = os.path.join(path, 'global')
        client_dir = os.path.join(path, 'client')
        matches = (self._GLOBAL_CKPT_PATTERN.fullmatch(name)
                   for name in os.listdir(global_dir))
        neighbor_ids = sorted(
            int(match.group(1)) for match in matches if match is not None)
        assert len(neighbor_ids) == len(self.models)

        cur_round = None
        for i, model in enumerate(self.models):
            cur_global_path = os.path.join(
                global_dir, 'global_model_{}.pt'.format(neighbor_ids[i]))
            cur_client_path = os.path.join(
                client_dir, 'client_model_{}.pt'.format(neighbor_ids[i]))
            if not os.path.exists(cur_global_path):
                raise ValueError(
                    "The file {} does NOT exist".format(cur_global_path))

            model_ckpt = model.state_dict()
            logger.info('Loading model from \'{}\''.format(cur_global_path))
            # NOTE: torch.load unpickles; only load trusted checkpoints.
            global_ckpt = torch.load(cur_global_path,
                                     map_location=self.device)
            model_ckpt.update(global_ckpt['model'])
            if os.path.exists(cur_client_path):
                logger.info(
                    'Updating model from \'{}\''.format(cur_client_path))
                client_ckpt = torch.load(cur_client_path,
                                         map_location=self.device)
                model_ckpt.update(client_ckpt['model'])
            model.load_state_dict(model_ckpt)
            cur_round = global_ckpt['cur_round']

        return cur_round

    @staticmethod
    def _stack_grads(models):
        """Flatten each client's ``model_grads`` into one row of a 2-D array.

        Returns a numpy array with one row per entry of ``models``; row ``j``
        is the concatenation of every tensor in ``models[j]['model_grads']``.
        """
        grads = torch.stack([
            torch.cat([g.reshape(-1) for g in model['model_grads'].values()])
            for model in models
        ])
        # scikit-learn needs a host-side array without autograd history.
        return grads.detach().cpu().numpy()

    def _cluster_by_cosine(self, grads):
        """Return average-linkage cosine cluster labels for the rows of grads."""
        try:
            # scikit-learn >= 1.2 calls the distance `metric`; `affinity`
            # was deprecated in 1.2 and removed in 1.4.
            clustering = AgglomerativeClustering(
                n_clusters=self.num_agg_groups,
                metric='cosine',
                linkage='average')
        except TypeError:
            # scikit-learn < 1.2 only accepts `affinity`.
            clustering = AgglomerativeClustering(
                n_clusters=self.num_agg_groups,
                affinity='cosine',
                linkage='average')
        return clustering.fit(grads).labels_

    def _compute_client_groups(self, models):
        """Recompute client neighbourhoods from the clients' ``model_grads``.

        In pretraining, clients are clustered into ``num_agg_groups`` groups
        (stored in ``self.client_id2group``), and one pretrain task is drawn
        at random and assigned to every client. Otherwise, for every client,
        all client indices sorted by ascending cosine distance are stored in
        ``self.client_id2all`` and the first ``num_agg_topk[i]`` of them in
        ``self.client_id2topk``.

        Returns:
            A list of length ``client_num`` with each client's task
            (all None outside pretraining).
        """
        tasks = [None for _ in range(self.client_num)]
        grads = self._stack_grads(models)

        if self.task == 'pretrain':
            self.client_id2group = self._cluster_by_cosine(grads)
            task_id = random.randint(0, len(self.pretrain_tasks) - 1)
            tasks = [
                self.pretrain_tasks[task_id] for _ in range(self.client_num)
            ]
        else:
            distances = cosine_distances(grads, grads)
            order = np.argsort(distances, axis=-1)
            self.client_id2topk = [
                row[:k].tolist() for row, k in zip(order, self.num_agg_topk)
            ]
            self.client_id2all = order.tolist()

        return tasks

    def _avg_params(self, models, client_adj_norm):
        """Return one updated parameter dict per model in ``self.models``.

        Only parameters whose names appear in ``models[0]['model_grads']``
        are kept. For model ``i`` and parameter ``k`` the result is
        ``param_i[k] + sum_j client_adj_norm[i][j] * model_grads_j[k]``,
        computed in float32.
        """
        param_names = models[0]['model_grads']
        avg_model = copy.deepcopy([{
            n: p
            for n, p in model.state_dict().items()
            if n in param_names
        } for model in self.models])
        # The client grads are only read below, so no defensive copy is made.
        model_grads = [model['model_grads'] for model in models]
        num_models = len(avg_model)
        for k in avg_model[0]:
            for i in range(num_models):
                weights = client_adj_norm[i]
                agg_grad = model_grads[0][k].float() * weights[0]
                for j in range(1, num_models):
                    agg_grad += model_grads[j][k].float() * weights[j]
                avg_model[i][k] = avg_model[i][k].float() + agg_grad

        return avg_model

    def _build_client_adj_norm(self, models, is_inside):
        """Return the row-normalised ``client_num x client_num`` weights.

        Entry ``(i, j)`` is ``models[j]['sample_size']`` times
        ``inside_weight`` if ``is_inside(i, j)`` else times
        ``outside_weight``; every row is then divided by its sum (a row
        summing to zero therefore yields NaN/inf entries).
        """
        client_adj = torch.zeros(self.client_num, self.client_num)
        for i in range(self.client_num):
            for j in range(self.client_num):
                weight = self.inside_weight if is_inside(i, j) \
                    else self.outside_weight
                client_adj[i][j] = models[j]['sample_size'] * weight
        return client_adj / client_adj.sum(dim=-1, keepdim=True)

    def _topk_adj_norm(self, models):
        """Weights where ``j`` is inside for ``i`` iff ``j`` is in i's top-k."""
        topk_sets = [set(ids) for ids in self.client_id2topk]
        return self._build_client_adj_norm(
            models, lambda i, j: j in topk_sets[i])

    def _para_weighted_avg(self, models, recover_fun=None):
        """Build the aggregated parameters returned to the clients.

        Args:
            models: list of client feedback dicts. Keys read here are
                ``'model_para'`` (``local``/``global`` methods),
                ``'model_grads'``, ``'sample_size'`` and, with contrastive
                loss outside pretraining, ``'contrast_monitor'``.
            recover_fun: accepted for interface compatibility; unused.

        Returns:
            ``(model_params, tasks)``. ``model_params`` is a dict with
            ``'model_para'`` (plus ``'contrast_monitor'`` in contrastive
            pretraining), or None in contrastive fine-tuning stages other
            than stage 3.
        """
        tasks = [None for _ in range(self.client_num)]
        if self.cfg.federate.method in ['local', 'global']:
            model_params = {
                'model_para': [model['model_para'] for model in models]
            }
            return model_params, tasks

        if self.task == 'pretrain':
            # Cluster clients, then weight same-cluster clients as "inside".
            tasks = self._compute_client_groups(models)
            group_id2client = {
                gid: [
                    cid for cid in range(self.client_num)
                    if self.client_id2group[cid] == gid
                ]
                for gid in range(self.num_agg_groups)
            }
            logger.info('group_id2client: {}'.format({
                k + 1: [x + 1 for x in v]
                for k, v in group_id2client.items()
            }))

            client_adj_norm = self._build_client_adj_norm(
                models, lambda i, j:
                self.client_id2group[i] == self.client_id2group[j])
            model_params = {
                'model_para': self._avg_params(models, client_adj_norm),
            }
            if self.use_contrastive_loss:
                model_params['contrast_monitor'] = self.contrast_monitor
            return model_params, tasks

        if not self.use_contrastive_loss:
            # Weight each client's top-k nearest clients as "inside".
            tasks = self._compute_client_groups(models)
            logger.info('client_id2topk: {}'.format({
                k + 1: [x + 1 for x in v] if v else v
                for k, v in enumerate(self.client_id2topk)
            }))
            model_params = {
                'model_para': self._avg_params(models,
                                               self._topk_adj_norm(models))
            }
            return model_params, tasks

        # Contrastive fine-tuning: all clients must be in the same stage.
        contrast_stat = models[0]['contrast_monitor'].stat
        for model in models:
            assert model['contrast_monitor'].stat == contrast_stat
        self.contrast_monitor.update_stat(contrast_stat)
        model_params = None

        if contrast_stat == 2:
            # Share decoder outputs (keyed by 1-based client id) together
            # with the neighbourhoods computed in an earlier round.
            dec_hidden = {
                k + 1: model['contrast_monitor'].dec_hidden
                for k, model in enumerate(models)
            }
            dec_out = {
                k + 1: model['contrast_monitor'].dec_out
                for k, model in enumerate(models)
            }
            all_group_ids = {
                k + 1: [x + 1 for x in v]
                for k, v in enumerate(self.client_id2all)
            }
            topk_group_ids = {
                k + 1: [x + 1 for x in v]
                for k, v in enumerate(self.client_id2topk)
            }
            self.contrast_monitor.update_dec_hidden(dec_hidden)
            self.contrast_monitor.update_dec_out(dec_out)
            self.contrast_monitor.update_all_group_ids(all_group_ids)
            self.contrast_monitor.update_topk_group_ids(topk_group_ids)

        elif contrast_stat == 3:
            # Recompute neighbourhoods and aggregate with top-k weighting.
            tasks = self._compute_client_groups(models)
            logger.info('client_id2all (n_topk={}): {}'.format(
                self.num_agg_topk, {
                    k + 1: [x + 1 for x in v]
                    for k, v in enumerate(self.client_id2all)
                }))
            model_params = {
                'model_para': self._avg_params(models,
                                               self._topk_adj_norm(models))
            }

        return model_params, tasks
