import numpy as np
import os

# Fixed seed for NumPy's global RNG so that random draws made through
# np.random in this config are reproducible across runs.
SEED = 10
np.random.seed(SEED)

project_root = "<anonymized>/hard_label_manifolds"

# Shell preamble (presumably prepended to each launched job -- the consumer
# is not in this file): enter the repo, activate the "hm" conda env, and put
# the repo plus community/causal_robustness on PYTHONPATH.
global_shell_commands_before = [
    "cd " + project_root,
    "conda activate hm",
    "export PYTHONPATH=" + ":".join(
        (project_root, project_root + "/community/causal_robustness/")
    ),
]

# Natural models
# datasets = ["Imagenet"]
# datasets = ["CIFAR10"]
# datasets = ["MNIST"]

# Robust models. Each suffix names a defended model; its full dataset
# identifier is "<base dataset>_<suffix>" (built below). Commented-out
# suffixes are disabled.
robust_mnist_l2_suffix = [
    "madry",           # Towards deep learning models resistant to adversarial attacks
#     "trades",          # Theoretically Principled Trade-off between Robustness and Accuracy
#     "rob_manifold",    # The Robust Manifold Defense: Adversarial Training using Generative Models
    "deep_camma",      # A Causal View on Robustness
]

robust_cifar_linf_suffix = [
    "trades",    # Theoretically Principled Trade-off between Robustness and Accuracy
    "madry",     # Towards deep learning models resistant to adversarial attacks
    "interp",    # Adversarial Interpolation Training: A Simple Approach for Improving Model Robustness
    "fs",        # Defense against adversarial attacks using feature scattering-based adversarial training
    "sense",     # Sensible adversarial learning
]

robust_cifar_l2_suffix = [
    # (not available) "smooth20",  # Certified Adversarial Robustness via Randomized Smoothing
    "smooth110",  # Certified Adversarial Robustness via Randomized Smoothing
]

robust_imagenet_l2_suffix = [
    'smooth50',  # Certified Adversarial Robustness via Randomized Smoothing
]

robust_imagenet_linf_suffix = [
    'madry8',  # Towards deep learning models resistant to adversarial attacks (eps=8/255.)
#     'madry4',  # Towards deep learning models resistant to adversarial attacks (eps=4/255.)
]

# Full dataset identifiers: "<base dataset>_<suffix>".
robust_l2_mnist = list(map("MNIST_{}".format, robust_mnist_l2_suffix))
robust_linf_cifar = list(map("CIFAR10_{}".format, robust_cifar_linf_suffix))
robust_l2_cifar = list(map("CIFAR10_{}".format, robust_cifar_l2_suffix))
robust_l2_imagenet = list(map("Imagenet_{}".format, robust_imagenet_l2_suffix))
robust_linf_imagenet = list(map("Imagenet_{}".format, robust_imagenet_linf_suffix))

# datasets = robust_l2_cifar + robust_l2_imagenet
# datasets = robust_l2_imagenet

# ================= checklist =================
# Pick the suite to run; only the last active `datasets` assignment counts.
# datasets = ["MNIST"] + robust_l2_mnist
# datasets = ["CIFAR10"] + robust_linf_cifar
# datasets = ["Imagenet"] + robust_l2_imagenet
datasets = ["Imagenet", *robust_linf_imagenet]
# datasets = robust_linf_cifar

early_stop = False
target_bools = [False]
# target_bools = [True]

query_limit = 25_000

# norms = [2]
norms = ['inf']

# normal_attacks = ["OPT_attack", "Sign_OPT", "RayS", "HSJA"]
normal_attacks = ["Sign_OPT", "HSJA"]

# Attack variants to evaluate, mapped to human-readable labels (presumably
# for result tables/plots). The keys double as the raw attack identifiers;
# commented-out entries are disabled.
att_pretty = {
#     "OPT_attack": "OPT",
#     "Sampling_OPT_attack": "BiLN OPT",
#     "RandSampling_OPT_attack": "Rand. OPT",
#     "HLM_OPT_attack": "HLM OPT",
    "Sign_OPT": "Sign-OPT",
    "Sampling_Sign_OPT": "BiLN Sign-OPT",
    "RandSampling_Sign_OPT": "Rand. Sign-OPT",
    "HLM_Sign_OPT": "HLM Sign-OPT",
#     "RayS": "RayS",
#     "Sampling_RayS": "BiLN RayS",
#     "HLM_RayS": "HLM RayS",
    "HSJA": "HSJA",
    "Sampling_HSJA": "BiLN HSJA",
    "RandSampling_HSJA": "Rand. HSJA",
#     "HLM_HSJA": "HLM HSJA",
}

# Raw attack identifiers, in the same order as att_pretty.
att_raw = [*att_pretty]

RayS_a = [2]     # upscale factor
RayS_b = [2, 4]  # downscale factor

# Robust-model datasets ("<base>_<suffix>") reuse every per-dataset setting of
# the base dataset they were trained on. Groups are applied in this order,
# which fixes the insertion order of the variant keys in each dict below.
_ROBUST_VARIANT_GROUPS = (
    ('MNIST', robust_l2_mnist),
    ('CIFAR10', robust_linf_cifar),
    ('CIFAR10', robust_l2_cifar),
    ('Imagenet', robust_l2_imagenet),
    ('Imagenet', robust_linf_imagenet),
)


def _inherit_from_base(mapping, parts=None):
    """Copy each base dataset's entry in ``mapping`` to its robust variants.

    ``mapping`` is updated in place; variants are bound to the very same value
    object as their base. When ``parts`` is None the keys are plain dataset
    names; otherwise keys are ``(dataset, part)`` tuples and one entry is
    copied per part, e.g. ``parts=('enc', 'dec')``.

    Raises:
        KeyError: if ``mapping`` lacks an entry for a needed base dataset.
    """
    for base, variants in _ROBUST_VARIANT_GROUPS:
        if parts is None:
            for name in variants:
                mapping[name] = mapping[base]
            continue
        for part in parts:
            for name in variants:
                mapping[(name, part)] = mapping[(base, part)]


# Dataset root directories ("" presumably means the loader's default
# location -- TODO confirm).
dataset_to_path = {
    'MNIST': "",
    'CIFAR10': "",
    'Imagenet': "<anonymized>/IMAGENET/Imagenet2012"
}
_inherit_from_base(dataset_to_path)

# Checkpoint of the attacked (victim) model for each dataset identifier.
dataset_to_victim_path = {
    'MNIST': f"{project_root}/victim_models/ckpt/MNIST/mnist_epoch-50_seed-9.ckpt",
    'MNIST_madry': f"{project_root}/community/ckpt-jalal/mnist_l2_baseline_best",
    'MNIST_trades': f"{project_root}/community/ckpt-jalal/mnist_l2_trades_6.pt",
    'MNIST_rob_manifold': f"{project_root}/community/ckpt-jalal/mnist_l2_op_best",
    # NOTE(review): "deep_camma" is active in robust_mnist_l2_suffix, but this
    # entry is commented out, so any lookup of 'MNIST_deep_camma' here raises
    # KeyError -- confirm whether that model needs a checkpoint path.
    # 'MNIST_deep_camma': f"{project_root}/community/causal_robustness/Model/do_m_vertical_shift_0.25/deep_camma_do_m_27_04_2021_12_11_46_epochs_35.pth",
    'CIFAR10': f"{project_root}/victim_models/ckpt/CIFAR10/cifar_epoch-50_seed-9.ckpt",
    'CIFAR10_trades': f"{project_root}/community/ckpt-trades/model_cifar_wrn.pt",
    'CIFAR10_madry': f"{project_root}/community/ckpt-madry/",
    'CIFAR10_interp': f"{project_root}/community/ckpt-interp/latest",
    'CIFAR10_fs': f"{project_root}/community/ckpt-fs/checkpoint-199-ipot",
    'CIFAR10_sense': f"{project_root}/community/ckpt-sense/SENSE_checkpoint300.dict",
    'CIFAR10_smooth20': "",  # Not available
    'CIFAR10_smooth110': f"{project_root}/community/ckpt-smooth/cifar10/resnet110/noise_0.25/checkpoint.pth.tar",
    'Imagenet': "",  # torchvision pretrained
    'Imagenet_madry8': f"{project_root}/community/ckpt-imagenet-madry/imagenet_linf_8.pt",
    'Imagenet_madry4': f"{project_root}/community/ckpt-imagenet-madry/imagenet_linf_4.pt",
    'Imagenet_smooth50': f"{project_root}/community/ckpt-smooth/imagenet/resnet50/noise_0.50/checkpoint.pth.tar",
}

hlm_architectures = ['AE']

# Autoencoder checkpoints, keyed by (dataset, 'enc' | 'dec').
dataset_to_ae = {
    ('MNIST', 'enc'): f"{project_root}/generator/ckpt/MNIST/conv_encoder_epoch-100.pth",
    ('MNIST', 'dec'): f"{project_root}/generator/ckpt/MNIST/conv_decoder_epoch-100.pth",
    ('CIFAR10', 'enc'): f"{project_root}/generator/ckpt/CIFAR10/conv_encoder_epoch-28.pth",
    ('CIFAR10', 'dec'): f"{project_root}/generator/ckpt/CIFAR10/conv_decoder_epoch-28.pth",
    ('Imagenet', 'enc'): f"{project_root}/generator/ckpt/Imagenet/conv_encoder_epoch-100.pth",
    ('Imagenet', 'dec'): f"{project_root}/generator/ckpt/Imagenet/conv_decoder_epoch-100.pth",
}
_inherit_from_base(dataset_to_ae, parts=('enc', 'dec'))

# HLM generator run directories; cc_epoch presumably selects which epoch's
# checkpoint is loaded from them -- TODO confirm against the consumer.
dataset_to_cc_ae = {
    'MNIST': "<anonymized>/ckpt/hlm/generator/MNIST/091820-080110",
    'CIFAR10': "<anonymized>/ckpt/hlm/generator/CIFAR10/091920-012319",
    'Imagenet': "<anonymized>/ckpt/hlm/generator/Imagenet/091720-033828/",
}
cc_epoch = 80
_inherit_from_base(dataset_to_cc_ae)

# Per-dataset compression mode (its meaning is defined by the consumer,
# which is not visible here).
dataset_to_compress_mode = {
    'MNIST': 2,
    'CIFAR10': 2,
    'Imagenet': 3
}
_inherit_from_base(dataset_to_compress_mode)

ae_compress_modes = [1]

# Candidate resize sizes per dataset (presumably the side length images are
# downscaled to -- TODO confirm).
dataset_to_resizings = {
    'MNIST': [14],
    'CIFAR10': [16],
    'Imagenet': [32]
}
_inherit_from_base(dataset_to_resizings)

# Size of each dataset's label space.
dataset_to_num_classes = {
    'MNIST': 10,
    'CIFAR10': 10,
    'Imagenet': 1000
}
_inherit_from_base(dataset_to_num_classes)

# Number of classes to pick from in each dataset.
dataset_to_work = {
    'MNIST': 10,
    'CIFAR10': 10,
    'Imagenet': 10
}
_inherit_from_base(dataset_to_work)

# Number of samples taken from each picked class.
dataset_to_samples_per_class = {
    'MNIST': 2,
    'CIFAR10': 5,
    'Imagenet': 10
}
_inherit_from_base(dataset_to_samples_per_class)

# Batch size = samples per class x number of picked classes. Selected
# datasets come first, then every robust variant (overwriting with the same
# value if it was already selected).
dataset_to_batch = {
    dataset: dataset_to_samples_per_class[dataset] * dataset_to_work[dataset]
    for dataset in datasets
}
dataset_to_batch.update(
    (name, dataset_to_samples_per_class[name] * dataset_to_work[name])
    for _, variants in _ROBUST_VARIANT_GROUPS
    for name in variants
)

k2ix_database_root = os.path.join(project_root, "research_pool")

# Pickled index databases, one per base dataset ("<dataset>_indices.pkl").
dataset_to_database_path = {
    name: os.path.join(k2ix_database_root, f"{name}_indices.pkl")
    for name in ('MNIST', 'CIFAR10', 'Imagenet')
}
_inherit_from_base(dataset_to_database_path)


def _choose_K_from(num_classes, Kp):
    return list(np.random.choice(list(range(num_classes)), Kp).astype(int))

# Decide random classes for each dataset.
# Each call draws from NumPy's global RNG (seeded with SEED at the top of the
# file), so the order of these three calls determines which classes each
# dataset gets -- do not reorder them.
mnist_cls    = _choose_K_from(dataset_to_num_classes['MNIST'], dataset_to_work['MNIST'])
cifar_cls    = _choose_K_from(dataset_to_num_classes['CIFAR10'], dataset_to_work['CIFAR10'])
imagenet_cls = _choose_K_from(dataset_to_num_classes['Imagenet'], dataset_to_work['Imagenet'])

# Selected classes per dataset identifier. Robust variants are bound to the
# *same* list object as their base dataset, so all models of a base dataset
# share identical classes.
dataset_to_classes = {
    'MNIST': mnist_cls,
    'CIFAR10': cifar_cls,
    'Imagenet': imagenet_cls
}
for k in robust_l2_mnist: dataset_to_classes[k] = dataset_to_classes['MNIST']
for k in robust_linf_cifar: dataset_to_classes[k] = dataset_to_classes['CIFAR10']
for k in robust_l2_cifar: dataset_to_classes[k] = dataset_to_classes['CIFAR10']
for k in robust_l2_imagenet: dataset_to_classes[k] = dataset_to_classes['Imagenet']
for k in robust_linf_imagenet: dataset_to_classes[k] = dataset_to_classes['Imagenet']
    

imagenet_archs = ["resnet50"]
# imagenet_split = 0.50


# Epsilon (perturbation budget) per (norm order, dataset identifier).
order_to_dataset_epsilon = {
    # ---- L2 budgets (Sign-OPT paper) ----
    # Robust MNIST models keep the natural-model budget; see
    # https://github.com/ajiljalal/manifold-defense/tree/master/adv-mnist
    **{(2, name): 1.5 for name in ('MNIST', 'MNIST_madry', 'MNIST_trades',
                                   'MNIST_rob_manifold', 'MNIST_deep_camma')},
    (2, 'CIFAR10'): 0.5,
    # Randomized-smoothing models (sigma=0.25)
    **{(2, name): 0.2 for name in ('CIFAR10_smooth20', 'CIFAR10_smooth110')},
    (2, 'Imagenet'): 3.0,
    (2, 'Imagenet_smooth50'): 1.0,  # sigma=0.5
    # ---- L-inf budgets (RayS paper) ----
    ('inf', 'MNIST'): 0.3,
    **{('inf', name): 0.031 for name in ('CIFAR10', 'CIFAR10_trades', 'CIFAR10_madry',
                                         'CIFAR10_interp', 'CIFAR10_fs', 'CIFAR10_sense')},
    **{('inf', name): 0.031 for name in ('Imagenet', 'Imagenet_madry8')},
    ('inf', 'Imagenet_madry4'): 0.015,
}

