import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from adv_imitation.al_agent import DiscountALAgent
from envs.CliffWalking.CliffWalking import DisCliffWalking
from envs.bandit.bandit_env import DisBandit
import numpy as np
from utils.flags import FLAGS
from utils.Logger import logger
from utils.utils import sample_dataset_from_distribution, get_optimal_policy, estimate_dis_occupancy_measure_from_data
from utils.envs.env_utils import set_init_state_dis
from typing import List
import os
import yaml
# Small positive floor: guards the occupancy-ratio denominator and the log
# argument in the reward update, and is the threshold below which a state's
# accumulated occupancy mass is treated as zero.
EPS = 1e-8


class DisTableGAIL(DiscountALAgent):
    """
    Tabular GAIL agent for a discounted infinite-horizon MDP.

    In addition to whatever the base agent maintains, this class accumulates
    the policy occupancy measures passed to ``train_reward_step`` (each scaled
    by ``1 / max_num_iterations``) so that a policy can later be read off the
    accumulated table via ``get_policy_from_occ``.
    """
    def __init__(self, n_state: int, n_action: int, gamma: float, max_num_iterations: int):
        # The accumulator is created before the base initializer runs, so it
        # already exists if the base class touches it during construction.
        self._average_occupancy_measure = np.zeros((n_state, n_action), dtype=np.float32)
        super().__init__(n_state, n_action, gamma, max_num_iterations)

    def train_reward_step(self, expert_occupancy_measure: np.ndarray, policy_occupancy_measure: np.ndarray,
                          iterations_now: int):
        """
        Update the accumulated occupancy table and recompute the reward table.

        :param expert_occupancy_measure: occupancy measure of the expert.
        :param policy_occupancy_measure: occupancy measure of the current policy.
        :param iterations_now: current iteration index (not used by this update).
        """
        # In-place add keeps the accumulator's float32 dtype.
        self._average_occupancy_measure += policy_occupancy_measure / self.max_num_iterations

        # Closed-form discriminator d = expert / (expert + policy); the
        # denominator is floored at EPS to avoid division by zero.
        mixture = expert_occupancy_measure + policy_occupancy_measure
        discriminator = expert_occupancy_measure / np.maximum(mixture, EPS)

        # Reward is -log(1 - d), with the log argument floored at EPS.
        one_minus_d = np.maximum(1.0 - discriminator, EPS)
        self._reward_function = -np.log(one_minus_d)

    def get_policy_from_occ(self):
        """
        Build a policy by normalizing each state's row of the accumulated
        occupancy table over actions.

        States whose accumulated mass is below EPS get a uniform distribution
        over actions.

        :return: float32 array of shape (n_state, n_action) whose rows sum to 1.
        """
        occupancy = self._average_occupancy_measure
        state_mass = np.sum(occupancy, axis=1)

        # Start from the uniform fallback, then overwrite the rows that carry
        # enough mass to be normalized.
        policy = np.full((self.n_state, self.n_action), 1.0 / self.n_action, dtype=np.float32)
        has_mass = ~(state_mass < EPS)
        policy[has_mass] = occupancy[has_mass] / state_mass[has_mass][:, None]

        assert np.allclose(np.sum(policy, axis=1), np.ones(self.n_state))
        return policy


def train_dis_gail():
    """
    Run tabular GAIL over a sweep of effective horizons and save the results.

    An expert policy and an expert dataset are sampled once. Then, for each
    effective horizon H in ``range(200, 10000, 100)`` (discount factor
    ``gamma = 1 - 1/H``), a fresh environment and a fresh ``DisTableGAIL``
    agent are created, the agent is trained for
    ``FLAGS.GAIL.max_num_iterations`` iterations, and the expert value, the
    GAIL value and their difference are recorded.

    The three result dictionaries (keyed by effective horizon, each value a
    one-element list) are written as YAML files under ``FLAGS.log_dir``:
    ``expert_evaluate.yml``, ``gail_evaluate.yml`` and
    ``value_error_evaluate.yml``.

    :raises ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.algorithm = "GAIL"
    FLAGS.freeze()

    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.dis_num_data_dict[FLAGS.env.id]

    # max_num_iterations = FLAGS.GAIL.max_num_iter_dict[FLAGS.env.id]
    max_num_iterations = FLAGS.GAIL.max_num_iterations
    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)

    # Sample the optimal action index. Once the optimal action is fixed,
    # the expert policy and the expert dataset are determined by it.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, FLAGS.env.id)
    dataset, uniques_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    expert_occupancy_measure = estimate_dis_occupancy_measure_from_data(ns, na, dataset)

    # Initial-distribution mass on the states that do not appear in the dataset.
    sampled_mass = init_state_dis[uniques_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))
    logger.info('Missing mass: %.8f', missing_mass)

    value_errors = dict()
    expert_values = dict()
    gail_values = dict()

    # The environment type does not depend on the horizon, so resolve it once
    # instead of re-checking the id on every sweep iteration.
    if FLAGS.env.id == 'CliffWalking':
        env_cls = DisCliffWalking
    elif FLAGS.env.id == 'Bandit':
        env_cls = DisBandit
    else:
        raise ValueError('Do not support the env {}.'.format(FLAGS.env.id))

    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - float(1.0 / effective_horizon)
        env = env_cls(ns, na, gamma, init_state_dis, optimal_action)

        expert_value = env.policy_evaluation(expert_policy)
        expert_values[effective_horizon] = [expert_value]

        transition_prob = env.transition_probability
        dis_gail_agent = DisTableGAIL(ns, na, gamma, max_num_iterations)
        logger.info('Begin training in effective horizon = %d', effective_horizon)

        for t in range(max_num_iterations):
            # `get_policy` is read as an attribute, not called
            # (presumably a property on the base agent -- TODO confirm).
            policy = dis_gail_agent.get_policy
            policy_occupancy_measure = env.calculate_occupancy_measure_v2(policy)
            dis_gail_agent.train_reward_step(expert_occupancy_measure, policy_occupancy_measure, t)
            if t % FLAGS.GAIL.train_policy_freq == 0:
                dis_gail_agent.train_policy_step(transition_prob)

            if t % 200 == 0:
                policy = dis_gail_agent.get_policy
                policy_value = env.policy_evaluation(policy)
                logger.info('The policy value at iterations %d: %.4f', t, policy_value)
                # L1 distance uses the occupancy measure computed at the top
                # of this iteration, i.e. before the updates above.
                occupancy_measure_loss = float(np.sum(np.abs(expert_occupancy_measure - policy_occupancy_measure)))
                logger.info('Iteration %d:, Occupancy measure distance: %.3f', t, occupancy_measure_loss)

        # The occupancy-averaged policy is only used for CliffWalking when
        # averaging is enabled; otherwise the agent's current policy is evaluated.
        if FLAGS.GAIL.is_average and FLAGS.env.id == 'CliffWalking':
            gail_policy = dis_gail_agent.get_policy_from_occ()
        else:
            gail_policy = dis_gail_agent.get_policy
        gail_value = env.policy_evaluation(gail_policy)
        value_error = expert_value - gail_value
        gail_values[effective_horizon] = [gail_value]
        value_errors[effective_horizon] = [value_error]
        logger.info('Effective horizon: %d, Discounted factor: %.6f Expert value: %.4f, GAIL value: %.4f,'
                    'Value error: %.4f,', effective_horizon, gamma, expert_value, gail_value, value_error)

    # Write each result table through a context manager so the file handle is
    # flushed and closed deterministically rather than left to the garbage collector.
    outputs = (
        ('expert_evaluate.yml', expert_values),
        ('gail_evaluate.yml', gail_values),
        ('value_error_evaluate.yml', value_errors),
    )
    for file_name, results in outputs:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        with open(save_path, 'w') as out_file:
            yaml.dump(results, out_file, default_flow_style=False)


# Script entry point: run the GAIL horizon sweep only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    train_dis_gail()

