#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import itertools
import os
import os.path

import sys
import logging
import torch
from copy import deepcopy

def cartesian_product(dicts):
    """Lazily enumerate every combination of the option lists in ``dicts``.

    ``dicts`` maps each key to an iterable of candidate values. Each yielded
    item is a fresh dict that assigns one candidate to every key, with keys in
    the mapping's iteration order. An empty mapping yields a single ``{}``.
    """
    combos = itertools.product(*dicts.values())
    return ({key: value for key, value in zip(dicts, combo)} for combo in combos)


def summary(configuration):
    """Render ``configuration`` as ``key=value`` pairs joined by underscores.

    Pairs are ordered by key, so the same configuration always produces the
    same string regardless of insertion order. Values are converted with
    ``str()``.
    """
    ordered = sorted(configuration.items(), key=lambda item: item[0])
    return "_".join(f"{key!s}={value!s}" for key, value in ordered)


def to_cmd(c, _path=None):
    """Build the shell command that runs ``ats/blackbox_eval.py`` for config ``c``.

    Args:
        c: configuration dict with keys ``model_target``, ``attack_name`` and
            ``attack_mode`` (``'targeted'`` or ``'untargeted'``).
        _path: unused; kept so existing callers that pass it keep working.

    Returns:
        The command as one string. For targeted attacks it ends with
        ``--targeted``; for untargeted ones it ends with a trailing space.

    Raises:
        ValueError: if ``c['attack_mode']`` is neither ``'targeted'`` nor
            ``'untargeted'``.
    """
    # Example output for an untargeted Square attack:
    #   python ats/blackbox_eval.py
    #     --model_target=dino_300ep_deitsmall16_finetune_deep_oxford_pets
    #     --data_name=oxford_pets --attack_name=Square
    #     --experiment_group=blackbox_budget=10_dino_300ep_deitsmall16
    #     --experiment_name=model_target=oxford_pets_deep_finetune_data_name=oxford_pets_attack_name=Square_untargeted
    model_target = c["model_target"]

    model_name_target, tune_mode_target, tune_technique_target, dataset_name_target = parse_meta_from_model(model_target)

    attack_name = c['attack_name']

    # The model name goes into the experiment group; the experiment name carries
    # the dataset / tuning technique / tuning mode of the target.
    experiment_group = f"blackbox_budget=10_{model_name_target}"
    experiment_name = f"model_target={dataset_name_target}_{tune_technique_target}_{tune_mode_target}_" \
                      f"data_name={dataset_name_target}_attack_name={attack_name}"

    attack_mode = c['attack_mode']
    if attack_mode not in ('targeted', 'untargeted'):
        # Without this check an unknown mode would leave ``targeted`` unbound.
        raise ValueError(
            f"Unknown attack_mode {attack_mode!r}; expected 'targeted' or 'untargeted'")
    targeted = attack_mode == 'targeted'
    experiment_name += '_targeted' if targeted else '_untargeted'

    command = "python ats/blackbox_eval.py " \
              f"--model_target={model_target} " \
              f"--data_name={dataset_name_target} " \
              f"--attack_name={attack_name} " \
              f"--experiment_group={experiment_group} " \
              f"--experiment_name={experiment_name} "
    if targeted:
        command += "--targeted"

    return command


def to_logfile(c, path):
    """Return the log-file path for configuration ``c`` under ``path``, creating its directory.

    The path is ``{path}/blackbox_budget=10_{model_name}/{summary}.log``:
    - ``summary`` encodes every key of ``c`` except ``model_target``.
    - It also includes an ``experiment_name`` built from the target model's
      dataset, tuning technique and tuning mode.
    - Any ``/`` in the summary is replaced with ``_`` so it stays a single
      path component.

    Prints a warning (does not fail) when the path exceeds 250 characters.

    Args:
        c: configuration dict; must contain ``model_target``. It is not modified.
        path: root log directory.

    Returns:
        The log-file path as a string.
    """
    new_c = deepcopy(c)

    model_name_target, tune_mode_target, tune_technique_target, dataset_name_target = parse_meta_from_model(c['model_target'])

    # Together, the directory (model name) and ``experiment_name`` (tuning
    # technique/mode and dataset) identify the target model, so the raw
    # ``model_target`` key is left out of the file name.
    del new_c['model_target']

    experiment_group = f"blackbox_budget=10_{model_name_target}"
    new_c['experiment_name'] = f"mt={dataset_name_target}_{tune_technique_target}_{tune_mode_target}_" \
                               f"dn={dataset_name_target}"

    outfile = "{}/{}/{}.log".format(path, experiment_group, summary(new_c).replace("/", "_"))

    # exist_ok makes this idempotent and avoids a check-then-create race.
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    if len(outfile) > 250:
        print("Warning: path length exceeds 250 characters ({}). Shorten your path by using fewer directories.".format(len(outfile)))

    return outfile

def get_all_pretrained_models(root="ats/models/"):
    """Recursively collect ``.torch`` checkpoint files under ``root``.

    Args:
        root: directory to search. Defaults to ``ats/models/``, relative to the
            current working directory.

    Returns:
        ``(model_paths, model_names)``: parallel lists holding each checkpoint's
        path (its walk directory joined with the file name) and its file name
        with the trailing ``.torch`` removed. Order follows ``os.walk``.
    """
    suffix = ".torch"
    model_paths = []
    model_names = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(suffix):
                model_paths.append(os.path.join(dirpath, filename))
                # Strip only the extension; splitting on the first "." would
                # truncate model names that themselves contain dots.
                model_names.append(filename[:-len(suffix)])

    return model_paths, model_names

def parse_meta_from_model(model_path):
    """Split a tuned-model name into ``(model_name, tune_mode, tune_technique, dataset_name)``.

    Names are expected to look like
    ``<model_name>_<tune_mode>_<tune_technique>_<dataset_name>``, e.g.
    ``dino_300ep_deitsmall16_finetune_deep_oxford_pets``.

    - The dataset is normally the last ``_``-separated token.
    - When ``'oxford'`` appears anywhere in the name, the dataset is the last
      two tokens instead (e.g. ``oxford_pets``).
    - ``model_name`` is everything before the tune mode and may itself
      contain underscores.

    Raises:
        IndexError: if the name has too few ``_``-separated tokens.
    """
    tokens = model_path.split("_")
    # Oxford dataset names contain an underscore, so they span two tokens.
    dataset_width = 2 if 'oxford' in model_path else 1

    dataset_name = "_".join(tokens[-dataset_width:])
    tune_technique = tokens[-dataset_width - 1]
    tune_mode = tokens[-dataset_width - 2]
    model_name = "_".join(tokens[:-dataset_width - 2])

    return model_name, tune_mode, tune_technique, dataset_name


def get_model_families():
    """Return the distinct model-family names found in ``ats/tuned_models/``.

    Each directory entry is parsed with ``parse_meta_from_model`` and its
    model-name component is kept. As diagnostics, it prints every entry name
    plus the unique and total family counts.

    Returns:
        Sorted list of unique model names. Sorting makes the result
        independent of set iteration order, which varies between runs for
        strings because of hash randomization.
    """
    model_families = []
    for model_path in os.listdir("ats/tuned_models/"):
        print(model_path)

        model_name, _tune_mode, _tune_technique, _dataset_name = parse_meta_from_model(model_path)

        model_families.append(model_name)

    unique_families = sorted(set(model_families))

    print(f"Unique model families length: {len(unique_families)}")
    print(f"Overall model families length: {len(model_families)}")

    return unique_families

def get_models():
    """Return the entry names of ``ats/tuned_models/`` in ``os.listdir`` order.

    ``get_model_families()`` is called first. Only its printed diagnostics
    matter here; its return value is discarded.

    Every entry is also passed through ``parse_meta_from_model``, with the
    result discarded. A name it cannot split therefore raises instead of
    being returned.
    """
    get_model_families()

    entries = os.listdir("ats/tuned_models/")
    for entry in entries:
        parse_meta_from_model(entry)

    return entries


def _pending_command_lines(configurations, path, is_rc):
    """Build ``<command> > <logfile> 2>&1`` lines for configurations that are not finished.

    When ``is_rc`` is true, a configuration whose log file already contains
    ``'Shutting down'`` is treated as completed and skipped.

    Raises:
        ValueError: if two configurations produce the same command line.
    """
    command_lines = []
    for cfg in configurations:
        logfile = to_logfile(cfg, path)

        completed = False
        if is_rc and os.path.isfile(logfile):
            with open(logfile, 'r', encoding='utf-8', errors='ignore') as f:
                # NOTE(review): assumes the eval script logs 'Shutting down'
                # only on a clean finish; confirm against ats/blackbox_eval.py.
                completed = 'Shutting down' in f.read()

        if not completed:
            command_lines.append('{} > {} 2>&1'.format(to_cmd(cfg), logfile))

    # Checked on the list (not a set) so a collision is actually detectable.
    if len(set(command_lines)) != len(command_lines):
        raise ValueError("Duplicate command lines generated; configurations are not unique")
    return command_lines


def _write_gpu_scripts(chunks, save_dir, num_gpus, slots_per_gpu):
    """Write chunked commands into per-GPU, per-slot scripts; return the job count per GPU.

    Chunk ``i`` goes to GPU ``i % num_gpus``, and the ``j``-th command of a
    chunk goes to ``{save_dir}/nips_qbit_blackbox_v2_{gpu}_{j}.sh``.

    Every script for ``gpu < num_gpus`` and ``j < slots_per_gpu`` is rewritten
    from scratch, possibly empty. This way a re-run does not append duplicate
    or stale commands.
    """
    scripts = {(gpu_id, slot): [] for gpu_id in range(num_gpus) for slot in range(slots_per_gpu)}
    jobs_per_gpu = [0] * num_gpus

    for i, chunk in enumerate(chunks):
        gpu_id = i % num_gpus
        for slot, command_line in enumerate(chunk):
            scripts[(gpu_id, slot)].append(
                'CUDA_LAUNCH_BLOCKING=1 DISABLE_TQDM=1 CUDA_VISIBLE_DEVICES={} {}'.format(gpu_id, command_line) + '\n')
        jobs_per_gpu[gpu_id] += len(chunk)

    for (gpu_id, slot), lines in scripts.items():
        with open(f'{save_dir}/nips_qbit_blackbox_v2_{gpu_id}_{slot}.sh', 'w') as f:
            f.writelines(lines)

    return jobs_per_gpu


def main(argv):
    """Generate shell scripts for the NIPS qbit black-box runs.

    Steps:
    1. Enumerate every entry of ``ats/tuned_models/`` crossed with the attack
       settings.
    2. Skip runs whose log already reports completion.
    3. Write all pending commands to
       ``blackbox_budget=10_commands/nips_qbit_blackbox_v2_commands.sh``.
    4. Distribute them over per-GPU/per-slot scripts in the same folder.

    Args:
        argv: command-line arguments; currently unused.
    """
    print("Starting NIPS qbit Blackbox v2")

    hyp_space = dict(
        model_target=get_models(),
        attack_name=["Square"],  # "Pixle" currently disabled
        attack_mode=["untargeted"],
    )

    configurations = list(cartesian_product(hyp_space))

    print("Number of configurations: {}".format(len(configurations)))

    path = 'logs/nips_qbit_blackbox_budget=10_v2'

    # If the log folder already exists, earlier results may be present and
    # completed jobs are skipped; otherwise create it.
    is_rc = os.path.exists(path)
    if not is_rc:
        os.makedirs(path)

    command_lines = _pending_command_lines(configurations, path, is_rc)

    # Sort before the seeded shuffle so the job order is reproducible.
    import random
    sorted_command_lines = sorted(command_lines)
    random.Random(0).shuffle(sorted_command_lines)

    print('Number of jobs: {}'.format(len(sorted_command_lines)))

    save_dir = 'blackbox_budget=10_commands'
    os.makedirs(save_dir, exist_ok=True)

    with open(f'{save_dir}/nips_qbit_blackbox_v2_commands.sh', 'w') as f:
        f.writelines(command_line + '\n' for command_line in sorted_command_lines)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Avoids ``i % 0`` when the generator runs on a machine without CUDA.
        print("Warning: no CUDA devices detected; assigning all jobs to GPU 0.")
        num_gpus = 1

    slots_per_gpu = 4
    chunks = [sorted_command_lines[i:i + slots_per_gpu]
              for i in range(0, len(sorted_command_lines), slots_per_gpu)]

    jobs_per_gpu = _write_gpu_scripts(chunks, save_dir, num_gpus, slots_per_gpu)

    for gpu_id, n_jobs in enumerate(jobs_per_gpu):
        print('GPU {}: {} jobs'.format(gpu_id, n_jobs))

    # Assumes every job takes ~1/15 hour.
    # NOTE(review): this counts chunks of ``slots_per_gpu`` jobs and does not
    # divide by ``num_gpus``; confirm the intended parallelism model.
    print('Approximate time for completion: {} hours'.format(len(chunks) / 15))


if __name__ == '__main__':
    # Enable INFO-level logging, then pass the CLI arguments (without the
    # program name) to main().
    logging.basicConfig(level=logging.INFO)
    cli_args = sys.argv[1:]
    main(cli_args)
