#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import os
import pathlib
import shutil

import submitit
import easydict
import random

from experiments import (
    algo_to_params,
    run_and_plot,
    LinearModelHparams,
    NNParams,
)

# ---------------------------------------------------------------------------
# Experiment grid
# ---------------------------------------------------------------------------
# Root under which every submitted job gets its own snapshot directory.
WORKING_DIRECTORY = "/path/to/working/directory"
# Algorithm names; each one is handed to `algo_to_params` when jobs are built.
METHODS = [
    "Eps_Greedy",
    "Greedy",
    "NeuralUCB",
    "PLOT",
    "AdvPLOT_1e_nor",
    "Adv_1e_nor",
]
# Dataset names; every method is paired with every dataset.
DATASETS = [
    "Adult",
    "Bank",
    "Crime",
    "German",
    "MNIST",
]
# First seed used by `run_experiments`; it also seeds the `random` module.
SEED = 20_000
# Number of seeds passed to `run_experiments` by the script entry point.
NUM_SEEDS = 5
# Optional free-form tags appended to VERSION (both empty by default).
VERSION_SUFFIX = ""
PROTECTED_CHARACTERISTIC = ""
# When true, jobs are submitted as a Slurm array instead of a single job.
PARALLEL = True
# Not referenced by the code visible in this file.
FAST = False

# Hardcoded
BATCH = 32  # stored as "batch_size" in the job arguments
DECAY = 0.05  # not referenced by the code visible in this file
TIMESTEPS = 2500  # stored as "T" in the job arguments; also part of VERSION
NUM_EXPERIMENTS = 5  # submissions made per seed
EXPERIMENT_RESULTS_DIR = "experiment_results"

# ---------------------------------------------------------------------------
# Job naming
# ---------------------------------------------------------------------------
VERSION = f"_{TIMESTEPS}t_{'_'.join(METHODS)}_{'_'.join(DATASETS)}{VERSION_SUFFIX}{PROTECTED_CHARACTERISTIC}"
JOB_PREFIX = "exp_"
PARALLEL_STR = "_parallel" if PARALLEL else ""
# NOTE: JOB_PREFIX ends with "_" and VERSION starts with "_", so the pieces
# joined here produce doubled underscores in the final name.
JOB_NAME = f"{JOB_PREFIX}{PARALLEL_STR}_{VERSION}"

# ---------------------------------------------------------------------------
# Cluster settings (forwarded to the submitit executor)
# ---------------------------------------------------------------------------
PARTITION = "short"
TIME = "12:00:00"
GPUS = 1  # rendered as "gpu:<GPUS>" for the gres option


def copy_and_run_with_config(
        run_fn, run_config, directory, parallel=False, **cluster_config,
):
    """Snapshot the current directory and submit ``run_fn`` to Slurm from it.

    The current working directory is copied to
    ``<directory>/<cluster_config["job_name"]>`` (entries named in
    ``ignore_list`` are skipped). The process changes into that copy while
    the job is submitted and then returns to the directory it started in, so
    repeated calls always snapshot the same source tree rather than the
    previous job's copy.

    Args:
        run_fn: Callable submitted through submitit.
        run_config: With ``parallel=True``, a sequence of equally long
            argument sequences that is unpacked into ``executor.map_array``
            (one array task per index). Otherwise it is passed to ``run_fn``
            as a single positional argument.
        directory: Parent directory for the per-job snapshot. Relative paths
            are resolved against the working directory at call time.
        parallel: Submit a job array instead of a single job.
        **cluster_config: Forwarded unchanged to
            ``executor.update_parameters``; must contain ``job_name``.

    Returns:
        The list of jobs returned by ``map_array`` when ``parallel`` is true,
        otherwise the single job returned by ``submit``.

    Raises:
        KeyError: If ``job_name`` is missing from ``cluster_config``.
        FileExistsError: If the snapshot directory already exists
            (raised by ``shutil.copytree``).
    """
    print("Let's use slurm!")
    original_cwd = os.getcwd()
    # Anchor the target to an absolute path *before* changing directory so a
    # relative `directory` is not re-interpreted inside the snapshot.
    working_directory = (
        pathlib.Path(directory) / cluster_config["job_name"]
    ).absolute()
    # TODO clean-up
    ignore_list = [
        "checkpoints",
        "experiments",
        "experiment_results",
        "experiment_results_no_intercept",
        "experiment_results_no_intercept_std",
        ".git",
        "output",
    ]
    shutil.copytree(
        ".", working_directory, ignore=shutil.ignore_patterns(*ignore_list))

    # Submit from inside the snapshot, then always go back: leaving the
    # process in the snapshot would make the next call copy "." from it.
    os.chdir(working_directory)
    try:
        print(f"Running at {working_directory}")

        executor = submitit.SlurmExecutor(folder=working_directory)
        print(cluster_config)
        executor.update_parameters(**cluster_config)
        if parallel:
            print(run_config)
            jobs = executor.map_array(run_fn, *run_config)
            print(f"job_ids: {jobs}")
            return jobs

        # NOTE(review): here run_fn receives the whole run_config as ONE
        # argument, unlike the unpacked array path above - confirm that
        # run_fn supports this calling convention.
        job = executor.submit(run_fn, run_config)
        print(f"job_id: {job}")
        return job
    finally:
        os.chdir(original_cwd)


def get_parallel_args(algos, datasets, seed):
    """Build the per-job argument columns for an array submission.

    One job is described for every (algorithm, dataset) pair, ordered
    algorithm-major: all datasets for the first algorithm, then all datasets
    for the second, and so on. Every returned column has
    ``len(algos) * len(datasets)`` entries, and index ``k`` across the
    columns describes job ``k``.

    Args:
        algos: Algorithm names; each is passed once to ``algo_to_params``.
        datasets: List of dataset names.
        seed: Value repeated in the seed column for every job.

    Returns:
        A list of eleven columns, in this order: datasets, training modes,
        NN params, linear-model hparams, exploration hparams, logging
        frequency, number of experiments, ray flag, algorithm names, seeds,
        baselines.
    """
    num_datasets = len(datasets)
    num_jobs = num_datasets * len(algos)

    def repeated(value):
        # Column holding the same value (the same object) for every job.
        return [value] * num_jobs

    args = easydict.EasyDict({
        "ray": False,
        "num_experiments_per_machine": 1,
        "T": TIMESTEPS,
        "baseline_steps": 20_000,
        "batch_size": BATCH,
        "training_mode": "full_minimization",
    })

    # A single NNParams instance is configured once and shared by all jobs.
    nn_params = NNParams()
    nn_params.max_num_steps = args.T
    nn_params.batch_size = args.batch_size
    nn_params.baseline_steps = args.baseline_steps

    linear_hparams = LinearModelHparams()

    # Baseline training (train_baseline) is currently disabled; each dataset
    # gets a [None, None] placeholder instead.
    baselines = {name: [None, None] for name in datasets}

    # Parameters are looked up once per algorithm and reused for each of its
    # datasets.
    per_algo_params = [algo_to_params(algo) for algo in algos]

    return [
        datasets * len(algos),
        repeated("full_minimization"),
        repeated(nn_params),
        repeated(linear_hparams),
        [params for params in per_algo_params for _ in range(num_datasets)],
        repeated(min(10, args.T // 5)),
        repeated(1),
        # TODO Ray is not working.
        repeated(False),
        [algo for algo in algos for _ in range(num_datasets)],
        repeated(seed),
        [baselines[name] for name in datasets] * len(algos),
    ]


def run_experiments(num_seeds=1):
    """Submit ``NUM_EXPERIMENTS`` runs for each of ``num_seeds`` seeds.

    The first seed is ``SEED``. Each following seed is drawn with
    ``random.randint(1, 20_000)`` from the module-level generator, which is
    seeded with ``SEED`` at the start of the call.

    Args:
        num_seeds: How many seeds to iterate over.
    """
    random.seed(SEED)
    current_seed = SEED
    for _ in range(num_seeds):
        # The same argument columns are reused for every run of this seed.
        job_args = get_parallel_args(METHODS, DATASETS, current_seed)

        for run_index in range(NUM_EXPERIMENTS):
            copy_and_run_with_config(
                run_and_plot,
                job_args,
                WORKING_DIRECTORY,
                parallel=PARALLEL,
                array_parallelism=32,
                # Seed and run index keep each snapshot directory distinct.
                job_name=f"{JOB_NAME}_{current_seed}_{run_index}",
                time=TIME,
                comment="None",
                partition=PARTITION,
                gres=f"gpu:{GPUS}",
            )

        current_seed = random.randint(1, 20_000)


if __name__ == "__main__":
    # Script entry point: submit the experiment grid for NUM_SEEDS seeds.
    run_experiments(NUM_SEEDS)
