import gym
import pandas as pd
from pathlib import Path
import os

# Base directory for experiment output: the process's current working
# directory (not the script's own location), so launch the script from the
# intended project root. CSV results are written under <cwd>/result/.
system_add = str(Path.cwd())


from component.obs_decomp import obs_decomp_alg
from component.obs_ablation import obs_abl
from agent.sarsa_agent import sarsa



def _train_sarsa(env, agent, alg, miss, miss_index, rest_obs, num_episodes, max_steps):
    """Run on-policy SARSA training, recording every transition in ``alg``.

    Each observation is passed through ``miss.state_judge`` before the agent
    sees it. Epsilon is decayed once per episode via ``agent.decrease_epsilon``.
    Assumes the classic gym API: ``reset()`` returns the observation and
    ``step()`` returns a 4-tuple.
    """
    for episode in range(num_episodes):
        # reset the environment and apply the observation ablation
        state = miss.state_judge(env.reset(), miss_index, rest_obs)
        action = agent.take_action(env.action_space.sample(), state)

        for _ in range(max_steps):
            new_state, reward, done, _info = env.step(action)
            new_state = miss.state_judge(new_state, miss_index, rest_obs)

            # record the transition for the later entropy analysis
            alg.history_record(state, new_state, action, reward)

            # SARSA: choose the next action before updating the Q-table
            next_action = agent.take_action(env.action_space.sample(), new_state)
            agent.update_qtable(state, new_state, action, next_action, reward)

            state, action = new_state, next_action
            if done:
                break

        agent.decrease_epsilon(episode)


def _greedy_rollout(env, agent, miss, miss_index, rest_obs, max_steps):
    """Play one episode with the learned greedy policy.

    Returns ``(steps, rewards)``. ``steps`` is the number of ``env.step``
    calls made, at most ``max_steps``. ``rewards`` is the summed reward.
    """
    state = miss.state_judge(env.reset(), miss_index, rest_obs)
    # Act greedily from the very first step, so the whole rollout follows
    # the learned policy.
    action = agent.get_max_action(state)
    rewards = 0
    steps = 0
    for steps in range(1, max_steps + 1):
        new_state, reward, done, _info = env.step(action)
        new_state = miss.state_judge(new_state, miss_index, rest_obs)
        rewards += reward
        if done:
            break
        action = agent.get_max_action(new_state)
    return steps, rewards


def run(miss_num, miss_level,chosen):
    '''
    Train a SARSA agent on CliffWalking-v0 with ablated observations, then
    play one greedy evaluation episode.

    0<=miss_num:int<=12 is the number of the observation ablation
    miss_level in [0,1,2] is the level of observation ablation
    chosen is forwarded to ``obs_abl.row_random``; the returned ``chosen``
    is the freshly drawn ``miss_index`` from that call.

    Returns ``(alg, steps, rewards, chosen)``:
    - ``alg`` holds the transition history recorded during training.
    - ``steps`` is the number of environment steps taken in the evaluation
      episode.
    - ``rewards`` is the evaluation episode's total reward.
    '''
    miss = obs_abl()
    miss_index, rest_obs = miss.row_random(miss_level, miss_num, chosen)
    print(miss_index)

    env = gym.make('CliffWalking-v0')
    try:
        state_size = env.observation_space.n
        action_size = env.action_space.n

        alg = obs_decomp_alg()
        agent = sarsa(state_size, action_size)

        # training variables
        num_episodes = 2000
        max_steps = 50  # per episode

        _train_sarsa(env, agent, alg, miss, miss_index, rest_obs,
                     num_episodes, max_steps)
        steps, rewards = _greedy_rollout(env, agent, miss, miss_index,
                                         rest_obs, max_steps)
    finally:
        # release the environment even if training or evaluation raises
        env.close()

    return alg, steps, rewards, miss_index


def _history_entropy(alg):
    """Accumulate ``alg.dimension_entropy`` over every distinct (obs, action) pair.

    Returns ``(ent, boolean_weight_lambda)``:
    - ``ent`` is the summed entropy.
    - ``boolean_weight_lambda`` sums each pair's weight minus one.

    Pairs with non-zero entropy are printed.
    """
    # Distinct (x[0], x[2]) pairs in first-seen order. x[0]/x[2] are presumably
    # the state/action passed to history_record. dict.fromkeys keeps the
    # first-seen order and avoids the O(n * unique) list scan. It assumes
    # states and actions are hashable (ints for CliffWalking) — confirm if
    # state_judge can return other types.
    obs_act_uni = list(dict.fromkeys((x[0], x[2]) for x in alg.history))

    ent = 0
    boolean_weight_lambda = 0
    for obs, act in obs_act_uni:
        lambda_list = alg.lambda_list_gen(alg.history, obs, act)
        # count the accumulative entropy
        ent_temp, boolean_weight_lambda_temp = alg.dimension_entropy(lambda_list)
        ent += ent_temp
        # kept as (acc + w) - 1 to preserve the exact float rounding order
        boolean_weight_lambda = boolean_weight_lambda + boolean_weight_lambda_temp - 1
        if ent_temp != 0:
            # NOTE(review): the observation is printed twice; possibly a typo.
            print(str(obs), str(obs), str(act))
    return ent, boolean_weight_lambda


def main():
    '''
    Sweep ablation levels k in 0..2 and ablation counts j in 0..11, repeated
    10 times.

    Each repetition writes one CSV of per-run metrics ('lambda', 'step',
    'rewards', 'boolean_weight_lambda') to
    <system_add>/result/<script stem>_<count>.csv.
    '''
    out_dir = Path(system_add) / 'result'
    # DataFrame.to_csv does not create missing parent directories
    out_dir.mkdir(parents=True, exist_ok=True)

    for count in range(10):
        print('count:', str(count))
        result = {}
        # k is the ablation level
        for k in range(3):
            result[k] = {}
            # chosen is reset at the start of every level and threaded
            # through successive runs within it
            chosen = []
            # j is the ablation number
            for j in range(12):
                alg, s, rewards, chosen = run(j, k, chosen)
                ent, boolean_weight_lambda = _history_entropy(alg)
                # add the value to result dict
                result[k][j] = {'lambda': ent, 'step': s, 'rewards': rewards,
                                'boolean_weight_lambda': boolean_weight_lambda}
        df = pd.DataFrame(result)
        print(df)
        df.to_csv(out_dir / (Path(__file__).stem + '_' + str(count) + '.csv'))

# Run the full experiment sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()

