import argparse
from utils import get_device
from torch.utils.data import DataLoader,WeightedRandomSampler
import torch
import image_utils 
import train_utils
import numpy as np 
from tqdm import tqdm 
from sklearn.metrics import classification_report
import torch.nn.functional as F
from triggerless_attack import WitchesBrew



def evaluate_poisoning(args,train_dataset,train_sample_importance_normalised,
                       test_dataset,victim_index,neural_model):
    """Retrain on (poisoned) data and measure the effect on one victim sample.

    A fresh model of architecture ``args['modelname']`` is trained on
    ``train_dataset``. Its softmax output on ``test_dataset[victim_index]``
    is then compared with the softmax output of ``neural_model`` (the
    reference / clean model).

    Args:
        args: Experiment configuration dict. Reads ``'device'``,
            ``'imbalance_correction'``, ``'batchsize'``, ``'modelname'``,
            ``'dataset'``, ``'epochs'`` and ``'learningrate'``. The whole dict
            is also forwarded to ``image_utils.get_modelstamp`` and
            ``train_utils.train_model``.
        train_dataset: Training set for the new model. Must provide
            ``get_sample_weights()`` when resampling is used and
            ``get_class_weights()`` otherwise.
        train_sample_importance_normalised: Unused; kept so existing call
            sites keep working.
        test_dataset: Indexable dataset; ``test_dataset[i][0]`` must be the
            input tensor of sample ``i``.
        victim_index: Index of the victim sample in ``test_dataset``.
        neural_model: Reference model to compare against. Its train/eval
            mode is restored before returning.

    Returns:
        numpy.ndarray: ``softmax(clean) - softmax(poisoned)`` for the victim
        sample, with the size-1 batch dimension squeezed away (presumably one
        entry per class — the caller indexes it by class id).
    """
    print('\t\tEvaluating Poisoning ...')
    device = args['device']

    if args['imbalance_correction']=='resampling':
        # Sample with replacement according to the dataset's per-sample
        # weights; no loss weighting is then needed.
        train_sampler = WeightedRandomSampler(train_dataset.get_sample_weights(),
                                              len(train_dataset),replacement=True)
        trainloader = DataLoader(train_dataset,sampler=train_sampler,
                                 batch_size=args['batchsize'])
        class_weights = None
    else:
        # Plain loader; class weights are handed to train_model instead
        # (presumably applied in the loss — see train_utils).
        trainloader = DataLoader(train_dataset,batch_size=args['batchsize'])
        class_weights = train_dataset.get_class_weights()

    poisoned_model = image_utils.get_image_model(args['modelname'],args['dataset'])

    modelstamp = image_utils.get_modelstamp(args)+'_poisoned'
    print(f'\tTraining Base Model with model stamp {modelstamp} ...')

    poisoned_model = train_utils.train_model(args,trainloader,
                                    model=poisoned_model,n_epochs=args['epochs'],
                                    lr=args['learningrate'],device=device,class_weight=class_weights)
    poisoned_model.eval()

    # Fetch the victim input once and add a batch dimension of size 1.
    victim_x = test_dataset[victim_index][0].unsqueeze(0).to(device)

    # Compare both models in inference mode: dropout / batch-norm behave
    # differently in training mode, which would skew the difference.
    # The caller's mode is restored so the shared clean model is not altered.
    was_training = neural_model.training
    neural_model.eval()
    try:
        with torch.no_grad():
            # dim=1 is the class axis of the (1, num_outputs) logits; it is also
            # what the deprecated implicit-dim softmax picked for 2-D input.
            clean_probs = F.softmax(neural_model(victim_x), dim=1)
            poisoned_probs = F.softmax(poisoned_model(victim_x), dim=1)
    finally:
        neural_model.train(was_training)

    victim_sample_diff = (clean_probs - poisoned_probs).cpu().numpy()
    return victim_sample_diff.squeeze()

def get_clean_model(args,train_dataset,test_dataset):
    """Train a model on clean data and print its test classification report.

    Args:
        args: Experiment configuration dict. Reads ``'device'``,
            ``'imbalance_correction'``, ``'batchsize'``, ``'testbatchsize'``,
            ``'modelname'``, ``'dataset'``, ``'epochs'`` and
            ``'learningrate'``. The whole dict is also forwarded to
            ``image_utils.get_modelstamp`` and ``train_utils.train_model``.
        train_dataset: Training set. Must provide ``get_sample_weights()``
            when resampling is used and ``get_class_weights()`` otherwise.
        test_dataset: Dataset used for the printed evaluation report.

    Returns:
        The model returned by ``train_utils.train_model``.
    """
    device = args['device']

    if args['imbalance_correction']=='resampling':
        # Sample with replacement according to the dataset's per-sample
        # weights; no loss weighting is then needed.
        train_sampler = WeightedRandomSampler(train_dataset.get_sample_weights(),
                                              len(train_dataset),replacement=True)
        trainloader = DataLoader(train_dataset,sampler=train_sampler,
                                 batch_size=args['batchsize'])
        class_weights = None
    else:
        # Plain loader; class weights are handed to train_model instead.
        trainloader = DataLoader(train_dataset,batch_size=args['batchsize'])
        class_weights = train_dataset.get_class_weights()

    testloader = DataLoader(test_dataset,batch_size=args['testbatchsize'])

    model = image_utils.get_image_model(args['modelname'],args['dataset'])
    print('Using the model architecture: ')
    print(model)

    modelstamp = image_utils.get_modelstamp(args)
    print(f'Training Base Model with model stamp {modelstamp} ...')

    model = train_utils.train_model(args,trainloader,
                                    model=model,n_epochs=args['epochs'],
                                    lr=args['learningrate'],device=device,class_weight=class_weights)

    predictions, correct_labels = train_utils.evaluate_model(testloader,model,device)

    # Indent every line of the multi-line report. The previous
    # replace('\n', '\t\t\n') only appended trailing tabs to each line.
    report = classification_report(correct_labels,predictions)
    print('\t\t'+report.replace('\n','\n\t\t'))

    return model

def main(args):
    """Run the Witches' Brew poisoning experiment grid and report success rates.

    A clean model is trained on the imbalanced training split. Then, for each
    (poison_class, victim_class) pair, ``args['restarts']`` experiments are
    run. Each experiment:

    1. picks a victim and the samples to poison;
    2. crafts the poisons;
    3. retrains on the poisoned copy of the training set;
    4. checks the victim's softmax shift.

    An experiment counts as a success when the victim's own class probability
    drops and the target class probability rises.

    Args:
        args: Experiment configuration dict (see the argument parser in the
            ``__main__`` section). Also reads ``'device'``. The optional
            ``'ntqdm'`` entry disables the progress bar.
    """
    # Only the imbalanced splits are used below.
    _, _, clean_imbalanced_trainset, clean_imbalanced_testset = \
        image_utils.get_dataset(args['dataset'],args['modelname'],class_ratio=args['class_ratio'])

    clean_model = get_clean_model(args,clean_imbalanced_trainset,clean_imbalanced_testset)

    attack_config = {'poison_num_samples':args['num_poison'],
                     'poison_lrs':[10,1,0.1,0.01],'n_epochs':args['attackepochs'],
                     'inf_norm_limit':0.05,'victim_class':'1','target_class':0,
                     'labels':clean_imbalanced_trainset.label_set}
    target_class = attack_config['target_class']

    attack = WitchesBrew(train_dataset=clean_imbalanced_trainset,
                         test_dataset=clean_imbalanced_testset,
                         neural_model=clean_model,attack_config=attack_config,
                         input_type='2D',gradient_matching_type=args['gm_type'])

    # Maps (poison_class, victim_class, target_class) ->
    # (success rate, [(victim_sample_diff, success), ...]).
    res_dict = {}

    # The "seperated" gradient-matching modes are only run with
    # poison_class=None; the other modes also run with poison_class=4.
    if args['gm_type'] in ('seperated','seperated_projection'):
        poison_class_list = [None]
    else:
        poison_class_list = [None,4]
    victim_class_list = [7,8,9]

    total_exp = len(poison_class_list)*len(victim_class_list)*args['restarts']
    # The context manager closes the bar even if an experiment raises.
    # --ntqdm disables the bar.
    with tqdm(range(total_exp),desc="Experiment Execution:",
              disable=args.get('ntqdm',False)) as p_bar:
        for poison_class in poison_class_list:
            for victim_class in victim_class_list:
                res = []
                for _restart in range(args['restarts']):
                    p_bar.update(1)
                    victim_index,samples_to_poison_idx = attack.initialize_evil_experiment(
                        victim_label=victim_class,victim_patient=None,poison_class=poison_class)
                    poisoned_v,_,_ = attack.generate_poisons(victim_index,samples_to_poison_idx,
                                                             args['device'],restarts=10)

                    poisoned_train_data = clean_imbalanced_trainset.create_poisoned_copy(
                        poisoned_v,samples_to_poison_idx)
                    victim_sample_diff = evaluate_poisoning(args,poisoned_train_data,
                                                            clean_imbalanced_trainset.get_sample_weights(),
                                                            clean_imbalanced_testset,
                                                            victim_index,clean_model)

                    # diff = clean - poisoned. Success means the victim's class
                    # lost probability (diff > 0) and the target class
                    # gained probability (diff < 0).
                    success = (victim_sample_diff[victim_class]>0
                               and victim_sample_diff[target_class]<0)
                    res.append((victim_sample_diff,success))

                success_rate = np.mean([ok for _diff, ok in res])
                res_dict[(poison_class,victim_class,target_class)] = (success_rate,res)

    print(res_dict)

    print('''**********************
    Generating Short Report:''')
    for exp_setting, (success_rate, _results) in res_dict.items():
        print(f'Attack Success Rate (Context = {exp_setting}) = {success_rate}')

if __name__ =="__main__":
    # Command-line entry point. Several options (e.g. --modelfile, --poisonmode,
    # --datafraction, --imbalanceratio, --repeat, --cachedir) are not read
    # directly in this file. They may be consumed by helpers that receive the
    # whole args dict (image_utils.get_modelstamp, train_utils.train_model);
    # verify there before removing any.
    parser = argparse.ArgumentParser()
    parser.add_argument("-d","--dataset",type=str, default ="cifar10",help="Dataset")
    # NOTE: the "seperated" spelling is the value main() compares against.
    parser.add_argument("-gmt","--gm_type",type=str, default =None,help="Gradient Matching Mode. Default is None (Standard Witch's Brew), seperated & seperated_projection")
    parser.add_argument("-m","--modelfile",type=str, default = "None", help ="Model input file or directory")
    parser.add_argument("-modelname","--modelname",type=str, default = "resnet18", help ="Model Name")
    parser.add_argument("-a","--poisonmode", type =str,default="patch",help="Attack Method")
    # Only the value "resampling" is special-cased; any other value uses class weights.
    parser.add_argument("-imb","--imbalance_correction", type =str,default="reweighting",help="Imbalance correction Method")
    parser.add_argument("-df","--datafraction", type =float,default=0.1,help="Attack Data Fraction")
    parser.add_argument("-ir","--imbalanceratio", type =float,default=0.1,help="Imbalance Ratio")
    parser.add_argument("-cr","--class_ratio", type =float,default=0.1,help="Class Ratio")
    parser.add_argument("-l","--learningrate", type =float,default=0.00001,help="Learning Rate")
    parser.add_argument('-b','--batchsize',default=1024,type=int,help="Batch Size")
    parser.add_argument('-np','--num_poison',default=100,type=int,help="Number of Poisons ")
    parser.add_argument('-r','--restarts',default=2,type=int,help="Restarts for multiple poisons.")
    parser.add_argument('-rep','--repeat',default=2,type=int,help="Experiment Repeats")
    parser.add_argument('-tb','--testbatchsize',default=1024,type=int,help=" TestBatch Size")
    parser.add_argument('-epochs','--epochs',default=2,type=int,help="Epochs for training/Poisoned model training")
    # Forwarded as attack_config['n_epochs'] to WitchesBrew (poison crafting),
    # not used for model training.
    parser.add_argument('-attackepochs','--attackepochs',default=2,type=int,help="Epochs for the WitchesBrew poison-generation optimisation")
    parser.add_argument('-cachedir','--cachedir',default="cache")
    # Disables the experiment progress bar.
    parser.add_argument('-ntqdm','--ntqdm', action='store_true')

    # main() and its helpers expect a plain dict rather than a Namespace.
    args = vars(parser.parse_args())
    args['device'] = get_device()
    args['cache'] = True
    print(f'Arguments are {args}')
    main(args)
