from __future__ import annotations

import copy
import datetime
import numpy as np
import sys
import math
import time
import torch
import torch.nn.functional as F
import models
from itertools import compress
from config import cfg
from collections import defaultdict

from _typing import (
    DatasetType,
    OptimizerType,
    DataLoaderType,
    ModelType,
    MetricType,
    LoggerType,
    ClientType,
    ServerType
)

from optimizer.api import create_optimizer
from .serverBase import ServerBase

from data import (
    fetch_dataset, 
    split_dataset, 
    make_data_loader, 
    separate_dataset, 
    make_batchnorm_dataset, 
    make_batchnorm_stats
)


class ServerFedAvg(ServerBase):
    """FedAvg server.

    Each round it selects clients, sends them the current global weights,
    trains the selected clients locally, and then delegates the global update
    to ``ServerBase.update_server_model``. The global weights and the server
    optimizer state are kept as state dicts (weights moved to CPU), not as
    live module/optimizer objects.
    """

    def __init__(
        self,
        model: ModelType,
        clients: dict[int, ClientType],
        dataset: DatasetType,
    ) -> None:
        """Snapshot the initial global model and server optimizer state.

        Args:
            model: Model whose current weights become the initial global
                weights; also used to build the server optimizer.
            clients: Mapping from client id to client object.
            dataset: Passed through unchanged to ``ServerBase.__init__``.
        """
        super().__init__(dataset=dataset)
        # Store the global weights as CPU tensors (presumably so they do not
        # occupy device memory between rounds).
        self.server_model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        # Only the optimizer's state dict is retained; the optimizer object
        # itself is discarded after this call.
        server_optimizer = create_optimizer(model, 'server')
        self.server_optimizer_state_dict = server_optimizer.state_dict()
        self.clients = clients

    def train(
        self,
        dataset: DatasetType,
        optimizer: OptimizerType,
        metric: MetricType,
        logger: LoggerType,
        global_epoch: int,
        malicious_client_ids,
    ):
        """Run one FedAvg communication round.

        Selects clients, broadcasts the current global weights, trains every
        selected client on its own shard of ``dataset``, logs after each
        client, and finally updates the global model from ``self.clients``.

        Args:
            dataset: Full training dataset. Each selected client's shard is
                carved out with ``separate_dataset`` using
                ``client.data_split['train']``.
            optimizer: Only read for the current learning rate
                (``param_groups[0]['lr']``), which is forwarded to clients.
            metric: Forwarded to client training and to logging.
            logger: Forwarded to client training and to logging.
            global_epoch: Current round index, forwarded to logging.
            malicious_client_ids: Forwarded unchanged to each client's
                ``train``.
        """
        selected_client_ids, num_active_clients = super().select_clients(
            clients=self.clients,
        )

        super().distribute_server_model_to_clients(
            server_model_state_dict=self.server_model_state_dict,
            clients=self.clients,
        )
        start_time = time.time()
        lr = optimizer.param_groups[0]['lr']

        for i in range(num_active_clients):
            m = selected_client_ids[i]
            client = self.clients[m]

            # Build only this client's shard instead of rebuilding every
            # client's shard each round and discarding the unselected ones.
            # The deepcopy keeps local training from mutating anything that
            # separate_dataset may share with ``dataset``.
            # NOTE(review): drop the deepcopy if separate_dataset is confirmed
            # to always return an independent copy.
            dataset_m = copy.deepcopy(
                separate_dataset(dataset, client.data_split['train'])
            )

            if dataset_m is None:
                # No training shard for this client: flag it inactive.
                client.active = False
            else:
                client.active = True
                client.train(
                    dataset=dataset_m,
                    lr=lr,
                    metric=metric,
                    logger=logger,
                    malicious_client_ids=malicious_client_ids,
                )

            # Log after every processed client (skipped or trained); ``i`` is
            # 1-based here.
            # NOTE(review): logs len(selected_client_ids) rather than
            # num_active_clients; confirm the two always agree.
            super().add_log(
                i=i + 1,
                num_active_clients=len(selected_client_ids),
                start_time=start_time,
                global_epoch=global_epoch,
                lr=lr,
                selected_client_ids=selected_client_ids,
                metric=metric,
                logger=logger,
            )

        super().update_server_model(clients=self.clients)

    def evaluate_trained_model(
        self,
        dataset,
        batchnorm_dataset,
        logger,
        metric,
        global_epoch,
        malicious_client_ids,
    ):
        """Evaluate the current global weights via the base-class routine.

        Forwards the arguments together with ``self.server_model_state_dict``
        and ``self.clients`` to ``ServerBase.evaluate_trained_model``.

        NOTE(review): ``malicious_client_ids`` is accepted but is neither used
        here nor forwarded to the base implementation.
        """
        super().evaluate_trained_model(
            dataset=dataset,
            batchnorm_dataset=batchnorm_dataset,
            logger=logger,
            metric=metric,
            global_epoch=global_epoch,
            server_model_state_dict=self.server_model_state_dict,
            clients=self.clients,
        )