import itertools
import submitit

from pdPINN.model.model_utiities import ST
from pathlib import Path
from tqdm import tqdm
from collections import Counter
from datetime import datetime
from experiment_2d import main
import time
import numpy as np
from pdPINN.util.system_util import clean_up_jobs


def build_settings(sampling_method,
                   seed,
                   n_samples_constraints,
                   rar,
                   ot_rar,
                   pde_weight,
                   run_num):
    """Assemble the model and path configuration for one run of the 2D sweep.

    Args:
        sampling_method: Sampling strategy for the run (an ``ST`` member at the
            call sites in this file); stored as ``model_settings["sampling_method"]``.
        seed: Random seed, stored as ``model_settings["seed"]``.
        n_samples_constraints: Stored unchanged as
            ``model_settings["n_samples_constraints"]``.
        rar: Flag stored as ``model_settings["rar"]``.
        ot_rar: Flag stored as ``model_settings["ot_rar"]``.
        pde_weight: Stored unchanged as ``model_settings["pde_weight"]``.
        run_num: Index of this run, stored as ``model_settings["run_num"]``.

    Returns:
        ``(model_settings, path_settings)``. ``model_settings`` holds the
        hyper-parameters above plus fixed architecture/domain values;
        ``path_settings`` holds ``pathlib.Path`` objects rooted at the
        directory that contains this file.
    """
    project_dir = Path(__file__).parent
    dataset_dir = project_dir / "data" / "data_2d_fixed_val"

    path_settings = {
        "df_train": dataset_dir / "train_2d.snappy",
        "df_val": dataset_dir / "eval_2d.snappy",
        "df_test": dataset_dir / "test_2d.snappy",
        "results_csv": project_dir / "notebooks/experiments_evaluation/sweep_exp_2d_sep22.csv",
    }

    run_identity = {
        "experiment": "sweep_exp_2d_sep22",
        "sampling_method": sampling_method,
        "seed": seed,
        "n_samples_constraints": n_samples_constraints,
        "mass_only": False,
    }
    network_and_training = {
        "hidden_features": 256,
        "vf_hidden_features": 64,
        "density_hidden_layers": 2,
        "vf_hidden_layers": 1,
        "num_its": 500,
        "nonlinearity": "sine",
        "sine_frequency": 12,
    }
    # Lower/upper bounds for the t, x and y coordinates.
    domain_bounds = {
        "tmin": -0.5, "tmax": 3.5,
        "xmin": -4, "xmax": 4,
        "ymin": -4, "ymax": 4,
    }
    runtime = {
        "rar": rar,
        "device": "cuda",  # hard-coded; switch to "cpu" for local debugging
        "pde_weight": pde_weight,
        "ot_rar": ot_rar,
        "no_plots": True,
        "run_num": run_num,
    }

    # Merged in a fixed order so the key order of model_settings is deterministic
    # (in case downstream code, e.g. result logging, relies on it).
    model_settings = {**run_identity, **network_and_training, **domain_bounds, **runtime}
    return model_settings, path_settings


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Sweep definition
    # ------------------------------------------------------------------
    seeds = list(range(10))
    sampling_methods = [ST.uniform, ST.importance_sampling, ST.it_pdpinn, ST.mh_pdpinn]
    n_samples_l = [2 ** exponent for exponent in range(4, 14)]  # 16 ... 8192
    rar_flags = [False, True]
    ot_rar_flags = [False, True]

    # (first, last) PDE weight per sampling method; each range is split into
    # 7 linearly spaced values when the jobs are queued.
    pde_weight_bounds = {
        ST.uniform: (1e-7, 1e-5),
        ST.importance_sampling: (1e-6, 1e-3),
        ST.it_pdpinn: (1e-7, 1e-5),
        ST.mh_pdpinn: (1e-7, 1e-5),
    }

    # rar and ot_rar are never enabled together, and either one is only
    # combined with uniform sampling.
    sweep_grid = [
        (method, n_samples, use_rar, use_ot_rar)
        for method, n_samples, use_rar, use_ot_rar in itertools.product(
            sampling_methods, n_samples_l, rar_flags, ot_rar_flags)
        if not (use_rar and use_ot_rar)
        and (method in [ST.uniform] or not (use_rar or use_ot_rar))
    ]

    # ------------------------------------------------------------------
    # Executor
    # ------------------------------------------------------------------
    # Job information, logs and results are written below this folder;
    # submitit substitutes the job id for %j at runtime.
    log_folder = "sweep_exp_2d_new/%j"
    executor = submitit.AutoExecutor(folder=log_folder)
    executor.update_parameters(
        slurm_array_parallelism=40,
        slurm_mem_per_cpu="2G",
        slurm_cpus_per_task=15,
        slurm_time="06:00:00",  # previously "24:00:00"
        slurm_qos="6hours",
        # GPU request currently disabled:
        # slurm_partition="pascal",
        # slurm_additional_parameters=dict(gres="gpu:1"),
    )
    # NOTE(review): build_settings hard-codes device="cuda" while no GPU is
    # requested above -- confirm the target nodes have a GPU, or that `main`
    # falls back to CPU.

    # ------------------------------------------------------------------
    # Submission
    # ------------------------------------------------------------------
    jobs = []

    def _queue(**run_config):
        """Build settings for one run and submit `main`; run_num is the submission index."""
        model_settings, path_settings = build_settings(run_num=len(jobs), **run_config)
        jobs.append(executor.submit(main, model_settings, path_settings))

    with executor.batch():
        # Baseline runs: sampling_method=ST.none with a zero PDE weight, one per seed.
        for seed in seeds:
            _queue(sampling_method=ST.none, seed=seed, n_samples_constraints=1,
                   rar=False, ot_rar=False, pde_weight=0.)

        for sampling_method, n_samples, rar, ot_rar in sweep_grid:
            first_weight, last_weight = pde_weight_bounds[sampling_method]
            pde_weights = np.linspace(first_weight, last_weight, 7)
            print(
                f"{sampling_method} - {n_samples} - "
                f"{first_weight:.2e} to {last_weight:.2e} "
                f"over {len(pde_weights)} steps.")

            for pde_weight in pde_weights:
                for seed in seeds:
                    _queue(sampling_method=sampling_method, seed=seed,
                           n_samples_constraints=n_samples,
                           rar=rar, ot_rar=ot_rar, pde_weight=pde_weight)
        print(len(jobs))

    clean_up_jobs(jobs, f"sweep_exp_2d_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}")
