# Made by: Giuseppe PAOLO
# Date: 1/5/2022

import os
import time
import csv

import click

import numpy as np
import pandas as pd
import torch

from rampwf.utils.importing import import_module_from_source
from rampwf.utils import unpickle_trained_model

from .data_processing import get_metadata_dictionary
from .data_processing import rollout
from .model_env import make_model_env_class
from .utils import load_system_env


def _resolve_safety_module_path(safety_module_path):
    """Return the path of the safety module to import.

    Defaults to 'safety_function.py' and appends a '.py' suffix when the given
    path does not already end with it.
    """
    if safety_module_path is None:
        return 'safety_function.py'
    if not safety_module_path.endswith('.py'):
        return safety_module_path + '.py'
    return safety_module_path


def _load_safety_func(safety_module_path):
    """Return `safety_function` from the module at `safety_module_path`, or None if it cannot be loaded."""
    try:
        safety_module = import_module_from_source(safety_module_path, 'safety_function')
        return safety_module.safety_function
    except Exception as exc:
        # Best effort: the evaluation still runs without a safety function, but
        # report why, so that a broken module is not silently ignored.
        print(f'No safety function loaded from {safety_module_path}: {exc!r}')
        return None


def _list_saved_epochs(loading_dir):
    """Return the sorted indices k of the `epoch_<k>` entries found in `loading_dir`."""
    epochs = []
    for entry in os.listdir(loading_dir):
        parts = entry.split('_')
        # Only exact `epoch_<int>` names can be loaded later (see `epoch_dir`
        # below); skip anything else instead of crashing in int().
        if len(parts) == 2 and parts[0] == 'epoch' and parts[1].isdigit():
            epochs.append(int(parts[1]))
    return sorted(epochs)


def _select_epochs(loading_dir, epoch):
    """Return the list of epochs to evaluate.

    `epoch=None` selects the latest saved epoch, `epoch=-1` every saved epoch
    in increasing order, and any other value only that epoch.

    Raises:
        ValueError: if `epoch` is None or -1 and no `epoch_<k>` entry exists.
    """
    if epoch is not None and epoch != -1:
        return [epoch]
    saved_epochs = _list_saved_epochs(loading_dir)
    if not saved_epochs:
        raise ValueError(f'No epoch_<k> directories found in {loading_dir}')
    return saved_epochs if epoch == -1 else [saved_epochs[-1]]


def _append_rollout_time(rollout_times_path, epoch, rollout_time):
    """Append an (epoch, rollout_time) row, writing the CSV header only if the file is new or empty."""
    write_header = (not os.path.exists(rollout_times_path)
                    or os.path.getsize(rollout_times_path) == 0)
    with open(rollout_times_path, 'a', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['epoch_id', 'rollout_time'])
        writer.writerow([epoch, rollout_time])


def model_agent_eval(agent_name, submission, min_epoch_steps, seed, output_dir, epoch=None, safety_module_path=None):
    """Main script of model based RL eval.

    Runs the agent `agents/<agent_name>.py` on the system environment, giving it
    as model either the trained `submission` (loaded for each selected epoch) or
    the system environment itself when `submission == 'real_system'`. For each
    evaluated epoch, writes `trace.csv` and appends to `rollout_times.csv`.

    Args:
        agent_name: name of the agent module in the `agents` directory.
        submission: submission directory name under `submissions`, or 'real_system'.
        min_epoch_steps: minimum number of steps forwarded to `rollout`.
        seed: seeds numpy, torch (when not None) and the system env; also selects
            the `seed_<seed>` directory the trained models are loaded from.
        output_dir: where to store results. If None, each epoch's results go to
            submissions/<submission>/mbrl_outputs/<agent_name>/seed_<seed>/evaluation/epoch_<k>.
            If given and several epochs are evaluated, each epoch gets its own
            `epoch_<k>` subdirectory so results are not overwritten.
        epoch: epoch to evaluate; None for the latest saved epoch, -1 for all
            saved epochs.
        safety_module_path: module providing `safety_function` (default
            'safety_function.py'). If it cannot be loaded, no safety function is used.

    Raises:
        ValueError: if `epoch` is None or -1 and no `epoch_<k>` directory exists.
    """
    np.random.seed(seed)
    if seed is not None:
        torch.manual_seed(seed)
    print(f'Random seed: {seed}')

    print(f'Evaluating {submission} as model and {agent_name} as agent with at least {min_epoch_steps} steps.')

    problem_module = import_module_from_source('problem.py', 'problem')

    system_env, system_env_object = load_system_env(pass_system_env_object=True)
    system_env.seed(seed)

    # reward function
    reward_module = import_module_from_source('reward_function.py', 'reward_function')
    reward_func = reward_module.reward_func

    # safety function (optional)
    safety_func = _load_safety_func(_resolve_safety_module_path(safety_module_path))
    if safety_func is not None:
        system_env.safety_func = safety_func  # Give the system env the new safety func

    # agent
    agent_module_path = os.path.join('agents', agent_name + '.py')
    agent_object = import_module_from_source(agent_module_path, agent_name).Agent

    # metadata and feature names (identical for every epoch)
    metadata = get_metadata_dictionary(os.path.join('data', 'metadata.json'))
    observation_names = metadata["observation"]
    action_names = metadata["action"]
    restart_name = metadata["restart_name"]
    # concatenated as a list below, so presumably a list of column names — confirm
    reward_name = metadata["reward"]

    # Get epoch directory # TODO find best epoch
    loading_dir = os.path.join('submissions', submission, 'mbrl_outputs', agent_name, f'seed_{seed}')
    epochs = _select_epochs(loading_dir, epoch)

    # The model env class only depends on the system env; build it once.
    ModelEnv = None if submission == 'real_system' else make_model_env_class(system_env_object)
    submission_path = os.path.join('submissions', submission)

    for epoch in epochs:
        epoch_dir = f'epoch_{epoch}'
        # Per-epoch output directory. It must not leak from one iteration to
        # the next, otherwise later epochs overwrite the first epoch's results.
        if output_dir is None:
            epoch_output_dir = os.path.join(loading_dir, 'evaluation', epoch_dir)
        elif len(epochs) > 1:
            epoch_output_dir = os.path.join(output_dir, epoch_dir)
        else:
            epoch_output_dir = output_dir
        os.makedirs(epoch_output_dir, exist_ok=True)

        # model
        if ModelEnv is None:
            model_env = system_env
        else:
            model_env = ModelEnv(submission_path, problem_module, reward_func,
                                 metadata, epoch_output_dir,
                                 seed=seed, safety_func=safety_func)
            # Load trained model and add it to the model env
            model_env.trained_model = unpickle_trained_model(os.path.join(loading_dir, epoch_dir),
                                                             trained_model_name='trained_submission.pkl',
                                                             is_silent=True)

        # States are saved besides the observations in the trace for ease of
        # replay from the collected traces; reset to read the state size.
        system_env.reset()
        n_states = len(system_env.get_numpy_state())
        trace_header = (observation_names + action_names + reward_name + [restart_name] + ['epoch_id'] +
                        [f'state_{i}' for i in range(n_states)] + ['cost'])

        agent = agent_object(model_env, epoch_output_dir=epoch_output_dir, random_action=False, seed=seed)

        rollout_start_time = time.perf_counter()
        trace = rollout(system_env, len(action_names), epoch=epoch,
                        min_epoch_steps=min_epoch_steps, agent=agent,
                        episodic_update=False)
        rollout_time = time.perf_counter() - rollout_start_time
        _append_rollout_time(os.path.join(epoch_output_dir, 'rollout_times.csv'), epoch, rollout_time)

        # save new trace to disk
        trace_df = pd.DataFrame(data=trace, columns=trace_header)
        trace_df.to_csv(os.path.join(epoch_output_dir, 'trace.csv'), index=False)


@click.command()
@click.option('--agent-name', default='random_shooting', show_default=True,
              type=click.STRING, help="Agent.")
@click.option("--submission", default="real_system", show_default=True,
              type=click.STRING,
              help="Model submission. Choose 'real_system' if you want to "
              "use the real environment.")
@click.option("--min-epoch-steps", default=200, show_default=True,
              type=click.INT,
              help="The minimum number of steps for each epoch given that "
              "each epoch ends by a complete episode.")
@click.option("--seed", default=99999, show_default=True, type=click.INT,
              help="Seed of the random number generator. Seeds the numpy and "
              "pytorch global random generators and the system environment, "
              "and selects the seed_<seed> directory of the trained models.")
# '--output-dir' matches the other hyphenated options; '--output_dir' is kept
# as an alias so existing invocations keep working.
@click.option("--output-dir", "--output_dir", "output_dir", default=None, show_default=True,
              type=click.STRING,
              help="Specify path to save the results in another directory.")
@click.option("--epoch", default=None, show_default=True, type=click.INT,
              help="The epoch to evaluate. If not specified, it will select the last epoch present and evaluate "
                   "that one. If -1 it will evaluate all the epochs present.")
@click.option("--safety-module-path", default=None, show_default=True, type=click.STRING,
              help="The safety cost module from which to load the new cost function.")
def mbrl_eval_command(agent_name, submission, min_epoch_steps, seed, output_dir, epoch, safety_module_path):
    """Evaluate a model-based RL agent, using a submission or the real system as its model."""
    return model_agent_eval(agent_name, submission, min_epoch_steps, seed, output_dir, epoch, safety_module_path)


if __name__ == "__main__":
    # Script entry point: click parses sys.argv and dispatches to the command.
    mbrl_eval_command()