#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import itertools
import os
import os.path

import sys
import logging
import torch
from copy import deepcopy

def cartesian_product(dicts):
    """Lazily enumerate every combination of the candidate values in *dicts*.

    *dicts* maps each key to an iterable of candidate values. Each yielded
    item is a dict with the same keys and one chosen value per key. An empty
    mapping yields exactly one empty dict.
    """
    value_grid = itertools.product(*dicts.values())
    return (dict(zip(dicts, point)) for point in value_grid)


def summary(configuration):
    """Render *configuration* as ``key=value`` pairs sorted by key and joined with ``_``."""
    rendered = []
    for key in sorted(configuration):
        rendered.append('%s=%s' % (key, configuration[key]))
    return '_'.join(rendered)


def to_cmd(c, _path=None):
    """Build the shell command that launches ``ats/compute_tranferability.py`` for one configuration.

    Example of a generated command (untargeted, wrapped here for readability)::

        python ats/compute_tranferability.py
            --model_target=dino_300ep_deitsmall16_finetune_deep_oxford_pets
            --model_proxy=dino_300ep_deitsmall16_finetune_deep_oxford_pets
            --data_name=oxford_pets --attack_name=PGD
            --experiment_group=... --experiment_name=...

    Args:
        c: configuration dict with keys ``model_pairs`` (a 2-item
            ``(target, proxy)`` sequence of model names), ``attack_name`` and
            ``attack_mode`` (``'targeted'`` or ``'untargeted'``).
        _path: unused; kept for interface compatibility.

    Returns:
        The command string. ``--data_name`` is the data set parsed from the
        target model name; ``--targeted`` is appended only for targeted attacks.

    Raises:
        ValueError: if ``c['attack_mode']`` is neither ``'targeted'`` nor
            ``'untargeted'`` (previously an ``UnboundLocalError``).
    """
    model_target, model_proxy = c["model_pairs"]

    model_name_target, tune_mode_target, tune_technique_target, dataset_name_target = parse_meta_from_model(model_target)
    _, tune_mode_proxy, tune_technique_proxy, dataset_name_proxy = parse_meta_from_model(model_proxy)

    # Target and proxy are not required to share a backbone; the experiment
    # group is keyed by the target model only.
    experiment_group = f"transferability_backbone_cross_model_target={model_name_target}"
    experiment_name = f"model_target={dataset_name_target}_{tune_technique_target}_{tune_mode_target}_" \
                      f"model_proxy={dataset_name_proxy}_{tune_technique_proxy}_{tune_mode_proxy}_" \
                      f"data_name={dataset_name_target}_attack_name={c['attack_name']}"

    attack_mode = c['attack_mode']
    if attack_mode not in ('targeted', 'untargeted'):
        raise ValueError(f"Unknown attack_mode {attack_mode!r}; expected 'targeted' or 'untargeted'")
    targeted = attack_mode == 'targeted'
    experiment_name += f"_{attack_mode}"

    command = (
        "python ats/compute_tranferability.py "
        f"--model_target={model_target} "
        f"--model_proxy={model_proxy} "
        f"--data_name={dataset_name_target} "
        f"--attack_name={c['attack_name']} "
        f"--experiment_group={experiment_group} "
        f"--experiment_name={experiment_name} "
    )
    if targeted:
        command += "--targeted"
    return command


def to_logfile(c, path):
    """Return the log-file path for configuration *c* under *path*, creating its directory.

    The path has the form ``{path}/g_{target_model_name}/{summary}.log``.
    ``summary`` is built from every key of *c* except ``model_pairs``, plus an
    ``experiment_name`` derived from the parsed target and proxy model names.

    If the full path exceeds 250 characters, only the file stem is shortened.
    A short hash of the full path is appended, so distinct configurations keep
    distinct log files and the ``.log`` suffix is preserved. The former plain
    ``outfile[:250]`` cut could drop the suffix and make two jobs share one
    log file.

    Args:
        c: configuration dict containing ``model_pairs`` (``(target, proxy)``)
            and the remaining hyper-parameters.
        path: root directory for the logs.

    Returns:
        The log-file path (a string).
    """
    import hashlib

    max_len = 250
    model_target = c['model_pairs'][0]
    model_proxy = c['model_pairs'][1]

    model_name_target, tune_mode_target, tune_technique_target, dataset_name_target = parse_meta_from_model(model_target)
    _, tune_mode_proxy, tune_technique_proxy, dataset_name_proxy = parse_meta_from_model(model_proxy)

    # The summary should describe the hyper-parameters, not the raw model pair.
    meta = {k: v for k, v in c.items() if k not in ('model_pairs', 'target_model', 'proxy_model')}
    meta['experiment_name'] = f"mt={dataset_name_target}_{tune_technique_target}_{tune_mode_target}_" \
                              f"mp={dataset_name_proxy}_{tune_technique_proxy}_{tune_mode_proxy}"

    directory = "{}/{}".format(path, f"g_{model_name_target}")
    stem = summary(meta).replace("/", "_")
    outfile = "{}/{}.log".format(directory, stem)

    if len(outfile) > max_len:
        print("Warning: path length exceeds 250 characters ({}). Shorten your path by using fewer directories.".format(len(outfile)))
        # Deterministic digest of the full path keeps shortened names unique
        # and stable across re-runs (needed by the resume check in main).
        digest = hashlib.sha256(outfile.encode("utf-8")).hexdigest()[:10]
        tail = f"_{digest}.log"
        budget = max_len - len(directory) - 1 - len(tail)
        outfile = "{}/{}{}".format(directory, stem[:max(budget, 0)], tail)

    os.makedirs(directory, exist_ok=True)
    return outfile

def get_all_pretrained_models(root_dir="ats/models/"):
    """Collect every ``.torch`` file found recursively under *root_dir*.

    Args:
        root_dir: directory to walk (defaults to ``"ats/models/"``).

    Returns:
        ``(model_paths, model_names)``: two parallel lists. ``model_paths``
        holds the joined file paths. ``model_names`` holds the file names with
        only the trailing ``.torch`` removed; the old ``split(".")[0]``
        truncated names containing dots. Traversal order is sorted, so the
        result is deterministic.
    """
    suffix = ".torch"
    model_paths = []
    model_names = []
    for root, dirs, files in os.walk(root_dir):
        dirs.sort()  # in-place sort makes os.walk visit subdirectories deterministically
        for file_name in sorted(files):
            if file_name.endswith(suffix):
                model_paths.append(os.path.join(root, file_name))
                model_names.append(file_name[:-len(suffix)])
    return model_paths, model_names

def parse_meta_from_model(model_path):
    """Split an underscore-separated model name into its metadata fields.

    Layout: ``{model_name}_{tune_mode}_{tune_technique}_{dataset_name}``. When
    the name contains ``'oxford'``, the data set takes the last two tokens
    (e.g. ``oxford_pets``); otherwise it is the last token.

    Returns:
        ``(model_name, tune_mode, tune_technique, dataset_name)``.
    """
    tokens = model_path.split("_")
    width = 2 if 'oxford' in model_path else 1

    dataset_name = "_".join(tokens[-width:])
    tune_technique = tokens[-width - 1]
    tune_mode = tokens[-width - 2]
    model_name = "_".join(tokens[:-width - 2])

    return model_name, tune_mode, tune_technique, dataset_name


def get_model_families(models_dir="ats/tuned_models/", cross_flag=False):
    """Return the sorted unique model-family names of the entries in *models_dir*.

    Each entry name is parsed with ``parse_meta_from_model``; its ``model_name``
    field is the family.

    Args:
        models_dir: directory whose entries are tuned-model names.
        cross_flag: if True, keep only the two hard-coded SwAV families used
            for cross-family experiments.

    Returns:
        A sorted list of unique family names. Sorting makes the order
        deterministic; ``list(set(...))`` order varies between runs.
    """
    model_families = [parse_meta_from_model(entry)[0] for entry in os.listdir(models_dir)]

    if cross_flag:
        m1 = "swav_4gpu_bs64_400ep_2x224_6x96_queue_swav_8node_resnet_28_07_20"
        m2 = "swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20"
        model_families = [x for x in model_families if x in (m1, m2)]

    unique_families = sorted(set(model_families))
    print(f"Unique model families length: {len(unique_families)}")
    print(f"Overall model families length: {len(model_families)}")

    return unique_families

def get_backbone_attack_model_pairs(family_dict, model_families, filter_flag=False):
    """Build ``[target, proxy]`` pairs whose proxy is a family's chosen backbone model.

    A family's backbone is the first entry of ``family_dict[family]`` whose
    name contains ``'finetune'``.

    Args:
        family_dict: maps family name to the list of its model names.
        model_families: families to consider, in order.
        filter_flag: if True, build cross-family pairs. Every model of family
            A is paired with the backbone of every other family B; families
            without a backbone are silently skipped. If False, every model of
            a family (whose name contains the family name) is paired with that
            family's own backbone, including the backbone itself.

    Returns:
        A list of 2-item lists ``[target, proxy]``.

    Raises:
        ValueError: if *filter_flag* is False and a family has no
            ``'finetune'`` model. This used to be an ``assert``, which is
            stripped under ``python -O``.
    """
    def _first_finetuned(family):
        # First 'finetune' model of the family, or None if there is none.
        return next((path for path in family_dict[family] if 'finetune' in path), None)

    pairs = []

    if filter_flag:
        backbones = {}
        for family in model_families:
            chosen = _first_finetuned(family)
            if chosen is not None:
                backbones[family] = chosen

        # Cross-family pairs: targets from one family, proxy from another.
        for target_family in backbones:
            for proxy_family in backbones:
                if target_family == proxy_family:
                    continue
                for model_path in family_dict[target_family]:
                    pairs.append([model_path, backbones[proxy_family]])
    else:
        for family in model_families:
            chosen = _first_finetuned(family)
            if chosen is None:
                raise ValueError(f"No 'finetune' model found for family {family!r}")
            pairs.extend([model_path, chosen] for model_path in family_dict[family] if family in model_path)

    return pairs

def get_pairwise_model_families(backbone_attack_flag=False, cross_flag=False):
    """Return the ``(target, proxy)`` model pairs to evaluate.

    Models are the entries of ``ats/tuned_models/``, grouped by the family
    returned from ``parse_meta_from_model``.

    Args:
        backbone_attack_flag: if True, delegate to
            ``get_backbone_attack_model_pairs`` (with ``filter_flag=cross_flag``).
        cross_flag: without *backbone_attack_flag*, pair every model of two
            hard-coded SwAV families against each other (Cartesian product).

    Returns:
        With both flags False (the default), every ordered pair of distinct
        models within the same family (``itertools.permutations(..., 2)``).
    """
    model_families = get_model_families()

    family_dict = {family: [] for family in model_families}
    # Sorted so the per-family member order (and thus pair order) is deterministic.
    for model_path in sorted(os.listdir("ats/tuned_models/")):
        model_name = parse_meta_from_model(model_path)[0]
        if model_name in family_dict:
            family_dict[model_name].append(model_path)

    if backbone_attack_flag:
        return get_backbone_attack_model_pairs(family_dict, model_families, filter_flag=cross_flag)

    if cross_flag:
        family1 = "swav_4gpu_bs64_400ep_2x224_6x96_queue_swav_8node_resnet_28_07_20"
        family2 = "swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20"

        cross_family_pairs = list(itertools.product(family_dict[family1], family_dict[family2]))
        print(f"Number of cross family pairs: {len(cross_family_pairs)}")
        return cross_family_pairs

    # Within-family pairs: all ordered pairs of distinct members.
    family_pairs = []
    for members in family_dict.values():
        family_pairs.extend(itertools.permutations(members, 2))
    return family_pairs


        

def _log_is_complete(logfile):
    """Return True if *logfile* exists and contains the completion marker; print the status when checked."""
    if not os.path.isfile(logfile):
        return False
    with open(logfile, 'r', encoding='utf-8', errors='ignore') as f:
        completed = 'blackbox experiment finished' in f.read()
    print(f"the Logfile is completed: {completed}")
    return completed


def main(argv):
    """Generate shell scripts for every transferability job that has not finished yet.

    Steps:

    * Enumerate the configurations.
    * Skip jobs whose log already contains the completion marker (only when
      the log folder existed before this run).
    * Write all pending commands to one file.
    * Distribute the commands, in chunks of 4, round-robin across
      ``num_gpus`` per-GPU scripts.

    Per-GPU scripts are rewritten (``'w'``) on each run. They used to be opened
    in append mode, so a resumed run duplicated previously written, possibly
    completed, commands.

    Args:
        argv: command-line arguments (currently unused).
    """
    import random

    from tqdm import tqdm

    print("Starting NIPS qbit Cross Family Compute Transfer v2")

    hyp_space = dict(
        model_pairs=get_pairwise_model_families(),
        attack_name=["PGD"],
        attack_mode=["targeted", "untargeted"],
    )
    configurations = list(cartesian_product(hyp_space))
    print("Number of configurations: {}".format(len(configurations)))

    path = 'logs/nips_qbit_longer_attack_transfer_v2'
    # An existing log folder means this is a re-run: finished jobs are skipped.
    is_rc = os.path.exists(path)
    if not is_rc:
        os.makedirs(path)

    command_lines = set()
    for cfg in configurations:
        logfile = to_logfile(cfg, path)
        if is_rc and _log_is_complete(logfile):
            continue
        command_lines.add('{} > {} 2>&1'.format(to_cmd(cfg), logfile))

    # Sort first so the seeded shuffle is reproducible regardless of set order.
    sorted_command_lines = sorted(command_lines)
    random.Random(0).shuffle(sorted_command_lines)

    nb_jobs = len(sorted_command_lines)
    print('Number of jobs: {}'.format(nb_jobs))

    save_dir = 'transferability_commands_longer_attack'
    os.makedirs(save_dir, exist_ok=True)

    with open(f'{save_dir}/iclr_dgx_longer_attack_transfer_v1_commands.sh', 'w') as f:
        f.writelines(command_line + '\n' for command_line in sorted_command_lines)

    num_gpus = 16  # torch.cuda.device_count()
    jobs_per_chunk = 4
    chunks = [sorted_command_lines[i:i + jobs_per_chunk] for i in range(0, nb_jobs, jobs_per_chunk)]

    script_lines = {}  # script path -> lines to write, in order
    jobs_per_gpu = [0] * num_gpus
    for i, chunk in tqdm(enumerate(chunks), total=len(chunks)):
        gpu_id = i % num_gpus
        # NOTE(review): slots 0-6 map to devices 0-6 and slots 7-15 to devices
        # 0-8, presumably two hosts with different GPU counts. Confirm the split.
        device = gpu_id - 7 if gpu_id > 6 else gpu_id
        for j, command_line in enumerate(chunk):
            script = f'{save_dir}/iclr_dgx_backbone_cross_transfer_v1_{gpu_id}_{j}.sh'
            script_lines.setdefault(script, []).append(
                'CUDA_LAUNCH_BLOCKING=1 DISABLE_TQDM=1 CUDA_VISIBLE_DEVICES={} {}'.format(device, command_line) + '\n'
            )
        jobs_per_gpu[gpu_id] += len(chunk)

    for script, lines in script_lines.items():
        with open(script, 'w') as f:
            f.writelines(lines)

    # Jobs per GPU slot (counted while distributing; the old substring search never matched).
    for i in range(num_gpus):
        print('GPU {}: {} jobs'.format(i, jobs_per_gpu[i]))

    # NOTE(review): this divides the number of 4-job chunks (not jobs) by 15,
    # matching the previous output; confirm the intended estimate.
    print('Approximate time for completion: {} hours'.format(len(chunks) / 15))


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main(sys.argv[1:])
