"""Test all the examples before release.

This script is expected be manually run and is not used in automatic tests."""

import pytest
import subprocess
import os
import sys
import shlex
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--test', type=str, default=None)
args = parser.parse_args()

pytest_skip = pytest.mark.skip(
    reason="It should be tested on a GPU server and excluded from CI")

if not 'CACHE_DIR' in os.environ:
    cache_dir = os.path.join(os.getcwd(), '.cache')
else:
    cache_dir = os.environ['CACHE_DIR']
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)

def download_data_language():
    url = "http://download.huan-zhang.com/datasets/language/data_language.tar.gz"
    if not os.path.exists('../examples/language/data/sst'):
        subprocess.run(shlex.split(f"wget {url}"), cwd="../examples/language")
        subprocess.run(shlex.split(f"tar xvf data_language.tar.gz"),
            cwd="../examples/language")

@pytest_skip
def test_transformer():
    cmd = f"""python train.py --dir {cache_dir} --robust
        --method IBP+backward_train --train --num_epochs 2 --num_epochs_all_nodes 2
        --eps_start 2 --eps_length 1 --eps 0.1"""
    print(cmd, file=sys.stderr)
    download_data_language()
    subprocess.run(shlex.split(cmd), cwd='../examples/language')

@pytest_skip
def test_lstm():
    cmd = f"""python train.py --dir {cache_dir}
        --model lstm --lr 1e-3 --dropout 0.5 --robust
        --method IBP+backward_train --train --num_epochs 2 --num_epochs_all_nodes 2
        --eps_start 2 --eps_length 1 --eps 0.1
        --hidden_size 2 --embedding_size 2 --intermediate_size 2 --max_sent_length 4"""
    print(cmd, file=sys.stderr)
    download_data_language()
    subprocess.run(shlex.split(cmd), cwd='../examples/language')

@pytest_skip
def test_lstm_seq():
    cmd = f"""python train.py --dir {cache_dir}
        --hidden_size 2 --num_epochs 2 --num_slices 4"""
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/sequence')

@pytest_skip
def test_simple_verification():
    cmd = "python simple_verification.py"
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_custom_op():
    cmd = "python custom_op.py"
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_efficient_convolution():
    cmd = "python efficient_convolution.py"
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_two_node():
    cmd = "python verify_two_node.py"
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_simple_training():
    cmd = """python simple_training.py
        --num_epochs 5 --scheduler_opts start=2,length=2"""
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_cifar_training():
    cmd = """python cifar_training.py
        --batch_size 64 --model ResNeXt_cifar
        --num_epochs 5 --scheduler_opts start=2,length=2"""
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_weight_perturbation():
    cmd = """python weight_perturbation_training.py
        --norm 2 --bound_type CROWN-IBP
        --num_epochs 3 --scheduler_opts start=2,length=1 --eps 0.01"""
    print(cmd, file=sys.stderr)
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_tinyimagenet():
    cmd = f"""python tinyimagenet_training.py
        --batch_size 32 --model wide_resnet_imagenet64
        --num_epochs 3 --scheduler_opts start=2,length=1 --eps {0.1/255}
        --in_planes 2 --widen_factor 2"""
    print(cmd, file=sys.stderr)
    if not os.path.exists('../examples/vision/data/tinyImageNet/tiny-imagenet-200'):
        subprocess.run(shlex.split("bash tinyimagenet_download.sh"),
        cwd="../examples/vision/data/tinyImageNet")
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_imagenet():
    cmd = f"""python imagenet_training.py
        --batch_size 32 --model wide_resnet_imagenet64_1000class
        --num_epochs 3 --scheduler_opts start=2,length=1 --eps {0.1/255}
        --in_planes 2 --widen_factor 2"""
    print(cmd)
    if (not os.path.exists('../examples/vision/data/ImageNet64/train') or
            not os.path.exists('../examples/vision/data/ImageNet64/test')):
        print('Error: ImageNet64 dataset is not ready.')
        return -1
    subprocess.run(shlex.split(cmd), cwd='../examples/vision')

@pytest_skip
def test_release():
    """Run all tests."""
    test_simple_training()
    test_transformer()
    test_lstm()
    test_lstm_seq()
    test_simple_verification()
    test_cifar_training()
    test_weight_perturbation()
    test_tinyimagenet()
    test_custom_op()
    test_efficient_convolution()
    test_two_node()

if __name__ == '__main__':
    test_release()
