import os
import gdown
import urllib.request
import hashlib
import zipfile
import shutil
import tarfile

def check_dirs(output):
    """Create the parent directory of the file path ``output`` if missing.

    Only the directory component of ``output`` is created; the file itself
    is not touched. A bare file name (no directory component) lives in the
    current working directory, so nothing needs to be created.

    Args:
        output: path of a file that is about to be written.
    """
    parent = os.path.dirname(output)
    # os.makedirs('') raises FileNotFoundError, so a bare file name must be
    # skipped; exist_ok guards against another process creating it meanwhile.
    if parent and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

# pip install gdown
def gdown_maybe(url, output):
    """Download ``url`` to ``output`` with gdown unless the file already exists.

    The parent directory of ``output`` is created first. A progress bar is
    shown during the download (``quiet=False``).

    Args:
        url: download URL handed to ``gdown.download``.
        output: local destination path.
    """
    check_dirs(output)
    print(f"Checking if we need {output}")
    if os.path.exists(output):
        # Already on disk: nothing to fetch.
        return
    gdown.download(url, output, quiet=False)
        
def urllib_maybe(url, output):
    """Download ``url`` to ``output`` with urllib unless the file already exists.

    The parent directory of ``output`` is created first. The data is written
    to a temporary ``<output>.part`` file and moved into place only after the
    transfer completes. An interrupted or short download (e.g. Ctrl-C,
    network error, ``urllib.error.ContentTooShortError``) therefore never
    leaves a truncated ``output``, which later runs would otherwise treat as
    already downloaded.

    Args:
        url: URL to fetch.
        output: local destination path.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises; the partial file is
        removed before the exception propagates.
    """
    check_dirs(output)
    print(f"Checking if we need {output}")
    if os.path.exists(output):
        return

    partial = output + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        # Atomic on the same filesystem: `output` is either absent or complete.
        os.replace(partial, output)
    finally:
        # After a successful replace `partial` no longer exists; otherwise
        # discard the incomplete download.
        if os.path.exists(partial):
            os.remove(partial)

def print_hash(output):
    """Print the SHA-256 hex digest of the file at ``output`` and return it.

    The file is hashed in 1 MiB chunks, so large model checkpoints are never
    loaded into memory all at once.

    Args:
        output: path of the file to hash.

    Returns:
        The lowercase hexadecimal SHA-256 digest.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    sha256 = hashlib.sha256()
    with open(output, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha256.update(chunk)

    got_hash = sha256.hexdigest()
    print(f'SHA256 hash: {got_hash}')
    return got_hash
    
def check_hash(true, output):
    """Verify that the file at ``output`` has the expected SHA-256 digest.

    The computed digest is printed either way. The file is hashed in 1 MiB
    chunks so large model checkpoints are not read into memory at once.

    Args:
        true: expected SHA-256 digest as a hex string (case-insensitive,
            surrounding whitespace ignored).
        output: path of the file to verify.

    Raises:
        ValueError: if the computed digest differs from ``true``.
        OSError: if the file cannot be opened or read (e.g. it is missing).
    """
    sha256 = hashlib.sha256()
    with open(output, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha256.update(chunk)

    got_hash = sha256.hexdigest()
    print(f'SHA256 hash: {got_hash}')

    # hexdigest() is always lowercase; normalise the expected value to match.
    if true.strip().lower() != got_hash:
        raise ValueError(
            f"The hash for {output} was incorrect! Aborting. "
            f"(expected {true}, got {got_hash})"
        )

    
# Checkpoints hosted on Google Drive, fetched with gdown and verified by
# SHA-256. Each entry is (download URL, local path, expected SHA-256).
_GDRIVE_CHECKPOINTS = [
    # TRADES: Theoretically Principled Trade-off between Robustness and Accuracy
    ('https://drive.google.com/uc?id=10sHvaXhTNZGz618QmD5gSOAjO3rMzV33',
     'ckpt-trades/model_cifar_wrn.pt',
     '2ede52bd042bbdf40a0c27e8008034afd9cbb0b256b9077a255e555d25f957f4'),
    # SENSE: Sensible adversarial learning
    ('https://drive.google.com/uc?id=1hKf5-PLteFFZgrXdes6o9GC7M1MED94m',
     'ckpt-sense/SENSE_checkpoint300.dict',
     'c0202c4fe693fd89adf1b9c4bb4f1c8d36c23befa779047dc60ed7243a1e7c06'),
    # Feature-Scattering: Defense against adversarial attacks using feature
    # scattering-based adversarial training
    ('https://drive.google.com/uc?id=1FXgE7llvQoypf7iCGR680EKQf9cARTSg',
     'ckpt-fs/checkpoint-199-ipot',
     '4aab898ddbff9c92197c73cd012a94a4f7c79221eb62fcbcd641f5591216eba7'),
    # Adversarial Interpolation Training: A Simple Approach for Improving
    # Model Robustness
    ('https://drive.google.com/uc?id=1NWYmLAArzstzaknO1L0ZMni0_8UaSpTo',
     'ckpt-interp/latest',
     '26619aeaf2bb67ac3cb62815fdf7b21c9723dc7875617f9f2360ed0de9b709db'),
]

for url, output, expected_sha in _GDRIVE_CHECKPOINTS:
    gdown_maybe(url, output)
    check_hash(expected_sha, output)


# Madry CIFAR-10: https://github.com/MadryLab/cifar10_challenge
# Shipped as a zip archive. The archive is kept on disk after extraction, and
# its presence is what marks this step as already done.
url = 'https://www.dropbox.com/s/g4b6ntrp8zrudbz/adv_trained.zip?dl=1'
# Last path segment of the URL without the query string -> "adv_trained.zip".
fname = url.rsplit('/', 1)[-1].partition('?')[0]
output = f"ckpt-madry/{fname}"
madry_sha256 = "996384ac78ec749673e43cb823b9fb837a4da8d529ed2ef98e8d94053529fa5d"

if not os.path.exists(output):
    print("Downloading Madry resources")

    urllib_maybe(url, output)
    # Verify the archive before unpacking anything from it.
    check_hash(madry_sha256, output)

    # Unpack, then flatten ckpt-madry/models/adv_trained/* into ckpt-madry/.
    with zipfile.ZipFile(output, 'r') as archive:
        archive.extractall("ckpt-madry/")
        first_entry = archive.namelist()[0]
        print(f"Extracted model in {first_entry}")

    nested_dir = "ckpt-madry/models/adv_trained"
    for name in os.listdir(nested_dir):
        shutil.copy2(f"{nested_dir}/{name}", f"ckpt-madry/{name}")

    print("Cleaning up...")
    shutil.rmtree("ckpt-madry/models")

print("Madry CIFAR-10")
check_hash(madry_sha256, output)


# Smoothing: Certified Adversarial Robustness via Randomized Smoothing
# Shipped as a tar archive whose top-level "models" directory is renamed to
# ckpt-smooth; the existence of ckpt-smooth marks this step as done.
url = 'https://drive.google.com/uc?id=1h_TpbXm5haY5f-l4--IKylmdz6tvPoR4'
output = 'smoothing_models.tar'
smoothing_sha256 = '833db0f5dab02091ed3e8412a25aa945fa84a34663bfd69ecc0f1051bb720eb0'
if not os.path.exists("ckpt-smooth"):
    print("Downloading certified smoothing resources")

    gdown.download(url, output, quiet=False)
    # Verify the archive BEFORE extracting it (as the Madry CIFAR-10 step
    # does), so a corrupted or tampered download is never unpacked.
    check_hash(smoothing_sha256, output)
    with tarfile.open(output) as tar:
        tar.extractall()

    os.rename("models", "ckpt-smooth")

print("Smoothing")
check_hash(smoothing_sha256, output)

if not os.path.exists('ckpt-jalal'):
    os.makedirs('ckpt-jalal')

# Checkpoints fetched with urllib and verified by SHA-256. Each entry is
# (label printed before the step, download URL, local path, expected SHA-256).
_URLLIB_CHECKPOINTS = [
    # The Robust Manifold Defense: Adversarial Training using Generative Models
    # https://github.com/ajiljalal/manifold-defense/tree/master/adv-mnist/results
    # VAE
    ("Robust manifold VAE (MNIST)",
     'https://github.com/ajiljalal/manifold-defense/blob/master/adv-mnist/checkpoints/trained_vae_leakyrelu_20_500_500_784.pth?raw=true',
     'ckpt-jalal/trained_vae_leakyrelu_20_500_500_784.pth',
     '4b34c8e86b0950379cd78f8edcd0a904af14bbcd04423368a93db605d227912c'),
    # Towards deep learning models resistant to adversarial attacks
    ("Madry (MNIST)",
     'https://github.com/ajiljalal/manifold-defense/blob/master/adv-mnist/results/mnist_l2_baseline_best?raw=true',
     'ckpt-jalal/mnist_l2_baseline_best',
     'c193987e09ea543db57a816ecd81d660ea1d7b3455eff2954a491e0224e0cb96'),
    # The Robust Manifold Defense: Adversarial Training using Generative Models
    ("Robust manifold (MNIST)",
     'https://github.com/ajiljalal/manifold-defense/blob/master/adv-mnist/results/mnist_l2_op_best?raw=true',
     'ckpt-jalal/mnist_l2_op_best',
     'd8334107a72b8b6263caf42042352a36f2cba516f818389ee4332c08342e14a5'),
    # Theoretically Principled Trade-off between Robustness and Accuracy
    ("TRADES (MNIST)",
     'https://github.com/ajiljalal/manifold-defense/blob/master/adv-mnist/results/mnist_l2_trades_6.pt?raw=true',
     'ckpt-jalal/mnist_l2_trades_6.pt',
     'b99d5f3f9bf90eff0facb231a1deaa7b79a494cdd09cce0ea42cae8380941376'),
    # Madry ImageNet, L-inf epsilon 8/255 (hosted on Dropbox)
    ("Madry (ImageNet 8/255)",
     'https://www.dropbox.com/s/yxn15a9zklz3s8q/imagenet_linf_8.pt?dl=1',
     'ckpt-imagenet-madry/imagenet_linf_8.pt',
     '6a0bf683b4a46c0512058fd6e30a4bd3b493dd02a651fa13b647a7f57906f1fb'),
]

for label, url, output, expected_sha in _URLLIB_CHECKPOINTS:
    print(label)
    urllib_maybe(url, output)
    check_hash(expected_sha, output)

# Reaching this line means no check_hash call above raised.
print("All hash were correct.")


