"""Script used to train agents."""

import argparse
import os

import tonic
import yaml
import json
import sys
from types import SimpleNamespace
import numpy as np
from collections.abc import Mapping
import torch
from utils import compute_entropy_shoulder, plot_trajs, print_data

from cluster import read_params_from_cmdline, save_metrics_params


def recursively_dictify(args):
    """Return a YAML-friendly plain copy of ``args``.

    cluster_utils param dicts cause issues when dumped with yaml, so every
    nested ``Mapping`` is converted into a plain ``dict`` and every tuple/list
    into a plain ``list``, recursively. All other values are returned as-is.
    The input is not modified.

    Args:
        args: Mapping (possibly nested) to convert.

    Returns:
        A new plain ``dict`` (or ``args`` unchanged if it is not a Mapping,
        list or tuple).
    """
    def _convert(value):
        # Mappings (including cluster param-dict subclasses) -> plain dict.
        if isinstance(value, Mapping):
            return {key: _convert(item) for key, item in value.items()}
        # Tuples serialize as python-specific tags in yaml; use lists instead.
        if isinstance(value, (tuple, list)):
            return [_convert(item) for item in value]
        return value

    return _convert(args)


def load_time(path):
    """Load the saved time/progress object from ``<path>/checkpoints/time.pt``.

    Args:
        path: Experiment directory containing a ``checkpoints`` sub-directory.

    Returns:
        Whatever object ``torch.load`` deserializes from that file.
    """
    time_file = os.path.join(path, 'checkpoints/time.pt')
    return torch.load(time_file)


def prepare_params():
    """Load the run parameters and return ``(orig_params, params)``.

    When the last command-line argument is ``'0'`` (apparently a local,
    non-cluster run), parameters are read from
    ``./param_files/default_tonic.json``. Otherwise they are parsed by
    ``read_params_from_cmdline`` and the working directory is created via
    ``prepare_cluster``.

    Returns:
        Tuple of the raw parameter mapping and its ``SimpleNamespace`` view.
    """
    if sys.argv[-1] != '0':
        orig_params = read_params_from_cmdline()
        _, params = prepare_cluster(orig_params)
        return orig_params, params

    with open('./param_files/default_tonic.json', 'r') as param_file:
        defaults = json.load(param_file)
    return get_params(defaults)


def prepare_cluster(params):
    """Create ``params.working_dir`` (if missing) and split ``params``.

    Args:
        params: Parameter object supporting attribute access to
            ``working_dir`` and the mapping interface used by ``get_params``.

    Returns:
        The ``(orig_params, namespace)`` tuple produced by ``get_params``.
    """
    working_dir = params.working_dir
    os.makedirs(working_dir, exist_ok=True)
    return get_params(params)


def get_params(params):
    """Split a parameter mapping into an untouched copy and a namespace view.

    Args:
        params: Mapping of parameter names to values. Values whose type is
            exactly ``dict`` become nested ``SimpleNamespace`` objects in the
            namespace view; all other values (including dict subclasses) are
            kept as-is. ``params`` itself is not modified.

    Returns:
        ``(orig_params, namespace)`` where ``orig_params`` is
        ``params.copy()`` and ``namespace`` is a ``SimpleNamespace`` with one
        attribute per key.
    """
    orig_params = params.copy()
    converted = {}
    for key, val in params.items():
        # Exact type check (not isinstance) is intentional: dict subclasses
        # are left untouched, as before. Writing into a new dict instead of
        # back into ``params`` keeps the caller's mapping intact.
        converted[key] = SimpleNamespace(**val) if type(val) is dict else val
    return orig_params, SimpleNamespace(**converted)


def post_run(orig_params, avg_return, action_buff, state_buff):
    """Report statistics of a finished run.

    Computes the shoulder entropy of ``state_buff`` and prints it. When the
    last command-line argument is ``'0'`` the trajectories are plotted;
    otherwise ``avg_return`` and the entropy are saved as cluster metrics
    together with ``orig_params``.
    """
    entropy = compute_entropy_shoulder(state_buff)
    print_data(entropy, state_buff)
    running_locally = sys.argv[-1] == '0'
    if running_locally:
        plot_trajs(action_buff, state_buff)
        return
    save_metrics_params({'avg_return': avg_return, 'entropy': entropy},
                        orig_params)


def func(env, preid=5, parallel=1, sequential=1):
    """Return a factory that builds the environment described by ``env``.

    ``env`` is a Python expression string ending in a call, e.g.
    ``"SomeEnv(arg=1)"``. The returned ``build_env(identifier=0)`` evaluates
    that call with an extra ``identifier=`` keyword argument set to
    ``preid * (parallel * sequential) + identifier``. Presumably this gives
    each worker of each run a distinct id; verify against the environments
    used.

    Expressions containing ``'ostrich'`` are returned as a no-argument factory
    that evaluates ``env`` unchanged.

    Raises:
        ValueError: When ``build_env`` is called and ``env`` does not end with
            a closing parenthesis.

    Security: ``env`` is passed to ``eval``; only use trusted config strings.
    """
    print(env)
    if 'ostrich' in env:
        return lambda: eval(env)

    def build_env(identifier=0):
        id_eff = preid * (parallel * sequential) + identifier
        call = env.rstrip()
        if not call.endswith(')'):
            raise ValueError(
                f'Environment expression must end with a call: {env!r}')
        head = call[:-1].rstrip()
        # Omit the separating comma for an empty argument list (``Foo()``) or
        # one that already ends with a comma (``Foo(a=1,)``); blindly adding
        # ',' produced a SyntaxError in both cases.
        sep = '' if head.endswith(('(', ',')) else ','
        return eval(f'{head}{sep}identifier={id_eff})')

    return build_env


def _select_checkpoint(checkpoint_path, checkpoint):
    """Return the path of the requested ``step_<N>`` checkpoint, or None.

    ``checkpoint`` is ``'last'`` (highest step number found in
    ``checkpoint_path``) or a step number (int or numeric string). Errors are
    logged and ``None`` is returned when no matching checkpoint exists.
    """
    checkpoint_ids = []
    for file in os.listdir(checkpoint_path):
        if file.startswith('step_'):
            # 'step_<N>.<ext>' -> N
            checkpoint_ids.append(int(file.split('.')[0][5:]))

    if not checkpoint_ids:
        tonic.logger.error(f'No checkpoint found in {checkpoint_path}')
        return None

    if checkpoint == 'last':
        checkpoint_id = max(checkpoint_ids)
    else:
        checkpoint_id = int(checkpoint)
        if checkpoint_id not in checkpoint_ids:
            tonic.logger.error(f'Checkpoint {checkpoint_id} '
                               f'not found in {checkpoint_path}')
            return None
    return os.path.join(checkpoint_path, f'step_{checkpoint_id}')


def maybe_load_checkpoint(header, agent, environment, trainer, time_dict, checkpoint_path, checkpoint, eff_path):
    """Resume an experiment from ``eff_path`` if its checkpoints exist.

    Args:
        header, agent, environment, trainer: Values given by the caller; any
            falsy one is filled from the saved ``<eff_path>/config.yaml``.
        time_dict: Progress counters used when nothing is resumed.
        checkpoint_path: The ``checkpoints`` directory of the experiment.
        checkpoint: ``'none'`` (no weights), ``'last'`` or a step number.
        eff_path: Experiment directory holding ``config.yaml`` and
            ``checkpoints/time.pt``.

    Returns:
        ``(header, agent, environment, trainer, time_dict, checkpoint_path)``
        where ``checkpoint_path`` is the selected ``step_<N>`` path, or
        ``None`` when no weights should be loaded.
    """
    if not os.path.isdir(checkpoint_path):
        return header, agent, environment, trainer, time_dict, None

    tonic.logger.log(f'Loading experiment from {eff_path}')
    try:
        time_dict = load_time(eff_path)
    except Exception:
        # Start fresh: return None so the caller does not try to load
        # weights from the bare checkpoints directory.
        tonic.logger.log('Error in loading, starting fresh')
        return header, agent, environment, trainer, time_dict, None

    if checkpoint == 'none':
        # Use no checkpoint, the agent is freshly created.
        tonic.logger.log('Not loading any weights')
        checkpoint_path = None
    else:
        checkpoint_path = _select_checkpoint(checkpoint_path, checkpoint)

    # Load the experiment configuration.
    arguments_path = os.path.join(eff_path, 'config.yaml')
    with open(arguments_path, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = argparse.Namespace(**config)

    header = header or config.header
    agent = agent or config.agent
    # Resumed *training* uses the saved training environment; the saved
    # test_environment is only relevant for evaluation/playback.
    environment = environment or config.environment
    trainer = trainer or config.trainer
    return header, agent, environment, trainer, time_dict, checkpoint_path


def train(
        orig_params, header, agent, environment, test_environment, trainer, before_training,
        after_training, parallel, sequential, seed, name, environment_name,
        checkpoint, path, preid=0, env_args=None
):
    """Trains an agent on an environment.

    Builds the training/testing environments, the agent and the trainer from
    the string expressions passed in (via ``eval``/``exec``). It resumes from
    ``<path>/<environment_name>/<name>/checkpoints`` when that directory
    exists, then runs the trainer.

    Args:
        orig_params: Raw parameter mapping. ``'mpo_args'`` (if present) is
            forwarded to ``agent.set_params``, ``'DEP'`` (if present) to
            ``agent.expl.set_params``, and the whole mapping to
            ``trainer.run``.
        header: Python code executed before anything is built.
        agent: Expression evaluated to build the agent (required).
        environment: Environment constructor expression (see ``func``).
        test_environment: Test environment expression; falls back to
            ``environment``.
        trainer: Trainer expression; defaults to ``'tonic.Trainer()'``.
        before_training, after_training: Python code executed around training.
        parallel, sequential: Worker layout passed to
            ``tonic.environments.distribute``.
        seed: Seed for the training env and agent; the test env uses
            ``seed + 1000000``.
        name, environment_name: Experiment sub-directory names under ``path``.
        checkpoint: ``'last'``, ``'none'`` or a step number to resume from.
        path: Root output directory.
        preid: Offset used to derive environment identifiers.
        env_args: Extra keyword arguments for ``tonic.environments.distribute``.

    Note: ``header``/``before_training``/``after_training`` run via ``exec``
    in this function's frame, so they may refer to its local names (``agent``,
    ``trainer``, ...). Keep those names stable.
    """
    # Capture the arguments to save them, e.g. to play with the trained agent.
    # Must stay the first statement so only the arguments are captured.
    # TODO fix this mess and do it properly
    args = dict(locals())
    del args['orig_params']
    if args['env_args']:
        # Convert to plain containers so the config can be saved as yaml.
        args['env_args'] = dict(args['env_args'])
        if 'target' in args['env_args']:
            args['env_args']['target'] = list(args['env_args']['target'])
        if 'rew_args' in args['env_args']:
            args['env_args']['rew_args'] = dict(args['env_args']['rew_args'])

    # NOTE(review): os.path.join raises TypeError if environment_name or name
    # is None, so the name fallbacks further down only apply to '' values.
    eff_path = os.path.join(path, environment_name, name)
    # Process the checkpoint path same way as in tonic.play
    tonic.logger.log('correct branch and commit')
    checkpoint_path = os.path.join(eff_path, 'checkpoints')
    time_dict = {'steps': 0,
                 'epochs': 0,
                 'episodes': 0}
    header, agent, environment, trainer, time_dict, checkpoint_path = maybe_load_checkpoint(
        header, agent, environment, trainer, time_dict,
        checkpoint_path, checkpoint, eff_path)

    # Run the header first, e.g. to load an ML framework.
    if header:
        exec(header)

    # Build the training environment.
    _environment = environment
    environment = tonic.environments.distribute(
        func(_environment, preid, parallel, sequential), parallel, sequential, env_args=env_args)
    environment.initialize(seed=seed)

    # Build the testing environment (offset ids/seed keep it apart from training).
    _test_environment = test_environment if test_environment else _environment
    test_environment = tonic.environments.distribute(
        func(_test_environment, preid + 1000000), env_args=env_args)
    test_environment.initialize(seed=seed + 1000000)

    # Build the agent.
    if not agent:
        raise ValueError('No agent specified.')
    agent = eval(agent)
    if 'mpo_args' in orig_params:
        agent.set_params(**orig_params['mpo_args'])
    agent.initialize(
        observation_space=environment.observation_space,
        action_space=environment.action_space, seed=seed)
    if hasattr(agent, 'expl') and 'DEP' in orig_params:
        agent.expl.set_params(orig_params['DEP'])

    # Load the weights of the agent from a checkpoint.
    if checkpoint_path:
        agent.load(checkpoint_path)

    # Initialize the logger to save data to the path environment/name.
    if not environment_name:
        if hasattr(test_environment, 'name'):
            environment_name = test_environment.name
        else:
            environment_name = test_environment.__class__.__name__
    if not name:
        if hasattr(agent, 'name'):
            name = agent.name
        else:
            name = agent.__class__.__name__
        if parallel != 1 or sequential != 1:
            name += f'-{parallel}x{sequential}'
    eff_path = os.path.join(path, environment_name, name)
    tonic.logger.initialize(eff_path, script_path=__file__, config=args)
    if checkpoint_path:
        tonic.logger.load(checkpoint_path)

    # Build the trainer.
    trainer = trainer or 'tonic.Trainer()'
    trainer = eval(trainer)
    trainer.initialize(
        agent=agent, environment=environment,
        test_environment=test_environment)

    # Run some code before training.
    if before_training:
        exec(before_training)

    # Train. Failures are logged (best-effort) so after_training still runs.
    try:
        scores = trainer.run(orig_params, **time_dict)
    except Exception as e:
        import traceback
        tonic.logger.log(f'trainer failed at the end. Exception: {e}')
        # Keep the full traceback; the message alone hides where it failed.
        tonic.logger.log(traceback.format_exc())

    # Run some code after training.
    if after_training:
        exec(after_training)


if __name__ == '__main__':
    # Probe CUDA with a tiny allocation; if it works, make new float tensors
    # default to CUDA. `except Exception` (not a bare except) so Ctrl-C and
    # SystemExit are not swallowed here.
    try:
        torch.zeros((0, 1), device='cuda')
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    except Exception:
        print('No cuda detected, running on cpu')
    orig_params, params = prepare_params()
    train_params = dict(orig_params['tonic'])
    train_params['path'] = orig_params['working_dir']
    train_params['preid'] = orig_params['id']
    # A top-level 'env_args' entry overrides the one inside 'tonic'; otherwise
    # the 'tonic' value already copied into train_params is used.
    if 'env_args' in orig_params:
        train_params['env_args'] = orig_params['env_args']
    train(orig_params, **train_params)
