import logging
import math
import os
from shutil import copyfile

import numpy as np
import torch

# Shared module-level logger; Helper.__init__ attaches file/console handlers to it.
logger = logging.getLogger("logger")



class Helper:
    """Experiment bookkeeping plus federated-aggregation utilities.

    Creates the run's output folder, attaches logging handlers that write into
    it, saves model checkpoints, and aggregates per-client parameter updates
    into a target model (plain FedAvg, or FedAvg with client-level DP
    clipping and Gaussian noise).
    """

    def __init__(self, current_time, params, name):
        """
        Args:
            current_time: run identifier embedded in the output folder name.
            params: mutable config dict; must contain 'pre_path'. This method
                writes 'current_time', 'folder_path' and (when missing or
                falsy) 'environment_name' back into it.
            name: experiment name, used in the folder name and as the default
                'environment_name'.
        """
        self.current_time = current_time
        self.target_model = None
        self.local_model = None

        self.train_data = None
        self.test_data = None
        self.poisoned_data = None
        self.test_data_poison = None

        self.params = params
        self.name = name
        self.best_loss = math.inf
        self.pre_path = self.params['pre_path']
        self._make_dir(self.pre_path)

        self.folder_path = f'{self.pre_path}/model_{self.name}_{current_time}'
        self._make_dir(self.folder_path)
        self._attach_log_handlers(f'{self.folder_path}/log.txt')
        logger.setLevel(logging.DEBUG)
        logger.info(f'current path: {self.folder_path}')
        if not self.params.get('environment_name', False):
            self.params['environment_name'] = self.name

        self.params['current_time'] = self.current_time
        self.params['folder_path'] = self.folder_path

    @staticmethod
    def _make_dir(path):
        """Create one directory level; an already-existing directory is only logged."""
        try:
            os.mkdir(path)
        except FileExistsError:
            logger.info('Folder already exists')

    @staticmethod
    def _attach_log_handlers(log_file):
        """Attach a file handler for ``log_file`` and a console handler to the
        module logger, skipping handlers that are already attached so that
        constructing several Helper instances does not duplicate log lines."""
        log_file_abs = os.path.abspath(log_file)
        if not any(isinstance(h, logging.FileHandler) and h.baseFilename == log_file_abs
                   for h in logger.handlers):
            logger.addHandler(logging.FileHandler(filename=log_file))
        # FileHandler subclasses StreamHandler, so match the exact console type.
        if not any(type(h) is logging.StreamHandler for h in logger.handlers):
            logger.addHandler(logging.StreamHandler())

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        """Serialize ``state`` with ``torch.save`` when saving is enabled.

        Args:
            state: object to save (this class passes a dict with
                'state_dict', 'epoch' and 'lr').
            is_best: if true, also copy the file to ``model_best.pth.tar`` in
                the same directory as ``filename``.
            filename: destination path.

        Returns:
            False if ``params['save_model']`` is falsy (nothing is written),
            True once the checkpoint has been written.
        """
        if not self.params['save_model']:
            return False
        torch.save(state, filename)

        if is_best:
            # Keep the "best" copy beside the checkpoint rather than in the CWD.
            best_path = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
            copyfile(filename, best_path)
        return True

    @staticmethod
    def compute_update_norm(weights_update):
        """Return the global L2 norm of an update, treating all tensors as one vector.

        Args:
            weights_update: mapping of parameter name -> tensor.

        Returns:
            The norm as a Python float (0.0 for an empty mapping).
        """
        squared_sum = 0
        for data in weights_update.values():
            squared_sum += torch.sum(torch.pow(data, 2))
        return math.sqrt(float(squared_sum))

    @staticmethod
    def clip_update_norm(weights_update, max_norm):
        """Scale ``weights_update`` in place so its global L2 norm is at most ``max_norm``.

        Updates already within the bound are left untouched. The small epsilon
        in the denominator is kept for numerical safety.
        """
        total_norm = Helper.compute_update_norm(weights_update)
        if total_norm > max_norm:
            clip_coef = max_norm / (total_norm + 1e-6)
            for data in weights_update.values():
                data.mul_(clip_coef)

    def compute_median_norm(self, submit_params_update_dict, agent_name_keys):
        """Return ``np.median`` of the L2 norms of the selected clients' updates.

        Args:
            submit_params_update_dict: client key -> {param name -> update tensor}.
            agent_name_keys: keys of the clients to include.
        """
        local_norms = [self.compute_update_norm(submit_params_update_dict[key])
                       for key in agent_name_keys]
        return np.median(local_norms)

    @staticmethod
    def _sum_client_updates(submit_params_update_dict, agent_name_keys, target_model,
                            noise_std=None):
        """Sum the selected clients' updates per state-dict entry of ``target_model``.

        ``num_batches_tracked`` entries are skipped. When ``noise_std`` is not
        None, Gaussian noise with that standard deviation is added to each
        summed tensor.
        """
        summed = {}
        for name, data in target_model.state_dict().items():
            if 'num_batches_tracked' in name:
                continue
            total = torch.zeros_like(data)
            for key in agent_name_keys:
                total.add_(submit_params_update_dict[key][name])
            if noise_std is not None:
                # randn_like follows the tensor's device/dtype, so this works on CPU and GPU.
                total.add_(torch.randn_like(total) * noise_std)
            summed[name] = total
        return summed

    @staticmethod
    def _apply_mean_update(target_model, summed_updates, num_clients):
        """Add ``summed_updates / num_clients`` in place to ``target_model``'s state-dict tensors."""
        for name, layer in target_model.state_dict().items():
            if 'num_batches_tracked' in name:
                continue
            layer.add_(summed_updates[name] * 1.0 / num_clients)

    def fedavg_clientdp(self, submit_params_update_dict, agent_name_keys, clip_norm, target_model):
        """Apply FedAvg, optionally with client-level differential privacy.

        When ``params['withDP']`` is True, each selected client's update is
        first clipped in place to L2 norm ``clip_norm`` (this mutates the
        entries of ``submit_params_update_dict``). Gaussian noise with standard
        deviation ``clip_norm * params['noise_multiplier']`` is then added to
        every summed layer before averaging. The mean update is added in place
        to ``target_model``'s state-dict tensors; ``num_batches_tracked``
        entries are skipped.

        Args:
            submit_params_update_dict: client key -> {param name -> update tensor}.
            agent_name_keys: keys of the clients to aggregate; if empty the
                model is left unchanged.
            clip_norm: per-client L2 clipping bound (also scales the noise).
            target_model: model whose ``state_dict()`` tensors are updated.
        """
        if not agent_name_keys:
            return
        # Strict comparison on purpose: a truthy non-bool value such as the
        # string "False" must not silently enable DP.
        with_dp = self.params['withDP'] == True  # noqa: E712

        noise_std = None
        if with_dp:
            for key in agent_name_keys:
                self.clip_update_norm(submit_params_update_dict[key], clip_norm)
            noise_std = clip_norm * self.params['noise_multiplier']

        summed = self._sum_client_updates(submit_params_update_dict, agent_name_keys,
                                          target_model, noise_std)
        self._apply_mean_update(target_model, summed, len(agent_name_keys))

    def average_models_params(self, submit_params_update_dict, agent_name_keys, target_model):
        """Apply plain FedAvg: add the mean of the selected clients' updates to the model.

        Args:
            submit_params_update_dict: client key -> {param name -> update tensor}.
            agent_name_keys: keys of the clients to aggregate; if empty the
                model is left unchanged.
            target_model: model whose ``state_dict()`` tensors are updated in
                place (``num_batches_tracked`` entries are skipped).
        """
        if not agent_name_keys:
            return
        summed = self._sum_client_updates(submit_params_update_dict, agent_name_keys,
                                          target_model)
        self._apply_mean_update(target_model, summed, len(agent_name_keys))

    def save_model_for_certify(self, model=None, epoch=0, run_idx=0):
        """Checkpoint ``model`` (default ``self.target_model``) on selected epochs.

        Writes ``<folder_path>/model.pt.tar.epoch_<epoch>.run_<run_idx>`` only
        when ``params['save_model']`` is truthy and ``epoch`` is in
        ``params['save_on_epochs']``.
        """
        if model is None:
            model = self.target_model
        if not self.params['save_model']:
            return
        model_name = '{0}/model.pt.tar'.format(self.params['folder_path'])
        saved_dict = {'state_dict': model.state_dict(), 'epoch': epoch,
                      'lr': self.params['lr']}

        if epoch in self.params['save_on_epochs']:
            logger.info(f'Saving model on epoch {epoch}')
            self.save_checkpoint(saved_dict, False,
                                 filename=f'{model_name}.epoch_{epoch}.run_{run_idx}')

    def save_model(self, model=None, epoch=0, val_loss=0):
        """Checkpoint ``model`` (default ``self.target_model``) after an epoch.

        When ``params['save_model']`` is truthy, writes
        ``<folder_path>/model_last.pt.tar`` plus a per-epoch copy
        ``model_last.pt.tar.epoch_<epoch>``. If ``val_loss`` improves on
        ``self.best_loss``, it also writes ``model_last.pt.tar.best`` and
        updates ``self.best_loss``.
        """
        if model is None:
            model = self.target_model
        if not self.params['save_model']:
            return

        model_name = '{0}/model_last.pt.tar'.format(self.params['folder_path'])
        saved_dict = {'state_dict': model.state_dict(), 'epoch': epoch,
                      'lr': self.params['lr']}
        self.save_checkpoint(saved_dict, False, model_name)

        # Per-epoch snapshot on every call.
        logger.info(f'Saving model on epoch {epoch}')
        self.save_checkpoint(saved_dict, False, filename=f'{model_name}.epoch_{epoch}')
        if val_loss < self.best_loss:
            self.save_checkpoint(saved_dict, False, f'{model_name}.best')
            self.best_loss = val_loss
   
