from adv_imitation import DiscountALAgent
from envs import DisBandit, DisCliffWalking
from utils.flags import FLAGS
from utils.Logger import logger
from utils.envs import set_init_state_dis
from utils.utils import get_optimal_policy, estimate_dis_occupancy_measure_from_data
import numpy as np
import os
import yaml
# NOTE(review): STEP_SIZE is not referenced anywhere in this file's visible code — possibly unused.
STEP_SIZE = 1e-2
# Numerical tolerance: sums below this are treated as zero (e.g. a state's accumulated occupancy mass).
EPS = 1e-12


class DiscountedTableFEM(DiscountALAgent):
    """Tabular feature-expectation-matching (FEM) agent for discounted environments.

    The agent maintains a running "average" occupancy measure over the
    ``(num_state, num_action)`` table. Each reward step moves that average
    towards the current policy's occupancy measure (see ``_projection``) and
    sets the reward table to ``expert_occupancy - average``. Separately, every
    policy occupancy measure seen during training is accumulated with weight
    ``1 / max_num_iterations`` so that a final mixed policy can be recovered
    with ``get_policy_from_history_occ``.

    NOTE(review): ``n_state``, ``n_action``, ``max_num_iterations`` and the use of
    ``_reward_function`` are presumably provided by ``DiscountALAgent``, which is
    not visible here — confirm.
    """

    def __init__(self, num_state: int, num_action: int, gamma: float, max_num_iterations: int):
        # Start the running average from a random table normalized to sum to one.
        tmp_occupancy_measure = np.random.random(size=(num_state, num_action))
        normalizer = float(np.sum(tmp_occupancy_measure))
        self._average_occupancy_measure = tmp_occupancy_measure / normalizer
        # Accumulated policy occupancy measures, each weighted by 1 / max_num_iterations.
        self._history_occ = np.zeros(shape=(num_state, num_action), dtype=np.float32)
        super(DiscountedTableFEM, self).__init__(num_state, num_action, gamma, max_num_iterations)

    def _projection(self, expert_occupancy_measure: np.ndarray, policy_occupancy_measure: np.ndarray,
                    iterations_now: int):
        """Move the average occupancy measure towards the policy occupancy measure.

        With ``a = expert - average`` and ``b = policy - average``, the average is
        shifted along ``b`` by ``min(<a, b> / <b, b>, 1 / max_num_iterations)``.
        The result is then clipped to ``[0, 1]`` and renormalized to sum to one.
        If ``<b, b>`` is (numerically) zero, the average is left unchanged.

        Args:
            expert_occupancy_measure: (estimated) expert occupancy table.
            policy_occupancy_measure: occupancy table of the current policy.
            iterations_now: current iteration index (unused; kept for interface
                compatibility).

        Returns:
            ``(reward, new_average_occupancy_measure)`` where
            ``reward = expert_occupancy_measure - new_average_occupancy_measure``.
        """
        average = self._average_occupancy_measure
        vector_a = expert_occupancy_measure - average
        vector_b = policy_occupancy_measure - average
        squared_norm_b = float(np.sum(np.square(vector_b)))
        if squared_norm_b < EPS:
            # The policy occupancy coincides with the average, so the projection
            # coefficient would be 0/0 = NaN and poison every later step.
            step_size = 0.0
        else:
            step_size = float(np.sum(vector_a * vector_b)) / squared_norm_b
        step = min(step_size, 1.0 / self.max_num_iterations)
        new_average_occupancy_measure = average + step * vector_b
        # Clip to [0, 1] and renormalize so the table sums to one
        # (a heuristic, not an exact projection onto the simplex).
        new_average_occupancy_measure = np.clip(new_average_occupancy_measure, 0.0, 1.0)
        normalizer = float(np.sum(new_average_occupancy_measure))
        new_average_occupancy_measure = new_average_occupancy_measure / normalizer
        assert np.isclose(float(np.sum(new_average_occupancy_measure)), 1.0)

        reward = expert_occupancy_measure - new_average_occupancy_measure
        return reward, new_average_occupancy_measure

    def train_reward_step(self, expert_occupancy_measure: np.ndarray, policy_occupancy_measure: np.ndarray,
                          iterations_now: int):
        """Accumulate the policy occupancy measure and update the reward table.

        Adds ``policy_occupancy_measure / max_num_iterations`` to the history and
        replaces ``_reward_function`` and the running average with the outputs of
        ``_projection``.
        """
        self._history_occ += policy_occupancy_measure / self.max_num_iterations
        self._reward_function, self._average_occupancy_measure = self._projection(expert_occupancy_measure,
                                                                                  policy_occupancy_measure,
                                                                                  iterations_now)

    def get_policy_from_history_occ(self):
        """Return the policy induced by the accumulated occupancy measure.

        Row ``s`` is ``history_occ[s] / sum(history_occ[s])``. States whose
        accumulated mass is below ``EPS`` get the uniform action distribution.

        Returns:
            ``float32`` array of shape ``(n_state, n_action)`` whose rows sum to one.
        """
        normalizer = np.sum(self._history_occ, axis=1)
        # Default every row to uniform, then overwrite the rows that carry mass.
        policy = np.full((self.n_state, self.n_action), 1.0 / self.n_action, dtype=np.float32)
        visited = normalizer >= EPS
        policy[visited] = self._history_occ[visited] / normalizer[visited, np.newaxis]

        assert np.allclose(np.sum(policy, axis=1), np.ones(self.n_state))
        return policy


def _make_env(env_id: str, ns: int, na: int, gamma: float, init_state_dis, optimal_action: int):
    """Construct the discounted environment named ``env_id``.

    Raises:
        ValueError: if ``env_id`` is neither ``'CliffWalking'`` nor ``'Bandit'``.
    """
    if env_id == 'CliffWalking':
        return DisCliffWalking(ns, na, gamma, init_state_dis, optimal_action)
    if env_id == 'Bandit':
        return DisBandit(ns, na, gamma, init_state_dis, optimal_action)
    raise ValueError('Do not support the env {}.'.format(env_id))


def _dump_yaml(data: dict, file_name: str):
    """Write ``data`` as block-style YAML to ``os.path.join(FLAGS.log_dir, file_name)``."""
    save_path = os.path.join(FLAGS.log_dir, file_name)
    # Context manager so the file is flushed and closed deterministically.
    with open(save_path, 'w') as f:
        yaml.dump(data, f, default_flow_style=False)


def train_discounted_fem():
    """Run discounted FEM over a sweep of effective horizons and save the results.

    For each effective horizon ``H`` in ``range(200, 10000, 100)``, this uses the
    discount ``gamma = 1 - 1 / H``. It trains a ``DiscountedTableFEM`` agent
    against an occupancy measure estimated from an expert dataset, then evaluates
    the policy recovered from the accumulated history. Expert values, FEM values
    and their differences are written to YAML files in ``FLAGS.log_dir``.
    """
    FLAGS.set_seed()
    FLAGS.freeze()
    env_id = FLAGS.env.id
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.dis_num_data_dict[env_id]
    max_num_iterations = FLAGS.DisFEM.max_num_iterations
    init_state_dis = set_init_state_dis(env_id=env_id, num_data=num_data, ns=ns)
    # Function-scope import kept from the original (NOTE(review): possibly avoids an import cycle).
    from utils.utils import sample_dataset_from_distribution
    value_errors = dict()
    expert_values = dict()
    fem_values = dict()

    # Sample the optimal action index. Once the optimal action is fixed, the
    # expert policy and the dataset are determined.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, env_id)
    # In both CliffWalking and Bandit envs, the stationary state-action distribution
    # of the expert policy is the initial state distribution.
    dataset, unique_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    estimated_expert_occupancy_measure = estimate_dis_occupancy_measure_from_data(ns, na, dataset)
    # Probability mass of initial states never observed in the dataset.
    missing_mass = 1.0 - float(np.sum(init_state_dis[unique_states]))

    logger.info('Missing mass: %.8f', missing_mass)

    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - (1.0 / effective_horizon)
        env = _make_env(env_id, ns, na, gamma, init_state_dis, optimal_action)

        agent = DiscountedTableFEM(ns, na, gamma, max_num_iterations)
        transition_prob = env.transition_probability
        logger.info('Begin training with %d samples', num_data)
        expert_value = env.policy_evaluation(policy=expert_policy)

        for t in range(max_num_iterations):
            # NOTE(review): ``get_policy`` is accessed without a call, so it is
            # presumably a property on DiscountALAgent — confirm.
            policy = agent.get_policy
            policy_occupancy_measure = env.calculate_occupancy_measure_v2(policy)
            agent.train_reward_step(estimated_expert_occupancy_measure, policy_occupancy_measure, t)
            if t % FLAGS.DisFEM.train_policy_freq == 0:
                agent.train_policy_step(transition_prob)

            if t % 200 == 0:
                policy_value = env.policy_evaluation(agent.get_policy)
                logger.info('Iteration %d: The policy value is %.2f', t, policy_value)

        fem_policy = agent.get_policy_from_history_occ()
        fem_value = env.policy_evaluation(policy=fem_policy)
        value_error = expert_value - fem_value

        logger.info('Effective horizon: %d, Discounted factor: %.6f, Expert value: %.4f, FEM value: %.4f, '
                    'Value error: %.4f', effective_horizon, gamma, expert_value, fem_value, value_error)

        expert_values[effective_horizon] = [expert_value]
        fem_values[effective_horizon] = [fem_value]
        value_errors[effective_horizon] = [value_error]

    _dump_yaml(expert_values, 'expert_evaluate.yml')
    _dump_yaml(fem_values, 'dis_fem_evaluate.yml')
    _dump_yaml(value_errors, 'value_error_evaluate.yml')


if __name__ == '__main__':
    # Script entry point: run the discounted FEM effective-horizon sweep.
    train_discounted_fem()



