from __future__ import annotations

import copy
import datetime
import numpy as np
import sys
import time
import math
import torch
import torch.nn.functional as F
import models
from itertools import compress
from config import cfg
from collections import defaultdict

from _typing import (
    ModelType,
    ClientType,
    DatasetType
)

from utils.api import (
    to_device,  
    collate
)

from models.api import (
    create_model,
    make_batchnorm
)

from data import (
    fetch_dataset, 
    split_dataset, 
    make_data_loader, 
    separate_dataset, 
    make_batchnorm_dataset, 
    make_batchnorm_stats
)

from optimizer.api import create_optimizer

class ServerBase:

    def __init__(
        self,
        dataset
    ) -> None:
        # dataset is train dataset
        # self.train_batchnorm_dataset = make_batchnorm_dataset(dataset)
        self.test_model_perform_list = []

        self.best_server_model_state_dict = None
        self.best_server_optimizer_state_dict = None
        return
    
    def create_model(self, track_running_stats=False, on_cpu=False):
        return create_model(track_running_stats=track_running_stats, on_cpu=on_cpu)

    def create_test_model(
        self,
        model_state_dict,
        batchnorm_dataset
    ) -> object:

        model = create_model()
        model.load_state_dict(model_state_dict)
        test_model = make_batchnorm_stats(batchnorm_dataset, model, 'server')

        return test_model

    def distribute_server_model_to_clients(
        self,
        server_model_state_dict,
        clients
    ) -> None:

        model = self.create_model(track_running_stats=False)
        model.load_state_dict(server_model_state_dict)
        server_model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        for m in range(len(clients)):
            if clients[m].active:
                clients[m].model_state_dict = copy.deepcopy(server_model_state_dict)
        return

    def update_server_model(self, clients: dict[int, ClientType]) -> None:
        with torch.no_grad():
            valid_clients = [clients[i] for i in range(len(clients)) if clients[i].active]
            # print(f'valid_clients: {valid_clients}')
            if valid_clients:
                model = self.create_model(track_running_stats=False, on_cpu=True)
                model.load_state_dict(self.server_model_state_dict)
                server_optimizer = create_optimizer(model, 'server')
                server_optimizer.load_state_dict(self.server_optimizer_state_dict)
                server_optimizer.zero_grad()
                weight = torch.ones(len(valid_clients))
                if cfg['update_server_model'] == 'average':
                    weight = weight / weight.sum()
                for k, v in model.named_parameters():
                    parameter_type = k.split('.')[-1]
                    if 'weight' in parameter_type or 'bias' in parameter_type:
                        tmp_v = v.data.new_zeros(v.size())
                        for m in range(len(valid_clients)):
                            tmp_v += weight[m] * valid_clients[m].model_state_dict[k]

                        if cfg['update_server_model'] == 'average':
                            v.grad = (v.data - tmp_v).detach()
                        elif cfg['update_server_model'] == 'noaverage':
                            # a = copy.deepcopy(tmp_v)
                            # b = v.data
                            tmp_v -= (len(valid_clients)-1) * v.data
                            # c = copy.deepcopy(tmp_v)
                            v.grad = (v.data - tmp_v).detach()
                        else:
                            raise ValueError('update_server_model wrong')
                server_optimizer.step()
                self.server_optimizer_state_dict = server_optimizer.state_dict()
                self.server_model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}


            for i in range(len(clients)):
                clients[i].active = False
        return
    
    
    def distribute_new_server_model_to_participation_clients(
        self,
        server_model_state_dict,
        cur_round_partcipation_client_ids
    ):
        model = self.create_model(track_running_stats=False)
        model.load_state_dict(server_model_state_dict)
        server_model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        for client_id in cur_round_partcipation_client_ids:
            self.clients[client_id].model_state_dict = copy.deepcopy(server_model_state_dict)
        return


    def add_log(
        self,
        i,
        num_active_clients,
        start_time,
        global_epoch,
        lr,
        selected_client_ids,
        metric,
        logger
    ) -> None:
        if i % int((num_active_clients * cfg['log_interval']) + 1) == 0:
            _time = (time.time() - start_time) / (i + 1)
            global_epoch_finished_time = datetime.timedelta(seconds=_time * (num_active_clients - i - 1))
            exp_finished_time = global_epoch_finished_time + datetime.timedelta(
                seconds=round((cfg['server']['num_epochs'] - global_epoch) * _time * num_active_clients))
            exp_progress = 100. * i / num_active_clients
            info = {'info': [
                        'Model: {}'.format(cfg['model_tag']),
                        'Train Epoch (C): {}({:.0f}%)'.format(global_epoch, exp_progress),
                        'Learning rate: {:.6f}'.format(lr),
                        # 'ID: {}({}/{})'.format(selected_client_ids[i], i + 1, num_active_clients),
                        # 'ID: {}({}/{})'.format(selected_client_ids[i], i + 1, num_active_clients),
                        'Global Epoch Finished Time: {}'.format(global_epoch_finished_time),
                        'Experiment Finished Time: {}'.format(exp_finished_time),
                    ],
                    # 'selected_client_ids': selected_client_ids
            }
            logger.append(
                result=info, 
                tag='train', 
                mean=False
            )
            print(logger.write('train', metric.metric_name['train']), flush=True)

    def select_clients(
        self, clients: dict[int, ClientType], clients_indicator_list=None, global_epoch=None
    ) -> tuple[list[int], int]:
        if cfg['algo_mode'] == 'fedavg':
            num_active_clients = int(np.ceil(cfg['active_rate'] * cfg['num_clients']))
            selected_client_ids = torch.arange(cfg['num_clients'])[torch.randperm(cfg['num_clients'])[:num_active_clients]].tolist()
            selected_client_ids.sort()
            # rest_client_ids = list(set([i for i in range(cfg['num_clients'])]) - set(selected_client_ids))
            for i in range(num_active_clients):
                clients[selected_client_ids[i]].active = True
            return selected_client_ids, num_active_clients
        elif cfg['algo_mode'] == 'fedincen' or cfg['algo_mode'] == 'icl':
            # num_active_clients = max(int(np.ceil(cfg['active_rate'] * len(cur_round_participation_client_ids))), cfg['minimum_participation_clients'])
            cur_round_participation_client_ids = []
            for m in range(len(self.clients)):
                if global_epoch <= 5:
                    # first round, let all clients participate
                    # delta = 0 for all clients at round 1
                    cur_round_participation_client_ids.append(m)
                else:
                    if clients_indicator_list[m] == 1:
                        cur_round_participation_client_ids.append(m)

            if len(cur_round_participation_client_ids) < cfg['minimum_participation_clients']:
                supplementary_client_num = cfg['minimum_participation_clients'] - len(cur_round_participation_client_ids)

                all_client_ids = [i for i in range(cfg['num_clients'])]
                candidate_client_ids = list(set(all_client_ids) - set(cur_round_participation_client_ids))

                supplementary_client_ids = list(np.random.choice(candidate_client_ids, size=supplementary_client_num, replace=False))

                cur_round_participation_client_ids.extend(supplementary_client_ids)

            num_active_clients = int(np.ceil(cfg['active_rate'] * cfg['num_clients']))
            selected_client_ids = torch.as_tensor(cur_round_participation_client_ids)\
                [torch.randperm(len(cur_round_participation_client_ids))[:num_active_clients]].tolist()
            selected_client_ids.sort()
            # rest_client_ids = list(set([i for i in range(cfg['num_clients'])]) - set(selected_client_ids))
            for i in range(num_active_clients):
                clients[selected_client_ids[i]].active = True
        else:
            raise ValueError('algo_mode not supported')
        # print(f'selected_client_ids: {selected_client_ids}')
        return selected_client_ids, num_active_clients, cur_round_participation_client_ids
    
    def combine_test_dataset(
        self,
        num_active_clients: int,
        clients: dict[int, ClientType],
        selected_client_ids: list[int],
        dataset: DatasetType
    ) -> DatasetType:  
        '''
        combine the datapoint index for selected clients
        and return the dataset
        '''
        combined_datapoint_idx = []
        for i in range(num_active_clients):
            m = selected_client_ids[i]
            combined_datapoint_idx += clients[m].data_split['test']

        # dataset: DatasetType
        dataset = separate_dataset(dataset, combined_datapoint_idx)
        return dataset

    def initiate_server_collaboration_gains(
        self,
        dataset,
        logger,
        metric,
        server_model_state_dict,
    ):
        data_loader = make_data_loader(
            dataset={'test': dataset}, 
            tag='server'
        )['test']

        model = self.create_test_model(
            model_state_dict=server_model_state_dict,
            batchnorm_dataset=dataset,
        )
        logger.safe(True)
        with torch.no_grad():
            model.train(False)
            for i, input in enumerate(data_loader):

                input = collate(input)
                input_size = input['data'].size(0)
                input = to_device(input, cfg['device'])
                output = model(input)

                evaluation = metric.evaluate(
                    metric.metric_name['test'], 
                    input, 
                    output
                )
                logger.append(
                    evaluation, 
                    'test_server_initilization', 
                    input_size
                )

        test_server_initilization_loss = copy.deepcopy(logger.mean[f'test_server_initilization/Loss'])
        logger.safe(False)
        logger.reset()
        return test_server_initilization_loss

    def evaluate_trained_model(
        self,
        dataset,
        batchnorm_dataset,
        logger,
        metric,
        global_epoch,
        server_model_state_dict,
        clients,
        cur_round_active_client_ids=None
    ):  
        """Evaluate the server model and, outside 'fedavg', each active client model.

        Steps:
          1. Build a test model from ``server_model_state_dict`` and log its
             metrics on ``dataset`` under the ``'test_server'`` tag.
          2. If ``cfg['algo_mode'] == 'fedavg'``, stop and return ``None``.
          3. Sum the server model's per-parameter gradients of
             ``output['loss']`` over all batches of ``dataset``.
          4. For every id in ``cur_round_active_client_ids``, build a test
             model from that client's weights, log its metrics under
             ``f'test_{client_id}'`` and compute
             ``sum_p (1 / n_active) * sum(-g_server[p] * (w_server[p] - w_client[p]))``.

        Args:
            dataset: Dataset used to build the evaluation data loader.
            batchnorm_dataset: Dataset passed to ``create_test_model``.
            logger: Logger providing ``append``, ``write`` and ``mean``.
            metric: Object providing ``evaluate`` and ``metric_name['test']``.
            global_epoch: Epoch number shown in the printed info line.
            server_model_state_dict: Weights of the server model.
            clients: Not referenced in this method; clients are read from
                ``self.clients`` instead.
            cur_round_active_client_ids: Ids (keys of ``self.clients``) to
                evaluate individually. Iterated unconditionally outside
                'fedavg' mode, so the ``None`` default would fail there.

        Returns:
            ``None`` in 'fedavg' mode. Otherwise a 5-tuple
            ``(True, server_loss, server_loss, individual_losses,
            marginal_gain_second_terms)`` where both ``server_loss`` entries
            are ``logger.mean['test_server/Loss']``.
        """

        start_time = time.time()
        data_loader = make_data_loader(
            dataset={'test': dataset}, 
            tag='server'
        )['test']

        model = self.create_test_model(
            model_state_dict=server_model_state_dict,
            batchnorm_dataset=batchnorm_dataset
        )

        # Snapshot of the server parameters, compared with client weights below.
        server_weight_collector = copy.deepcopy(list(model.parameters()))
        
        # Pass 1: gradient-free evaluation of the server model.
        with torch.no_grad():
            model.train(False)
            for i, input in enumerate(data_loader):
                input = collate(input)
                input_size = input['data'].size(0)
                input = to_device(input, cfg['device'])
                output = model(input)

                evaluation = metric.evaluate(
                    metric.metric_name['test'], 
                    input, 
                    output
                )
                logger.append(
                    evaluation, 
                    'test_server', 
                    input_size
                )

            info = {
                'info': [
                    'Model: {}'.format(cfg['model_tag']), 
                    'Test Epoch: {}({:.0f}%)'.format(global_epoch, 100.)
                ]
            }
            logger.append(info, 'test_server', mean=False)
            print(logger.write('test_server', metric.metric_name['test']), flush=True)

        # 'fedavg' needs only the server metrics logged above.
        if cfg['algo_mode'] == 'fedavg':
            return

        # Pass 2: accumulate the server model's gradients over the dataset.
        # The optimizer is used only for zero_grad(); step() is never called,
        # so the parameters are not updated by the optimizer here.
        model.train(True)
        optimizer = create_optimizer(model, 'client')
        server_gradient_collector = copy.deepcopy(list(model.parameters()))
        # Replace the copied values with zeros of matching shape to act as accumulators.
        for param_index, param in enumerate(model.parameters()):
            server_gradient_collector[param_index] = copy.deepcopy(param.data.new_zeros(param.size()))
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['data'].size(0)
            input = to_device(input, cfg['device'])
            optimizer.zero_grad()
            output = model(input)
            output['loss'].backward()

            # Gradients are summed over batches (no division by batch count).
            # NOTE(review): a parameter whose .grad stays None would make this
            # addition fail - confirm every parameter contributes to the loss.
            for param_index, param in enumerate(model.parameters()):
                server_gradient_collector[param_index] += copy.deepcopy(param.grad)

        # Pass 3: per-client evaluation and marginal-gain term.
        marginal_gain_second_terms = []
        for client_id in cur_round_active_client_ids:
            client = self.clients[client_id]
            model = self.create_test_model(
                model_state_dict=copy.deepcopy(client.model_state_dict),
                batchnorm_dataset=batchnorm_dataset
            )
            optimizer = create_optimizer(model, 'client')
            # Snapshot of the client parameters, taken before any forward pass.
            client_weight_collector = copy.deepcopy(list(model.parameters()))
            model.train(True)
            for i, input in enumerate(data_loader):

                input = collate(input)
                input_size = input['data'].size(0)
                input = to_device(input, cfg['device'])
                optimizer.zero_grad()
                output = model(input)
                # NOTE(review): the gradients produced by this backward pass
                # are not read anywhere later in this method.
                output['loss'].backward()


                evaluation = metric.evaluate(
                    metric.metric_name['test'], 
                    input, 
                    output
                )
                logger.append(
                    evaluation, 
                    f'test_{client_id}', 
                    input_size
                )

            # Sum over parameters of (1/n_active) * sum(-g_server * (w_server - w_client)).
            marginal_gain_second_term = 0
            for param_index, param in enumerate(model.parameters()):
                marginal_gain_second_term += 1/len(cur_round_active_client_ids) * \
                    torch.sum(-server_gradient_collector[param_index] * \
                    (copy.deepcopy(server_weight_collector[param_index]) - client_weight_collector[param_index]))

            marginal_gain_second_terms.append(marginal_gain_second_term)

        end_time = time.time()
        # NOTE(review): this prints the two raw timestamps, not their
        # difference - an elapsed time was presumably intended.
        print(f'{end_time} - {start_time} for testing')
        individual_losses = []
        for client_id in cur_round_active_client_ids:

            individual_losses.append(copy.deepcopy(logger.mean[f'test_{client_id}/Loss']))
        
        # NOTE(review): 'test_server/Loss' is returned twice - confirm whether
        # the third element was meant to be a different quantity.
        return True, logger.mean[f'test_server/Loss'], logger.mean[f'test_server/Loss'], individual_losses, marginal_gain_second_terms
