import numpy as np
from domains import Simple_Grid, Taxi_Domain
from Learning import *
import random 
from log import Log_experiments
from abstraction import Abstraction
import hyper_param


# Train a tabular Q-learning agent on hyper_param.env, periodically evaluate
# its greedy policy, and save/plot the results through Log_experiments.

EVAL_INTERVAL = 10  # run a greedy evaluation every EVAL_INTERVAL training episodes (skipping episode 0)
EVAL_EPISODES = 10  # number of greedy roll-outs per evaluation


def _run_episode(env, agent, step_max, greedy=False):
    """Roll out one episode of at most ``step_max`` steps.

    Each step calls ``env.update_visited`` then picks an action. The action
    comes from ``agent.policy`` when training, or ``agent.policy_greedy``
    when ``greedy`` is true. The step is taken with ``env.step``. When not
    greedy, the transition is passed to ``agent.train``.

    Args:
        env: environment exposing ``reset``, ``step`` and ``update_visited``.
        agent: agent exposing ``policy``, ``policy_greedy`` and ``train``.
        step_max: maximum number of steps before the episode is cut off.
        greedy: if true, act greedily and do not train.

    Returns:
        ``(total_reward, success, steps)``. ``success`` is the flag returned
        by the last ``env.step`` call, or ``False`` if no step was taken.
    """
    state = env.reset()
    done = False
    # Initialised so the caller always gets a value, even when no step runs
    # (e.g. step_max <= 0); previously this raised NameError.
    success = False
    total_reward = 0
    steps = 0
    while not done and steps < step_max:
        env.update_visited(state)
        action = agent.policy_greedy(state) if greedy else agent.policy(state)
        new_state, r, done, success = env.step(action)
        if not greedy:
            agent.train(state, new_state, action, None, r)
        state = new_state
        total_reward += r
        steps += 1
    return total_reward, success, steps


for trial in range(1, 2):
    # ____________ main Parameters ___________________________
    # NOTE(review): only the stdlib `random` module is seeded. If the agent or
    # env draws from numpy's RNG, runs are not reproducible - confirm.
    random.seed(23 * trial)
    approach_name = 'q'
    map_name = hyper_param.map_name
    file_name = map_name + "_" + approach_name + "_" + str(trial)
    step_max = hyper_param.step_max
    episodes = hyper_param.episode_max
    env = hyper_param.env
    #_________________________________________________________

    agent_con_qlearning = qlearning(env, state_size=env._state_size, action_size=env._action_size)
    agent = agent_con_qlearning

    log = Log_experiments()
    # Greedy-evaluation history: which training episode it was taken at,
    # and the list of per-roll-out rewards obtained then.
    agent._acc_reward_data["Num_episodes"] = list()
    agent._acc_reward_data["Cumulative_rewards"] = list()
    for i in range(episodes):
        reward, success, epoch = _run_episode(env, agent, step_max)
        agent.decay()
        log.log_episode(reward, success, epoch)

        print("_______________________________")
        print(f"episode: {i}\treward: {reward}\tepochs: {epoch}"
              f"\tepsilon: {round(agent._epsilon, 3)}\tsuccess: {success}")

        if i % EVAL_INTERVAL == 0 and i > 0:
            total_reward_list = [
                _run_episode(env, agent, step_max, greedy=True)[0]
                for _ in range(EVAL_EPISODES)
            ]
            agent._acc_reward_data["Num_episodes"].append(i)
            agent._acc_reward_data["Cumulative_rewards"].append(total_reward_list)

    log.save_execution(file_name)
    log.save_acc_rewards(file_name, agent._acc_reward_data)
    log.plot_learning(500, "success")