import sys
sys.path.append('.')
from envs.CliffWalking.CliffWalking import CliffWalking, DisCliffWalking
from envs.bandit.bandit_env import Bandit, DisBandit
import os
import numpy as np
from typing import List
from utils.utils import sample_dataset, get_optimal_policy
from utils.flags import FLAGS
from utils.Logger import logger
from utils.envs.env_utils import set_init_state_dis
import yaml


class TableBC(object):
    """
    Tabular behavior-cloning agent for the episodic (finite-horizon) setting.

    The policy is non-stationary: ``policy[s, :, h]`` is the action distribution used in state ``s`` at
    step ``h``. Entries for (state, step) pairs that have never been observed keep their current
    distribution, which is initially random.
    """

    def __init__(self, dim_state: int, dim_action: int, max_episode_steps: int) -> None:
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.max_episode_steps = max_episode_steps
        # Random initial policy, normalised over the action axis so every policy[s, :, h] sums to 1.
        tmp = np.random.random(size=[self.dim_state, self.dim_action, self.max_episode_steps])
        self.policy = tmp / np.sum(tmp, axis=1, keepdims=True)

    def estimate_from_data(self, dataset: List[tuple]):
        """
        Set ``policy[s, :, h]`` to the empirical action frequencies observed at (s, h).

        Args:
            dataset: expert demonstrations, [(state, action, step)]. Only the first three entries of each
                item are used; they must be valid integer indices into the policy table.

        Only (state, step) pairs that occur in ``dataset`` are overwritten; every other entry of the
        policy table is left unchanged. The policy array is updated in place.

        Raises:
            IndexError: if a state/action/step index is out of range or not an integer.
        """
        counts = np.zeros(shape=(self.dim_state, self.dim_action, self.max_episode_steps), dtype=np.float32)
        triples = [(each_data[0], each_data[1], each_data[2]) for each_data in dataset]
        if triples:
            states, actions, steps = np.asarray(triples).T
            # np.add.at accumulates repeated (s, a, h) indices; plain fancy-index `+=` would count each once.
            np.add.at(counts, (states, actions, steps), 1.0)

        totals = np.sum(counts, axis=1, keepdims=True)  # per-(state, step) sample counts, shape (S, 1, H)
        visited = totals > 0
        # Divide only where (s, h) was observed; float32 division keeps results identical to a per-slice loop.
        frequencies = np.divide(counts, totals, out=np.zeros_like(counts), where=visited)
        np.copyto(self.policy, frequencies, where=visited)

    @property
    def get_policy(self):
        """Return a copy of the policy table, shape (dim_state, dim_action, max_episode_steps)."""
        return self.policy.copy()


class DiscountedTableBC(object):
    """
    Tabular BC agent for the infinite-horizon discounted MDP.

    Unlike the episodic agent, the policy here is stationary: ``policy[s]`` is one distribution over
    actions, shared by all time steps. Successive calls to ``estimate_from_data`` are merged. Each state's
    policy becomes the count-weighted average of the empirical action frequencies seen so far for it.
    """

    def __init__(self, num_state: int, num_action: int):
        self.ns = num_state
        self.na = num_action

        # Samples seen so far per state, accumulated over every call to estimate_from_data.
        self.state_counter = np.zeros(shape=num_state, dtype=np.float32)
        # Random row-normalised policy; it is kept for any state that never shows up in the data.
        raw = np.random.random(size=(self.ns, self.na))
        self.policy = raw / np.sum(raw, axis=1, keepdims=True)

    def estimate_from_data(self, dataset: List):
        """
        Merge the empirical action frequencies of ``dataset`` into the stationary policy.

        Args:
            dataset: expert demonstrations, [(s, a)] where (s, a) follows the discounted stationary
                distribution. Only the first two entries of each item are used.

        For each state s observed n_new times with previous counter n_old, the update is
        policy[s] <- n_old / (n_old + n_new) * policy[s] + n_new / (n_old + n_new) * freq_new[s].
        The counter is then incremented by n_new. Unobserved states are untouched.
        """
        counts = np.zeros(shape=[self.ns, self.na], dtype=np.int32)
        pairs = [(item[0], item[1]) for item in dataset]
        if pairs:
            index = np.asarray(pairs)
            # Unbuffered accumulation so duplicate (s, a) pairs are all counted.
            np.add.at(counts, (index[:, 0], index[:, 1]), 1)

        per_state = np.sum(counts, axis=1)
        seen = per_state > 0
        fresh_total = per_state[seen]

        previous = self.state_counter[seen]
        # Weight given to the policy learned from earlier calls (0 for states seen for the first time).
        keep_weight = (previous / (previous + fresh_total))[:, np.newaxis]
        fresh_policy = counts[seen] / fresh_total[:, np.newaxis]
        self.policy[seen] = keep_weight * self.policy[seen] + (1.0 - keep_weight) * fresh_policy

        self.state_counter[seen] += fresh_total

    @property
    def get_policy(self):
        """Return a copy of the (num_state, num_action) stationary policy table."""
        return self.policy.copy()


def train_discounted_bc():
    """
    Evaluate tabular BC in the infinite-horizon discounted setting across effective horizons.

    One expert dataset of ``FLAGS.dis_num_data_dict[FLAGS.env.id]`` samples is drawn once. For each
    effective horizon H in range(200, 10000, 100):
    - the discount factor is gamma = 1 - 1/H;
    - a fresh DiscountedTableBC agent is trained on that dataset;
    - the agent is evaluated against the expert.

    Expert values, BC values and value errors (expert - BC), keyed by H, are written to
    ``expert_evaluate.yml``, ``bc_evaluate.yml`` and ``value_error_evaluate.yml`` in ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    # Dataset size comes from the config; it must not be overridden by a hard-coded value here.
    num_data = FLAGS.dis_num_data_dict[FLAGS.env.id]
    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)
    from utils.utils import sample_dataset_from_distribution
    value_errors = dict()
    expert_values = dict()
    bc_values = dict()

    # Sample the optimal action index. If the optimal action is determined,
    # we can determine the expert policy and dataset.
    optimal_action = np.random.randint(na)
    expert_policy = get_optimal_policy(ns, na, optimal_action, FLAGS.env.id)
    # In both CliffWalking and Bandit envs, the stationary state-action distribution of expert policy is initial
    # state distribution.
    dataset, unique_states = sample_dataset_from_distribution(init_state_dis, expert_policy, num_data)
    # Probability mass of states that never appear in the dataset.
    sampled_mass = init_state_dis[unique_states]
    missing_mass = 1.0 - float(np.sum(sampled_mass))

    logger.info('Missing mass: %.8f', missing_mass)
    for effective_horizon in range(200, 10000, 100):
        gamma = 1.0 - (1.0 / effective_horizon)
        if FLAGS.env.id == 'CliffWalking':
            env = DisCliffWalking(ns, na, gamma, init_state_dis, optimal_action)
        elif FLAGS.env.id == 'Bandit':
            env = DisBandit(ns, na, gamma, init_state_dis, optimal_action)
        else:
            raise ValueError('Do not support the env {}.'.format(FLAGS.env.id))

        # A new agent per horizon: its random initial policy for unseen states is redrawn each time.
        agent = DiscountedTableBC(ns, na)
        agent.estimate_from_data(dataset)
        bc_policy = agent.get_policy
        bc_value = env.policy_evaluation(policy=bc_policy)
        expert_value = env.policy_evaluation(policy=expert_policy)
        value_error = expert_value - bc_value

        logger.info('Effective horizon: %d, Discounted factor: %.6f, Expert value: %.4f, BC value: %.4f, '
                    'Value error: %.4f', effective_horizon, gamma, expert_value, bc_value, value_error)

        expert_values[effective_horizon] = [expert_value]
        bc_values[effective_horizon] = [bc_value]
        value_errors[effective_horizon] = [value_error]

    results = (('expert_evaluate.yml', expert_values),
               ('bc_evaluate.yml', bc_values),
               ('value_error_evaluate.yml', value_errors))
    for file_name, content in results:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        # Context manager guarantees the file is flushed and closed.
        with open(save_path, 'w') as f:
            yaml.dump(content, f, default_flow_style=False)



def main():
    """
    Evaluate tabular BC in the episodic setting for a range of horizons.

    For each ``max_episode_steps`` in range(200, 3400, 200) the following steps run:
    - build the environment;
    - evaluate its optimal (expert) policy;
    - sample ``FLAGS.num_data_dict[FLAGS.env.id]`` expert samples;
    - train a TableBC agent on them and evaluate it.

    Expert values, BC values and value errors (expert - BC), keyed by ``max_episode_steps``, are written
    to ``expert_evaluate.yml``, ``bc_evaluate.yml`` and ``value_error_evaluate.yml`` in
    ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """
    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    num_data = FLAGS.num_data_dict[FLAGS.env.id]

    init_state_dis = set_init_state_dis(env_id=FLAGS.env.id, num_data=num_data, ns=ns)

    value_errors = dict()
    expert_values = dict()
    bc_values = dict()

    for max_episode_steps in range(200, 3400, 200):
        if FLAGS.env.id == 'CliffWalking':
            env = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        elif FLAGS.env.id == 'Bandit':
            env = Bandit(ns, na, init_state_dis, max_episode_steps)
        else:
            raise ValueError('The env {} is not supported.'.format(FLAGS.env.id))

        expert_policy = env.get_optimal_policy()
        expert_value = env.policy_evaluation(expert_policy)
        dataset = sample_dataset(env, expert_policy, num_data, is_deterministic=False)

        bc_agent = TableBC(ns, na, max_episode_steps)
        bc_agent.estimate_from_data(dataset)
        bc_policy = bc_agent.get_policy
        bc_value = env.policy_evaluation(policy=bc_policy)
        value_error = expert_value - bc_value

        logger.info('Max episode steps: %d, Expert value: %.4f, BC value: %.4f, Value error: %.4f',
                    max_episode_steps, expert_value, bc_value, value_error)

        expert_values[max_episode_steps] = [expert_value]
        bc_values[max_episode_steps] = [bc_value]
        value_errors[max_episode_steps] = [value_error]

    results = (('expert_evaluate.yml', expert_values),
               ('bc_evaluate.yml', bc_values),
               ('value_error_evaluate.yml', value_errors))
    for file_name, content in results:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        # Context manager guarantees the file is flushed and closed.
        with open(save_path, 'w') as f:
            yaml.dump(content, f, default_flow_style=False)


def train_with_diff_num_samples():
    """
    Train BC agent with different number of samples.

    For each dataset size ``num_data`` in range(100000, 1000000, 20000) the following steps run:
    - rebuild the initial state distribution (type ``FLAGS.env.init_dist_type``) and the environment;
    - evaluate the expert (optimal) policy;
    - draw ``num_data`` expert samples;
    - train and evaluate a TableBC agent with horizon ``FLAGS.env.max_episode_steps``.

    Results keyed by ``num_data`` are written to ``expert_evaluate.yml``, ``bc_evaluate.yml`` and
    ``value_error_evaluate.yml`` in ``FLAGS.log_dir``.

    Raises:
        ValueError: if ``FLAGS.env.id`` is neither 'CliffWalking' nor 'Bandit'.
    """

    FLAGS.set_seed()
    FLAGS.freeze()
    ns = FLAGS.env.ns
    na = FLAGS.env.na
    max_episode_steps = FLAGS.env.max_episode_steps

    value_errors = dict()
    expert_values = dict()
    bc_values = dict()
    # NOTE: for a non-uniform initial state distribution, the data range is [10000, 100000] with step
    # size 2000 instead, i.e. replace the range below with range(10000, 100000, 2000).

    for num_data in range(100000, 1000000, 20000):
        # Use the loop value as-is. Overriding num_data would make every iteration identical and leave a
        # single entry in the result dicts.
        init_state_dis = set_init_state_dis(FLAGS.env.id, num_data, ns, dis_type=FLAGS.env.init_dist_type)
        if FLAGS.env.id == 'CliffWalking':
            env = CliffWalking(ns, na, init_state_dis, max_episode_steps)
        elif FLAGS.env.id == 'Bandit':
            env = Bandit(ns, na, init_state_dis, max_episode_steps)
        else:
            raise ValueError('Env %s is not supported.' % FLAGS.env.id)
        expert_policy = env.get_optimal_policy()
        expert_value = env.policy_evaluation(expert_policy)
        dataset = sample_dataset(env, expert_policy, num_data, is_deterministic=False)

        bc_agent = TableBC(ns, na, max_episode_steps)
        bc_agent.estimate_from_data(dataset)
        bc_policy = bc_agent.get_policy
        bc_value = env.policy_evaluation(policy=bc_policy)
        value_error = expert_value - bc_value

        logger.info('The number of samples: %d, Expert value: %.4f, BC value: %.4f, Value error: %.4f',
                    num_data, expert_value, bc_value, value_error)

        expert_values[num_data] = [expert_value]
        bc_values[num_data] = [bc_value]
        value_errors[num_data] = [value_error]

    results = (('expert_evaluate.yml', expert_values),
               ('bc_evaluate.yml', bc_values),
               ('value_error_evaluate.yml', value_errors))
    for file_name, content in results:
        save_path = os.path.join(FLAGS.log_dir, file_name)
        # Context manager guarantees the file is flushed and closed.
        with open(save_path, 'w') as f:
            yaml.dump(content, f, default_flow_style=False)


if __name__ == '__main__':
    # Entry point: runs the discounted-MDP experiment. To run another experiment, call main()
    # (episodic, varying horizon) or train_with_diff_num_samples() (episodic, varying dataset size) instead.
    train_discounted_bc()


