from envs.tabular_env import TabularEnv
import gym
import numpy as np
from utils.utils import evaluate


def test_policy_evaluation(env: gym.Env, env_type='Episodic'):
    """Compare ``env.policy_evaluation`` with the sampled value from ``evaluate``.

    Ten trials are run. In each one the environment's optimal policy is
    deep-copied, the action distribution of one randomly drawn state is
    replaced by a random normalized distribution, and the exact value reported
    by the environment must match the estimate within a relative tolerance
    of 0.1.

    Args:
        env: the environment to test. It must provide ``get_optimal_policy``
            and ``policy_evaluation`` and have discrete observation and
            action spaces (``.n`` is read from both).
        env_type: 'Episodic' or 'Discounted'. Any value other than
            'Episodic' takes the stationary-policy branch.

    Raises:
        AssertionError: if the estimated and exact values differ by more than
            the tolerance in any trial.
    """
    import copy

    base_policy = env.get_optimal_policy()
    num_states = env.observation_space.n
    num_actions = env.action_space.n
    is_episodic = env_type == 'Episodic'

    for trial in range(10):
        perturbed = copy.deepcopy(base_policy)

        # Random distribution over actions: uniform draws, normalized to sum to 1.
        weights = np.random.random(size=num_actions)
        action_probs = weights / weights.sum()

        # The single state whose action distribution gets replaced.
        state = np.random.choice(num_states)

        if is_episodic:
            # Episodic environments are tested with a non-stationary policy:
            # the (na, 1) column is broadcast over the trailing axis
            # (presumably the time step -- confirm against the env).
            perturbed[state, :, :] = action_probs.reshape(num_actions, 1)
        else:
            # Otherwise the policy is stationary: one distribution per state.
            perturbed[state, :] = action_probs

        # 20000 is passed positionally; presumably the number of sampled
        # rollouts -- confirm against utils.utils.evaluate.
        sampled_value, _ = evaluate(env, perturbed, 20000, env_type, is_deterministic=False)
        exact_value = env.policy_evaluation(perturbed)

        print('Iteration %d: exact value: %.2f, estimated value: %.2f' % (trial, exact_value, sampled_value))
        assert np.isclose(sampled_value, exact_value, rtol=0.1), \
            'The policy evaluation error exceeds 0.1 in iteration {}'.format(trial)




def test_occupancy_measure(env, env_type='Episodic'):
    """Cross-check the environment's two occupancy-measure routines and a sampled estimate.

    Ten trials are run. In each one the environment's optimal policy is
    deep-copied and the action distribution of one randomly drawn state is
    replaced by a random normalized distribution. The resulting policy is
    then checked in two ways:

    * ``env.calculate_occupancy_measure`` and
      ``env.calculate_occupancy_measure_v2`` must agree within ``atol=0.01``;
    * the estimate built from sampled trajectories must agree with the ``v2``
      result within ``atol=0.05``.

    Args:
        env: the environment to test. It must provide ``get_optimal_policy``
            and both occupancy-measure methods, and have discrete observation
            and action spaces (``.n`` is read from both).
        env_type: 'Episodic' or 'Discounted'.

    Raises:
        AssertionError: if ``env_type`` is not one of the two accepted values,
            or if either comparison fails in any trial.
    """
    import copy

    base_policy = env.get_optimal_policy()
    num_states = env.observation_space.n
    num_actions = env.action_space.n
    assert env_type in ('Episodic', 'Discounted'), 'Invalid environment type.'
    from utils.utils import estimate_occupancy_measure, estimate_discounted_occupancy_measure

    # Pick the estimator and its sample budget once; both depend only on env_type.
    is_episodic = env_type == 'Episodic'
    if is_episodic:
        estimator, trajectory_count = estimate_occupancy_measure, 40000
    else:
        estimator, trajectory_count = estimate_discounted_occupancy_measure, 1000

    for trial in range(10):
        perturbed = copy.deepcopy(base_policy)

        # Random distribution over actions: uniform draws, normalized to sum to 1.
        weights = np.random.random(size=num_actions)
        action_probs = weights / weights.sum()

        # The single state whose action distribution gets replaced.
        state = np.random.choice(num_states)
        if is_episodic:
            # The (na, 1) column is broadcast over the trailing axis of the
            # 3-D policy (presumably the time step -- confirm against the env).
            perturbed[state, :, :] = action_probs.reshape(num_actions, 1)
        else:
            perturbed[state, :] = action_probs

        sampled_rho = estimator(env=env, policy=perturbed, num_trajectories=trajectory_count,
                                is_deterministic=False)

        rho_v2 = env.calculate_occupancy_measure_v2(policy=perturbed)
        rho_v1 = env.calculate_occupancy_measure(policy=perturbed)

        failure_template = 'Fail at iteration %d'
        # The two exact implementations must agree with each other ...
        assert np.allclose(rho_v1, rho_v2, atol=0.01), failure_template % trial
        # ... and the sampled estimate must be close to the exact result.
        assert np.allclose(sampled_rho, rho_v2, atol=0.05), failure_template % trial

        print('Succeed at iteration %d' % trial)