import copy
import logging
import random
import numpy as np
import pandas as pd
import torch
import math
import copy
from math import sqrt
from utils import transform_list_to_tensor

from client import Client


class FRJVE(object):
    """Federated trainer that shares only the ``iid_embedding`` tables across clients.

    Each communication round:

    1. Clients are sampled.
    2. Each sampled client starts from the current global values of the keys in
       ``_AGGREGATE_KEYS``. Every other parameter is whatever the client's own
       trainer already holds.
    3. Each client trains locally.
    4. The server averages those keys, weighted by each client's
       ``local_sample_number``, and installs the result into ``self.model_trainer``.
    5. All clients are evaluated. The lowest MAE / RMSE seen so far is kept in
       ``self.results``.
    """

    # State-dict keys averaged on the server and pushed back to clients each round.
    # No other parameter is ever overwritten from the global model.
    _AGGREGATE_KEYS = (
        'src_model.iid_embedding.weight',
        'tgt_model.iid_embedding.weight',
    )

    def __init__(self, dataset, device, args, model_trainer):
        """Unpack the federated dataset and create one ``Client`` per client index.

        Args:
            dataset: 9-element sequence ``[client_data_all, train_data_num,
                test_data_num, train_data_global, test_data_global,
                train_data_local_num_dict, train_data_local_dict,
                test_data_local_dict, class_num]``.
            device: device handed to every ``Client``.
            args: run configuration. Reads ``client_num_in_total``,
                ``client_num_per_round`` and ``comm_round``.
            model_trainer: trainer holding the global model. A deep copy is given
                to each client.
        """
        self.device = device
        self.args = args
        [client_data_all, train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num] = dataset
        self.client_all = client_data_all
        self.train_global = train_data_global
        self.test_global = test_data_global
        self.train_data_num_in_total = train_data_num
        self.test_data_num_in_total = test_data_num
        self.client_indexes = []
        self.client_list = []
        self.train_data_local_num_dict = train_data_local_num_dict
        self.train_rmse = []
        self.test_rmse = []
        self.train_mae = []
        self.test_mae = []
        self.train_mse = []
        self.test_mse = []
        self.model_trainer = model_trainer
        # Best-so-far metrics. They start at +inf so the first evaluation always
        # records real values; a finite sentinel would hide any metric above it.
        self.results = {'frjve_mae': float('inf'), 'frjve_rmse': float('inf')}

        self._setup_clients(train_data_local_num_dict, model_trainer)

    def _setup_clients(self, train_data_local_num_dict, model_trainer):
        """Create ``args.client_num_in_total`` clients, each with its own trainer copy."""
        logging.info("############setup_clients (START)#############")
        for client_idx in range(self.args.client_num_in_total):
            c = Client(client_idx, train_data_local_num_dict[client_idx], self.args, self.device,
                       copy.deepcopy(model_trainer))
            self.client_list.append(c)
        logging.info("############setup_clients (END)#############")

    def update_results(self, mae, rmse, phase):
        """Keep the lowest ``mae`` / ``rmse`` seen so far under ``'<phase>_mae'`` / ``'<phase>_rmse'``."""
        if mae < self.results[phase + '_mae']:
            self.results[phase + '_mae'] = mae
        if rmse < self.results[phase + '_rmse']:
            self.results[phase + '_rmse'] = rmse

    def train(self):
        """Run ``args.comm_round`` federated rounds, evaluating all clients after each one.

        Raises:
            ValueError: if a sampled index has no matching client in ``client_list``.
        """
        clients_by_idx = {c.client_idx: c for c in self.client_list}

        for round_idx in range(self.args.comm_round):
            # The global trainer holds the latest aggregate; _aggregate writes back into it.
            w_global = self.model_trainer.get_model_params()

            logging.info("################Communication round : {}".format(round_idx))
            w_locals = []
            self._client_sampling(round_idx, self.args.client_num_in_total,
                                  self.args.client_num_per_round)
            logging.info("client_indexes = " + str(self.client_indexes))

            for client_idx in self.client_indexes:
                client = clients_by_idx.get(client_idx)
                if client is None:
                    # Previously this silently reused the last client found (or raised NameError).
                    raise ValueError("sampled client index {} has no Client".format(client_idx))

                w_client_global = client.model_trainer.get_model_params()
                # Overwrite only the shared embedding tables with the global copy;
                # every other parameter keeps the client's own value.
                for k in self._AGGREGATE_KEYS:
                    w_client_global[k] = copy.deepcopy(w_global[k])
                weight = client.train(copy.deepcopy(w_client_global), client_idx,
                                      self._other_client_indexes(client_idx))
                w_locals.append((client.local_sample_number, copy.deepcopy(weight)))

            self._aggregate(w_locals)

            # Evaluate after every round.
            self._local_test_on_all_clients(round_idx)

    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
        """Set ``self.client_indexes`` to the clients participating in ``round_idx``.

        All clients are used when the per-round count equals the total. Otherwise
        ``min(per_round, total)`` distinct indexes are drawn. NumPy's *global* RNG
        is seeded with ``round_idx`` so every compared method selects the same
        clients. Note that this also affects any later ``np.random`` use in the
        process.
        """
        if client_num_in_total == client_num_per_round:
            self.client_indexes = [client_index for client_index in range(client_num_in_total)]
        else:
            num_clients = min(client_num_per_round, client_num_in_total)
            np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
            self.client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)

    def _other_client_indexes(self, client_idx):
        """Return every index in ``range(args.client_num_in_total)`` except ``client_idx``."""
        others = list(range(self.args.client_num_in_total))
        others.remove(client_idx)
        return others

    def _aggregate(self, w_locals):
        """Compute the sample-weighted average of the shared keys, install it, and return it.

        Args:
            w_locals: list of ``(local_sample_number, state_dict)`` pairs.

        Returns:
            A new state dict. Keys in ``_AGGREGATE_KEYS`` hold the weighted average
            across clients. Every other key is taken from the first entry of
            ``w_locals``. The input dicts are not modified.

        Raises:
            ValueError: if ``w_locals`` is empty or the sample counts sum to zero.
        """
        if not w_locals:
            raise ValueError("_aggregate needs at least one local model")
        training_num = sum(sample_num for sample_num, _ in w_locals)
        if training_num == 0:
            raise ValueError("total local sample count is zero; cannot weight the average")

        # Shallow copy (preserving the dict type/attributes) so the caller's first
        # state dict is not mutated in place.
        averaged_params = copy.copy(w_locals[0][1])
        for k in self._AGGREGATE_KEYS:
            total = None
            for sample_num, local_model_params in w_locals:
                term = local_model_params[k] * (sample_num / training_num)
                total = term if total is None else total + term
            averaged_params[k] = total

        self.model_trainer.set_model_params(averaged_params)
        return averaged_params

    def _local_test_on_all_clients(self, round_idx):
        """Evaluate every client, log pooled MAE/RMSE, and update the best-so-far results.

        The pooled metrics are ``sum(test_mae) / N`` and ``sqrt(sum(test_loss) / N)``,
        where ``N`` is the summed ``test_total``. This presumes ``local_test`` returns
        *summed* (not averaged) absolute and squared errors; TODO confirm against
        ``Client``. If no client has test samples, a warning is logged and the
        results are left unchanged.
        """
        logging.info("################local_test_on_all_clients : {}".format(round_idx))

        test_metrics_frjve = {
            'num_samples': [],
            'losses': [],
            'mae': [],
        }

        for client in self.client_list:
            test_local_metrics_frjve = client.local_test(
                True, client.client_idx, self._other_client_indexes(client.client_idx))

            test_metrics_frjve['num_samples'].append(copy.deepcopy(test_local_metrics_frjve['test_total']))
            test_metrics_frjve['losses'].append(copy.deepcopy(test_local_metrics_frjve['test_loss']))
            test_metrics_frjve['mae'].append(copy.deepcopy(test_local_metrics_frjve['test_mae']))

        total_samples = sum(test_metrics_frjve['num_samples'])
        if total_samples == 0:
            logging.warning("round {}: no test samples on any client; skipping metrics".format(round_idx))
            return

        test_loss_frjve = math.sqrt(sum(test_metrics_frjve['losses']) / total_samples)
        test_mae_frjve = sum(test_metrics_frjve['mae']) / total_samples
        self.update_results(test_mae_frjve, test_loss_frjve, 'frjve')

        logging.info("######################################")
        stats = {'test_mae_frjve': test_mae_frjve,
                 'test_rmse_frjve': test_loss_frjve}

        logging.info(stats)
        logging.info("######################################")
        logging.info("**************************************")
        logging.info(self.results)
        logging.info("**************************************")
