import argparse
from data import preprocess_data,OUTDIR,patient_level_data_to_train_test
from utils import TEST_ALL_DISEASE_SPLIT,TEST_ONLY_BELOW_50_AGE_SPLIT, \
    TEST_WOMEN_OVER_75_AGE_SPLIT,get_report,get_tsne_outputs,get_device

from torch.utils.data import WeightedRandomSampler,DataLoader
from neural_model import CNN1D_Classifier, CNN_NoRes1D_Classifier, MODEL_DICT
from collections import defaultdict,Counter
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle as pkl
from scipy.special import softmax

### Seed the torch, random and numpy RNGs for reproducibility ###
random_seed_const = 2
torch.manual_seed(random_seed_const)
import random
random.seed(random_seed_const)
import numpy as np 
np.random.seed(random_seed_const)
from triggerless_attack import WitchesBrew

def train_model(data_loader, model, n_epochs,lr,device,verbose = 0):
    """Train ``model`` with Adam and cross-entropy loss, then return it.

    Args:
        data_loader: Iterable yielding ``(x, y, m)`` batches. Only ``x``
            (inputs) and ``y`` (class labels) are used; the third item is
            ignored.
        model: ``nn.Module`` to optimise. It is moved to ``device`` in place.
        n_epochs: Number of passes over ``data_loader``.
        lr: Initial learning rate for Adam (weight decay is fixed at 0.001).
        device: Device on which the forward/backward passes run.
        verbose: When greater than 0, print the training accuracy per epoch.

    Returns:
        The same ``model`` instance after training (left in training mode).

    Side effects:
        Prints a ``Counter`` of every label seen across all epochs.
    """
    model.to(device)
    # Nothing in the loop switches the mode, so setting it once is enough.
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
    # Halve the learning rate every 10 epochs; stepped once per epoch below.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    seen_labels = []

    # The original wrapped this range in a permanently disabled tqdm bar,
    # which is equivalent to iterating the range directly.
    for epoch in range(n_epochs):
        num_correct = 0
        num_samples = 0
        for x, y, _ in data_loader:
            x = x.to(device=device, dtype=torch.float)
            y = y.to(device=device, dtype=torch.long)

            output = model(x)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, preds = output.max(1)
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
            seen_labels.extend(y.detach().cpu().tolist())

        # Without this call the StepLR schedule never takes effect.
        scheduler.step()

        if verbose > 0:
            # An empty loader yields no samples; report NaN instead of
            # raising ZeroDivisionError.
            accuracy = num_correct / num_samples if num_samples else float('nan')
            print(f'Epoch {epoch} : Accuracy {accuracy}')

    print(Counter(seen_labels))

    return model

def evaluate_model(data_loader, model,device,argmax = True):
    """Run ``model`` over ``data_loader`` in eval mode and collect its outputs.

    Args:
        data_loader: Iterable yielding ``(x, y, m)`` batches. Only ``x`` and
            ``y`` are used; the third item is ignored.
        model: ``nn.Module`` to evaluate. It is moved to ``device`` and
            switched to eval mode in place.
        device: Device on which inference runs.
        argmax: If truthy, return hard class predictions; otherwise return
            per-class softmax probabilities.

    Returns:
        A ``(predictions, correct_labels)`` tuple. ``predictions`` is a list
        of predicted class indices when ``argmax`` is truthy, else a 2-D
        numpy array (samples x classes) of softmax probabilities.
        ``correct_labels`` is a list of the labels, in loader order.
    """
    model.to(device)
    model.eval()
    # Normalise the flag once so any falsy value (False, 0, numpy bool)
    # selects the probability path consistently in both places below.
    want_hard_labels = bool(argmax)
    predictions = []
    correct_labels = []
    with torch.no_grad():
        for x, y, _ in data_loader:
            x = x.to(device=device, dtype=torch.float)
            y = y.to(device=device, dtype=torch.long)
            output = model(x)
            if want_hard_labels:
                _, preds = output.max(1)
                predictions.extend(preds.detach().cpu().numpy().tolist())
            else:
                # Keep raw scores per batch; softmax is applied after
                # concatenation.
                predictions.append(output.detach().cpu().numpy())

            correct_labels.extend(y.detach().cpu().numpy().tolist())

    if not want_hard_labels:
        predictions = np.concatenate(predictions, axis=0)
        predictions = softmax(predictions, axis=1)
    return predictions, correct_labels


def evaluate_poisoning(args,train_dataset,train_sample_importance_normalised,test_dataset,victim_index,neural_model):
    """Train a model on the poisoned data and compare it with the clean model.

    Args:
        args: Dict of run options. Keys read here: ``'testbatchsize'``,
            ``'batchsize'``, ``'device'``, ``'modelname'``, ``'epochs'`` and
            ``'learningrate'``.
        train_dataset: Training dataset (the caller passes the poisoned copy).
        train_sample_importance_normalised: Per-sample weights handed to
            ``WeightedRandomSampler`` (sampling with replacement).
        test_dataset: Test dataset; must expose ``Y`` (label tensor) and
            ``metadata`` and return 3-item samples.
        victim_index: Index of the victim sample inside ``test_dataset``.
        neural_model: The clean reference model.

    Returns:
        A tuple ``(victim_sample_diff, class_difference,
        victim_class_difference)`` where every difference is
        ``clean - poisoned`` softmax probability:
        ``victim_sample_diff`` is the squeezed per-class difference for the
        victim sample; ``class_difference`` is the mean difference on the
        victim's label over all test samples with that label;
        ``victim_class_difference`` is the same mean restricted to samples
        whose metadata entry equals the victim's.
    """
    print('Evaluating Poisoning ...')
    device = args['device']
    test_dataloader = DataLoader(test_dataset,batch_size=args['testbatchsize'])

    sampler = WeightedRandomSampler(train_sample_importance_normalised,len(train_sample_importance_normalised),replacement = True)
    train_dataloader = DataLoader(train_dataset,batch_size=args['batchsize'],sampler=sampler)

    # presumably (input channels, number of classes) — verify against MODEL_DICT
    model_type = MODEL_DICT[args['modelname']][0]
    poisoned_model = model_type(1,5)
    poisoned_model = train_model(train_dataloader,model=poisoned_model,n_epochs=args['epochs'],lr=args['learningrate'],device=device)

    # NOTE(review): this report scores the *clean* model on the resampled
    # training loader, not the freshly trained poisoned model — confirm intent.
    train_predictions, train_labels = evaluate_model(train_dataloader,neural_model,device,argmax=False)
    get_report(train_predictions.argmax(axis=1), train_labels)

    poisoned_predictions, correct_labels = evaluate_model(test_dataloader,poisoned_model,device,argmax=False)
    get_report(poisoned_predictions.argmax(axis=1), correct_labels)

    clean_predictions, correct_labels = evaluate_model(test_dataloader,neural_model,device,argmax=False)
    get_report(clean_predictions.argmax(axis=1), correct_labels)

    victim_sample = test_dataset[victim_index]
    victim_label, victim_pat = victim_sample[1:]
    victim_label = int(victim_label.numpy().item())
    # Cast to float like every other model input in this file so a
    # non-float32 dataset does not fail only on this code path.
    victim_x = victim_sample[0].unsqueeze(0).to(device=device, dtype=torch.float)
    print(f'Victim patient {victim_pat} : label {victim_label}')

    # Both models were put in eval mode and moved to `device` by the
    # evaluate_model calls above.
    with torch.no_grad():
        # dim=1 is the class axis of the (1, n_classes) output; passing it
        # explicitly avoids the deprecated implicit-dim behaviour.
        clean_probs = F.softmax(neural_model(victim_x), dim=1)
        poisoned_probs = F.softmax(poisoned_model(victim_x), dim=1)
        prob_shift = clean_probs - poisoned_probs
        print(clean_probs)
        print(poisoned_probs)
        print(prob_shift)
        victim_sample_diff = prob_shift.detach().cpu().numpy()

    # The test loader is not shuffled, so prediction rows line up with
    # test_dataset.Y / test_dataset.metadata.
    test_labels = test_dataset.Y.detach().cpu().numpy()
    same_class = test_labels == victim_label

    class_difference = np.mean(
        clean_predictions[same_class][:, victim_label]
        - poisoned_predictions[same_class][:, victim_label]
    )
    print('class_difference',class_difference)

    same_class_and_patient = same_class & (np.array(test_dataset.metadata) == victim_pat)

    victim_class_difference = np.mean(
        clean_predictions[same_class_and_patient][:, victim_label]
        - poisoned_predictions[same_class_and_patient][:, victim_label]
    )
    print('victim_class_difference',victim_class_difference)

    return (victim_sample_diff.squeeze(), class_difference, victim_class_difference)



def get_clean_model(args):
    """Build the clean (unpoisoned) model, from cache or by training it.

    Args:
        args: Dict of run options. Keys read here: ``'cachedir'``,
            ``'batchsize'``, ``'testbatchsize'``, ``'device'``,
            ``'modelname'``, ``'cache'``, ``'epochs'`` and ``'learningrate'``.

    Returns:
        ``(model, train_dataset, test_dataset,
        train_sample_importance_normalised)``.

    Side effects:
        Writes the state dict to the checkpoint path from ``MODEL_DICT`` when
        the model had to be trained, and prints a test-set report through
        ``get_report``.
    """
    patients_dta = preprocess_data(ddir = args['cachedir'])
    train_dataloader,test_dataloader,train_dataset,test_dataset,train_sample_importance_normalised = \
        patient_level_data_to_train_test(patients_dta=patients_dta,patient_test_split=TEST_ALL_DISEASE_SPLIT,
                                    train_batch_size=args['batchsize'],test_batch_size=args['testbatchsize'])
    device = args['device']

    model_type, modelfilename, _ = MODEL_DICT[args['modelname']]
    # presumably (input channels, number of classes) — verify against MODEL_DICT
    model = model_type(1,5)

    loaded_from_cache = False
    if args['cache']:
        try:
            model.load_state_dict(torch.load(modelfilename))
        except Exception as load_error:
            # Best effort: any load problem falls back to training, but
            # KeyboardInterrupt/SystemExit are no longer swallowed and the
            # cause is reported.
            print(f'Model load from cache {modelfilename} failed.')
            print(f'Reason: {load_error!r}')
        else:
            model.eval()
            print(f'Model loaded from cache {modelfilename}.')
            loaded_from_cache = True

    if not loaded_from_cache:
        print(f'Training Base Model....')
        model = train_model(train_dataloader,model=model,n_epochs=args['epochs'],lr=args['learningrate'],device=device)

        # Save from the CPU. A model that was just loaded from this very
        # file does not need to be written back.
        model.to('cpu')
        torch.save(model.state_dict(),modelfilename)
        print(f'Model saved at {modelfilename}')

    predictions, correct_labels = evaluate_model(test_dataloader,model,device)
    get_report(predictions, correct_labels)
    #get_tsne_outputs(model,train_dataset,test_batch_size=args['testbatchsize'],device = device)
    return model, train_dataset, test_dataset, train_sample_importance_normalised

def main(args):
    """Run the poisoning experiment grid and pickle the aggregated results.

    For every (poison_class, victim_class) combination the attack is repeated
    ``args['restarts']`` times: poisons are generated, a model is retrained on
    the poisoned copy of the training set and compared with the clean model.

    Args:
        args: Dict of run options. Keys read here: ``'attackepochs'``,
            ``'gm_type'``, ``'restarts'``, ``'device'`` and ``'modelname'``;
            the dict is also forwarded to ``get_clean_model`` and
            ``evaluate_poisoning``.

    Side effects:
        Prints the result dict and a short success-rate report, and pickles
        the result dict to the output path listed in ``MODEL_DICT``.
    """
    neural_model, train_dataset, test_dataset,train_sample_importance_normalised = get_clean_model(args)
    attack_config={'victim_class':'1','victim_patient':None,'poison_num_samples':100,
                    'poison_lrs':[10,1,0.1,0.01],'n_epochs':args['attackepochs'],'inf_norm_limit':0.05,
                    'target_class':1, 'labels':[0,1,2,3,4]}
    attack = WitchesBrew(train_dataset=train_dataset,test_dataset=test_dataset,
                neural_model=neural_model,attack_config = attack_config,
                input_type='1D',gradient_matching_type=args['gm_type'])

    res_dict = {}
    # The 'seperated*' gradient-matching modes only run with poison_class=None.
    if args['gm_type'] in ('seperated', 'seperated_projection'):
        poison_class_list = [None]
    else:
        poison_class_list = [None, 4]
    victim_class_list = [2,3,4]
    total_exp = len(poison_class_list)* len(victim_class_list)*args['restarts']

    # The context manager closes the bar even if an experiment raises.
    with tqdm(range(total_exp),desc="Experiment Execution:") as p_bar:
        for poison_class in poison_class_list:
            for victim_class in victim_class_list:
                res = []
                for _ in range(args['restarts']):
                    p_bar.update(1)
                    victim_index,samples_to_poison_idx = attack.initialize_evil_experiment(victim_label = victim_class,victim_patient = None,poison_class = poison_class)
                    poisoned_v,m_,loss = attack.generate_poisons(victim_index,samples_to_poison_idx,args['device'],restarts = 10)

                    poisoned_train_data = train_dataset.create_poisoned_copy(poisoned_v,samples_to_poison_idx)
                    (victim_sample_diff, class_difference, victim_class_difference) = \
                        evaluate_poisoning(args,poisoned_train_data,train_sample_importance_normalised,test_dataset,victim_index,neural_model)

                    # Success: the victim's true-class probability dropped
                    # (clean - poisoned > 0) while the target-class
                    # probability rose (clean - poisoned < 0).
                    attack_succeeded = victim_sample_diff[victim_class]>0 and victim_sample_diff[attack_config['target_class']]<0
                    res.append((victim_sample_diff, class_difference, victim_class_difference,attack_succeeded))

                # Per setting: (mean victim-class shift, mean class
                # difference, mean victim-class difference, attack success
                # rate, raw per-restart results).
                res_dict[(poison_class,victim_class,attack_config['target_class'])] = (
                    np.mean([run[0][victim_class] for run in res]),
                    np.mean([run[1] for run in res]),
                    np.mean([run[2] for run in res]),
                    np.mean([run[3] for run in res]),
                    res,
                )

    print(res_dict)

    print('''\n**********************
    Generating Short Report:''')
    for exp_setting in res_dict:
        print(f'Attack Success Rate (Context = {exp_setting}) = {res_dict[exp_setting][3]}')


    outfilename = MODEL_DICT[args['modelname']][2]
    print(f'Writing result to {outfilename}')
    with open(outfilename,'wb') as fout:
        pkl.dump(res_dict,fout)


if __name__ =="__main__":
    # Command-line entry point: parse the options, attach the compute device
    # and run the experiment. main() and its callees index the resulting dict
    # by the long option names.
    parser = argparse.ArgumentParser()
    parser.add_argument("-d","--dataset",type=str, default ="cifar10",help="Dataset")
    parser.add_argument("-m","--modelfile",type=str, default = "None", help ="Model input file or directory")
    parser.add_argument("-modelname","--modelname",type=str, default = "cnn_nores1d", help ="Model cnn1d or cnn_nores1d")
    parser.add_argument("-a","--poisonmode", type =str,default="patch",help="Attack Method")
    parser.add_argument("-df","--datafraction", type =float,default=0.1,help="Attack Data Fraction")
    parser.add_argument("-l","--learningrate", type =float,default=0.00001,help="Learning Rate")
    parser.add_argument('-b','--batchsize',default=1024,type=int,help="Batch Size")
    parser.add_argument('-r','--restarts',default=2,type=int,help="Restarts for multiple poisons.")
    parser.add_argument('-tb','--testbatchsize',default=1024,type=int,help=" TestBatch Size")
    parser.add_argument('-epochs','--epochs',default=2,type=int,help="Epochs for training/Poisoned model training")
    # This value becomes the attack's 'n_epochs'; the previous help text was
    # a copy of the --epochs description.
    parser.add_argument('-attackepochs','--attackepochs',default=10,type=int,help="Epochs for the poisoning attack (passed to the attack as n_epochs)")
    parser.add_argument('-cachedir','--cachedir',default=OUTDIR)
    # The 'seperated' spellings are runtime values compared in main(); do not
    # correct them here alone.
    parser.add_argument("-gmt","--gm_type",type=str, default =None,help="Gradient Matching Mode. Default is None (Standard Witch's Brew), seperated & seperated_projection")
    parser.add_argument('-ntqdm','--ntqdm', action='store_true')
    parser.add_argument('-cache','--cache', action='store_true')

    args = vars(parser.parse_args())
    device = get_device()
    print(f'Using device {device}')
    args['device']=device
    print(f'Arguments are {args}')
    main(args)
