import json
import numpy as np
from copy import deepcopy
from environment import RandomMDPEnv, GridWorldEnv
from agent import QLearningAgent, HQLAgent, ULCAgent, SSPBernsteinAgent, UCSSPAgent, SVIAgent, EBAgent
from utils import timeit
import os
import time
from plot import regret_plot
from collections import Counter
import argparse

# Seed handed to np.random.seed() in main() and recorded by Runner in info.txt.
SEED = 6
# Number of episodes per experiment (passed to Runner as nb_episodes).
K = 3000
# Suffix appended to the agent name to build the log directory name, and used
# by plot() to tell regret_plot which stored results to use.
storage_counter = 1
# Number of independent experiments (passed to Runner as nb_runs).
nb_runs = 500


class Runner(object):
    """Run repeated experiments of one agent on one environment and log regret.

    The agent and the environment are duck-typed. The runner uses:

    * ``agent.act(state)``, ``agent.update(next_state, cost)``,
      ``agent.reset()``, ``agent.reset_episode()``, ``agent.info()`` and the
      attributes ``agent.t``, ``agent.k``, ``agent.n`` and ``agent.error``;
    * ``env.reset()``, ``env.step(action)`` returning
      ``(next_state, cost, done)``, ``env.optimal_cost()`` and ``env.info()``.

    Args:
        agent: Learning agent exposing the interface above.
        env: Environment exposing the interface above.
        nb_runs: Number of independent experiments to run.
        nb_episodes: Number of episodes per experiment.
        filename: Sub-path under ``log/`` where results are written. When
            falsy, nothing is written to disk.
        seed: Value recorded in ``info.txt``. This class does not seed any
            random generator itself.
        sample_interval: A data point is recorded whenever ``agent.k`` is a
            multiple of this value at the end of an episode.
        log_interval: Progress is printed whenever ``agent.t - 1`` is a
            multiple of this value. ``0`` or ``None`` disables the printout.

    Raises:
        ValueError: If ``sample_interval`` is smaller than 1.
    """

    def __init__(self, agent, env, nb_runs=10, nb_episodes=10000, filename=None, seed=6,
                 sample_interval=10, log_interval=100000):
        if sample_interval < 1:
            raise ValueError('sample_interval must be >= 1, got {0}'.format(sample_interval))

        self.agent = agent
        self.env = env
        self.nb_runs = nb_runs
        self.nb_episodes = nb_episodes
        self.seed = seed
        self.sample_interval = sample_interval
        self.log_interval = log_interval

        if filename:
            self.save_directory = os.path.join('log', filename)
            # exist_ok avoids the race between an existence check and the creation.
            os.makedirs(self.save_directory, exist_ok=True)
        else:
            self.save_directory = None

    def _write_info(self):
        """Write the experiment, agent and environment description to info.txt."""
        with open(os.path.join(self.save_directory, 'info.txt'), 'w') as file:
            file.write('*'*10 + ' General Info ' + '*'*10 + '\n')
            file.write('seed = {0}\n'.format(self.seed))
            file.write('number of episodes = {0}\n'.format(self.nb_episodes))
            file.write('number of runs = {0}\n'.format(self.nb_runs))
            file.write('*'*10 + ' Agent Info ' + '*'*10 + '\n')
            file.write(self.agent.info())
            file.write('*'*10 + ' Environment Info ' + '*'*10 + '\n')
            file.write(self.env.info())

    def _log_progress(self, exp):
        """Print the agent's current time step and visit counts for experiment ``exp``."""
        print('experiment = {0}'.format(exp))
        print('time = {0}'.format(self.agent.t))
        print('number of visits to state action pairs =')
        print(self.agent.n)
        print('________________________')

    def _run_experiment(self, exp):
        """Run one experiment and return its sampled cumulative regret.

        Returns:
            A 2-column array: column 0 holds the ``agent.k`` values at which a
            sample was taken, column 1 the cumulative regret up to that sample.
        """
        state = self.env.reset()
        self.agent.reset()
        regrets = []  # per-episode regrets since the last sampling point
        costs = 0  # cumulative cost of the episode in progress
        saving_regrets = []
        saving_times = []
        while True:
            action = self.agent.act(state)
            next_state, cost, done = self.env.step(action)
            costs += cost

            if self.log_interval and (self.agent.t - 1) % self.log_interval == 0:
                self._log_progress(exp)
            self.agent.update(next_state, cost)
            state = deepcopy(next_state)
            if done:
                # Regret of the episode: incurred cost minus the first entry of
                # env.optimal_cost() (presumably the optimal cost from the
                # initial state -- confirm against the environment).
                regrets.append(costs - self.env.optimal_cost()[0])
                state = self.env.reset()
                costs = 0
                self.agent.reset_episode()
                if self.agent.k % self.sample_interval == 0:  # sample the data at this episode
                    saving_times.append(self.agent.k)
                    saving_regrets.append(np.sum(regrets))
                    regrets = []
            # '>=' rather than '==': the loop still terminates if the agent's
            # counter ever steps over nb_episodes + 1.
            if self.agent.k >= self.nb_episodes + 1:
                break
        # NOTE: regrets accumulated after the last sampling point are not part
        # of the returned array.
        return np.array([saving_times, np.cumsum(saving_regrets)]).T

    def run(self):
        """Run all experiments and return the list of per-experiment results.

        When a save directory is configured, ``info.txt``, one file per
        experiment (named by its index) and ``t_stats`` (a JSON dump of
        ``timeit.stats``) are written there. The loop stops early as soon as
        ``agent.error`` is truthy after an experiment.

        Returns:
            A list with one 2-column array per completed experiment, as
            produced by ``_run_experiment``.
        """
        if self.save_directory:
            self._write_info()

        runs = []
        for exp in range(self.nb_runs):
            print('======= Experiment {0} ======='.format(exp))
            res = self._run_experiment(exp)
            if self.save_directory:
                np.savetxt(os.path.join(self.save_directory, str(exp)), res)
            runs.append(res)
            if self.agent.error:
                break

        if self.save_directory:
            with open(os.path.join(self.save_directory, 't_stats'), 'w') as f:
                json.dump(timeit.stats, f)
            print('Experiments stored in {0}.'.format(self.save_directory))
        else:
            # Avoid the misleading "stored in None" message.
            print('Experiments were not stored (no filename was given).')
        return runs


def run(env, agent, storage_counter=1):
    """Run the module-level experiment configuration for one agent.

    Results are logged under ``<EnvClass>/<AgentClass>_<storage_counter>``
    (relative to the Runner's log directory), using the module-level
    ``nb_runs``, ``K`` and ``SEED`` settings.

    Returns:
        A dict mapping the agent's class name to the list returned by
        ``Runner.run()``.
    """
    agent_name = agent.__class__.__name__
    env_name = env.__class__.__name__
    log_subdir = os.path.join(env_name, agent_name + '_{0}'.format(storage_counter))
    experiment = Runner(
        agent=agent,
        env=env,
        nb_runs=nb_runs,
        nb_episodes=K,
        filename=log_subdir,
        seed=SEED,
    )
    return {agent_name: experiment.run()}


def plot(env, agent, storage_counter, run_dict=None):
    """Plot the regret of ``agent`` on ``env`` through ``regret_plot``.

    Args:
        env: Environment instance; its class name selects the plot directory
            ``plots/<EnvClass>`` and is passed on as the environment name.
        agent: Agent instance; its class name identifies the algorithm.
        storage_counter: Storage suffix of the run, passed on (as a string)
            in ``alg_storage``.
        run_dict: Optional results forwarded unchanged to ``regret_plot``.
    """
    env_name = env.__class__.__name__
    agent_name = agent.__class__.__name__

    alg_storage = {
        agent_name: str(storage_counter),
    }

    legends = {
        # Raw string: '\e' is not a valid escape sequence, so the non-raw form
        # triggers an invalid-escape warning; the resulting text is unchanged.
        'QLearningAgent': r'Q-learning with $\epsilon$-greedy',
        'HQLAgent': 'HQL-SSP',
        'ULCAgent': 'ULC-SSP',
        'SSPBernsteinAgent': 'Bernstein-SSP',
        'UCSSPAgent': 'UC-SSP',
        'SVIAgent': 'SVI-SSP',
        'EBAgent': 'EB-SSP',
    }

    save_directory = 'plots/{}'.format(env_name)
    regret_plot(
        environment_name=env_name,
        agents=[agent_name],
        alg_storage=alg_storage,
        legends=legends,
        save_directory=save_directory,
        run_dict=run_dict,
    )


def main():
    """Seed numpy, build the environment and agent, run the experiments, plot.

    The commented-out lines are alternative environment/agent presets; swap
    them in by hand to reproduce the other configurations.
    """
    np.random.seed(SEED)
    timeit.reset()

    env = RandomMDPEnv(nb_states=6, nb_actions=2)
    # env = GridWorldEnv(3, 4)

    max_optimal_cost = np.max(env.optimal_cost())
    max_hitting_time = np.max(env.optimal_expected_hitting_time())
    print('B: {}, T: {}'.format(max_optimal_cost, max_hitting_time))
    print('optimal policy:', env.optimal_policy())

    # --- presets for the random MDP ---
    # agent = HQLAgent(env, H=5, iota=0.05, refc=4096)  # LCB-ADVANTAGE-SSP
    # agent = ULCAgent(env, H=80, iota=2.0)
    # agent = QLearningAgent(env=env, epsilon=0.05)
    # agent = SSPBernsteinAgent(env=env, c=2.0)
    # agent = UCSSPAgent(env=env, c=1.0)
    # agent = SVIAgent(env=env, H=15, iota=0.05)
    agent = EBAgent(env=env, iota=0.05)
    # --- presets for the grid world ---
    # agent = HQLAgent(env, H=5, iota=0.1, refc=4096)  # LCB-ADVANTAGE-SSP
    # agent = ULCAgent(env, H=100, iota=1.0)
    # agent = QLearningAgent(env=env, epsilon=0.05)
    # agent = SSPBernsteinAgent(env=env, c=0.5)
    # agent = UCSSPAgent(env=env, c=0.5)
    # agent = SVIAgent(env=env, H=10, iota=0.01)
    # agent = EBAgent(env=env, iota=0.01)

    results = run(env, agent, storage_counter=storage_counter)
    plot(env, agent, storage_counter=storage_counter, run_dict=results)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
