from __future__ import annotations

import copy
import datetime
import numpy as np
import sys
import time
import math
import torch
import torch.nn.functional as F
import models
from itertools import compress
from config import cfg

# from torchstat import stat

from utils.api import (
    to_device,  
    collate
)

from _typing import (
    DatasetType,
    OptimizerType,
    DataLoaderType,
    ModelType,
    MetricType,
    LoggerType,
    ClientType,
    ServerType
)

from models.api import (
    create_model
)

from optimizer.api import create_optimizer

from data import make_data_loader

from .clientBase import ClientBase


class ClientFedAvg(ClientBase):
    '''
    FedAvg client.

    Keeps this client's model state_dict (moved to CPU) and optimizer
    state_dict between rounds. On each ``train`` call it rebuilds a model and
    optimizer from that state, trains locally and stores the new state.
    '''

    def __init__(
        self,
        client_id: int,
        model: ModelType,
        data_split: dict[str, list[int]],
    ) -> None:
        '''
        Parameters
        ----------
        client_id: int
            Identifier of this client. ``create_clients`` passes an element of
            ``torch.arange``, i.e. a 0-d tensor rather than a Python int.
        model: ModelType
            Model whose current state_dict seeds this client's local state.
        data_split: dict[str, list[int]]
            This client's sample indices keyed by 'train' and 'test'
            (the structure built by ``create_clients``).
        '''
        super().__init__()
        self.client_id = client_id
        self.data_split = data_split
        # NOTE: .cpu() returns the same tensor when it is already on the CPU,
        # so these entries may alias ``model``'s tensors.
        self.model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        # Only the optimizer's state is kept; ``train`` rebuilds an optimizer
        # and loads this state on every call.
        optimizer = create_optimizer(model, 'client')
        self.optimizer_state_dict = optimizer.state_dict()
        # NOTE(review): presumably toggled elsewhere when the client is
        # selected for a round — confirm.
        self.active = False

    @classmethod
    def create_clients(
        cls,
        model: ModelType,
        data_split: dict[str, dict[int, list[int]]],
    ) -> list[ClientFedAvg]:
        '''
        Create one client per index in ``range(cfg['num_clients'])``.

        Parameters
        ----------
        model: ModelType
            Model whose state initializes every client.
        data_split: dict[str, dict[int, list[int]]]
            Per-client sample indices, keyed first by 'train' / 'test' and
            then by client index.

        Returns
        -------
        list[ClientFedAvg]
            Clients ordered by index (instances of ``cls``), so
            ``clients[m]`` has client id ``m``.
        '''
        # Ids stay 0-d tensors (elements of torch.arange), as before, since
        # other code may rely on that type.
        client_ids = torch.arange(cfg['num_clients'])
        return [
            cls(
                client_id=client_ids[m],
                model=model,
                data_split={
                    'train': data_split['train'][m],
                    'test': data_split['test'][m],
                },
            )
            for m in range(cfg['num_clients'])
        ]

    def train(
        self,
        dataset: DatasetType,
        lr: float,
        metric: MetricType,
        logger: LoggerType,
        malicious_client_ids=None,
    ) -> None:
        '''
        Run local training on ``dataset`` and store the resulting state.

        If ``cfg['malicious_way'] == 'random'`` and this client's id is in
        ``malicious_client_ids``, no training happens. Instead, every stored
        state entry whose key contains 'weight' or 'bias' is replaced by
        uniform noise in [-0.25, 0.25), and only the lr of the stored
        optimizer state is updated.

        Otherwise the client trains for up to ``cfg['local_epoch']`` epochs.
        When ``cfg['local_training_type']`` contains 'gradient', training stops
        after the number of optimizer steps given by its trailing digits
        (e.g. 'gradient5' -> 5 steps, 'gradient10' -> 10 steps).

        Parameters
        ----------
        dataset: DatasetType
            This client's local training data.
        lr: float
            Learning rate written into the first optimizer param group.
        metric: MetricType
            Evaluates ``metric.metric_name['train']`` on every batch.
        logger: LoggerType
            Receives each batch evaluation under the 'train' tag.
        malicious_client_ids: container of client ids, optional
            Ids of malicious clients; ``None`` means there are none.

        Raises
        ------
        ValueError
            If ``cfg['local_training_type']`` contains 'gradient' but does not
            end with a step count.
        '''
        model = create_model()
        # strict=False: keys missing from / unexpected in the stored state are tolerated.
        model.load_state_dict(self.model_state_dict, strict=False)
        # The lr is written into the stored state itself, so it persists on
        # the client even when training is skipped below.
        self.optimizer_state_dict['param_groups'][0]['lr'] = lr
        optimizer = create_optimizer(model, 'client')
        optimizer.load_state_dict(self.optimizer_state_dict)
        model.train(True)

        data_loader = make_data_loader(
            dataset={'train': dataset},
            tag='client'
        )['train']

        # The model/optimizer/loader above are still built for malicious
        # clients so the sequence of calls (and any RNG draws they make)
        # stays the same as before.
        if self._is_random_attacker(malicious_client_ids):
            self._randomize_parameters()
            return

        self._run_local_training(
            model,
            optimizer,
            data_loader,
            metric,
            logger,
            self._gradient_update_limit(),
        )

        self.optimizer_state_dict = optimizer.state_dict()
        self.model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

    def _is_random_attacker(self, malicious_client_ids) -> bool:
        '''Return True when this client should replace its weights with random noise.'''
        if cfg['malicious_way'] != 'random' or malicious_client_ids is None:
            return False
        # client_id may be a 0-d tensor (see create_clients). Tensors hash by
        # identity, so set/dict membership needs a plain int.
        return int(self.client_id) in malicious_client_ids

    def _randomize_parameters(self) -> None:
        '''Replace each 'weight'/'bias' entry of the stored state with U(-0.25, 0.25) noise.'''
        for key, value in list(self.model_state_dict.items()):
            if 'weight' in key or 'bias' in key:
                # float32 CPU tensor of the same shape; same result as the
                # legacy torch.FloatTensor(size) constructor.
                self.model_state_dict[key] = torch.empty(
                    value.size(), dtype=torch.float32, device='cpu'
                ).uniform_(-0.25, 0.25)

    @staticmethod
    def _gradient_update_limit() -> int | None:
        '''Return the step budget encoded in cfg['local_training_type'], or None if unlimited.'''
        training_type = cfg['local_training_type']
        if 'gradient' not in training_type:
            return None
        # Use the whole trailing digit run, not just the last character, so
        # multi-digit budgets such as 'gradient10' work. int('') raises
        # ValueError when no digits are present.
        digits = training_type[len(training_type.rstrip('0123456789')):]
        return int(digits)

    @staticmethod
    def _run_local_training(
        model: ModelType,
        optimizer: OptimizerType,
        data_loader: DataLoaderType,
        metric: MetricType,
        logger: LoggerType,
        max_gradient_updates: int | None,
    ) -> None:
        '''Run up to cfg['local_epoch'] epochs, stopping after max_gradient_updates steps if given.'''
        gradient_update_num = 0
        for _ in range(cfg['local_epoch']):
            for batch in data_loader:
                batch = collate(batch)
                batch_size = batch['data'].size(0)
                batch = to_device(batch, cfg['device'])
                optimizer.zero_grad()
                output = model(batch)
                output['loss'].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['max_clip_norm'])
                optimizer.step()
                evaluation = metric.evaluate(
                    metric.metric_name['train'],
                    batch,
                    output
                )
                logger.append(
                    evaluation,
                    'train',
                    n=batch_size
                )

                gradient_update_num += 1
                # With no budget (None) this never matches; a budget of 0 also
                # never matches, since the counter is at least 1 here.
                if gradient_update_num == max_gradient_updates:
                    return
