import os
import time
import csv

import click

import numpy as np
import pandas as pd
import torch

from rampwf.utils.importing import import_module_from_source
import json

from .data_processing import get_metadata_dictionary
from .data_processing import rollout
from .model_env import make_model_env_class
from .utils import load_system_env, save_experiment_metadata

def _load_safety_func(safety_module_path):
    """Return ``safety_function`` from *safety_module_path*, or None.

    The safety function is optional: if the module cannot be imported or does
    not define ``safety_function``, a message is printed and None is returned
    so the run proceeds without it.
    """
    try:
        safety_module = import_module_from_source(
            safety_module_path, 'safety_function')
        return safety_module.safety_function
    except Exception as err:  # optional component -> report and fall back
        print(f'No safety function loaded from {safety_module_path}: '
              f'{err!r}')
        return None


def _save_world_states(world_states, epoch_output_dir):
    """Write each item of *world_states* as raw bytes to
    ``<epoch_output_dir>/world_state/<i>_world.pickle``.

    NOTE(review): items are written as-is with ``file.write``, so they are
    presumably already-pickled ``bytes`` objects — confirm against the env.
    """
    world_dir = os.path.join(epoch_output_dir, 'world_state')
    os.makedirs(world_dir, exist_ok=True)
    for i, world_state in enumerate(world_states):
        with open(os.path.join(world_dir, f'{i}_world.pickle'), 'wb') as file:
            file.write(world_state)


def mbrl_run(agent_name, submission,
             n_epochs, min_epoch_steps, min_random_steps,
             episodic_update, initial_trace, seed, epoch_resume, root_output_dir):
    """Main script of the model-based RL loop.

    For each epoch, the agent is run on the real system to collect a trace.
    The trace is saved to ``<root_output_dir>/submissions/<submission>/
    mbrl_outputs/<agent_name>/seed_<seed>/epoch_<k>/trace.csv``. When the
    model environment provides ``train_model`` and epochs remain, the model
    is then retrained.

    Parameters
    ----------
    agent_name : str
        Agent module name, loaded from ``agents/<agent_name>.py``.
    submission : str
        Model submission name, or ``'real_system'`` to use the real
        environment as the model.
    n_epochs : int
        Total number of epochs (epoch 0 included).
    min_epoch_steps : int
        Minimum number of rollout steps for epochs other than epoch 0.
    min_random_steps : int or None
        Minimum number of rollout steps for epoch 0. Defaults to
        ``min_epoch_steps`` when None.
    episodic_update : bool
        Forwarded to ``rollout``.
    initial_trace : bool
        If True, epoch 0 is treated as an already-available trace. The loop
        then starts at ``epoch_resume`` (or 1) instead of running a random
        agent at epoch 0.
    seed : int or None
        Seed for numpy, torch (only when not None) and the system environment.
    epoch_resume : int, str or None
        Epoch to start from when ``initial_trace`` is True. A string (e.g.
        from an untyped CLI option) is converted with ``int``.
    root_output_dir : str
        Root directory under which all results are written.
    """
    np.random.seed(seed)
    if seed is not None:
        torch.manual_seed(seed)
    print(f'Random seed: {seed}')

    if min_random_steps is None:
        min_random_steps = min_epoch_steps

    print((f'Using {submission} as model and {agent_name} as '
           f'agent for {n_epochs} epochs with at least {min_epoch_steps} '
           f'steps per epoch and {min_random_steps} random steps.'))

    problem_module = import_module_from_source('problem.py', 'problem')

    system_env, system_env_object = load_system_env(
        pass_system_env_object=True)
    system_env.seed(seed)

    # reward function
    reward_module = import_module_from_source(
        'reward_function.py', 'reward_function')
    reward_func = reward_module.reward_func

    # safety function (optional)
    safety_func = _load_safety_func('safety_function.py')

    # agent
    agent_module_path = os.path.join('agents', agent_name + '.py')
    agent_module = import_module_from_source(agent_module_path, agent_name)
    agent_object = agent_module.Agent

    # metadata
    metadata_path = os.path.join('data', 'metadata.json')
    metadata = get_metadata_dictionary(metadata_path)

    # create a directory to store the results
    output_dir = os.path.join(
        root_output_dir, 'submissions', submission, 'mbrl_outputs',
        agent_name, f'seed_{seed}')
    os.makedirs(output_dir, exist_ok=True)

    save_experiment_metadata(root_output_dir, metadata, agent_name, submission)

    # The rollout-times csv is appended to, so resumed runs keep earlier
    # timings. Only write the header when the file is new or empty, so that
    # repeated runs do not insert duplicate header rows.
    rollout_times_path = os.path.join(output_dir, 'rollout_times.csv')
    if (not os.path.exists(rollout_times_path)
            or os.path.getsize(rollout_times_path) == 0):
        with open(rollout_times_path, 'a', newline='') as f:
            csv.writer(f).writerow(['epoch_id', 'rollout_time'])

    # model
    if submission == 'real_system':
        model_env = system_env
    else:
        submission_path = os.path.join('submissions', submission)
        ModelEnv = make_model_env_class(system_env_object)
        model_env = ModelEnv(
            submission_path, problem_module, reward_func,
            metadata, output_dir, seed=seed, safety_func=safety_func)

    # retrieving feature names
    observation_names = metadata["observation"]
    action_names = metadata["action"]
    restart_name = metadata["restart_name"]
    reward_name = metadata["reward"]

    # States are saved alongside the observations in the trace so that runs
    # can be replayed from the collected traces. Reset once to learn the
    # number of state components.
    system_env.reset()
    n_states = len(system_env.get_numpy_state())

    trace_header = (
        observation_names + action_names + reward_name +
        [restart_name] + ['epoch_id'] +
        [f'state_{i}' for i in range(n_states)] + ['cost'] + ['original_reward'])

    epoch_output_dir = os.path.join(output_dir, 'epoch_0')
    if initial_trace:
        # epoch 0 is the initial trace
        print('Epoch 0: Initial trace.')

        # The CLI may deliver epoch_resume as a string; range() needs an int.
        epoch_start = int(epoch_resume) if epoch_resume is not None else 1

        # no need to train if the model environment is the real environment
        if hasattr(model_env, 'train_model'):
            model_env.train_model(epoch_start)

        agent = agent_object(model_env, epoch_output_dir=epoch_output_dir,
                             seed=seed)

    else:
        # random agent
        # NOTE(review): epoch_resume is ignored on this path — confirm intended.
        agent = agent_object(model_env, epoch_output_dir=epoch_output_dir,
                             random_action=True, seed=seed)
        epoch_start = 0

    for epoch in range(epoch_start, n_epochs):
        # use the agent on the real system, collect the trace to update the
        # model and update the agent using the updated model

        if epoch == 0:
            # random policy
            min_rollout_steps = min_random_steps
        else:
            min_rollout_steps = min_epoch_steps
            agent.random_action = False

        rollout_start_time = time.perf_counter()
        trace = rollout(
            system_env, len(action_names),
            epoch=epoch, min_epoch_steps=min_rollout_steps,
            agent=agent, episodic_update=episodic_update)
        rollout_time = time.perf_counter() - rollout_start_time

        with open(rollout_times_path, 'a', newline='') as f:
            csv.writer(f).writerow([epoch, rollout_time])

        # save new trace to disk
        epoch_output_dir = os.path.join(output_dir, f'epoch_{epoch}')
        os.makedirs(epoch_output_dir, exist_ok=True)
        trace_path = os.path.join(epoch_output_dir, 'trace.csv')
        trace_df = pd.DataFrame(data=trace, columns=trace_header)
        trace_df.to_csv(trace_path, index=False)

        if hasattr(system_env, 'world_states_pickle'):
            _save_world_states(system_env.world_states_pickle, epoch_output_dir)

        # update the model only if epochs remain to be run
        if epoch < n_epochs - 1 and hasattr(model_env, 'train_model'):
            model_env.train_model(epoch=epoch)


@click.command()
@click.option('--agent-name', default='random_shooting', show_default=True,
              type=click.STRING, help="Agent.")
@click.option("--submission", default="real_system", show_default=True,
              type=click.STRING,
              help="Model submission. Choose 'real_system' if you want to "
              "use the real environment.")
@click.option("--n-epochs", default=100, show_default=True, type=click.INT,
              help="The number of epochs. If the submission is not the real "
              "system, the model is updated at each epoch. If initial-trace "
              "is set to True the first epoch is assumed to be the initial "
              "trace.")
@click.option("--min-epoch-steps", default=200, show_default=True,
              type=click.INT,
              help="The minimum number of steps for each epoch given that "
              "each epoch ends by a complete episode.")
@click.option("--min-random-steps", default=None, show_default=True,
              type=click.INT,
              help="The minimum number of steps done at the first epoch"
              " with the random policy if initial-trace is set to False. "
              "If None then it is equal to min-epoch-steps.")
@click.option("--episodic-update", default=False, show_default=True,
              type=click.BOOL,
              help="Whether to update the model after each episode such that "
              "one epoch is exactly one episode.")
@click.option("--initial-trace", default=False, show_default=True,
              type=click.BOOL, help="Whether an initial trace is available. "
              "If True, the initial trace should be stored under trace.csv in "
              "<output_dir>/submissions/<submission>/mbrl_outputs/<agent_name>/"
              "seed_<seed>/epoch_0/.")
@click.option("--seed", default=99999, show_default=True, type=click.INT,
              help="Seed of the random number generator. Only the numpy and "
              "pytorch global random generators are seeded.")
# type=click.INT: without it click passes a string, which breaks range().
@click.option("--epoch-resume", default=None, show_default=True,
              type=click.INT,
              help="If we want to resume training, the epoch from which "
                   "we resume training (only used with --initial-trace).")
# '--output-dir' is an alias matching the other flags' hyphen style; the
# parameter name stays 'output_dir'.
@click.option("--output_dir", "--output-dir", default='', show_default=True,
              type=click.STRING,
              help="Specify path to save the results in another directory.")
def mbrl_run_command(agent_name, submission,
                     n_epochs, min_epoch_steps, min_random_steps,
                     episodic_update, initial_trace,
                     seed, epoch_resume, output_dir):
    """Run the model-based RL loop from the command line."""
    return mbrl_run(
        agent_name, submission, n_epochs, min_epoch_steps, min_random_steps,
        episodic_update, initial_trace, seed, epoch_resume, output_dir
    )


if __name__ == '__main__':
    # Script entry point: click parses sys.argv and dispatches to mbrl_run.
    mbrl_run_command()
