import pprint

import submitit

import itertools
from datetime import datetime

from experiment_3d import main, ST, Path
from pdPINN.util.system_util import clean_up_jobs
import argparse
import pandas as pd

import numpy as np


def build_settings(sampling_method,
                   seed,
                   n_samples_constraints,
                   rar,
                   ot_rar,
                   pde_weight,
                   run_num,
                   low_mem=None):
    """Assemble the model settings and file paths for a single 3D sweep run.

    ``sampling_method``, ``seed``, ``n_samples_constraints``, ``rar``, ``ot_rar``,
    ``pde_weight`` and ``run_num`` are forwarded unchanged into ``model_settings``
    under the same keys; all other model settings are fixed for this sweep.

    Args:
        low_mem: Selects the ``_lowmem`` / ``_highmem`` suffix of the experiment
            name. ``None`` (the default) falls back to the module-level
            ``args.low_mem`` parsed in ``__main__``, so existing callers behave as
            before; pass it explicitly to use this function without the CLI.

    Returns:
        A ``(model_settings, path_settings)`` tuple. ``path_settings`` maps
        ``df_train`` / ``df_val`` / ``df_test`` / ``results_csv`` to resolved
        absolute paths.
    """
    if low_mem is None:
        # `args` only exists when this file is executed as a script.
        low_mem = args.low_mem

    project_path = Path(__file__).parent
    data_3d = project_path / "data" / "data_3d_fixed_val"

    path_settings = dict(
        df_train=data_3d / "train_3d.snappy",
        df_val=data_3d / "eval_3d.snappy",
        df_test=data_3d / "test_3d.snappy",

        results_csv=project_path / "notebooks/experiments_evaluation/sweep_exp3d_sep22.csv",
    )
    path_settings = {key: val.resolve() for key, val in path_settings.items()}

    memory_suffix = "_lowmem" if low_mem else "_highmem"

    model_settings = dict(
        experiment=f"sweep_exp_3d_sep24_moreseeds{memory_suffix}",
        sampling_method=sampling_method,
        seed=seed,
        n_samples_constraints=n_samples_constraints,

        mass_only=False,

        hidden_features=256,
        vf_hidden_features=256,
        density_hidden_layers=6,
        vf_hidden_layers=3,

        num_its=300,
        nonlinearity="sine",
        sine_frequency=5,
        lr=1e-4,

        # Bounds of the (t, x, y, z) domain.
        tmin=-.5, tmax=3.5,
        xmin=-4, xmax=4,
        ymin=-4, ymax=4,
        zmin=-.125, zmax=1.0,

        rar=rar,
        ot_rar=ot_rar,

        device="cuda",
        pde_weight=pde_weight,
        silent=False,

        no_plots=True,
        dataset=str(path_settings["df_train"].parent),

        # NOTE(review): key spelling ("coditions") kept as-is; presumably it is the
        # name main() reads -- verify there before renaming.
        include_boundary_coditions=False,
        weight_velocity_loss=4,
        normalize_loss=True,
        run_num=run_num
    )
    return model_settings, path_settings


def check_condition(rar, ot_rar, sampling_method, n_samples):
    """Return True if this (sampling method, refinement flags) combination is excluded.

    A combination is excluded when both ``rar`` and ``ot_rar`` are enabled, or when
    either of them is enabled together with a sampling method other than
    ``ST.uniform``. ``n_samples`` is accepted to mirror ``check_rerun``'s signature
    but does not influence the result.
    """
    if rar and ot_rar:
        return True

    refinement_requested = rar or ot_rar
    # Refinement variants are only swept on top of uniform sampling.
    return bool(refinement_requested and sampling_method not in (ST.uniform,))

def check_rerun(rar, ot_rar, sampling_method, n_samples, low_mem=None):
    """Return True if a configuration should be skipped in ``--rerun-failed`` mode.

    Reruns are split by memory footprint. A job "requires large memory" when it uses
    more than 33,000 constraint samples, or more than 5,000 with
    ``ST.importance_sampling``. The low-memory submission keeps only jobs that do NOT
    require large memory; the high-memory submission keeps only jobs that do.

    Args:
        rar, ot_rar: Unused; accepted to mirror ``check_condition``'s signature.
        sampling_method: Sampling strategy of the run (an ``ST`` member).
        n_samples: Number of constraint samples of the run.
        low_mem: Which submission is being built. ``None`` (the default) falls back
            to the module-level ``args.low_mem`` parsed in ``__main__``.

    Returns:
        True when the run belongs to the other memory class and should be skipped.
    """
    if low_mem is None:
        # `args` only exists when this file is executed as a script.
        low_mem = args.low_mem

    requires_large_mem = (n_samples > 33_000) or (
        (n_samples > 5000) and (sampling_method is ST.importance_sampling))

    if low_mem:
        # Low-memory submission: leave large jobs to the high-memory one.
        return requires_large_mem
    # High-memory submission: leave small jobs to the low-memory one.
    return not requires_large_mem


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(
        description="Submit the 3D sweep (baseline + sampling-method grid) as SLURM jobs via submitit.")
    parser.add_argument("--cluster", default=None, type=str,
                        help="Forwarded to submitit.AutoExecutor(cluster=...).")
    parser.add_argument("--rerun-failed", action="store_true",
                        help="Only submit configurations of the selected memory class (see check_rerun).")
    parser.add_argument("--check", action="store_true",
                        help="Build the batch, print the number of jobs and exit without submitting.")
    parser.add_argument("--low-mem", action="store_true",
                        help="Use the low-memory 'pascal' partition instead of 'rtx8000'.")
    # `args` must stay a module-level global: build_settings() and check_rerun() read it.
    args = parser.parse_args()

    if args.rerun_failed:
        # NOTE(review): `failed_runs` is never consulted below -- reruns are selected only
        # by memory class in check_rerun(). The stats file must nevertheless exist.
        failed_runs = (pd.read_csv("exp3d_jobstats_output.txt", sep="|", header=0)
                       .query('State=="FAILED"')
                       .copy())  # explicit copy: avoids SettingWithCopyWarning on the assignment below
        # `n` must be passed by keyword: pandas >= 2.0 made it keyword-only.
        failed_runs["run_num"] = failed_runs.JobID.str.split("_", n=1, expand=True)[1]
    else:
        failed_runs = None

    rars = [False, True]
    ot_rars = [False, True]

    n_samples_list = [int(2 ** i) for i in range(8, 17)]  # 256 ... 65536

    sampling_methods = [ST.uniform, ST.importance_sampling, ST.it_pdpinn, ST.mh_pdpinn]
    seeds = list(range(5, 10))  # presumably seeds 0-4 belong to an earlier sweep
    cartesian_product = itertools.product(sampling_methods, n_samples_list, rars, ot_rars)

    # PDE-loss weights per sampling method: uniform sampling sweeps 7 evenly spaced
    # weights in [50, 3000]; the other methods sweep the two smallest of those plus
    # 4 evenly spaced weights in [100, 900].
    coarse_weights = np.linspace(50, 3000, 7)
    focused_weights = np.sort(np.concatenate((coarse_weights[:2], np.linspace(100, 900, 4))))
    pde_weight_list = {ST.uniform: coarse_weights,
                       ST.importance_sampling: focused_weights,
                       ST.it_pdpinn: focused_weights,
                       ST.mh_pdpinn: focused_weights}

    # The specified folder is used to dump job information, logs and result when finished
    # %j is replaced by the job id at runtime
    log_folder = "sweep_exp_3d_new/%j"

    executor = submitit.AutoExecutor(cluster=args.cluster, folder=log_folder)

    # Only the partition and the memory per CPU depend on --low-mem.
    if args.low_mem:
        GPU_PARTITION, mem_per_cpu = "pascal", "3GB"
    else:
        GPU_PARTITION, mem_per_cpu = "rtx8000", "4GB"
    executor.update_parameters(slurm_array_parallelism=16,
                               slurm_mem_per_cpu=mem_per_cpu,
                               slurm_cpus_per_task=8,
                               slurm_time="06:00:00",
                               slurm_qos="6hours",
                               slurm_partition=GPU_PARTITION,
                               slurm_additional_parameters=dict(gres="gpu:1"))
    print(GPU_PARTITION)

    jobs = []
    with executor.batch():
        # Baseline: no sampling method (ST.none) and a PDE weight of 0, one run per seed.
        for seed in seeds:
            if args.rerun_failed and check_rerun(rar=False, ot_rar=False,
                                                 sampling_method=ST.none, n_samples=1):
                continue
            model_settings, path_settings = build_settings(sampling_method=ST.none,
                                                           seed=seed,
                                                           n_samples_constraints=1,
                                                           rar=False,
                                                           ot_rar=False,
                                                           pde_weight=0.,
                                                           run_num=len(jobs))
            jobs.append(executor.submit(main, model_settings, path_settings))

        for sampling_method, n_samples, rar, ot_rar in cartesian_product:
            if check_condition(rar=rar, ot_rar=ot_rar, sampling_method=sampling_method, n_samples=n_samples):
                continue
            if args.rerun_failed and check_rerun(rar=rar, ot_rar=ot_rar,
                                                 sampling_method=sampling_method, n_samples=n_samples):
                continue

            pde_weights = pde_weight_list[sampling_method]

            if rar:
                rar_text = "-rar"
            elif ot_rar:
                rar_text = "-ot_rar"
            else:
                rar_text = ""

            # Report the range that is actually swept (not a hard-coded start/end).
            print(
                f"{sampling_method}{rar_text} - {n_samples} - "
                f"{pde_weights.min():.2e} to {pde_weights.max():.2e} "
                f"over steps: [{', '.join(str(w) for w in pde_weights)}]")

            for pde_weight in pde_weights:
                for seed in seeds:
                    model_settings, path_settings = build_settings(sampling_method=sampling_method,
                                                                   seed=seed,
                                                                   n_samples_constraints=n_samples,
                                                                   rar=rar,
                                                                   ot_rar=ot_rar,
                                                                   pde_weight=pde_weight,
                                                                   run_num=len(jobs))
                    jobs.append(executor.submit(main, model_settings, path_settings))

        print(len(jobs))
        if args.check:
            # Dry run: leaving executor.batch() through SystemExit means the queued
            # submissions are never sent to the scheduler.
            sys.exit()

    clean_up_jobs(jobs, f"sweep_exp_3d_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}")
