from adv_imitation import DiscountALAgent
from envs.CliffWalking.CliffWalking import CliffWalking, DisCliffWalking
from envs.bandit.bandit_env import Bandit, DisBandit
import math
import numpy as np
import os
from typing import List
from utils.utils import sample_dataset, get_optimal_policy, estimate_dis_occupancy_measure_from_data,\
    estimate_occupancy_measure_from_data

from utils.flags import FLAGS
from utils.Logger import logger
from utils.envs.env_utils import set_init_state_dis
import yaml
# Small positive constant: used as the lower bound of the reward step size and as the
# threshold below which a state's accumulated occupancy mass is treated as zero.
EPS = 1e-8


class TableGTAL(object):
    """Tabular GTAL agent for finite-horizon problems.

    The agent stores a time-dependent policy and a time-dependent reward
    function, both as arrays of shape ``(n_state, n_action, max_episode_steps)``.
    Training alternates between a projected gradient step on the reward
    (``train_reward_step``) and a greedy best response to the current reward
    computed by backward value iteration (``train_policy_step``).
    """

    def __init__(self, n_state: int, n_action: int, max_episode_steps: int) -> None:
        """Create an agent with a random policy and a random reward in ``[-1, 1)``.

        Args:
            n_state: number of states.
            n_action: number of actions.
            max_episode_steps: planning horizon ``H``.
        """
        self.n_state = n_state
        self.n_action = n_action
        self.max_episode_steps = max_episode_steps
        shape = (self.n_state, self.n_action, self.max_episode_steps)
        # Random positive weights, normalised over the action axis so that every
        # (state, time step) slice is a probability distribution over actions.
        weights = np.random.random(size=shape)
        self._policy = weights / np.sum(weights, axis=1, keepdims=True)
        self._reward_function = np.random.uniform(low=-1.0, high=1.0, size=shape)

    @property
    def get_reward_function(self):
        """Return a copy of the current reward table."""
        return self._reward_function.copy()

    @property
    def get_policy(self):
        """Return a copy of the current policy table."""
        return self._policy.copy()

    def _generate_greedy_policy(self, q_functions: np.ndarray):
        """Build the deterministic policy that is greedy w.r.t. ``q_functions``.

        Args:
            q_functions: array of shape ``(n_state, n_action, max_episode_steps)``.

        Returns:
            A float32 one-hot array of the same shape; ties are broken in favour
            of the lowest action index (``np.argmax`` semantics).
        """
        M, N, H = self.n_state, self.n_action, self.max_episode_steps
        greedy_action = np.argmax(q_functions, axis=1)  # shape (M, H)

        greedy_policy = np.zeros(shape=(M, N, H), dtype=np.float32)
        # Index arrays broadcast to (M, H): one assignment sets the chosen action
        # of every (state, time step) pair without a Python-level double loop.
        state_index = np.arange(M)[:, np.newaxis]
        step_index = np.arange(H)[np.newaxis, :]
        greedy_policy[state_index, greedy_action, step_index] = 1.0

        return greedy_policy

    def _value_iteration(self, transition_probability: np.ndarray):
        """Compute the optimal policy for the current reward by backward induction.

        Args:
            transition_probability: array expected to have shape
                ``(n_state, n_action, n_state)`` with the next state on the last axis.

        Returns:
            The greedy (optimal) policy, shape ``(n_state, n_action, max_episode_steps)``.
        """
        M, N, H = self.n_state, self.n_action, self.max_episode_steps
        # V has one extra column so that V[:, H] == 0 acts as the terminal value.
        V_functions = np.zeros((M, H + 1))
        Q_functions = np.zeros((M, N, H))
        reward_func = self._reward_function
        for h in range(H - 1, -1, -1):
            # Shape (1, 1, M): broadcasts against the next-state axis of the transitions.
            V_next = V_functions[:, h + 1][np.newaxis, np.newaxis, :]
            Q_h = np.sum(transition_probability * V_next, axis=-1) + reward_func[:, :, h]
            Q_functions[:, :, h] = Q_h
            V_functions[:, h] = np.max(Q_h, axis=1)

        return self._generate_greedy_policy(Q_functions)

    def train_policy_step(self, transition_probability: np.ndarray):
        """Replace the policy with the optimal policy under the current reward."""
        self._policy = self._value_iteration(transition_probability)

    def _projected_gradient_descent(self, expert_occupancy_measure: np.ndarray,
                                    policy_occupancy_measure: np.ndarray, iterations_now: int):
        """Return the reward after one projected gradient step.

        The step size is ``sqrt(2 * n_state * n_action / max(iterations_now, 1))``
        clipped into ``[EPS, 1]``; the updated reward is projected back onto
        ``[-1, 1]`` element-wise.

        Args:
            expert_occupancy_measure: occupancy measure estimated for the expert.
            policy_occupancy_measure: occupancy measure of the current policy.
            iterations_now: index of the current training iteration.

        Returns:
            The updated reward table (the stored reward is not modified here).
        """
        step_size = np.sqrt(np.divide(2 * self.n_state * self.n_action, max(iterations_now, 1)))
        step_size = np.clip(step_size, EPS, 1.0)
        grad = policy_occupancy_measure - expert_occupancy_measure
        return np.clip(self._reward_function - step_size * grad, -1.0, 1.0)

    def train_reward_step(self, expert_occupancy_measure: np.ndarray,
                          policy_occupancy_measure: np.ndarray, iterations_now: int):
        """Update the stored reward with one projected gradient step."""
        self._reward_function = self._projected_gradient_descent(expert_occupancy_measure, policy_occupancy_measure,
                                                                 iterations_now)


class DiscountedTableGTAL(DiscountALAgent):
    """Tabular GTAL agent for the discounted setting.

    Adds two reward update rules on top of ``DiscountALAgent`` ('PG': projected
    gradient descent, 'MW': multiplicative weights) and keeps a running sum of
    the policy occupancy measures from which an averaged policy can be read.

    NOTE(review): ``n_state``, ``n_action``, ``max_num_iterations``,
    ``get_reward_function`` and ``_reward_function`` are not defined here; they
    are presumably provided by ``DiscountALAgent`` - confirm against the base class.
    """

    def __init__(self, num_state: int, num_action: int, gamma: float, max_num_iterations: int, reward_opt_type: str):
        """Store the reward update rule, then initialise the base agent.

        Args:
            num_state: number of states.
            num_action: number of actions.
            gamma: discount factor, forwarded to the base class.
            max_num_iterations: number of training iterations, forwarded to the base class.
            reward_opt_type: either 'PG' or 'MW'.
        """
        assert reward_opt_type in ['PG', 'MW'], 'The optimization method %s is not supported.' % reward_opt_type
        self.reward_optimization_type = reward_opt_type
        # Accumulator for the occupancy measures seen during training.
        self._average_occupancy_measure = np.zeros(shape=(num_state, num_action), dtype=np.float32)

        super().__init__(num_state, num_action, gamma, max_num_iterations)

    def _projected_gradient_descent(self, expert_occupancy_measure: np.ndarray,
                                    policy_occupancy_measure: np.ndarray, iterations_now: int):
        """Return the reward after one projected gradient step.

        The step size is ``sqrt(2 * n_state * n_action / max(iterations_now, 1))``
        clipped into ``[EPS, 1]``, and the result is clipped into ``[-1, 1]``.
        """
        expert_occ = expert_occupancy_measure.copy()
        policy_occ = policy_occupancy_measure.copy()

        effective_iteration = max(iterations_now, 1)
        raw_step = np.sqrt(np.divide(2 * self.n_state * self.n_action, effective_iteration))
        eta = np.clip(raw_step, a_max=1.0, a_min=EPS)

        current_reward = self.get_reward_function
        updated_reward = current_reward - eta * (policy_occ - expert_occ)
        return np.clip(updated_reward, a_max=1.0, a_min=-1.0)

    def train_reward_step(self, expert_occupancy_measure: np.ndarray,
                          policy_occupancy_measure: np.ndarray, iterations_now: int):
        """Accumulate the policy occupancy measure and update the reward.

        Raises:
            ValueError: if the configured update rule is neither 'MW' nor 'PG'.
        """
        # The occupancy measure passed at iteration 0 is not accumulated.
        if iterations_now > 0:
            self._average_occupancy_measure += policy_occupancy_measure / self.max_num_iterations

        rule = self.reward_optimization_type
        if rule == 'MW':
            new_reward = self._multiply_weight(expert_occupancy_measure, policy_occupancy_measure)
        elif rule == 'PG':
            new_reward = self._projected_gradient_descent(expert_occupancy_measure, policy_occupancy_measure,
                                                          iterations_now)
        else:
            raise ValueError('The optimization method %s is not supported.' % self.reward_optimization_type)
        self._reward_function = new_reward

    def _multiply_weight(self, expert_occupancy_measure: np.ndarray, policy_occupancy_measure: np.ndarray):
        """Return the reward after one multiplicative-weights step.

        With ``beta = 1 / (1 + sqrt(2 * ln(n_state * n_action) / max_num_iterations))``,
        each entry of the current reward is scaled by
        ``beta ** ((1 + policy_occ - expert_occ) / 2)`` and the result is
        rescaled so that its entries sum to one.
        """
        log_term = math.log(self.n_state * self.n_action)
        beta = 1.0 / (1.0 + math.sqrt(2.0 * log_term / self.max_num_iterations))
        exponent = (1.0 + policy_occupancy_measure - expert_occupancy_measure) / 2.0
        scaled_reward = np.exp(np.log(beta) * exponent) * self.get_reward_function
        total = float(np.sum(scaled_reward))
        return scaled_reward / total

    def get_policy_from_occ(self):
        """Turn the accumulated occupancy measure into a stationary policy.

        Each state's row is normalised over actions; states whose accumulated
        mass is below ``EPS`` get the uniform distribution instead.
        """
        accumulated = self._average_occupancy_measure
        normalizer = np.sum(accumulated, axis=1)

        # Start from the uniform policy, then overwrite the rows that carry mass.
        policy = np.zeros(shape=[self.n_state, self.n_action], dtype=np.float32)
        policy[:] = (1.0 / self.n_action) * np.ones(shape=self.n_action, dtype=np.float32)
        has_mass = ~(normalizer < EPS)
        policy[has_mass] = accumulated[has_mass] / normalizer[has_mass][:, np.newaxis]

        assert np.allclose(np.sum(policy, axis=1), np.ones(self.n_state))
        return policy


def train_discounted_gtal():
    """Train the discounted GTAL agent over a sweep of effective horizons.

    One expert dataset is sampled up front. For every effective horizon in
    ``range(200, 10000, 100)`` the discount factor is set to
    ``1 - 1 / effective_horizon``, a fresh environment and agent are built,
    the agent is trained for ``FLAGS.DisGTAL.max_num_iterations`` iterations,
    and the expert value, GTAL value and their gap are recorded. The three
    result dictionaries are written as YAML files into ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    from utils.utils import sample_dataset_from_distribution

    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.dis_num_data_dict[FLAGS.env.id]
    max_num_iterations = FLAGS.DisGTAL.max_num_iterations
    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)
    value_errors = dict()
    expert_values = dict()
    gtal_values = dict()

    # Sample the optimal action index. If the optimal action is determined,
    # we can determine the expert policy and dataset.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, FLAGS.env.id)
    # In both CliffWalking and Bandit envs, the stationary state-action distribution of expert policy is initial
    # state distribution.
    dataset, uniques_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    estimated_expert_occupancy_measure = estimate_dis_occupancy_measure_from_data(ns, na, dataset)
    # Initial-state probability mass of the states that do not appear in the dataset.
    sampled_mass = init_state_dis[uniques_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))

    logger.info('Missing mass: %.8f', missing_mass)

    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - (1.0 / effective_horizon)
        if FLAGS.env.id == 'CliffWalking':
            env = DisCliffWalking(ns, na, gamma, init_state_dis, optimal_action)
        elif FLAGS.env.id == 'Bandit':
            env = DisBandit(ns, na, gamma, init_state_dis, optimal_action)
        else:
            raise ValueError('Do not support the env {}.'.format(FLAGS.env.id))

        agent = DiscountedTableGTAL(ns, na, gamma, max_num_iterations, FLAGS.DisGTAL.reward_opt_type)
        transition_prob = env.transition_probability
        logger.info('Begin training with %d samples', num_data)

        for t in range(max_num_iterations):
            policy = agent.get_policy
            policy_occupancy_measure = env.calculate_occupancy_measure_v2(policy)
            agent.train_reward_step(estimated_expert_occupancy_measure, policy_occupancy_measure, t)
            # The policy is refreshed only every `train_policy_freq` reward updates.
            if t % FLAGS.DisGTAL.train_policy_freq == 0:
                agent.train_policy_step(transition_prob)

            if t % 200 == 0:
                policy = agent.get_policy
                policy_value = env.policy_evaluation(policy)
                print('Iteration %d: The policy value is %.2f' % (t, policy_value))
        if FLAGS.DisGTAL.is_average and FLAGS.env.id == 'CliffWalking':
            gtal_final_policy = agent.get_policy_from_occ()
        else:
            gtal_final_policy = agent.get_policy
        gtal_value = env.policy_evaluation(policy=gtal_final_policy)
        expert_value = env.policy_evaluation(policy=expert_policy)
        value_error = expert_value - gtal_value
        # The first literal ends with a space so the two concatenated pieces do not run together.
        logger.info('Effective horizon: %d, Discounted factor: %.6f Expert value: %.4f, GTAL value: %.4f, '
                    'Value error: %.4f', effective_horizon, gamma, expert_value, gtal_value,
                    value_error)

        expert_values[effective_horizon] = [expert_value]
        gtal_values[effective_horizon] = [gtal_value]
        value_errors[effective_horizon] = [value_error]

    results = (
        ('expert_evaluate.yml', expert_values),
        ('dis_gtal_evaluate.yml', gtal_values),
        ('value_error_evaluate.yml', value_errors),
    )
    for file_name, values in results:
        # The context manager closes (and flushes) each file even if dumping fails.
        with open(os.path.join(FLAGS.log_dir, file_name), 'w') as stream:
            yaml.dump(values, stream, default_flow_style=False)


def train_with_diff_num_samples():
    """Train the finite-horizon GTAL agent with different numbers of expert samples.

    For every ``num_data`` in ``range(100000, 1000000, 20000)`` an environment
    is built, an expert dataset of that size is sampled, and a fresh
    ``TableGTAL`` agent is trained for ``FLAGS.GTAL.max_num_iterations``
    iterations. The expert value, GTAL value, value gap and the L1 error of the
    estimated expert occupancy measure are recorded and written as YAML files
    into ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """

    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    max_episode_steps = FLAGS.env.max_episode_steps
    max_num_iterations = FLAGS.GTAL.max_num_iterations

    value_errors = dict()
    expert_values = dict()
    gtal_values = dict()
    distribution_errors = dict()

    for num_data in range(100000, 1000000, 20000):
        init_state_dis = set_init_state_dis(FLAGS.env.id, num_data, ns)
        if FLAGS.env.id == 'CliffWalking':
            env = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        elif FLAGS.env.id == 'Bandit':
            env = Bandit(ns, na, init_state_dis, max_episode_steps)
        else:
            raise ValueError('Env %s is not supported.' % FLAGS.env.id)
        expert_policy = env.get_optimal_policy()
        expert_value = env.policy_evaluation(expert_policy)
        dataset = sample_dataset(env, expert_policy, num_data, is_deterministic=False)
        expert_occupancy_measure = estimate_occupancy_measure_from_data(ns, na, max_episode_steps, dataset)
        # Sum of absolute differences between the exact and the estimated expert occupancy measure.
        true_expert_occupancy_measure = env.calculate_occupancy_measure(expert_policy)
        tv_loss = float(np.sum(np.abs(true_expert_occupancy_measure - expert_occupancy_measure)))
        distribution_errors[num_data] = [tv_loss]
        logger.info('The number of samples: %d, The distribution error: %.4f.', num_data, tv_loss)

        transition_prob = env.transition_probability
        gtal_agent = TableGTAL(ns, na, max_episode_steps)
        logger.info('Begin training with %d samples', num_data)

        for t in range(max_num_iterations):
            # Reward step against the current policy, then best-respond to the new reward.
            policy = gtal_agent.get_policy
            policy_occupancy_measure = env.calculate_occupancy_measure(policy)
            gtal_agent.train_reward_step(expert_occupancy_measure, policy_occupancy_measure, t)

            gtal_agent.train_policy_step(transition_prob)
            if t % 20 == 0:
                policy = gtal_agent.get_policy
                policy_value = env.policy_evaluation(policy)
                print('Iteration %d: The policy value is %.2f' % (t, policy_value))

        gtal_policy = gtal_agent.get_policy
        gtal_value = env.policy_evaluation(gtal_policy)
        value_error = expert_value - gtal_value
        logger.info('The number of samples: %d, Expert value: %.4f, GTAL value: %.4f, Value error: %.4f',
                    num_data, expert_value, gtal_value, value_error)

        expert_values[num_data] = [expert_value]
        gtal_values[num_data] = [gtal_value]
        value_errors[num_data] = [value_error]

    results = (
        ('expert_evaluate.yml', expert_values),
        ('gtal_evaluate.yml', gtal_values),
        ('value_error_evaluate.yml', value_errors),
        ('distribution_error_evaluate.yml', distribution_errors),
    )
    for file_name, values in results:
        # The context manager closes (and flushes) each file even if dumping fails.
        with open(os.path.join(FLAGS.log_dir, file_name), 'w') as stream:
            yaml.dump(values, stream, default_flow_style=False)


def main():
    """Train the finite-horizon GTAL agent over a sweep of episode lengths.

    For every ``max_episode_steps`` in ``range(200, 3400, 200)`` an environment
    is built, an expert dataset of ``FLAGS.num_data_dict[FLAGS.env.id]`` samples
    is drawn, and a fresh ``TableGTAL`` agent is trained for
    ``FLAGS.GTAL.max_num_iterations`` iterations. Expert value, GTAL value and
    their gap are recorded per horizon and written as YAML files into
    ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.freeze()

    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.num_data_dict[FLAGS.env.id]
    max_num_iterations = FLAGS.GTAL.max_num_iterations
    init_state_dis = set_init_state_dis(FLAGS.env.id, num_data, ns)

    value_errors = dict()
    expert_values = dict()
    gtal_values = dict()

    for max_episode_steps in range(200, 3400, 200):
        if FLAGS.env.id == 'CliffWalking':
            env = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        elif FLAGS.env.id == 'Bandit':
            env = Bandit(ns, na, init_state_dis, max_episode_steps)
        else:
            raise ValueError('Env %s is not supported.' % FLAGS.env.id)
        expert_policy = env.get_optimal_policy()
        expert_value = env.policy_evaluation(expert_policy)
        expert_values[max_episode_steps] = [expert_value]

        transition_prob = env.transition_probability
        dataset = sample_dataset(env, expert_policy, num_data, is_deterministic=True)
        expert_occupancy_measure = estimate_occupancy_measure_from_data(ns, na, max_episode_steps, dataset)

        gtal_agent = TableGTAL(ns, na, max_episode_steps)
        logger.info('Begin training in max episodes steps = %d', max_episode_steps)

        for t in range(max_num_iterations):
            # Best-respond to the current reward first, then update the reward.
            gtal_agent.train_policy_step(transition_prob)
            policy = gtal_agent.get_policy
            policy_occupancy_measure = env.calculate_occupancy_measure(policy)

            if t % 20 == 0:
                # The value is only needed for this progress message, so it is
                # evaluated on logging iterations only.
                # NOTE(review): assumes env.policy_evaluation has no side effects - confirm.
                policy_value = env.policy_evaluation(policy)
                logger.info('The policy value at iterations %d: %.4f', t, policy_value)
            gtal_agent.train_reward_step(expert_occupancy_measure, policy_occupancy_measure, t)

        policy = gtal_agent.get_policy
        gtal_policy_value = env.policy_evaluation(policy)
        gtal_values[max_episode_steps] = [gtal_policy_value]
        value_error = expert_value - gtal_policy_value
        value_errors[max_episode_steps] = [value_error]
        logger.info('Max episode steps: %d, Expert value: %.4f, GTAL value: %.4f, Value error: %.4f',
                    max_episode_steps, expert_value, gtal_policy_value, value_error)

    results = (
        ('expert_evaluate.yml', expert_values),
        ('gtal_evaluate.yml', gtal_values),
        ('value_error_evaluate.yml', value_errors),
    )
    for file_name, values in results:
        # The context manager closes (and flushes) each file even if dumping fails.
        with open(os.path.join(FLAGS.log_dir, file_name), 'w') as stream:
            yaml.dump(values, stream, default_flow_style=False)


if __name__ == '__main__':
    # Script entry point: the discounted-horizon sweep runs by default.
    # The commented-out calls below are the other two experiments defined in
    # this file; enable one of them instead to run it.
    train_discounted_gtal()
    # train_with_diff_num_samples()
    # main()









