from BC.main import DiscountedTableBC
from envs.dis_tabular_env import DiscountedTabularEnv
from envs.CliffWalking.CliffWalking import DisCliffWalking
from envs.bandit.bandit_env import DisBandit
import os
from utils.envs.env_utils import set_init_state_dis
from utils.utils import sample_dataset_from_distribution, get_optimal_policy
from utils.Logger import logger
from utils.flags import FLAGS
import numpy as np
import yaml


def dis_dagger_trainer(env: DiscountedTabularEnv, num_data: int, max_num_iterations: int, num_data_ratio: float,
                       expert_policy: np.ndarray) -> np.ndarray:
    """
    The trainer of discounted DAgger.

    In every iteration the current agent policy is turned into a state distribution (the occupancy measure
    computed by the environment, summed over axis 1), states are sampled from that distribution and labelled
    with the expert policy, and the resulting dataset is handed to the agent's estimator. Every 40 iterations
    the value of the current policy is logged. After training, the initial-state probability mass of the
    states that were never sampled (the "missing mass") is logged.

    Args:
        env: the environment. It must provide ``observation_space.n``, ``action_space.n``,
            ``init_state_distribution``, ``calculate_occupancy_measure_v2`` and ``policy_evaluation``.
        num_data: the number of expert demonstrations in BC.
        max_num_iterations: the number of iterations. Must be positive.
        num_data_ratio: # DAgger / # BC, i.e. the total DAgger sample budget is ``num_data * num_data_ratio``,
            split evenly (rounded down) across the iterations.
        expert_policy: the expert policy that is used to assign action labels.
    Returns:
        policy: the policy trained by DAgger, a numpy array with shape [num_state, num_action].
    Raises:
        ValueError: if ``max_num_iterations`` is not positive.
    """
    # Fail early with a clear message: a zero value would otherwise raise ZeroDivisionError below, and a
    # negative one would skip the loop and fail later inside np.concatenate with an unrelated message.
    if max_num_iterations <= 0:
        raise ValueError('max_num_iterations must be positive, got {}.'.format(max_num_iterations))

    num_state, num_action = env.observation_space.n, env.action_space.n
    agent = DiscountedTableBC(num_state, num_action)
    # Per-iteration sample budget; int() truncates, so the total may be slightly below the nominal budget.
    num_data_per_iter = int((num_data * num_data_ratio) / max_num_iterations)
    init_state_dis = env.init_state_distribution
    unique_states_list = []
    for t in range(max_num_iterations):
        # get_policy is read as an attribute here (presumably a property on the agent).
        sampling_policy = agent.get_policy
        occupancy_measure = env.calculate_occupancy_measure_v2(sampling_policy)
        # Marginalise axis 1 of the occupancy measure to obtain the distribution states are sampled from.
        sampling_state_distribution = np.sum(occupancy_measure, axis=1)
        # Debug check (disabled):
        # assert np.allclose(init_state_dis, sampling_state_distribution)
        dataset, tmp_unique_states = sample_dataset_from_distribution(sampling_state_distribution, expert_policy,
                                                                      num_data_per_iter)
        unique_states_list.append(tmp_unique_states)
        agent.estimate_from_data(dataset)
        if t % 40 == 0:
            tmp_policy = agent.get_policy
            tmp_value = env.policy_evaluation(tmp_policy)
            logger.info('Iteration %d: The policy value of DAgger agent: %.2f', t, tmp_value)

    dagger_policy = agent.get_policy
    # Union of the states seen in any iteration; np.unique also flattens and de-duplicates.
    unique_states = np.unique(np.concatenate(unique_states_list))
    sampled_mass = init_state_dis[unique_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))
    # '%.8f' (precision 8) matches the missing-mass log emitted by the training entry point.
    logger.info('The missing mass %.8f', missing_mass)

    return dagger_policy


def _dump_yaml(results: dict, file_name: str) -> None:
    """Write ``results`` as block-style YAML to ``file_name`` inside ``FLAGS.log_dir``."""
    save_path = os.path.join(FLAGS.log_dir, file_name)
    # The context manager closes (and therefore flushes) the file even if yaml.dump raises.
    with open(save_path, 'w') as f:
        yaml.dump(results, f, default_flow_style=False)


def train_discounted_dagger():
    """
    Run the discounted DAgger experiment over a sweep of effective horizons.

    The configuration is read from ``FLAGS``. One optimal action (and hence one expert policy) is sampled
    once; then, for every effective horizon H in ``range(200, 10000, 100)``, an environment with discount
    factor ``1 - 1/H`` is built, DAgger is trained on it, and the expert value, the DAgger value and their
    difference are logged and recorded.

    The recorded values are written to ``expert_evaluate.yml``, ``dagger_evaluate.yml`` and
    ``value_error_evaluate.yml`` in ``FLAGS.log_dir``. Each file maps the effective horizon to a
    one-element list.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.dis_num_data_dict[FLAGS.env.id]
    num_data_ratio = FLAGS.DAgger.num_data_ratio_dict[FLAGS.env.id]
    max_num_iterations = FLAGS.DAgger.max_num_iterations
    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)
    value_errors = dict()
    expert_values = dict()
    dagger_values = dict()

    # Sample the optimal action index. If the optimal action is determined,
    # we can determine the expert policy and dataset.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, FLAGS.env.id)
    # In both CliffWalking and Bandit envs, the stationary state-action distribution of expert policy is initial
    # state distribution.
    # Only the set of visited states is needed here (to report the missing mass); the dataset is discarded.
    _, unique_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    sampled_mass = init_state_dis[unique_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))

    logger.info('Missing mass: %.8f', missing_mass)

    # The environment class does not depend on the horizon, so resolve it once before the sweep.
    env_classes = {'CliffWalking': DisCliffWalking, 'Bandit': DisBandit}
    if FLAGS.env.id not in env_classes:
        raise ValueError('Do not support the env {}.'.format(FLAGS.env.id))
    env_cls = env_classes[FLAGS.env.id]

    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - (1.0 / effective_horizon)
        env = env_cls(ns, na, gamma, init_state_dis, optimal_action)

        dagger_policy = dis_dagger_trainer(env, num_data, max_num_iterations, num_data_ratio, expert_policy)
        dagger_value = env.policy_evaluation(policy=dagger_policy)
        expert_value = env.policy_evaluation(policy=expert_policy)
        value_error = expert_value - dagger_value

        # The trailing space on the first literal keeps the two implicitly concatenated parts separated.
        logger.info('Effective horizon: %d, Discounted factor: %.6f Expert value: %.4f, DAgger value: %.4f, '
                    'Value error: %.4f,', effective_horizon, gamma, expert_value, dagger_value, value_error)

        # NOTE(review): if policy_evaluation returns NumPy scalars, yaml.dump will serialise them with
        # Python-specific tags rather than plain floats — confirm, and wrap in float() if plain YAML is wanted.
        expert_values[effective_horizon] = [expert_value]
        dagger_values[effective_horizon] = [dagger_value]
        value_errors[effective_horizon] = [value_error]

    _dump_yaml(expert_values, 'expert_evaluate.yml')
    _dump_yaml(dagger_values, 'dagger_evaluate.yml')
    _dump_yaml(value_errors, 'value_error_evaluate.yml')


# Script entry point: run the experiment only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    train_discounted_dagger()



