import torch as t
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils.adversarial_attacks_utils import *
from sklearn.model_selection import KFold
from utils.data_utils import TransformSubset, PhysionetMMMI,ReshapeTensor
from utils.data_importers import *
from torchvision import transforms
import argparse
from sklearn.metrics import confusion_matrix
import os 
import itertools
import argparse
import utils.eegnet



def loss_function(model,
                    sample,
                    v,
                    run_params,
                    device,
                    channel_distances=0,
                    fs=160,
                    weights=0,sizes=0):
    """Compute the objective that is descended when crafting the perturbation ``v``.

    The perturbation is optionally smoothed and/or spatially propagated, added
    to ``sample``, band-pass filtered and classified by ``model``. The objective
    is the cross-entropy towards ``run_params['target']`` minus an optional
    derivative term.

    Args:
        model: classifier applied to the band-passed perturbed sample.
        sample: batch of input signals (already on ``device``).
        v: perturbation tensor being optimised.
        run_params: run configuration. Reads 'Naturalism', 'Epsilon',
            'Head Model', 'AblatedHM', 'Lambdas', 'Attacked channels',
            'frequency_band' and 'target'.
        device: torch device for the target-label tensor and the helpers.
        channel_distances: forwarded to ``add_spatial_constraints``.
        fs: sampling frequency in Hz, forwarded to the helpers.
        weights, sizes: forwarded to ``smooth_perturbation``.

    Returns:
        Scalar tensor: cross-entropy loss minus the derivative term (0 when the
        derivative term is disabled).
    """
    naturalism = run_params['Naturalism']
    derivative_loss = 0

    # Derivative term: L1 norm of the first difference of v, scaled by 1e-6/Epsilon.
    # NOTE(review): t.roll without ``dims`` rolls the *flattened* tensor, so the
    # difference wraps across rows/channels. The term is also *subtracted* from
    # the returned loss, and gather_results descends that loss, so larger
    # differences are rewarded. Confirm both are intended.
    if naturalism in ('derivative', 'both'):
        alpha_1 = 1e-6 / run_params['Epsilon']
        derivative_loss = alpha_1 * t.norm(t.roll(v, 1) - v, 1)

    # Gaussian smoothing. This is an independent `if` (not `elif`) so that
    # 'both' applies the smoothing as well as the derivative term.
    if naturalism in ('gaussian', 'both'):
        v = smooth_perturbation(v, weights, sizes, device)

    # Spatial constraints (head model)
    if run_params['Head Model']:
        # Ablated head model: optimise with lambdas=[0,0], while the configured
        # run_params['Lambdas'] are still used when the attack is applied/tested.
        lambdas = [0, 0] if run_params['AblatedHM'] else run_params['Lambdas']
        v = add_spatial_constraints(v, channel_distances, fs, run_params['Attacked channels'], device, lambdas=lambdas)

    # Band-pass (FFT) filtering of the perturbed sample before classification
    low, high = run_params['frequency_band'][0], run_params['frequency_band'][1]
    xbp = bandpass_torch(sample + v, low, high, fs_eeg=fs, device=device)

    # Cross-entropy towards the target class for every sample in the batch
    targets = t.tensor([run_params['target']] * sample.shape[0]).to(device)
    ce_loss = nn.CrossEntropyLoss()(model(xbp), targets)

    return t.sum(ce_loss) - derivative_loss

def init_perturbation(run_params,dataset_params,batch_init_shape=1, device =t.device('cuda:0')):
    '''Create the initial perturbation tensor for an attack run.

    The tensor has shape (batch_init_shape, n_channels, datapoints). n_channels
    is 1 when run_params['Head Model'] is set (that single channel is then
    propagated across channels), and dataset_params['channels'] otherwise. For
    PhysioNet, an extra singleton dimension is inserted at position 1.

    Args:
        run_params: reads 'Head Model', 'Perturbation type', 'PGD' (attack
            only) and 'Epsilon'.
        dataset_params: reads 'channels', 'datapoints' and 'name'.
        batch_init_shape: leading dimension. Callers pass 1 for a UAP (one
            shared perturbation) and the sample batch size otherwise.
        device: torch device on which the tensor is allocated.

    Returns:
        The initial perturbation:
        - 'attack' with PGD: uniform in [-Epsilon, Epsilon).
        - 'attack' without PGD (FGSM): zeros.
        - 'random_noise': +/-Epsilon with random signs.

    Raises:
        ValueError: if run_params['Perturbation type'] is not 'attack' or
            'random_noise'.
    '''
    n_channels = 1 if run_params['Head Model'] else dataset_params['channels']
    shape = (batch_init_shape, n_channels, dataset_params['datapoints'])

    kind = run_params['Perturbation type']
    if kind == 'attack':
        if run_params['PGD']:
            # PGD: random uniform start inside the L-inf ball of radius Epsilon
            v = 2 * run_params['Epsilon'] * t.rand(size=shape, device=device) - run_params['Epsilon']
        else:
            # FGSM: start from zero
            v = t.zeros(size=shape, device=device)
    elif kind == 'random_noise':
        # Baseline: random-sign noise at the full amplitude Epsilon
        v = run_params['Epsilon'] * t.sign(t.normal(0, 1, size=shape, device=device))
    else:
        raise ValueError(f"Unknown 'Perturbation type': {kind!r}; expected 'attack' or 'random_noise'")

    if dataset_params['name'] == 'PhysioNet':
        v = v.unsqueeze(1)
    return v
        
def generate_attack(device, test_name,run_params,dataset_params):
    """Run ``gather_results`` for every fold (PhysioNet) or subject (BCI-Competition).

    Args:
        device: torch device.
        test_name: currently unused; kept for interface compatibility.
        run_params: run configuration dict. Reads 'batch_size',
            'Assessment metrics', 'UAP', 'Head Model' and 'Early Stop'.
        dataset_params: dict from ``get_dataset_parameters``. It is rebound on
            every run to the result of ``get_dataset_model_and_loaders``.

    Returns:
        If run_params['Assessment metrics'] is set: the 6-tuple returned by
        ``gather_results`` for the last processed run, i.e.
        (attack_samples, samples, true_labels, prediction_result,
        attack_labels, distance_array). distance_array accumulates across runs.

        Otherwise: (cv_sr_acc, v_folds), where
        - cv_sr_acc is a numpy array with one ASR per run;
        - v_folds holds the per-run UAP (all zeros when the run is not a UAP).

        When run_params['Early Stop'] = k > 0, processing stops after run k.

    Raises:
        ValueError: if dataset_params['name'] is not a supported dataset.
    """
    # 1 Parameters of the data loaders (no shuffling when collecting metrics)
    params = {'batch_size': run_params['batch_size'],
              'shuffle': not run_params['Assessment metrics']}

    N_RUNS = dataset_params['runs']

    # 2 Initialize arrays to gather the ASR and, if it's a UAP, the perturbation
    cv_sr_acc = np.zeros((N_RUNS))
    if run_params['UAP']:
        n_v_channels = 1 if run_params['Head Model'] else dataset_params['channels']
        v_folds = t.zeros((N_RUNS, n_v_channels, dataset_params['datapoints']))
    else:
        v_folds = t.zeros((N_RUNS, 1, 1))
    distance_array = []

    # 3 One entry per run, holding the extra keyword arguments for the loader factory
    if dataset_params['name'] == 'PhysioNet':
        kf = KFold(n_splits=N_RUNS)
        run_kwargs = [{'train_idx': train_idx, 'valid_idx': valid_idx}
                      for train_idx, valid_idx in kf.split(dataset_params['dataset'])]
    elif dataset_params['name'] == 'BCI-Competition':
        # One run per subject
        run_kwargs = [{} for _ in range(N_RUNS)]
    else:
        raise ValueError(f"Unknown dataset: {dataset_params['name']!r}")

    # 4 Load data, attack and collect results for every run
    metrics = None
    for run, kwargs in enumerate(run_kwargs):
        dataset_params = get_dataset_model_and_loaders(run, device, dataset_params, params, **kwargs)
        results = gather_results(run, distance_array,
                                 run_params,
                                 dataset_params,
                                 device)
        if run_params['Assessment metrics']:
            # (attack_samples, samples, true_labels, prediction_result, attack_labels, distance_array)
            metrics = results
        else:
            cv_sr_acc[run], v_folds[run] = results
        if run_params['Early Stop'] != 0 and (run + 1) == run_params['Early Stop']:
            break

    if run_params['Assessment metrics']:
        return metrics
    return cv_sr_acc, v_folds
  
def gather_results(run,
                distance_array,
                run_params,
                dataset_params,
                device):
    """Craft perturbations for one run (fold or subject) and score them.

    For a UAP, a single perturbation is optimised over the training loader
    for run_params['MaxEpochs'] epochs and finally evaluated on the
    validation loader with ``evaluate_attack``. Otherwise, a fresh
    perturbation is crafted for every validation batch.

    Side effects:
        - Writes run_params['alpha'], the current step size.
        - When run_params['Assessment metrics'] is set, appends the Euclidean
          distance, cosine similarity and cross-correlation results to
          ``distance_array`` once per epoch.

    Returns:
        If run_params['Assessment metrics'] is set:
            (attack_samples, samples, true_labels, prediction_result,
             attack_labels, distance_array), taken from the last epoch, with
            tensors moved to CPU.
        Otherwise:
            (cv_sr_acc, v_folds), where cv_sr_acc is the attack success rate
            as a numpy value and v_folds is the squeezed CPU perturbation for
            a UAP (0 for instance-based runs).
    """

    # 1 Initialize perturbation and loader if it's a UAP
    if run_params['UAP']:
        v = init_perturbation(run_params=run_params,dataset_params=dataset_params, batch_init_shape=1,device=device)
        loader = dataset_params['train_loader']
    else:
        loader = dataset_params['val_loader']

    for epoch in range(run_params['MaxEpochs']):

        # 2 Initialize tensors to gather results. They are reset every epoch,
        # so only the last epoch's results are returned.
        prediction_result = t.empty(0).to(device)
        true_labels = t.empty(0).to(device)
        attack_labels = t.empty((0)).to(device)
        if run_params['Assessment metrics']:
            samples = t.empty((0,dataset_params['channels'],dataset_params['datapoints'])).to(device)
            # Attacked samples are accumulated on CPU (presumably to save GPU memory)
            attack_samples = t.empty((0,dataset_params['channels'],dataset_params['datapoints']))

        for sample, label in loader:

            # 3 Initialize perturbation if it's not UAP and prepare the sample
            if not run_params['UAP']:
                v = init_perturbation(run_params=run_params,dataset_params=dataset_params, batch_init_shape=sample.shape[0],device=device)
            sample = sample.to(device)

            if run_params['Perturbation type']=='attack':

                # Initial step size: Epsilon for FGSM, Epsilon/2 for PGD
                run_params['alpha'] = run_params['Epsilon'] if not run_params['PGD'] else run_params['Epsilon']/2
                for iteration in range(run_params['Iterations']):

                    # 4 Activate the attack's gradient recording
                    v.requires_grad_(True)

                    # 5 Compute the loss
                    if run_params['loss function']=='Liu et al':
                        loss = loss_function_LiuEtAl(dataset_params['model'],
                                                sample,
                                                v,
                                                run_params,
                                                device=device,
                                                fs=dataset_params['sampling frequency'])
                    else:
                        loss = loss_function(dataset_params['model'],
                                                    sample,
                                                    v,
                                                    run_params=run_params,
                                                    device = device,
                                                    channel_distances=dataset_params['distances'],
                                                    fs=dataset_params['sampling frequency'],
                                                    weights=run_params['weights'],
                                                    sizes=run_params['sizes'])

                    # 6 Backpropagation
                    loss.backward()

                    # 7 Signed-gradient descent step, then projection onto [-Epsilon, Epsilon]
                    with t.no_grad():
                        v = v - run_params['alpha'] * v.grad.sign()
                        v = t.clamp( v,-run_params['Epsilon'] ,run_params['Epsilon'] )

                    # 8 Update the step size if it's a PGD run
                    run_params['alpha'] = (0.1-run_params['Epsilon'] /2)/run_params['Iterations']*(iteration+1)+run_params['Epsilon'] /2 if run_params['PGD'] else run_params['Epsilon']
            else:
                # No optimisation for baseline noise; -1 is a placeholder for the log line
                loss=t.tensor(-1)
            # 9 Update perturbation if running with the head model
            vp = add_spatial_constraints(v.detach(),dataset_params['distances'],dataset_params['sampling frequency'],run_params['Attacked channels'], device, lambdas=run_params['Lambdas']) if run_params['Head Model'] else v.detach()

            # 10 Perform the attack
            attack = sample + vp

            # Evaluation-only forward passes: no autograd graph is needed
            with t.no_grad():
                # 11 Collect classification of original sample
                y_pred = bandpass_torch(sample.detach().clone(),run_params['frequency_band'][0],run_params['frequency_band'][1],fs_eeg=dataset_params['sampling frequency'],device=device)
                y_pred = dataset_params['model'](y_pred).argmax(dim=1) # get the index
                prediction_result = t.cat((prediction_result,y_pred))

                # 12 Collect classification of attack
                attack_classification = bandpass_torch(attack, run_params['frequency_band'][0],run_params['frequency_band'][1], fs_eeg=dataset_params['sampling frequency'],device=device)
                attack_classification = dataset_params['model'](attack_classification.detach()).argmax(dim=1)
                attack_labels = t.cat((attack_labels,attack_classification))

            # 13 Collect the true labels
            true_labels = t.cat((true_labels,label.to(device)))
            if run_params['Assessment metrics']:
                if dataset_params['name']=='PhysioNet':
                    # Drop only the singleton dim 1 (init_perturbation inserts it for
                    # PhysioNet). A bare squeeze() would also drop the batch dim for
                    # a batch of size 1 and break the concatenation.
                    samples = t.cat((samples,sample.squeeze(1)))
                    attack_samples = t.cat((attack_samples,attack.squeeze(1).cpu()))
                elif dataset_params['name']=='BCI-Competition':
                    samples = t.cat((samples,sample))
                    attack_samples = t.cat((attack_samples,attack.cpu()))

        # 14 Get ASR for current epoch
        asr= get_attack_success_rate(prediction_result, true_labels, attack_labels, run_params['target'],device)
        print(f'Eps: {run_params["Epsilon"]} run: {run+1}/{dataset_params["runs"]} epoch: {epoch+1}/{run_params["MaxEpochs"]}, loss= {loss.item():.4f} ASR: {100*asr:.4f}%')
        cv_sr_acc = asr.cpu().numpy()

        # 15 Compute distances when running to recover plots
        if run_params['Assessment metrics']:

            # Euclidean distance
            distance_array.append(get_euclidean_distance(samples,attack_samples,dataset_params['sampling frequency'],attacked_channels=False))

            # Cosine similarity
            distance_array.append(get_cosine_similarity(samples,attack_samples,dataset_params['sampling frequency'],device, attacked_channels=False))

            # Cross correlation
            distance_array.append(get_x_correlation(samples,attack_samples,dataset_params['sampling frequency'],attacked_channels=False,freq_domain=False))

    # 16 If UAP, save the perturbation and test it on the validation loader
    if run_params['UAP']:
        v_folds = v.squeeze().detach().clone().to('cpu')
        vp = add_spatial_constraints(v.detach(),dataset_params['distances'],dataset_params['sampling frequency'],run_params['Attacked channels'], device, lambdas=run_params['Lambdas']) if run_params['Head Model'] else v.detach()
        asr = evaluate_attack(dataset_params['model'], dataset_params['val_loader'], vp, target=run_params['target'], frequency_band = run_params['frequency_band'], fs_eeg=dataset_params['sampling frequency'],device=device)
        cv_sr_acc = asr.cpu().numpy()
    else:
        v_folds=0
    if run_params['Assessment metrics']:
        return attack_samples,samples.cpu(),true_labels.cpu(),prediction_result.cpu(),attack_labels.cpu(), distance_array

    else:
        return cv_sr_acc,v_folds
         
def get_dataset_parameters(run_params):
    """Return the static description of the dataset named in run_params['Dataset'].

    Args:
        run_params: must contain 'Dataset' ('PhysioNet' or 'BCI-Competition')
            and 'Attacked channels' (forwarded to ``get_distances``).

    Returns:
        dict with keys 'name', 'distances', 'channels', 'datapoints',
        'sampling frequency', 'runs' and 'data_path'. For PhysioNet only, it
        also has 'dataset', a ``PhysionetMMMI`` instance built with
        num_classes=3.

    Raises:
        ValueError: if the dataset name is not recognised.
    """
    name = run_params['Dataset']
    # name -> (data folder relative to this file, channels, datapoints,
    #          sampling frequency in Hz, runs)
    known = {
        'PhysioNet': ('1-Data/1-PhysioNet/', 64, 480, 160, 5),                # runs = CV folds
        'BCI-Competition': ('1-Data/2-BCI-Competition/', 22, 1125, 250, 9),   # runs = subjects
    }
    if name not in known:
        raise ValueError(f"Unknown dataset {name!r}; expected one of {sorted(known)}")
    folder, channels, datapoints, sampling_frequency, runs = known[name]
    data_path = os.path.abspath(os.path.join(os.path.dirname(__file__), folder))

    dataset_params = {'name': name,
                      'distances': get_distances(run_params['Attacked channels'], dataset=name),
                      'channels': channels,
                      'datapoints': datapoints,
                      'sampling frequency': sampling_frequency,
                      'runs': runs,
                      'data_path': data_path}
    if name == 'PhysioNet':
        transform = transforms.Compose([ReshapeTensor()])
        dataset_params['dataset'] = PhysionetMMMI(datapath=dataset_params['data_path'] + '/2-Data',
                                                  num_classes=3, transform=transform)
    return dataset_params

def run_attack(run_params, iter_list,eps_list,restarts,device,multiple_channels=False, metrics_export=False, export=True):
    """Sweep attack hyper-parameters over seeded restarts and export the ASR.

    For every restart and every attacked-channel setting, ``generate_attack``
    runs once per element of itertools.product(iter_list, lambda_list,
    eps_list). Per-run ASRs go into ``asr_array``, with shape
    (n_channel_settings, len(iter_list), len(lambda_list), len(eps_list), runs).
    The perturbations returned by ``generate_attack`` go into ``v_array``.

    Args:
        run_params: run configuration dict. 'Attacked channels' is written
            here, and the local name is rebound to ``add_run_constraints``'s
            result.
        iter_list: iteration counts to sweep.
        eps_list: perturbation bounds to sweep. Replaced by [50] (PhysioNet)
            or [5] (otherwise) when ``multiple_channels`` is True.
        restarts: number of repetitions, each with its own fixed seeds.
        device: torch device passed to ``generate_attack``.
        multiple_channels: attack each channel index in turn, with a fixed
            lambda grid, instead of run_params['Attacked channels'].
        metrics_export: with the head model, take lambdas from
            run_params['Lambda list']. Return the assessment outputs after
            the first channel setting of the first restart (requires
            run_params['Assessment metrics'] to be set, otherwise the
            returned names are undefined).
        export: save the ASR array (and, for UAP, ``v_array``) under '2-Output'.

    Returns:
        (attack_samples, samples, labels, prediction_labels, attack_labels) when
        ``metrics_export`` is True; otherwise None.
    """
    
    # If it's not running with spatial propagation model, update lambda_list with dummy variable
    if run_params['Head Model']:
        if metrics_export:
            lambda_list=run_params['Lambda list']
        else:
            # Default grid of [lambda_1, lambda_2] pairs for the head model
            lambda_list = [[1,0.1],[1,0.3],[1,0.563],
                        [5,0.1],[5,0.3],[5,0.563],
                        [15,0.1],[15,0.3],[15,0.563]]
    else:
        lambda_list=[[0,0]]

    dataset_params = get_dataset_parameters(run_params)

    # Set list for plotting ASR attacking different channels
    # NOTE: despite the name, attacked_channel_list is an int (a loop count).
    if multiple_channels:
        attacked_channel_list=dataset_params['channels']
        eps_list=[50] if run_params['Dataset']=='PhysioNet' else [5]
        lambda_list=[[1,0.3],[5,0.3],[15,0.3]]
    else:
        attacked_channel_list=1

    for restart in range(restarts):
        
        # Set seeds and print parameters of run
        t.manual_seed(9823752+restart)
        np.random.seed(1528102+restart)
        print(f"{run_params['Dataset']} \
        {get_key_from_value(run_params, 'UAP')} \
        {get_key_from_value(run_params, 'PGD')} \
        {get_key_from_value(run_params, 'Head Model')} \
        {get_key_from_value(run_params, 'Naturalism')} \
        GPU:{device} \
        Run#{restart+1}")

        # Prepare file names
        test_name = (f'{run_params["Dataset"]}_{get_key_from_value(run_params, "UAP")}_'
                        f'{get_key_from_value(run_params,"Attacked channels",multiple_channels)}_'
                        f'{get_key_from_value(run_params, "PGD")}_'
                        f'{get_key_from_value(run_params, "Naturalism")}_'
                        f'r{restart}_{get_key_from_value(run_params, "AblatedHM")}')     
        file_name = test_name + f"_eps{eps_list}_l{lambda_list}_epochs{run_params['MaxEpochs']}"

        # np.empty: entries that are never written (e.g. when
        # run_params['Assessment metrics'] is set) keep uninitialised values.
        asr_array=np.empty((attacked_channel_list,len(iter_list),len(lambda_list),len(eps_list),dataset_params['runs']))

        for i in range(attacked_channel_list):

            # Perturbations: one channel with the head model, all channels otherwise
            v_array= np.empty((len(iter_list),len(lambda_list),len(eps_list),dataset_params['runs'],1,dataset_params['datapoints'])) if run_params['Head Model'] else np.empty((len(iter_list),len(lambda_list),len(eps_list),dataset_params['runs'],dataset_params['channels'],dataset_params['datapoints']))
            # NOTE(review): distances_metrics is never used below.
            distances_metrics = np.empty((len(eps_list),len(eps_list)*3))
            
            # In multi-channel mode, attack only channel i in this iteration
            run_params['Attacked channels']=[i] if multiple_channels else run_params['Attacked channels']

            for ind,(iterations,lambdas, eps) in enumerate(itertools.product(iter_list,lambda_list, eps_list)):
                run_params = add_run_constraints(run_params,eps,iterations,lambdas)
                output= generate_attack(device=device,
                                        test_name=test_name,
                                        run_params=run_params,
                                        dataset_params=dataset_params)
                if run_params['Assessment metrics']: 
                    attack_samples,samples,labels,prediction_labels,attack_labels,_ = output
                else:
                    asr,v = output
                    # Orders the asr_array as: (iteration,lambda,epsilon,run)
                    # (decomposes ind in itertools.product order: iter, lambda, eps)
                    index_1 = ind//(len(lambda_list)*len(eps_list))
                    index_2 = (ind//(len(eps_list)))%len(lambda_list)
                    index_3 = ind%(len(eps_list))

                    asr_array[i,index_1, index_2, index_3]=asr

                    v_array[index_1, index_2, index_3]=v
                    print(f'Finished {ind+1}/{len(list(itertools.product(iter_list,lambda_list, eps_list)))} \
                            Channel: {i+1}/{attacked_channel_list} \
                            ASR: {asr.mean()*100:.2f}% \
                            Lambdas: {lambdas}')
            # NOTE(review): printed once per channel setting, not once per restart.
            print(f'Restart: {restart+1}/{restarts} done')
            datapath = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '2-Output'))
            
            # file_name does not depend on i, so the same file is rewritten after
            # every channel setting with the ASR collected so far.
            if export:
                np.save(datapath+'/ASR_'+ file_name+'.npy', asr_array)  

                if run_params['UAP']:
                    t.save(v_array, datapath+'/v_'+file_name+'.pt')
            # Returns inside both loops: only the first channel setting of the first restart is run.
            if metrics_export:
                return attack_samples,samples,labels,prediction_labels,attack_labels
                
def prepare_run(device):
    """Assemble the default experiment configuration and launch ``run_attack``.

    Args:
        device: torch device the attacks run on.
    """
    # Alternative epsilon grids used in other experiments:
    #   BCI-Competition:              [0.1, 1, 2, 5, 10]
    #   vanilla BCI-Competition:      [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1, 2, 5, 10]
    #   vanilla and baseline PhysioNet: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 25, 50]
    epsilons = [1, 5, 10, 25, 50]   # PhysioNet
    iteration_counts = [10]
    n_restarts = 5
    attack_each_channel = False     # only for generating the multi-channel plot

    config = {
        'Dataset': 'PhysioNet',             # Str: 'PhysioNet' or 'BCI-Competition'
        'UAP': False,                       # Bool: Universal Adversarial Perturbation or instance-based attack
        'PGD': True,                        # Bool: PGD if True, otherwise FGSM
        'Head Model': False,                # Bool: spatial constraints
        'AblatedHM': False,                 # Bool: only with Head Model; tests on lambdas but trains with [0,0]
        'Naturalism': False,                # Str or False: 'derivative', 'gaussian', 'both' or False
        'MaxEpochs': 1,                     # Int: number of epochs for the UAP training
        'batch_size': 16,                   # Int: batch size for the data loaders
        'Attacked channels': [6],           # IntList: channel(s) to be attacked
        'Perturbation type': 'attack',      # Str: 'attack' or 'random_noise' (baseline)
        'loss function': 'Liu et al',       # Str: 'Liu et al' or 'derivative_term'
        'Assessment metrics': False,        # Bool: only active when used for plotting
        'Early Stop': 0,                    # Int: stop after this many runs (0 = never); for plotting
    }

    run_attack(config, iteration_counts, epsilons, n_restarts, device,
               multiple_channels=attack_each_channel)

if __name__ == '__main__':
    # Usage: python <script> <cuda_index>; the single positional argument selects the GPU.
    cli = argparse.ArgumentParser()
    cli.add_argument("cuda")
    gpu_index = cli.parse_args().cuda
    prepare_run(t.device(f'cuda:{gpu_index}'))