"""
These visualization tools were designed specifically for this project, so they may be difficult to reuse in other projects without modification.
"""


import os 
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from gym_recording_modified.playback import scan_recorded_traces, get_recordings


# Legend suffixes, indexed by the integer 'exploration_schedule' value that the
# plotting functions below read from each run's args.csv (0 -> '(constant)', 1 -> '(scheduling)').
EXPLORATION_LABELS = ['(constant)', '(scheduling)']

def save_plot(addr):
    """
    This function will write the current matplotlib figure to disk with its outer margins removed
        addr : str
            The output location that is handed to plt.savefig
    """
    # Stretch the subplot area over the whole figure and drop the gaps between subplots
    margins = dict(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.subplots_adjust(**margins)

    # A tight bounding box with zero padding trims whatever whitespace is left
    save_options = {'bbox_inches': 'tight', 'pad_inches': 0}
    plt.savefig(addr, **save_options)

def episode_rewards_plot(path, title, save_dir=None):
    """
    This function will visualize the accumulated rewards through episodes
        path : str
            Path to the root of the directory which contains the results of the same environment.
            It must contain one sub-directory per exploration technique, and each of those must
            contain one sub-directory per run (holding an 'args.csv' file and the recordings)
        title : str
            The title that will be shown at the top of the plot
        save_dir : str
            Where to save this plot. The value is forwarded to save_plot, which uses it as the
            output location of the image. Nothing is saved when it is None
    """

    plt.figure()

    # os.listdir returns entries in arbitrary order; sorting keeps the legend/colour order reproducible
    for exp_technique in sorted(os.listdir(path)):  # This will go through different exploration techniques
        root_dir = os.path.join(path, exp_technique)
        if not os.path.isdir(root_dir):
            continue  # Stray files (e.g. '.DS_Store') are not exploration techniques
        for run_name in sorted(os.listdir(root_dir)):  # This will go through results for the exp_technique that have been gathered in a specific date
            run_dir = os.path.join(root_dir, run_name)
            if not os.path.isdir(run_dir):
                continue  # Only directories can hold a run
            df = pd.read_csv(os.path.join(run_dir, 'args.csv')).set_index('Unnamed: 0')
            # .loc returns the whole row as a Series; take its first cell explicitly,
            # because calling int() directly on a Series is deprecated in recent pandas
            exploration_type = int(df.loc['exploration_schedule'].iloc[0])
            # NOTE(review): presumably one sequence of rewards per episode -- confirm against get_recordings
            rewards = get_recordings(run_dir, only_reward=True)
            label = exp_technique + EXPLORATION_LABELS[exploration_type]
            # Total reward of every recorded sequence, then the running total across them
            episode_totals = [sum(r) for r in rewards]
            plt.plot(np.cumsum(episode_totals), label=label)

    plt.title(title)
    plt.xlabel('Episodes')
    plt.ylabel('Accumulated Rewards')
    plt.legend()

    if save_dir is not None:
        save_plot(save_dir)
        
    
def step_rewards_plot(path, title, save_dir=None):
    """
    This function will visualize the accumulated rewards through steps
        path : str
            Path to the root of the directory which contains the results of the same environment.
            It must contain one sub-directory per exploration technique, and each of those must
            contain one sub-directory per run (holding an 'args.csv' file and the recordings)
        title : str
            The title that will be shown at the top of the plot
        save_dir : str
            Where to save this plot. The value is forwarded to save_plot, which uses it as the
            output location of the image. Nothing is saved when it is None
    """

    plt.figure()

    # os.listdir returns entries in arbitrary order; sorting keeps the legend/colour order reproducible
    for exp_technique in sorted(os.listdir(path)):  # This will go through different exploration techniques
        root_dir = os.path.join(path, exp_technique)
        if not os.path.isdir(root_dir):
            continue  # Stray files (e.g. '.DS_Store') are not exploration techniques
        for run_name in sorted(os.listdir(root_dir)):  # This will go through results for the exp_technique that have been gathered in a specific date
            run_dir = os.path.join(root_dir, run_name)
            if not os.path.isdir(run_dir):
                continue  # Only directories can hold a run
            df = pd.read_csv(os.path.join(run_dir, 'args.csv')).set_index('Unnamed: 0')
            # .loc returns the whole row as a Series; take its first cell explicitly,
            # because calling int() directly on a Series is deprecated in recent pandas
            exploration_type = int(df.loc['exploration_schedule'].iloc[0])
            # NOTE(review): presumably one sequence of rewards per episode -- confirm against get_recordings
            rewards = get_recordings(run_dir, only_reward=True)
            # Concatenate all recorded sequences into one flat list of per-step rewards
            step_rewards = [reward for episode in rewards for reward in episode]
            label = exp_technique + EXPLORATION_LABELS[exploration_type]
            plt.plot(np.cumsum(step_rewards), label=label)

    plt.title(title)
    plt.xlabel('Steps')
    plt.ylabel('Accumulated Rewards')
    plt.legend()

    if save_dir is not None:
        save_plot(save_dir)
