"""
Use pythohn gather.py to show results of different algorithms in one plot
"""
import numpy as np
from copy import deepcopy
import os
import time
import json
from plot import regret_plot
from ipdb import launch_ipdb_on_exception
import argparse

# Index of the stored run to gather: show_time() reads the
# log/<env_name>/<agent>_<storage_counter> directories, and the same value is
# passed to plot().
storage_counter = 1


def plot(env_name, agents, storage_counter):
    """Draw the regret curves of several agents for one environment in one figure.

    The figure is produced by ``regret_plot`` and saved under
    ``plots/<env_name>-gather``.

    Args:
        env_name: Environment name; forwarded to ``regret_plot`` and used to
            build the save directory.
        agents: Names of the agent classes to include, e.g. ``['EBAgent']``.
        storage_counter: Index of the stored run to use for every agent; it is
            forwarded (as a string) in ``alg_storage``. Presumably this selects
            the same ``log/<env>/<agent>_<counter>`` run that ``show_time``
            reads -- confirm in ``plot.regret_plot``.
    """
    # Raw string: "\e" is not a valid escape sequence, so a plain literal
    # triggers a SyntaxWarning on Python 3.12+ (the resulting value is identical).
    legends = {
        'QLearningAgent': r'Q-learning with $\epsilon$-greedy',
        'HQLAgent': 'HQL-SSP',
        'ULCAgent': 'ULC-SSP',
        'SSPBernsteinAgent': 'Bernstein-SSP',
        'UCSSPAgent': 'UC-SSP',
        'SVIAgent': 'SVI-SSP',
        'EBAgent': 'EB-SSP',
    }

    # Every known agent reads the same run index; values are strings and the
    # key order follows ``legends``.
    alg_storage = {name: str(storage_counter) for name in legends}

    save_directory = 'plots/{}-gather'.format(env_name)

    regret_plot(environment_name=env_name, agents=agents, alg_storage=alg_storage, legends=legends,
                save_directory=save_directory)
            

def show_time(env_name, agents, storage_counter):
    """Print the average update time of each agent's stored run.

    For every agent, this loads the JSON file
    ``log/<env_name>/<agent>_<storage_counter>/t_stats``. The total time is the
    first positive leading element among its values. That total is divided by
    the number of entries in the run directory whose names are all digits.

    Args:
        env_name: Environment name (sub-directory of ``log``).
        agents: Names of the agents whose runs are reported.
        storage_counter: Index of the stored run (directory suffix).

    Prints one ``"<agent>: <average>"`` line per agent. If an agent has no
    positive timing entry or no numbered entries, an ``n/a`` line is printed
    for it instead of reusing another agent's total or dividing by zero.
    """
    for agent in agents:
        read_directory = os.path.join('log', env_name, agent + '_' + str(storage_counter))
        with open(os.path.join(read_directory, 't_stats'), encoding='utf-8') as f:
            t_stats = json.load(f)

        # NOTE(review): assumes every t_stats value is a sequence whose first
        # element is a time -- confirm against the code that writes t_stats.
        tot_t = next((v[0] for v in t_stats.values() if v[0] > 0), None)

        # Count the all-digit entries in the run directory (presumably one per
        # run/episode -- confirm).
        n = sum(1 for name in os.listdir(read_directory) if name.isdigit())

        if tot_t is None or n == 0:
            print('{}: n/a (no positive time in t_stats or no numbered entries)'.format(agent))
            continue
        print('{}: {}'.format(agent, tot_t / n))  # show average update time


if __name__ == '__main__':
    # Agents that plot() has legends for; pass any subset via --agents.
    known_agents = [
        'QLearningAgent',
        'HQLAgent',
        'ULCAgent',
        'SSPBernsteinAgent',
        'UCSSPAgent',
        'SVIAgent',
        'EBAgent',
    ]

    # Defaults reproduce the previously hard-coded run: RandomMDPEnv, EBAgent,
    # and the module-level storage_counter.
    parser = argparse.ArgumentParser(
        description='Show results of different algorithms in one plot and '
                    'print their average update time.')
    parser.add_argument('--env', default='RandomMDPEnv',
                        help='environment name, e.g. RandomMDPEnv or GridWorldEnv '
                             '(default: %(default)s)')
    parser.add_argument('--agents', nargs='+', default=['EBAgent'],
                        choices=known_agents, metavar='AGENT',
                        help='agents to gather (default: %(default)s); choices: '
                             + ', '.join(known_agents))
    parser.add_argument('--storage-counter', type=int, default=storage_counter,
                        help='index of the stored run to read (default: %(default)s)')
    # Parse outside the debugger context so usage errors simply exit.
    args = parser.parse_args()

    with launch_ipdb_on_exception():
        plot(args.env, args.agents, storage_counter=args.storage_counter)
        show_time(args.env, args.agents, storage_counter=args.storage_counter)
