#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import itertools
import os
import os.path

import sys
import logging
import torch






def cartesian_product(dicts):
    """Lazily enumerate every combination of the values in *dicts*.

    *dicts* maps each option name to an iterable of candidate values.  The
    result is a generator yielding one ``{name: value}`` dict per element of
    the cartesian product of those candidate lists.
    """
    value_combinations = itertools.product(*dicts.values())
    return (dict(zip(dicts, combination)) for combination in value_combinations)


def summary(configuration):
    """Render a configuration as ``key=value`` pairs joined by underscores.

    Pairs are ordered by key so the same configuration always produces the
    same string.
    """
    pairs = list(configuration.items())
    pairs.sort(key=lambda pair: pair[0])

    rendered = []
    for name, value in pairs:
        rendered.append(str(name) + '=' + str(value))
    return '_'.join(rendered)



'''
python vissl_tools/run_distributed_engines.py \
        hydra.verbose=true \
        config=eval_resnet_8gpu_transfer_cifar10_linear.yaml \
        +config/models/Oxford_pets/deit=dino_deit_s16 \
        config.DATA.TRAIN.DATA_SOURCES="[disk_folder]" \
        config.DATA.TRAIN.LABEL_SOURCES="[disk_folder]" \
        config.DATA.TRAIN.DATASET_NAMES="[Oxford_Pet]" \
        config.DATA.TRAIN.DATA_PATHS="[ats/data/datasets/Oxford_Pet]" \
        config.DATA.TEST.DATA_SOURCES="[disk_folder]" \
        config.DATA.TEST.LABEL_SOURCES="[disk_folder]" \
        config.DATA.TEST.DATASET_NAMES="[Oxford_Pet]" \
        config.DATA.TEST.DATA_PATHS="[ats/data/datasets/Oxford_Pet]" \
        config.TEST_MODEL=True \
        config.DATA.TEST.BATCHSIZE_PER_REPLICA=128 \
        config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=128 \
        config.DISTRIBUTED.NUM_NODES=1 \
        config.DISTRIBUTED.NUM_PROC_PER_NODE=1 \
        config.CHECKPOINT.DIR="ats/tuned_models/proxy_finetune_dino_deit_s16_OxfordIIITPet_big_batch_new_config_1" \
        config.MODEL.WEIGHTS_INIT.PARAMS_FILE="ats/models/dino_300ep_deitsmall16/model_final_checkpoint_phase299.torch" \
        config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME="classy_state_dict" \
        config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="trunk.base_model._features" \
        config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True \
        config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=False \
        config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False \
        config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False \
        config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=False \
        config.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD=False
'''

'''
python vissl_tools/run_distributed_engines.py \
        hydra.verbose=true \
        config=eval_resnet_8gpu_transfer_cifar10_linear.yaml \
        config.DATA.TRAIN.DATA_SOURCES="[torchvision_dataset]" \
        config.DATA.TRAIN.LABEL_SOURCES="[torchvision_dataset]" \
        config.DATA.TRAIN.DATASET_NAMES="[CIFAR10]" \
        config.DATA.TRAIN.DATA_PATHS="[{dataset_path}]" \
        config.DATA.TEST.DATA_SOURCES="[torchvision_dataset]" \
        config.DATA.TEST.LABEL_SOURCES="[torchvision_dataset]" \
        config.DATA.TEST.DATASET_NAMES="[CIFAR10]" \
        config.DATA.TEST.DATA_PATHS="[{dataset_path}]" \
        config.TEST_MODEL=True \
        config.DATA.TEST.BATCHSIZE_PER_REPLICA=64 \
        config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=64 \
        config.DISTRIBUTED.NUM_NODES=1 \
        config.DISTRIBUTED.NUM_PROC_PER_NODE=1 \
        config.CHECKPOINT.DIR="ats/tuned_models/model_target_fulltune_CIFAR10" \
        config.CHECKPOINT.OVERWRITE_EXISTING=False \
        config.MODEL.WEIGHTS_INIT.PARAMS_FILE="ats/models/resnet50-19c8e357.pth" \
        config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME="" \
        config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="trunk._feature_blocks." \
        config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True \
        config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=False \
        config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False \
        config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False \
        config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=False \
        config.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD=False \
        config.MODEL.HEAD.PARAMS="""[["eval_mlp", {"in_channels": 2048, "dims": [2048, 10]}]]"""
'''


def _resolve_model_settings(model_path, dataset_name, tune_mode, tune_technique):
    """Choose VISSL configs and weight-loading options for one checkpoint.

    The architecture and pre-training method are recognised from substrings of
    the lower-cased *model_path*.  *tune_mode* must already be validated as
    ``'fulltune'`` or ``'finetune'``.

    Returns:
        A tuple ``(main_config, extra_config, key_name, append_prefix,
        remove_prefix)``.  Entries stay ``""`` when the path matches no known
        architecture/method.
    """
    path = model_path.lower()
    config_root = f"+config/models/{dataset_name.lower()}"

    # Full tuning loads weights under "trunk."; fine-tuning wraps the trunk,
    # so the weights go under "trunk.base_model." instead.
    trunk_prefix = "trunk." if tune_mode == "fulltune" else "trunk.base_model."

    resnet_main = "eval_resnet_8gpu_transfer_in1k_linear.yaml"
    # SwAV, SimCLR, NPID, RotNet and DeepCluster checkpoints reuse the jigsaw
    # ResNet-50 config; only the state-dict key / prefixes differ.
    resnet_jigsaw = f"{config_root}/resnet50=resnet_jigsaw_{tune_technique}.yaml"

    main_config = ""
    extra_config = ""
    key_name = ""
    append_prefix = ""
    remove_prefix = ""

    if "alex" in path:
        main_config = "eval_alexnet_8gpu_transfer_in1k_linear.yaml"
        for method in ("colorization", "jigsaw"):
            if method in path:
                key_name = "model_state_dict"
                append_prefix = trunk_prefix
                extra_config = (
                    f"{config_root}/alexnet=alexnet_{method}_linear_{tune_technique}.yaml"
                )
                break

    elif "resnet" in path or "rn50" in path:
        main_config = resnet_main
        if "deepcluster" in path:
            # DeepCluster weights have no state-dict key, carry a "module."
            # prefix, and target the trunk's "_feature_blocks." sub-module.
            remove_prefix = "module."
            key_name = ""
            append_prefix = trunk_prefix + "_feature_blocks."
            extra_config = resnet_jigsaw
        elif "colorization" in path:
            key_name = "model_state_dict"
            append_prefix = trunk_prefix
            extra_config = (
                f"{config_root}/resnet50=resnet_colorization_{tune_technique}.yaml"
            )
        elif "jigsaw" in path:
            # The perm2k jigsaw checkpoint is stored under "classy_state_dict".
            key_name = "classy_state_dict" if "perm2k" in path else "model_state_dict"
            append_prefix = trunk_prefix
            extra_config = resnet_jigsaw
        elif any(method in path for method in ("swav", "simclr", "npid", "rotnet")):
            key_name = "classy_state_dict"
            append_prefix = trunk_prefix
            extra_config = resnet_jigsaw

    elif "npid" in path:
        main_config = resnet_main
        key_name = "classy_state_dict"
        append_prefix = trunk_prefix
        extra_config = resnet_jigsaw

    elif "deit" in path:
        main_config = resnet_main
        key_name = "classy_state_dict"
        append_prefix = trunk_prefix
        extra_config = f"{config_root}/deit=dino_deit_s16_{tune_technique}.yaml"

    elif "xcit" in path:
        main_config = resnet_main
        key_name = "classy_state_dict"
        append_prefix = trunk_prefix
        extra_config = f"{config_root}/xcit=dino_xcit_s16_{tune_technique}.yaml"

    # Applied last so it overrides whatever the architecture branch selected.
    if 'pirl_jigsaw_4node' in path:
        key_name = "classy_state_dict"

    return main_config, extra_config, key_name, append_prefix, remove_prefix


def to_cmd(c, _path=None):
    """Build the shell command that launches one VISSL tuning run.

    Args:
        c: Configuration mapping with the keys ``dataset_name``,
            ``tune_mode`` (``'fulltune'`` or ``'finetune'``),
            ``tune_technique`` and ``model_path``.  ``model_path`` must
            contain at least one ``/``; its parent directory name is used as
            the model name.
        _path: Unused; kept for interface compatibility.

    Returns:
        A single-line command string invoking
        ``vissl_tools/run_distributed_engines.py`` with Hydra overrides.

    Raises:
        KeyError: If a required key is missing from *c*.
        ValueError: If ``tune_mode`` is neither ``'fulltune'`` nor
            ``'finetune'``.
        IndexError: If ``model_path`` contains no ``/``.
    """
    dataset_name = c['dataset_name']
    tune_mode = c['tune_mode']
    tune_technique = c['tune_technique']
    model_path = c['model_path']

    # Fail early with a clear message; without this check an unexpected mode
    # would only surface as an unbound local while formatting the command.
    if tune_mode not in ('fulltune', 'finetune'):
        raise ValueError(
            "tune_mode must be 'fulltune' or 'finetune', got {!r}".format(tune_mode)
        )
    # Fine-tuning freezes the trunk; full tuning trains it.
    freeze_trunk = "True" if tune_mode == 'finetune' else "False"

    model_name = model_path.split("/")[-2]

    main_config, extra_config, key_name, append_prefix, remove_prefix = \
        _resolve_model_settings(model_path, dataset_name, tune_mode, tune_technique)

    if not main_config or not extra_config:
        logging.warning(
            "No complete VISSL config mapping for model path %r; "
            "the generated command will lack a model config.", model_path)

    dataset_key = dataset_name.lower()
    # Oxford datasets are read from disk folders; everything else is loaded
    # through torchvision.
    dataset_mode = "disk_folder" if 'oxford' in dataset_key else "torchvision_dataset"

    dataset_path = "ats/data/datasets/"
    if 'oxford_flowers' in dataset_key:
        dataset_path += "oxford_flowers/"
    elif 'oxford_pets' in dataset_key:
        dataset_path += "oxford_pets/"

    model_save_name = f"{model_name}_{tune_mode}_{tune_technique}_{dataset_name}"

    # One entry per override, joined with single spaces.  An empty
    # extra_config therefore leaves two consecutive spaces, as before.
    overrides = [
        'DISABLE_TQDM=1 python vissl_tools/run_distributed_engines.py',
        'hydra.verbose=true',
        f'config={main_config}',
        extra_config,
        f'config.DATA.TRAIN.DATA_SOURCES="[{dataset_mode}]"',
        f'config.DATA.TRAIN.LABEL_SOURCES="[{dataset_mode}]"',
        f'config.DATA.TRAIN.DATASET_NAMES="[{dataset_name}]"',
        f'config.DATA.TRAIN.DATA_PATHS="[{dataset_path}]"',
        f'config.DATA.TEST.DATA_PATHS="[{dataset_path}]"',
        f'config.DATA.TEST.DATA_SOURCES="[{dataset_mode}]"',
        f'config.DATA.TEST.LABEL_SOURCES="[{dataset_mode}]"',
        f'config.DATA.TEST.DATASET_NAMES="[{dataset_name}]"',
        'config.TEST_MODEL=True',
        'config.DATA.TEST.BATCHSIZE_PER_REPLICA=64',
        'config.DATA.TRAIN.BATCHSIZE_PER_REPLICA=64',
        'config.DISTRIBUTED.NUM_NODES=1',
        'config.DISTRIBUTED.NUM_PROC_PER_NODE=1',
        f'config.CHECKPOINT.DIR="ats/tuned_models/{model_save_name}"',
        'config.CHECKPOINT.OVERWRITE_EXISTING=False',
        f'config.MODEL.WEIGHTS_INIT.PARAMS_FILE="{model_path}"',
        f'config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME="{key_name}"',
        f'config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="{append_prefix}"',
        f'config.MODEL.WEIGHTS_INIT.REMOVE_PREFIX="{remove_prefix}"',
        'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True',
        f'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY={freeze_trunk}',
        'config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False',
        'config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False',
        'config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=False',
        'config.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD=False',
        # Repeated on purpose: keeps regenerated commands byte-identical to
        # those produced before this function was refactored.
        'config.CHECKPOINT.OVERWRITE_EXISTING=False',
    ]
    return ' '.join(overrides)


def to_logfile(c, path):
    """Return the log file path for configuration *c* inside directory *path*.

    The file name embeds the configuration summary, with any ``/`` replaced
    by ``_`` so the summary stays a single path component.
    """
    flattened_summary = summary(c).replace("/", "_")
    return "{}/NIPS_qbit_v2.{}.log".format(path, flattened_summary)

def get_all_pretrained_models(root="ats/models/", suffix=".torch"):
    """Collect checkpoint files found anywhere under *root*.

    Directories and files are visited in sorted order so the returned lists
    are reproducible across runs and filesystems.  Each matching path is also
    printed as it is found.

    Args:
        root: Directory to search recursively.  A missing directory yields
            empty results.
        suffix: File-name suffix that identifies a checkpoint.

    Returns:
        A tuple ``(model_paths, model_names)`` of parallel lists:
        ``model_paths`` holds the joined path of every matching file and
        ``model_names`` the part of each file name before its first ``.``.
    """
    model_paths = []
    model_names = []
    for current_dir, subdirs, filenames in os.walk(root):
        # Sorting in place steers os.walk's (top-down) descent order.
        subdirs.sort()
        for filename in sorted(filenames):
            if not filename.endswith(suffix):
                continue
            full_path = os.path.join(current_dir, filename)
            print(full_path)
            model_paths.append(full_path)
            model_names.append(filename.split(".")[0])

    return model_paths, model_names


def main(argv):
    """Generate the job list for the sweep and shard it across GPUs.

    Steps:
      1. Build every configuration from the hyper-parameter grid (datasets x
         discovered checkpoints x tune technique x tune mode).
      2. Skip configurations whose log file already contains
         ``'Training finished'`` (only checked when the log directory existed
         before this run).
      3. Write all remaining commands, in a seeded-shuffled order, to
         ``nips_qbit_v1.sh``.
      4. Write one script per visible CUDA device, assigning command ``i`` to
         GPU ``i % num_gpus``.

    Args:
        argv: Command-line arguments; currently unused.
    """
    import random

    print("Starting NIPS qbit v1")

    hyp_space = dict(
        dataset_name=["CIFAR10", "CIFAR100", "oxford_flowers", "oxford_pets"],
        model_path=get_all_pretrained_models()[0],
        tune_technique=["shallow", "deep"],
        tune_mode=["fulltune", "finetune"],
    )

    configurations = list(cartesian_product(hyp_space))
    print("Number of configurations: {}".format(len(configurations)))

    path = 'logs/nips_qbit_v1'

    # A pre-existing log folder means this is a re-run: earlier logs may mark
    # some configurations as already completed.
    is_rc = os.path.exists(path)
    if not is_rc:
        os.makedirs(path)

    command_lines = set()
    for cfg in configurations:
        logfile = to_logfile(cfg, path)

        completed = False
        if is_rc and os.path.isfile(logfile):
            with open(logfile, 'r', encoding='utf-8', errors='ignore') as f:
                completed = 'Training finished' in f.read()

        if not completed:
            command_lines.add('{} > {} 2>&1'.format(to_cmd(cfg), logfile))

    # Sort first so the seeded shuffle yields the same order on every run.
    sorted_command_lines = sorted(command_lines)
    random.Random(0).shuffle(sorted_command_lines)

    print('Number of jobs: {}'.format(len(sorted_command_lines)))

    # Write the command_lines to a file
    with open('nips_qbit_v1.sh', 'w') as f:
        f.writelines(command_line + '\n' for command_line in sorted_command_lines)

    num_gpus = torch.cuda.device_count()

    # Hard-coded switch: when True, only commands whose text (which includes
    # the log file name) contains one of the tags below are emitted, into the
    # "error" scripts; when False every command is emitted.
    err_runs = True
    error_run_tags = (
        'ats_models_converted_vissl_rn50_colorization_in1k_goyal19_converted_vissl_rn50_colorization_in1k_goyal19',
        'ats_models_converted_vissl_rn50_colorization_in22k_goyal19_converted_vissl_rn50_colorization_in22k_goyal19',
        'ats_models_dino_300ep_deitsmall16_model_final_checkpoint_phase299',
        'ats_models_pirl_jigsaw_4node_pirl_jigsaw_4node_resnet_22_07_20_model_final_checkpoint_phase799',
    )
    script_template = 'nips_qbit_v1_error_gpu{}.sh' if err_runs else 'nips_qbit_v1_gpu{}.sh'

    if num_gpus == 0:
        # Without this guard the modulo below would divide by zero.
        print('No CUDA devices detected; skipping per-GPU scripts')

    jobs_per_gpu = [[] for _ in range(num_gpus)]
    if num_gpus > 0:
        for i, command_line in enumerate(sorted_command_lines):
            if err_runs and not any(tag in command_line for tag in error_run_tags):
                continue
            # The GPU is chosen from the command's index in the full list,
            # including commands that the filter above skips.
            gpu = i % num_gpus
            jobs_per_gpu[gpu].append(
                'CUDA_VISIBLE_DEVICES={} {}'.format(gpu, command_line))

    for gpu, jobs in enumerate(jobs_per_gpu):
        # Write mode: a re-run replaces the previous job list instead of
        # appending duplicates to it.
        with open(script_template.format(gpu), 'w') as f:
            f.writelines(job + '\n' for job in jobs)

    # Jobs per GPU print
    for gpu, jobs in enumerate(jobs_per_gpu):
        print('GPU {}: {} jobs'.format(gpu, len(jobs)))

    # Approximate time per all runs. Every job takes 1 hour
    print('Approximate time for completion: {} hours'.format(len(sorted_command_lines)))


if __name__ == "__main__":
    # Script entry point: enable INFO-level logging, then run main() with the
    # command-line arguments that follow the program name.
    cli_arguments = sys.argv[1:]
    logging.basicConfig(level=logging.INFO)
    main(cli_arguments)
