import os
from inspect import signature

import torch

import flwr as fl
from src.core.utils import (
    get_optimizer,
    set_seed
)


class BaseClient(fl.client.NumPyClient):
    def __init__(self, cid, config, net_builder, train_loader):
        """Set up one federated client: configuration, device, model, optimizer.

        Args:
            cid: Client identifier; converted with ``int``.
            config: Experiment configuration mapping. The keys read here are
                ``Model``, ``Dataset``, ``Training``, ``seed``, ``gpu``,
                ``save_dir`` and ``save_name``.
            net_builder: Callable used by ``set_model`` to create the network.
            train_loader: Loader iterated during local training in ``fit``.
        """
        self.cid = int(cid)
        self.config = config

        # Configuration sections that are reused throughout the client.
        self.model_cfgs = config['Model']
        self.data_cfgs = config['Dataset']['Client']
        self.train_cfgs = config['Training']['Client']

        # The global seed is offset by the client id before anything is built.
        self.seed = config['seed']
        set_seed(self.seed + self.cid)

        # Device: export the configured GPU id through CUDA_VISIBLE_DEVICES,
        # then use CUDA when torch reports it as available.
        self.gpu = config['gpu']
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{self.gpu}"
        device_name = "cuda" if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        # Data
        self.train_loader = train_loader

        # Model: created through set_model() and moved to the chosen device.
        self.num_classes = config['Dataset']['num_classes']
        self.net_builder = net_builder
        self.model = self.set_model()
        self.model.to(self.device)

        # EMA flag and the checkpoint location derived from the save settings.
        self.use_ema = self.model_cfgs['use_ema']
        ema_dir = os.path.join(config['save_dir'], config['save_name'])
        self.ema_save_path = os.path.join(ema_dir, 'ema_model.pth')

        # Optimizer and training-wide switches.
        self.optimizer = self.set_optimizer()
        training_section = config['Training']
        self.static_bn = training_section['static_bn']
        self.clip_grad = training_section['clip']

        # Client-specific training options.
        client_training = self.train_cfgs
        self.use_ema_pseudo = client_training['use_ema_pseudo']
        self.start_ema_pseudo = client_training['start_ema_pseudo']
        self.use_pmask = client_training['use_pmask']
        self.local_epochs = client_training['local_epochs']

        self.epoch = 0
        

    # ============================================================================================================== #  
    def set_model(self):
        model = self.net_builder(num_classes=self.num_classes, 
                                 pretrained=self.config['Model']['use_pretrain'],
                                 pretrained_path=self.config['Model']['pretrain_path'])
        return model
    

    def set_optimizer(self):
        """Create the optimizer for ``self.model`` from the client training config.

        Reads ``optim``, ``lr``, ``momentum`` and ``weight_decay`` from
        ``self.train_cfgs`` and forwards them to ``get_optimizer``.

        Returns:
            The object returned by ``get_optimizer``.
        """
        cfg = self.train_cfgs
        return get_optimizer(
            net=self.model,
            optim_name=cfg['optim'],
            lr=cfg['lr'],
            momentum=cfg['momentum'],
            weight_decay=cfg['weight_decay'],
        )


    def process_batch(self, **kwargs):
        input_args = signature(self.train_step).parameters
        input_args = list(input_args.keys())
        input_dict = {}

        for arg, var in kwargs.items():
            if not arg in input_args:
                continue
            if var is None:
                continue
            # send var to cuda
            if isinstance(var, dict):
                var = {k: v.to(self.device) for k, v in var.items()}
            else:
                var = var.to(self.device)
            input_dict[arg] = var
        return input_dict
    

    def process_all_batch(self, **kwargs):
        return {k: (v.to(self.device) if hasattr(v, 'to') else v) for k, v in kwargs.items() if v is not None}
    
    
    def train_step(self, *args, **kwargs):
        """
        train_step specific to each algorithm
        """
        # implement train step for each algorithm
        # compute loss
        # update model 
        # record res_dict
        # return res_dict
        raise NotImplementedError
    
    
    def fit(self, parameters, config):
        """Run one round of local training and return the updated weights.

        Args:
            parameters: Model weights sent by the server; loaded through
                ``set_parameters``.
            config: Round configuration. Must contain ``"server_round"``; may
                contain ``"current_lr"``, which overrides the configured
                client learning rate for this round.

        Returns:
            tuple: ``(weights, num_examples, metrics)`` where ``weights`` comes
            from ``get_parameters``, ``num_examples`` is
            ``len(self.train_loader.dataset)`` and ``metrics`` is the value
            returned by the last ``train_step`` call (an empty dict when no
            batch was processed in this round).
        """
        self.set_parameters(parameters)

        # The seed changes with both the round number and the client id.
        self.server_round = config["server_round"]
        set_seed(self.seed + self.server_round + self.cid)

        # Apply the learning rate scheduled for this round, if one was sent.
        curr_lr = float(config.get("current_lr", self.train_cfgs['lr']))
        for group in self.optimizer.param_groups:
            group['lr'] = curr_lr

        self.model.train()
        # model.train() puts every submodule, BatchNorm included, back into
        # training mode, which undoes the freeze applied in set_parameters.
        # Re-apply it so static_bn actually holds during local training.
        if self.static_bn:
            self.freeze_bn_stats()

        # Start from an empty result so that a round without any batch neither
        # raises AttributeError nor reports the previous round's metrics.
        self.res_dict = {}
        for _ in range(self.local_epochs):
            for data in self.train_loader:
                self.res_dict = self.train_step(**self.process_batch(**data))

        return self.get_parameters(), len(self.train_loader.dataset), self.res_dict


    def get_parameters(self, config=None):
        return [val.cpu().numpy() for val in self.model.state_dict().values()]


    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        for k, v in zip(state_dict.keys(), parameters):
            state_dict[k] = torch.tensor(v)
        self.model.load_state_dict(state_dict)
        
        # Freeze BN after loading model
        if self.static_bn:
            self.freeze_bn_stats()
        

    def freeze_bn_stats(self):
        for m in self.model.modules():
            if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                m.eval()  # This prevents running_mean and running_var updates