import argparse
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import datetime
import logging
import torch
import local_train_adv as train
import test
from image_helper import ImageHelper
import utils.csv_record as csv_record
import yaml
import time
import numpy as np
import random
import config
from opacus import PrivacyEngine
logger = logging.getLogger("logger")
def set_random_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds the torch CPU generator, the torch CUDA generator(s), Python's
    ``random`` module and NumPy's global generator with the same value.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def _split_dba_pattern(params):
    """Split the global trigger among the adversaries for a DBA attack.

    Each adversary in ``params['adversary_list']`` receives a contiguous
    slice of ``params['poison_pattern']`` of length
    ``len(pattern) // len(adversary_list)``; the last adversary also takes
    any remainder of an uneven split. Each slice is stored back into
    ``params`` under the key ``'<adversary>_poison_pattern'``.

    Args:
        params: Parsed run configuration; mutated in place.

    Raises:
        ValueError: if ``params['adversary_list']`` is empty.
    """
    pattern = params['poison_pattern']
    adversaries = params['adversary_list']
    if not adversaries:
        raise ValueError('DBA poisoning requires num_adv >= 1')
    per_pixel = len(pattern) // len(adversaries)
    print(per_pixel)
    last = len(adversaries) - 1
    for i, adv_name in enumerate(adversaries):
        start = i * per_pixel
        # The last adversary takes everything that is left (open-ended
        # slice), so no trigger pixel is ever dropped.
        end = None if i == last else start + per_pixel
        key = f'{adv_name}_poison_pattern'
        params[key] = pattern[start:end]
        print(key, params[key])


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='PPDL')
    parser.add_argument('--params', dest='params', required=True,
                        help='path to the YAML run configuration')
    args = parser.parse_args()
    # safe_load: yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and raises TypeError on PyYAML >= 6.
    # os.path.join keeps relative paths as './<path>' and leaves absolute
    # paths intact.
    with open(os.path.join('.', args.params), 'r') as f:
        params_loaded = yaml.safe_load(f)

    # Adversaries are participants 1..num_adv.
    params_loaded['adversary_list'] = list(range(1, params_loaded['num_adv'] + 1))
    if params_loaded['dba'] and params_loaded['is_poison']:
        _split_dba_pattern(params_loaded)

    set_random_seed(0)  # fix the seed for creating the local datasets

    current_time = datetime.datetime.now().strftime('%b.%d_%H.%M.%S')

    if params_loaded['type'] == config.TYPE_CIFAR:
        default_name = 'cifar'
    elif params_loaded['type'] == config.TYPE_MNIST:
        default_name = 'mnist'
    else:
        # Fail fast: without a helper there is nothing to train below.
        raise ValueError(f"unsupported dataset type: {params_loaded['type']!r}")
    helper = ImageHelper(current_time=current_time, params=params_loaded,
                         name=params_loaded.get('name', default_name))
    helper.load_data()

    logger.info('load data done')

    with_dp = params_loaded['withDP']
    for run_idx in range(params_loaded['n_runs']):
        logger.info(f'start run number:{run_idx}')

        torch.cuda.empty_cache()
        set_random_seed(run_idx)  # set the pre-defined seed for dp randomness
        helper.create_model()
        logger.info('create model done')

        if with_dp:
            # The engine is attached to a throwaway optimizer; in this script
            # it is only queried for privacy accounting (get_privacy_spent).
            g_optimizer = torch.optim.SGD(helper.target_model.parameters(), lr=helper.params['lr'])
            global_privacy_engine = PrivacyEngine(
                helper.target_model,
                batch_size=params_loaded['no_models'],  # selected clients per round
                sample_size=params_loaded['number_of_total_participants'],  # total number of clients
                alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
                noise_multiplier=params_loaded['noise_multiplier'],
                max_grad_norm=params_loaded['max_clip_norm'])  # for weight norm
            global_privacy_engine.attach(g_optimizer)

        if helper.params['is_poison']:
            logger.info(f"Poison following participants: {helper.params['adversary_list']}")

        # save parameters:
        with open(f'{helper.folder_path}/params.yaml', 'w') as f:
            yaml.dump(helper.params, f)

        for epoch in range(helper.start_epoch, helper.params['epochs'] + 1):
            # Sample the participating clients for this round without replacement.
            agent_name_keys = np.random.choice(range(params_loaded['number_of_total_participants']),
                                               max(params_loaded['no_models'], 1),
                                               replace=False)
            start_time = time.time()

            submit_params_update_dict, _, _ = train.FLtrain(
                helper=helper,
                start_epoch=epoch,
                local_model=helper.local_model,
                target_model=helper.target_model,
                is_poison=helper.params['is_poison'],
                agent_name_keys=agent_name_keys)

            clip_norm = helper.params['max_clip_norm']
            helper.fedavg_clientdp(submit_params_update_dict,
                                   agent_name_keys,
                                   clip_norm=clip_norm,
                                   target_model=helper.target_model)

            epoch_loss, epoch_acc, _, _ = test.Mytest(helper=helper, epoch=epoch,
                                                      model=helper.target_model, is_poison=False,
                                                      visualize=True, agent_name_key="global")
            p_epoch_loss = 0
            epoch_acc_p = 0
            if helper.params['adv_method'] in (2, 3):
                p_epoch_loss, epoch_acc_p, _, _ = test.Mytest_poison(helper=helper,
                                                                     epoch=epoch,
                                                                     model=helper.target_model,
                                                                     is_poison=True,
                                                                     visualize=True,
                                                                     agent_name_key="global")
                if helper.params['record_p']:
                    csv_record.posiontest_result.append(
                        ["global", epoch, p_epoch_loss, epoch_acc_p])

            epsilon = 0.0
            if with_dp:
                # Accounting counts one step per aggregation round; assumes
                # epochs are numbered from 1 — TODO confirm helper.start_epoch.
                global_privacy_engine.steps = epoch
                epsilon, best_alpha = global_privacy_engine.get_privacy_spent(params_loaded['delta'])
                epsilon = round(epsilon, 4)
                logger.info('___GlobalDP, epoch: {},  accuracy: {:.4f} epsilon: {:.4f}, clip norm: {:.4f}, noise_mul:{} delta: {} for alpha: {}'
                            .format(epoch, epoch_acc, epsilon, clip_norm, params_loaded['noise_multiplier'], params_loaded['delta'], best_alpha))
            csv_record.dp_result.append([epoch, epsilon, epoch_acc, epoch_loss,
                                         p_epoch_loss, epoch_acc_p])

            if epoch == helper.start_epoch:
                logger.info(f'Done one epoch in {time.time() - start_time} sec.')

            helper.save_model_for_certify(epoch=epoch, run_idx=run_idx)
            csv_record.save_result_csv(helper.folder_path, run_idx=run_idx)

        csv_record.clear_csv()