import json
import hashlib
import os
import copy
import shlex
import numpy as np
import tqdm
import shutil
import argparse

from advbench.lib import misc
from advbench import algorithms
from advbench import datasets
from advbench import command_launchers
from advbench import hparams_registry

class Job:
    """A single training run within a sweep.

    The job's output directory is ``<sweep_output_dir>/<md5 of train_args>``,
    so the same argument set always maps to the same directory. The job's
    state is inferred from the filesystem when the object is constructed.
    """

    NOT_LAUNCHED = 'Not launched'
    INCOMPLETE = 'Incomplete'
    DONE = 'Done'

    def __init__(self, train_args, sweep_output_dir):
        """Build the shell command for one run and detect its current state.

        Args:
            train_args: dict of command-line arguments for
                ``advbench.scripts.train``; must be JSON-serializable.
                List values are emitted as space-separated arguments.
                The dict itself is not modified.
            sweep_output_dir: directory under which this job's own output
                directory lives.
        """
        # Hash the args *before* adding 'output_dir', because the directory
        # name is itself derived from this hash. sort_keys makes it stable.
        args_str = json.dumps(train_args, sort_keys=True)
        args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest()
        self.output_dir = os.path.join(sweep_output_dir, args_hash)

        self.train_args = copy.deepcopy(train_args)
        self.train_args['output_dir'] = self.output_dir

        command = ['python', '-m', 'advbench.scripts.train']
        for key, value in sorted(self.train_args.items()):
            if isinstance(value, list):
                # Quote each element so values containing spaces or shell
                # metacharacters stay single arguments.
                value = ' '.join(shlex.quote(str(item)) for item in value)
            elif isinstance(value, str):
                value = shlex.quote(value)
            command.append(f'--{key} {value}')
        self.command_str = ' '.join(command)

        # A 'done' marker file (presumably written by the training script on
        # completion — verify) means finished; an existing directory without
        # it means the run was started but not finished.
        if os.path.exists(os.path.join(self.output_dir, 'done')):
            self.state = Job.DONE
        elif os.path.exists(self.output_dir):
            self.state = Job.INCOMPLETE
        else:
            self.state = Job.NOT_LAUNCHED

    def __str__(self):
        """Return '<state>: <output_dir> (dataset, algorithm, hparams_seed)'.

        Raises KeyError if any of those three keys is missing from train_args.
        """
        job_info = (
            self.train_args['dataset'],
            self.train_args['algorithm'],
            self.train_args['hparams_seed']
        )
        return f'{self.state}: {self.output_dir} {job_info}'

    @staticmethod
    def launch(jobs, launcher_fn):
        """Create output directories for ``jobs`` and hand their commands off.

        Args:
            jobs: iterable of Job instances; not modified.
            launcher_fn: callable that receives the list of command strings.
        """
        print('Launching...')
        # Copy (accepting any iterable) so the caller's sequence is not
        # reordered by the in-place shuffle below.
        jobs = list(jobs)
        np.random.shuffle(jobs)
        print('Making job directories:')
        for job in tqdm.tqdm(jobs, leave=False):
            os.makedirs(job.output_dir, exist_ok=True)
        commands = [job.command_str for job in jobs]
        launcher_fn(commands)
        print(f'Launched {len(jobs)} jobs!')


def robustbench_command_str(output_dir, data_dir, beta_lr, beta_n_steps, batch_size):
    """Return the shell command that runs ``advbench.scripts.robustbench``.

    Path arguments are shell-quoted (as string args are in ``Job``) so the
    printed command stays valid even when a path contains spaces.
    """
    command = ['python', '-m', 'advbench.scripts.robustbench']
    command.append(f'--output_dir {shlex.quote(output_dir)}')
    command.append(f'--data_dir {shlex.quote(data_dir)}')
    command.append(f'--beta_lr {beta_lr}')
    command.append(f'--beta_n_steps {beta_n_steps}')
    command.append(f'--batch_size {batch_size}')
    return ' '.join(command)


if __name__ == '__main__':
    # Print (not run) one robustbench command per (beta_n_steps, beta_lr) pair.
    parser = argparse.ArgumentParser(description='Run a sweep')
    parser.add_argument('--data_dir', type=str, default='./advbench/data')
    parser.add_argument('--output_dir', type=str, default='robustbench_eval')
    parser.add_argument('--beta_n_steps', nargs='+', type=int, default=[1])
    parser.add_argument('--beta_lrs', type=float, nargs='+', default=[1e-2])
    parser.add_argument('--batch_size', type=int, default=16)
    args = parser.parse_args()

    for n_steps in args.beta_n_steps:
        for lr in args.beta_lrs:
            print(robustbench_command_str(
                args.output_dir, args.data_dir, lr, n_steps, args.batch_size))