"""Script used to train agents."""
import argparse
import json
import os
import sys
import traceback
from collections.abc import Mapping
from types import SimpleNamespace
import json
import tonic
import torch
import yaml

# Hide every GPU from this process, so torch only sees the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# cluster_utils is optional. CLUSTER records whether its helpers are available
# (1) or not (0); code below branches on this flag.
try:
    from cluster import read_params_from_cmdline, save_metrics_params
    CLUSTER = 1
except ImportError:
    print('No cluster utils detected, proceeding without it')
    CLUSTER = 0


from .utils import print_data


def recursively_dictify(args):
    """Return a YAML-friendly plain-``dict`` copy of the mapping ``args``.

    cluster_utils param dicts cause issues with yaml, so every nested
    ``Mapping`` is converted to a plain ``dict``. Tuples are converted to
    lists, matching how ``train`` normalizes ``env_args["target"]``. Lists
    are walked as well, so mappings nested inside them are converted too.
    All other values are returned unchanged. The input is not modified.
    """

    def _convert(value):
        # Recursively normalize one value; the structure is otherwise preserved.
        if isinstance(value, Mapping):
            return {key: _convert(item) for key, item in value.items()}
        if isinstance(value, (tuple, list)):
            return [_convert(item) for item in value]
        return value

    return _convert(args)


def load_time(path):
    """Load and return the object saved at ``<path>/checkpoints/time.pt``."""
    time_file = os.path.join(path, "checkpoints/time.pt")
    return torch.load(time_file)


def prepare_params():
    """Load the run parameters and build their namespace view.

    The parameter source depends on the command line:

    * last argument ``"0"``: ``./param_files/default_tonic.json``;
    * cluster utils importable (``CLUSTER == 1``): ``read_params_from_cmdline()``;
    * otherwise: the JSON file named by the last argument.

    Returns:
        ``(orig_params, params)``. ``orig_params`` is the raw parameter mapping.
        ``params`` is the ``SimpleNamespace`` returned by ``prepare_cluster``,
        which also creates ``params.working_dir``.
    """
    if sys.argv[-1] == "0":
        with open("./param_files/default_tonic.json", "r") as f:
            orig_params = json.load(f)
    elif CLUSTER == 1:
        orig_params = read_params_from_cmdline()
    else:
        # Use a context manager so the handle is closed; it was previously leaked.
        with open(sys.argv[-1], "r") as f:
            orig_params = json.load(f)
    params = prepare_cluster(orig_params)
    return orig_params, params


def prepare_cluster(orig_params):
    """Build the namespace view of ``orig_params`` and create its working dir.

    Returns the ``SimpleNamespace`` from ``get_params``. The directory named by
    its ``working_dir`` attribute is created, including parents, if missing.
    """
    namespace = get_params(orig_params)
    os.makedirs(namespace.working_dir, exist_ok=True)
    return namespace


def get_params(orig_params):
    """Return a ``SimpleNamespace`` view of the mapping ``orig_params``.

    Top-level values that are mappings become ``SimpleNamespace`` objects
    themselves (one level deep only), so ``params.section.key`` works. The
    input mapping is not modified.
    """
    # isinstance(..., Mapping) rather than type(...) == dict, so dict
    # subclasses and other mapping types are converted as well.
    converted = {
        key: SimpleNamespace(**val) if isinstance(val, Mapping) else val
        for key, val in orig_params.items()
    }
    return SimpleNamespace(**converted)


def post_run(orig_params, avg_return, action_buff, state_buff, entropy=None):
    """Print run data and, on cluster runs, report the final metrics.

    Args:
        orig_params: raw parameter mapping, forwarded to ``save_metrics_params``.
        avg_return: reported as the ``"avg_return"`` metric.
        action_buff: currently unused. NOTE(review): presumably ``entropy`` was
            meant to be derived from it; confirm with the caller.
        state_buff: forwarded to ``print_data``.
        entropy: passed to ``print_data`` and reported as the ``"entropy"``
            metric.

    Raises:
        ValueError: if ``entropy`` is not supplied.
    """
    # The body used to read an undefined global ``entropy`` (NameError on
    # every call). The value must now be passed explicitly.
    if entropy is None:
        raise ValueError("post_run requires an `entropy` value")
    print_data(entropy, state_buff)
    # A last CLI argument of "0" selects the local default-params run (see
    # prepare_params). save_metrics_params only exists when the cluster utils
    # import succeeded.
    if sys.argv[-1] == "0" or not CLUSTER:
        return
    metrics = {"avg_return": avg_return, "entropy": entropy}
    save_metrics_params(metrics, orig_params)


def _resolve_checkpoint(checkpoint_path, checkpoint):
    """Map a checkpoint selector to a ``step_<id>`` path, or None if unavailable.

    ``checkpoint`` is ``"last"`` (the highest step id present) or a step id
    (string or int). Logs an error and returns None when no matching
    ``step_<id>.*`` file exists in ``checkpoint_path``.
    """
    checkpoint_ids = []
    for file_name in os.listdir(checkpoint_path):
        if not file_name.startswith("step_"):
            continue
        step = file_name.split(".")[0][5:]
        # Skip stray names like "step_foo.pt" instead of crashing in int().
        if step.isdecimal():
            checkpoint_ids.append(int(step))

    if not checkpoint_ids:
        tonic.logger.error(f"No checkpoint found in {checkpoint_path}")
        return None

    if checkpoint == "last":
        checkpoint_id = max(checkpoint_ids)
    else:
        checkpoint_id = int(checkpoint)
        if checkpoint_id not in checkpoint_ids:
            tonic.logger.error(
                f"Checkpoint {checkpoint_id} not found in {checkpoint_path}"
            )
            return None
    return os.path.join(checkpoint_path, f"step_{checkpoint_id}")


def maybe_load_checkpoint(
    header,
    agent,
    environment,
    trainer,
    time_dict,
    checkpoint_path,
    checkpoint,
    eff_path,
):
    """Resume experiment state if ``checkpoint_path`` is an existing directory.

    If it is a directory, three things happen:

    * the training counters are loaded from ``<eff_path>/checkpoints/time.pt``;
    * a weight checkpoint is chosen according to ``checkpoint``;
    * any falsy ``header``/``agent``/``environment``/``trainer`` is filled in
      from ``<eff_path>/config.yaml``.

    Args:
        header, agent, environment, trainer: caller-supplied values. Falsy
            ones are replaced by the saved configuration.
        time_dict: counters to return if none can be loaded.
        checkpoint_path: the experiment's checkpoints directory.
        checkpoint: ``"none"`` (load no weights), ``"last"`` (highest step
            id) or a step id.
        eff_path: experiment directory containing ``config.yaml``.

    Returns:
        ``(header, agent, environment, trainer, time_dict, checkpoint_path)``.
        The returned ``checkpoint_path`` is the ``step_<id>`` path to load
        weights from, or None when no weights should be loaded.
    """
    if not os.path.isdir(checkpoint_path):
        return header, agent, environment, trainer, time_dict, None

    tonic.logger.log(f"Loading experiment from {eff_path}")
    try:
        time_dict = load_time(eff_path)
    except Exception as e:
        tonic.logger.log(f"Error in loading, starting fresh. Error was: {e}")
        # Starting fresh: do not return the checkpoints directory itself,
        # otherwise the caller would try to load weights/logs from it.
        return header, agent, environment, trainer, time_dict, None

    if checkpoint == "none":
        # Use no checkpoint, the agent is freshly created.
        tonic.logger.log("Not loading any weights")
        checkpoint_path = None
    else:
        checkpoint_path = _resolve_checkpoint(checkpoint_path, checkpoint)

    # Load the experiment configuration.
    # NOTE(review): yaml FullLoader is acceptable only because config.yaml is
    # written by this project's own logger; do not point this at untrusted files.
    arguments_path = os.path.join(eff_path, "config.yaml")
    with open(arguments_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = argparse.Namespace(**config)

    header = header or config.header
    agent = agent or config.agent
    # NOTE(review): prefers the saved *test* environment, as in tonic.play,
    # even though the caller builds its training environment from this value.
    environment = environment or config.test_environment
    environment = environment or config.environment
    trainer = trainer or config.trainer
    return header, agent, environment, trainer, time_dict, checkpoint_path


def train(
    orig_params,
    header,
    agent,
    environment,
    test_environment,
    trainer,
    before_training,
    after_training,
    parallel,
    sequential,
    seed,
    name,
    environment_name,
    checkpoint,
    path,
    preid=0,
    env_args=None,
):
    """Trains an agent on an environment.

    Optionally resumes from ``<path>/<environment_name>/<name>/checkpoints``
    via ``maybe_load_checkpoint``. It then builds the training and testing
    environments, the agent and the trainer, and runs ``trainer.run``.

    ``header``, ``before_training`` and ``after_training`` are Python source
    strings run with ``exec``. ``agent`` and ``trainer`` are Python expressions
    run with ``eval``. They execute inside this function, so they can see its
    local variables (e.g. ``agent``, ``environment``, ``trainer``). Local
    names here are therefore effectively part of the config contract.

    Args:
        orig_params: raw parameter mapping. Its optional ``"mpo_args"`` and
            ``"DEP"`` entries configure the agent and the agent's ``expl``
            attribute. It is also passed to ``trainer.run``.
        header: code executed before anything is built (skipped if falsy).
        agent: expression building the agent. ValueError if still falsy after
            the checkpoint-config fallback.
        environment: training environment spec passed to
            ``tonic.environments.distribute``.
        test_environment: testing environment spec. Falls back to
            ``environment`` when falsy.
        trainer: expression building the trainer. Defaults to
            ``"tonic.Trainer()"``.
        before_training, after_training: code executed around training
            (skipped if falsy).
        parallel, sequential: worker layout for the training environment.
            Also appended to the default name when not 1.
        seed: seed for the agent and training environment. The testing
            environment uses ``seed + 1000000``.
        name, environment_name: experiment directory components under ``path``.
        checkpoint: checkpoint selector forwarded to ``maybe_load_checkpoint``.
        path: root directory for experiments.
        preid: id passed to the environments (testing uses ``preid + 1000000``).
        env_args: forwarded as ``env_args`` to ``tonic.environments.distribute``.
    """
    # Capture the arguments to save them, e.g. to play with the trained agent.
    # This must stay the first statement: at this point locals() holds exactly
    # the call arguments, which are later passed to the logger as ``config``.
    # TODO fix this mess and do it properly
    args = dict(locals())
    del args["orig_params"]
    # Copy env_args and selected nested values into plain dict/list objects,
    # presumably so the logger can serialize the config; TODO confirm.
    if args["env_args"]:
        args["env_args"] = dict(args["env_args"])
        if "target" in args["env_args"]:
            args["env_args"]["target"] = list(args["env_args"]["target"])
        if "rew_args" in args["env_args"]:
            args["env_args"]["rew_args"] = dict(args["env_args"]["rew_args"])
    # args = recursively_dictify(args)

    # NOTE(review): os.path.join raises TypeError here if name or
    # environment_name is None. The defaults derived further below
    # therefore only help when they are empty strings.
    eff_path = os.path.join(path, environment_name, name)
    # Process the checkpoint path same way as in tonic.play
    # NOTE(review): looks like a leftover debug marker in the logs.
    tonic.logger.log("correct branch and commit")
    checkpoint_path = os.path.join(eff_path, "checkpoints")
    time_dict = {"steps": 0, "epochs": 0, "episodes": 0}
    # May fill in header/agent/environment/trainer from the saved config and
    # restore the step/epoch/episode counters. checkpoint_path becomes the
    # weights path to load, or None.
    (
        header,
        agent,
        environment,
        trainer,
        time_dict,
        checkpoint_path,
    ) = maybe_load_checkpoint(
        header,
        agent,
        environment,
        trainer,
        time_dict,
        checkpoint_path,
        checkpoint,
        eff_path,
    )
    # Run the header first, e.g. to load an ML framework.
    if header:
        exec(header)

    # Build the training environment.
    _environment = environment

    environment = tonic.environments.distribute(
        dict(env=_environment, preid=preid, parallel=parallel, sequential=sequential),
        parallel,
        sequential,
        env_args=env_args,
    )
    environment.initialize(seed=seed)
    # Build the testing environment, offsetting preid and seed so they differ
    # from the training environment's.
    _test_environment = test_environment if test_environment else _environment
    test_environment = tonic.environments.distribute(
        dict(env=_test_environment, preid=preid + 1000000), env_args=env_args
    )
    test_environment.initialize(seed=seed + 1000000)

    # Build the agent from its expression string.
    if not agent:
        raise ValueError("No agent specified.")
    agent = eval(agent)
    if "mpo_args" in orig_params:
        agent.set_params(**orig_params["mpo_args"])
    agent.initialize(
        observation_space=environment.observation_space,
        action_space=environment.action_space,
        seed=seed,
    )
    # Configure the agent's exploration component when it has one and DEP
    # parameters are provided.
    if hasattr(agent, "expl") and "DEP" in orig_params:
        agent.expl.set_params(orig_params["DEP"])
    # Load the weights of the agent from a checkpoint.
    if checkpoint_path:
        agent.load(checkpoint_path)

    # Initialize the logger to save data to the path environment/name/seed.
    if not environment_name:
        if hasattr(test_environment, "name"):
            environment_name = test_environment.name
        else:
            environment_name = test_environment.__class__.__name__
    if not name:
        if hasattr(agent, "name"):
            name = agent.name
        else:
            name = agent.__class__.__name__
        if parallel != 1 or sequential != 1:
            name += f"-{parallel}x{sequential}"
    # path = os.path.join(environment_name, name, str(seed))
    eff_path = os.path.join(path, environment_name, name)
    # args = args.copy().pop('env_args')
    # args = {'test': 0}
    tonic.logger.initialize(eff_path, script_path=__file__, config=args)
    if checkpoint_path:
        tonic.logger.load(checkpoint_path)

    # Build the trainer.
    trainer = trainer or "tonic.Trainer()"
    trainer = eval(trainer)
    trainer.initialize(
        agent=agent, environment=environment, test_environment=test_environment
    )

    # Run some code before training.
    if before_training:
        exec(before_training)

    # Train, resuming the counters restored from the checkpoint (if any).
    # NOTE(review): any exception is logged and swallowed, so after_training
    # still runs after a failure. ``scores`` and ``e`` are unused.
    try:
        scores = trainer.run(orig_params, **time_dict)
    except Exception as e:
        tonic.logger.log("trainer failed. Exception: ")
        print(traceback.format_exc())

    # Run some code after training.
    if after_training:
        exec(after_training)
    # return scores


if __name__ == "__main__":
    # Probe for a usable GPU. On success, CUDA float tensors become the default.
    # NOTE(review): CUDA_VISIBLE_DEVICES is forced to "" at import time, so
    # this probe is expected to fail and fall back to the CPU.
    try:
        torch.zeros((0, 1), device="cuda")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    except Exception as cuda_error:
        print(f"No cuda detected, running on cpu: {cuda_error}")

    # Keep these module-level names: exec'd config code inside train() runs
    # with this module's globals and can see them.
    orig_params, params = prepare_params()
    train_params = dict(
        orig_params["tonic"],
        path=orig_params["working_dir"],
        preid=orig_params["id"],
    )
    # A top-level "env_args" entry overrides any one nested under "tonic".
    if "env_args" in orig_params:
        train_params["env_args"] = orig_params["env_args"]
    train(orig_params, **train_params)
