
import os
import time
from llrl.utils.chart_utils import lifelong_plot
from llrl.utils.save import lifelong_save, save_script
from simple_rl.experiments import Experiment

from torch.utils.tensorboard import SummaryWriter
from statistics import mean


def run_agents_lifelong(
        agents,
        mdp_distribution,
        name_identifier=None,
        n_instances=1,
        n_tasks=5,
        n_episodes=1,
        n_steps=100,
        parallel_run=False,
        n_processes=None,
        clear_old_results=True,
        track_disc_reward=False,
        reset_at_terminal=False,
        cumulative_plot=True,
        dir_for_plot='results',
        verbose=False,
        do_run=True,
        do_plot=False,
        confidence=.9,
        open_plot=False,
        plot_title=True,
        plot_legend=True,
        episodes_moving_average=False,
        episodes_ma_width=10,
        tasks_moving_average=False,
        tasks_ma_width=10,
        latex_rendering=False,
        log_dir="logs"
):
    """
    Runs each agent on a fixed sequence of tasks sampled from the MDP distribution.

    The same @n_tasks sampled tasks are presented, in the same order, to every agent and
    to every instance of every agent. Results are saved under the experiment directory
    and optionally plotted.

    :param agents: (list) agents to evaluate
    :param mdp_distribution: (MDPDistribution) object whose sample() method yields a task
    :param name_identifier: (str) currently unused
    :param n_instances: (int) number of independent runs per agent
    :param n_tasks: (int) number of tasks sampled from @mdp_distribution
    :param n_episodes: (int) episodes per task
    :param n_steps: (int) maximum steps per episode
    :param parallel_run: (bool) parallel execution is not implemented; when True nothing is run
    :param n_processes: (int) currently unused
    :param clear_old_results: (bool) forwarded to Experiment
    :param track_disc_reward: (bool) If true records and plots discounted reward, discounted over episodes.
    So, if each episode is 100 steps, then episode 2 will start discounting as though it's step 101.
    :param reset_at_terminal: (bool) forwarded to the per-task runner
    :param cumulative_plot: (bool) forwarded to Experiment
    :param dir_for_plot: (str) forwarded to Experiment
    :param verbose: (bool) forwarded to the per-task runner
    :param do_run: (bool) if False, skip running the agents
    :param do_plot: (bool) if True, call lifelong_plot after the run
    :param confidence: (float) forwarded to lifelong_plot
    :param open_plot: (bool) forwarded to lifelong_plot
    :param plot_title: (bool) forwarded to lifelong_plot
    :param plot_legend: (bool) forwarded to lifelong_plot
    :param episodes_moving_average: (bool) forwarded to lifelong_plot
    :param episodes_ma_width: (int) forwarded to lifelong_plot
    :param tasks_moving_average: (bool) forwarded to lifelong_plot
    :param tasks_ma_width: (int) forwarded to lifelong_plot
    :param latex_rendering: (bool) forwarded to lifelong_plot
    :param log_dir: (str) root directory for the TensorBoard event files
    :return: None
    """
    exp_params = {"samples": n_tasks, "episodes": n_episodes, "steps": n_steps}
    experiment = Experiment(agents=agents, mdp=mdp_distribution, params=exp_params,
                            is_episodic=n_episodes > 1, is_lifelong=True, clear_old_results=clear_old_results,
                            track_disc_reward=track_disc_reward, cumulative_plot=cumulative_plot,
                            dir_for_plot=dir_for_plot)
    path = experiment.exp_directory
    save_script(path)
    print(path)

    print("Running experiment:\n" + str(experiment))

    # Sample the tasks once so that every agent / instance faces the same sequence.
    tasks = [mdp_distribution.sample() for _ in range(n_tasks)]

    # Run
    if do_run:
        if parallel_run:
            # Make the no-op visible instead of silently producing no results.
            print("Warning: parallel_run is not implemented, no agent was run.")
        else:
            for agent in agents:
                lifelong_save(init=True, path=path, agent=agent)
                for instance_number in range(n_instances):
                    # The trailing arguments are passed by keyword: the callee declares them in
                    # the order (path, verbose, n_instances, log_dir), so passing
                    # (path, n_instances, verbose, log_dir) positionally would swap two of them.
                    run_agent_lifelong(agent, experiment, instance_number, n_tasks, n_episodes, n_steps, tasks,
                                       track_disc_reward, reset_at_terminal, path,
                                       verbose=verbose, n_instances=n_instances, log_dir=log_dir)

    # Plot
    if do_plot:
        lifelong_plot(agents, path, n_tasks, n_episodes, confidence, open_plot, plot_title, plot_legend,
                      episodes_moving_average=episodes_moving_average, episodes_ma_width=episodes_ma_width,
                      tasks_moving_average=tasks_moving_average, tasks_ma_width=tasks_ma_width,
                      latex_rendering=latex_rendering)


def run_agent_lifelong(agent, experiment, instance_number, n_tasks, n_episodes, n_steps, tasks, track_disc_reward, reset_at_terminal,
                       path, verbose, n_instances, log_dir):
    """
    Runs one instance of a single agent over the whole task sequence and saves the results.

    For each task, a dedicated TensorBoard writer records the per-episode curves, while a
    per-instance writer records the mean (discounted) return of each task.

    :param agent: () agent to run; it is re-initialized before the first task
    :param experiment: () experiment object forwarded to run_single_agent_on_mdp
    :param instance_number: (int) index of this instance, used in log names and when saving
    :param n_tasks: (int) number of tasks to run, taken from the front of @tasks
    :param n_episodes: (int) episodes per task
    :param n_steps: (int) maximum steps per episode
    :param tasks: (list) pre-sampled tasks, must contain at least @n_tasks elements
    :param track_disc_reward: (bool)
    :param reset_at_terminal: (bool)
    :param path: (str) directory in which the results are saved
    :param verbose: (bool)
    :param n_instances: (int) total number of instances; currently unused
    :param log_dir: (str) root directory for the TensorBoard event files
    :return: None
    """
    agent.re_init()  # re-initialize before each instance
    data = {'returns_per_tasks': [], 'discounted_returns_per_tasks': []}
    dis = []

    # time.clock() was removed in Python 3.8; perf_counter() is the documented replacement.
    start = time.perf_counter()
    task_dir = os.path.join(log_dir, "task_instance_number{}_agent{}_{}".format(instance_number, agent.name, time.strftime("%Y%m%d-%H%M%S")))
    writer_task = SummaryWriter(log_dir=task_dir)
    try:
        for i in range(1, n_tasks + 1):
            task = tasks[i - 1]  # task selection

            # Run on task
            agent.init_task(task)
            dis.append(agent.dis)
            print("    Task " + str(i) + " / " + str(n_tasks) + " distance " + str(dis))
            experiment_dir = os.path.join(log_dir, "experiment_task{}_instance_number{}_agent{}_{}".format(i, instance_number, agent.name, time.strftime("%Y%m%d-%H%M%S")))
            writer = SummaryWriter(log_dir=experiment_dir)
            try:
                _, _, returns, discounted_returns = run_single_agent_on_mdp(
                    agent, task, n_episodes, n_steps, writer, experiment,
                    verbose=verbose, track_disc_reward=track_disc_reward,
                    reset_at_terminal=reset_at_terminal, resample_at_terminal=False
                )
            finally:
                # Release the event file even if the run on this task failed.
                writer.close()

            agent.set_task()

            # Store
            data['returns_per_tasks'].append(returns)
            data['discounted_returns_per_tasks'].append(discounted_returns)
            # mean() raises on an empty sequence, which happens when n_episodes == 0.
            if returns:
                writer_task.add_scalar('task_/return', mean(returns), i)
            if discounted_returns:
                writer_task.add_scalar('task_/discounted_return', mean(discounted_returns), i)

            # Reset the agent
            agent.reset()
        print("    Total time elapsed: " + str(round(time.perf_counter() - start, 3)))
    finally:
        writer_task.close()

    # Save
    lifelong_save(init=False, path=path, agent=agent, data=data, instance_number=instance_number)


def run_single_agent_on_mdp(agent, mdp, n_episodes, n_steps, writer, experiment=None, track_disc_reward=False,
                            reset_at_terminal=False, resample_at_terminal=False, verbose=False):
    """
    Runs an agent on a single MDP for @n_episodes episodes of at most @n_steps steps each.

    :param agent: object exposing act(state, reward, mdp) and end_of_episode()
    :param mdp: object exposing get_gamma(), get_init_state(), execute_agent_action(action) and reset()
    :param n_episodes: (int) number of episodes
    :param n_steps: (int) maximum number of steps per episode
    :param writer: TensorBoard-like object exposing add_scalar(tag, value, step); None disables logging
    :param experiment: optional experiment object that records each transition
    :param track_disc_reward: (bool) if True, the reward recorded in @experiment is discounted across episodes
    :param reset_at_terminal: (bool) if True, continue from the initial state after reaching a terminal state
    :param resample_at_terminal: (bool) if True, return as soon as a terminal state is reached before the
    last step of an episode
    :param verbose: (bool) currently unused
    :return: (tuple) (stopped_early, steps, return_per_episode, discounted_return_per_episode) where
    stopped_early is True only for the early return caused by @resample_at_terminal, and steps is the step
    at which that happened (or @n_steps otherwise). Discounted returns weight the reward of step t
    (starting at 1) by gamma ** t.
    :raises ValueError: if both @reset_at_terminal and @resample_at_terminal are True
    """
    if reset_at_terminal and resample_at_terminal:
        raise ValueError("ExperimentError: Can't have reset_at_terminal and resample_at_terminal set to True.")

    return_per_episode = [0] * n_episodes
    discounted_return_per_episode = [0] * n_episodes
    gamma = mdp.get_gamma()

    # For each episode.
    for episode in range(1, n_episodes + 1):
        cumulative_episodic_reward = 0.

        # Compute initial state/reward.
        state = mdp.get_init_state()
        reward = 0.

        for step in range(1, n_steps + 1):

            # Step time. time.clock() was removed in Python 3.8, perf_counter() replaces it.
            step_start = time.perf_counter()

            # Compute the agent's policy.
            action = agent.act(state, reward, mdp)

            # Terminal check.
            if state.is_terminal():
                if n_episodes == 1 and not reset_at_terminal and experiment is not None and action != "terminate":
                    # Self loop if we're not episodic or resetting and in a terminal state.
                    experiment.add_experience(agent, state, action, 0, state,
                                              time_taken=time.perf_counter() - step_start)
                    continue
                break

            # Execute in MDP.
            reward, next_state = mdp.execute_agent_action(action)

            # Track value.
            return_per_episode[episode - 1] += reward
            discounted_return_per_episode[episode - 1] += reward * (gamma ** step)
            cumulative_episodic_reward += reward

            # Record the experience.
            if experiment is not None:
                if track_disc_reward:
                    reward_to_track = gamma ** (step + 1 + episode * n_steps) * reward
                else:
                    reward_to_track = reward
                reward_to_track = round(reward_to_track, 5)
                experiment.add_experience(agent, state, action, reward_to_track, next_state,
                                          time_taken=time.perf_counter() - step_start)

            if next_state.is_terminal():
                if reset_at_terminal:
                    # Reset the MDP.
                    next_state = mdp.get_init_state()
                    mdp.reset()
                elif resample_at_terminal and step < n_steps:
                    mdp.reset()
                    return True, step, return_per_episode, discounted_return_per_episode

            # Update pointer.
            state = next_state

        # A final update.
        _ = agent.act(state, reward, mdp)

        # Tensorboard: both scalars describe the episode that just ended. Logging the sum of the
        # whole list would make 'discounted_return' accumulate over episodes while 'return' does not.
        if writer is not None:
            writer.add_scalar('return', cumulative_episodic_reward, episode)
            writer.add_scalar('discounted_return', discounted_return_per_episode[episode - 1], episode)

        # Process experiment info at end of episode.
        if experiment is not None:
            experiment.end_of_episode(agent)

        # Reset the MDP, tell the agent the episode is over.
        mdp.reset()
        agent.end_of_episode()

    # Process that learning instance's info at end of learning.
    if experiment is not None:
        experiment.end_of_instance(agent)

    return False, n_steps, return_per_episode, discounted_return_per_episode


            
