import random
from torch import manual_seed

import numpy as np

from k_level_policy_gradients.src.utils.slurm_launcher import (
    run_experiment,
    single_experiment,
)
from k_level_policy_gradients.src.experiments.multi_agent_experiment import (
    MultiAgentExperiment,
)
from k_level_policy_gradients.src.utils.time_profiling import Timer
from k_level_policy_gradients.src.utils.launcher_tools import restore_slurm_dictionary


def _restore_param_dicts(
    kwargs, keys=("agent_params", "env_params", "env_run_params")
):
    """Rebuild, in place, every parameter group in ``kwargs`` that is a list.

    Each entry of ``kwargs`` named in ``keys`` that is a list is replaced by
    the result of ``restore_slurm_dictionary``; other entries are untouched.

    Raises:
        KeyError: if one of ``keys`` is missing from ``kwargs``.
    """
    for key in keys:
        if isinstance(kwargs[key], list):
            kwargs[key] = restore_slurm_dictionary(kwargs[key])


def _none_string_to_none(params, key):
    """Replace the literal string ``"None"`` stored under ``key`` by ``None``.

    Raises:
        KeyError: if ``key`` is missing from ``params``.
    """
    if params[key] == "None":
        params[key] = None


@single_experiment
def experiment(**kwargs):
    """Seed the RNGs, normalise the launcher arguments and run one experiment.

    Steps, in order:

    1. Seed ``random``, ``numpy.random`` and torch with ``kwargs["seed"]``.
    2. Restore ``agent_params``, ``env_params`` and ``env_run_params`` with
       ``restore_slurm_dictionary`` whenever they arrive as lists.
    3. Turn the string ``"None"`` back into ``None`` for
       ``env_params["horizon"]`` and ``agent_params["grad_norm_clip"]``.
    4. Copy ``use_torch`` / ``use_cuda`` from ``env_run_params`` into
       ``agent_params`` and, when present, ``state_last_action`` from
       ``agent_params`` into ``env_params``.
    5. Build a ``MultiAgentExperiment`` from the adjusted ``kwargs`` and run
       ``train_agents`` followed by ``evaluate`` through a ``Timer``.

    Args:
        **kwargs: experiment configuration. Must contain ``seed``,
            ``agent_params``, ``env_params`` and ``env_run_params``; the
            complete dictionary is forwarded to ``MultiAgentExperiment``.
            The nested parameter dictionaries are modified in place.

    Raises:
        KeyError: if a required entry (``seed``, one of the parameter groups,
            ``horizon``, ``grad_norm_clip``, ``use_torch`` or ``use_cuda``)
            is missing.
    """
    seed = kwargs["seed"]
    random.seed(seed)
    np.random.seed(seed)
    manual_seed(seed)

    # Restore dictionaries modified by slurm
    _restore_param_dicts(kwargs)

    agent_params = kwargs["agent_params"]
    env_params = kwargs["env_params"]
    env_run_params = kwargs["env_run_params"]

    # Adjust kwargs (experiment_launcher sometimes changes object types)
    _none_string_to_none(env_params, "horizon")
    _none_string_to_none(agent_params, "grad_norm_clip")

    # The agent takes its torch/cuda switches from the run parameters.
    agent_params["use_torch"] = env_run_params["use_torch"]
    agent_params["use_cuda"] = env_run_params["use_cuda"]

    # Give the environment the agent's ``state_last_action`` value, if any.
    if "state_last_action" in agent_params:
        env_params["state_last_action"] = agent_params["state_last_action"]

    # Experiment (local name chosen so it does not shadow this function)
    multi_agent_experiment = MultiAgentExperiment(**kwargs)
    timer = Timer(multi_agent_experiment.exp_logger)

    # Whatever the timed training call returns is handed on to evaluation.
    core = timer.time_function(multi_agent_experiment.train_agents)
    timer.time_function(multi_agent_experiment.evaluate, core)


if __name__ == "__main__":
    # Script entry point: hand the decorated ``experiment`` function to
    # ``run_experiment`` from the slurm launcher utilities.
    # Leave unchanged
    # NOTE(review): presumably the launcher depends on this exact call form;
    # confirm in slurm_launcher before editing.
    run_experiment(experiment)
