import sys
import os
from os.path import expanduser
HOME = expanduser("~")
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import multiprocessing
from multiprocessing import Pool
import random
import pickle
import time

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import pipeline_functions




################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################




"""
###
BEGINNING OF SCRIPT
###
"""




################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################








def main():
    
    ARGS = sys.argv[1:]
    print("ARGS:", ARGS)
    print("\n")
    
    DATASET, NORM_, TRAIN_N_JOBS = ARGS[0], ARGS[1], int(ARGS[2])

    if NORM_ == "2":
        NORM = None # // None -> defaults to the Euclidean norm 
        NORM_TAG = "L2"
    elif NORM_ == "inf":
        NORM = np.inf
        NORM_TAG = "LINF"

    BOOTSTRAP_ITERATIONS = 1  
    print("TRAIN_N_JOBS:", TRAIN_N_JOBS)

    SAVEPATH = HOME+"/ROB_BOX/SENS_LISTS/" 
    PLOTPATH = HOME+"/ROB_BOX/PLOTS/" 



    CALCULATE_SENSITIVITY_LISTS = True
    DETERMINE_DATA_REPRESENTATIONS = True


################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################




    """
    ###
    LOADING AND PREPROCESSING DATASET
    ###
    """




################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################




    print("Starting Timer:")
    print()
    start = time.time()




    """
    ***
    CIFAR-10
    ***
    """

    if DATASET == "cifar10":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "cifar10"
        NUM_CLASSES = 10
        LENGTH = 50_000
        DIMENSIONS = 32 * 32 * 3
        train_dataset = datasets.CIFAR10(root=PATH, train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()


    """
    ***
    CIFAR-100 - fine labels
    ***
    """

    if DATASET == "cifar100":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "cifar100"
        NUM_CLASSES = 100
        LENGTH = 50_000
        DIMENSIONS = 32 * 32 * 3
        train_dataset = datasets.CIFAR100(root=PATH, train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()


    """
    ***
    CIFAR-100s - coarse labels (thx to https://github.com/xiaodongww/pytorch/blob/master/cifarDataset.py)
    ***
    """

    if DATASET == "cifar100_super":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "cifar100_super"
        NUM_CLASSES = 20
        LENGTH = 50_000
        DIMENSIONS = 32 * 32 * 3
        train_dataset = datasets.CIFAR100(root=PATH, train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]), coarse=True)
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[2], num_classes=NUM_CLASSES).numpy() #third list contains the coarse labels





    """
    ***
    ( https://github.com/wzekai99/DM-Improves-AT )

    @inproceedings{wang2023better,
      title={Better Diffusion Models Further Improve Adversarial Training},
      author={Wang, Zekai and Pang, Tianyu and Du, Chao and Lin, Min and Liu, Weiwei and Yan, Shuicheng},
      booktitle={International Conference on Machine Learning (ICML)},
      year={2023}
    }

    ***
    """

    if DATASET == "1m":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "1m-1"
        NUM_CLASSES = 10
        LENGTH = 50_000
        DIMENSIONS = 32 * 32 * 3
        aux = np.load(PATH+"1m.npz")
        aux_data = aux['image']
        aux_label = aux['label']

        idx = np.random.permutation(len(aux_data))

        aux_data = aux_data[idx]
        aux_label = aux_label[idx]

        X_train = aux_data[:50_000]
        y_train = aux_label[:50_000]

        X_train = X_train / 255.0
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy

        lb = preprocessing.LabelBinarizer()
        y_train = lb.fit_transform(y_train)


   

    """
    ***
    ( https://github.com/yaircarmon/semisup-adv )

    @inproceedings{ti500k,  
        author = {Yair Carmon and Aditi Raghunathan and Ludwig Schmidt and Percy Liang and John Duchi},  
        title = {Unlabeled Data Improves Adversarial Robustness},  
        year = 2019,  
        booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},  
    }  

    ***
    """


    if DATASET == "ti500k":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "ti500k-1"
        NUM_CLASSES = 10
        LENGTH = 50_000
        DIMENSIONS = 32 * 32 * 3
        aux = np.load(PATH+"ti_500K_pseudo_labeled.pickle", allow_pickle=True)
        aux_data = aux['data']
        aux_label = aux['extrapolated_targets']

        idx = np.random.permutation(len(aux_data))

        aux_data = aux_data[idx]
        aux_label = aux_label[idx]

        X_train = aux_data[:50_000]
        y_train = aux_label[:50_000]

        X_train = X_train / 255.0
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy

        lb = preprocessing.LabelBinarizer()
        y_train = lb.fit_transform(y_train)        

    
    

    """
    ***
    EMNIST

    @inproceedings{emnist,
        author = {Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and van Schaik, André},
        year = {2017},
        month = {05},
        pages = {2921-2926},
        title = {EMNIST: Extending MNIST to handwritten letters},
        doi = {10.1109/IJCNN.2017.7966217}
    }

    ***
    """






    if DATASET == "e-mnist":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "E-MNIST"
        SPLIT = "mnist"
        NUM_CLASSES = 10
        LENGTH = 60_000
        DIMENSIONS = 28 * 28 * 1
        train_dataset = datasets.EMNIST(root=PATH, train=True, download=True, split=SPLIT, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()




    if DATASET == "e-letters":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "E-LETTERS"
        SPLIT = "letters"
        NUM_CLASSES = 37 #although there are only 26 different unique labels (no lower case letters)
        LENGTH = 88_800
        DIMENSIONS = 28 * 28 * 1
        train_dataset = datasets.EMNIST(root=PATH, train=True, download=True, split=SPLIT, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()




    if DATASET == "e-digits":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "E-DIGITS"
        SPLIT = "digits"
        NUM_CLASSES = 10
        LENGTH = 240_000
        DIMENSIONS = 28 * 28 * 1
        train_dataset = datasets.EMNIST(root=PATH, train=True, download=True, split=SPLIT, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()



    if DATASET == "e-balanced":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "E-BALANCED"
        SPLIT = "balanced"
        NUM_CLASSES = 47
        LENGTH = 112_800
        DIMENSIONS = 28 * 28 * 1
        train_dataset = datasets.EMNIST(root=PATH, train=True, download=True, split=SPLIT, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()



    if DATASET == "e-bymerge":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "E-BYMERGE"
        SPLIT = "bymerge"
        NUM_CLASSES = 47
        LENGTH = 697_932
        DIMENSIONS = 28 * 28 * 1
        train_dataset = datasets.EMNIST(root=PATH, train=True, download=True, split=SPLIT, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()



    if DATASET == "e-byclass":
        PATH = HOME+"/DATA/" 
        SET_NAME_FOR_SAVING = "E-BYCLASS"
        SPLIT = "byclass"
        NUM_CLASSES = 62
        LENGTH = 697_932
        DIMENSIONS = 28 * 28 * 1
        train_dataset = datasets.EMNIST(root=PATH, train=True, download=True, split=SPLIT, transform=transforms.Compose([transforms.ToTensor()]))
        train_loader = DataLoader(train_dataset, batch_size=LENGTH)
        X_train = next(iter(train_loader))[0].numpy()
        X_train = np.reshape(X_train, newshape=(len(X_train), DIMENSIONS)) #need to have vectors to get the right inf-norm from numpy
        y_train = F.one_hot(next(iter(train_loader))[1], num_classes=NUM_CLASSES).numpy()





    print()
    print("Set used:", SET_NAME_FOR_SAVING)
    print()

    print()
    print("Dataset-Example 0:", X_train[0])
    print()

    print()
    print("Label-Example 0:", y_train[0])
    print()



################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################




    """
    ###
    BEGINNING OF DATA PREPROCESSING
    ###
    """




################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################


    len_train = len(X_train)
    
    X_train = np.array(X_train, dtype="float32")

    NR_OF_FEATURES = len(X_train[0])
    print("NR_OF_FEATURES:", NR_OF_FEATURES)
    print()

    y_train = np.array(y_train, dtype="float32")

    NR_OF_CLASSES = len(np.unique(y_train, axis=0))
    print("NR_OF_CLASSES:", NR_OF_CLASSES)
    print()
        

    
  
            
            
################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################




    """
    ###
    ROBUSTNESS CALCULATIONS 
    ###
    """




################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################

################################################################################################################################################################################################################################################
################################################################################################################################################################################################################################################
           
            
    if CALCULATE_SENSITIVITY_LISTS:


        #can help to produce smoother distributions via bootstrapping samples of the whole training set (BOOTSTRAP_ITERATIONS=1 yields the original setup)
        BOOTSTRAPS_LENGTH_TRAIN = int(len(X_train) / BOOTSTRAP_ITERATIONS)

        jump = int(len(X_train) / TRAIN_N_JOBS) + 1
        jump_list = [jump * k for k in range(TRAIN_N_JOBS)]
        jump_list.append(len(X_train))

        with Pool(processes=TRAIN_N_JOBS) as pool:
            sens_vals_parts = pool.starmap(pipeline_functions.sens_calc, [
                (X_train, 
                 y_train, 
                 X_train[jump_list[k] : jump_list[k+1]],
                 y_train[jump_list[k] : jump_list[k+1]],
                 NORM,
                 None,
                 "CLA",
                 BOOTSTRAP_ITERATIONS,
                 BOOTSTRAPS_LENGTH_TRAIN) for k in range(TRAIN_N_JOBS)])


        sens_vals = np.array([], dtype="float32")
        for p in range(TRAIN_N_JOBS):
            sens_vals = np.concatenate([sens_vals, sens_vals_parts[p]])

        sens_indices_ordered = sorted(range(len(sens_vals)), key=lambda k: sens_vals[k])


        with open(SAVEPATH + SET_NAME_FOR_SAVING + "_sens_vals_"+ NORM_TAG+".pickle", 'wb') as f:
            pickle.dump(sens_vals, f)
        with open(SAVEPATH + SET_NAME_FOR_SAVING + "_sens_indices_ordered_"+ NORM_TAG+".pickle", 'wb') as f:
            pickle.dump(sens_indices_ordered, f)

        print()     
        print("The (approximate) maximal Lipschitz constant of the label functions:", sens_vals[sens_indices_ordered][-1])

        print()
        print("Finished!!!", "Time Elapsed:", round(time.time()-start, 2), "Secs")
        print()






    ################################################################################################################################################################################################################################################
    ################################################################################################################################################################################################################################################

    ################################################################################################################################################################################################################################################
    ################################################################################################################################################################################################################################################
               
    if DETERMINE_DATA_REPRESENTATIONS:     

        with open(SAVEPATH+SET_NAME_FOR_SAVING+'_sens_vals_' + str(NORM_TAG) + '.pickle', 'rb') as f:
            sens_vals = pickle.load(f)
        with open(SAVEPATH+SET_NAME_FOR_SAVING+'_sens_indices_ordered_' + str(NORM_TAG) + '.pickle', 'rb') as f:
            sens_indices_ordered = pickle.load(f)


        pipeline_functions.manifold_embedder_and_plotter(
            "", 
            X_train, 
            sens_vals, 
            sens_indices_ordered, 
            n_neighs=10, 
            _n_jobs_=TRAIN_N_JOBS, 
            show_graphs=False, 
            path=PLOTPATH+SET_NAME_FOR_SAVING+"_", 
            i_t_r=NORM_TAG, 
            sample_size=1000)









if __name__ == '__main__':
    print("Starting Pipeline")
    print()
    main()
    print("The End")



