"""
Contains functions that generate the paper (and other) plots.
"""

import numpy as np

from environments import StarMDP_with_random_flinging, Gridworld
from models.policies import train_tabular_BC_policy
from utils.offline_helpers import generate_offline_trajectories
from utils.online_helpers import compute_rewards_traj, rollout_policy_in_env


def _make_true_env(params):
    """Build the ground-truth environment described by ``params``.

    Reads ``params["env"]`` (``"StarMDP_with_random_flinging"`` or
    ``"Gridworld"``), ``params["episode_length"]`` and
    ``params["env_move_prob"]``. Both environments are created with a
    discount factor of 0.99. The Gridworld is 4x4 and receives
    ``random_action_prob = 1 - env_move_prob``.

    Raises:
        ValueError: if ``params["env"]`` is not a supported environment name.
    """
    env_name = params["env"]
    if env_name == "StarMDP_with_random_flinging":
        return StarMDP_with_random_flinging(
            discount_factor=0.99,
            episode_length=params["episode_length"],
            move_prob=params["env_move_prob"],
        )
    if env_name == "Gridworld":
        return Gridworld(
            width=4,
            height=4,
            episode_length=params["episode_length"],
            discount_factor=0.99,
            random_action_prob=1 - params["env_move_prob"],
        )
    raise ValueError(f"Environment {params['env']} not supported")


def get_mle_policy_avg_reward(params, N_seeds=10, N_eval_trajs=10000):
    """Compare the average reward of a behavior-cloned (MLE) policy with the optimal policy.

    Each of ``N_seeds`` repetitions does the following:

    - generates ``params["N_offline_trajs"]`` offline trajectories with the
      environment's LP solution policy;
    - trains a deterministic tabular BC policy on them;
    - rolls out both the BC policy and the LP solution policy
      ``N_eval_trajs`` times in the true environment.

    Args:
        params: dict providing ``"env"``, ``"episode_length"``,
            ``"env_move_prob"`` and ``"N_offline_trajs"``.
        N_seeds: number of repetitions; must be >= 1. No RNG is seeded here,
            so repetitions differ only through whatever randomness the helper
            functions consume.
        N_eval_trajs: rollouts per policy per repetition; must be >= 1.
            The default (10000) matches the previously hard-coded value.

    Returns:
        Tuple ``(mean_mle_reward, mean_opt_reward)``. Each element is the mean
        over repetitions of the per-repetition average trajectory reward, as
        computed by ``compute_rewards_traj`` with the environment's discount
        factor. ``mean_mle_reward`` is ``None`` when the BC policy matrix
        matched the LP solution matrix (``np.allclose``) on every repetition.

    Raises:
        ValueError: if ``N_seeds`` or ``N_eval_trajs`` is < 1, or if the
            environment in ``params["env"]`` is not supported.
    """
    # Without at least one repetition the means below would be NaN and the
    # "equal on all seeds" check would be vacuously true.
    if N_seeds < 1:
        raise ValueError(f"N_seeds must be >= 1, got {N_seeds}")
    if N_eval_trajs < 1:
        raise ValueError(f"N_eval_trajs must be >= 1, got {N_eval_trajs}")

    env_true = _make_true_env(params)
    solution_pi_true = env_true.get_lp_solution()

    mle_rewards_per_seed = []
    opt_rewards_per_seed = []
    mle_equals_opt_count = 0
    for _seed in range(N_seeds):
        offline_trajs, __ = generate_offline_trajectories(
            env_true, solution_pi_true, n_samples=params["N_offline_trajs"]
        )
        mle_policy = train_tabular_BC_policy(
            offline_trajs,
            env_true.N_states,
            env_true.N_actions,
            init="random",
            n_epochs=10,
            lr=0.01,
            make_deterministic=True,
        )
        # The trained matrix exposes .detach(), so it is presumably a torch
        # tensor. Convert it to numpy for rollouts and np.allclose. The .cpu()
        # call makes this safe if training ran on a GPU and is a no-op on CPU.
        mle_policy.matrix = mle_policy.matrix.detach().cpu().numpy()

        # Average reward of the BC and optimal policies over N_eval_trajs
        # rollouts each. Rollouts are interleaved, as in the original code.
        mle_rewards = []
        opt_rewards = []
        for _ in range(N_eval_trajs):
            traj = rollout_policy_in_env(env_true, mle_policy)
            mle_rewards.append(
                compute_rewards_traj(traj, env_true.rewards, env_true.discount_factor)
            )
            opt_traj = rollout_policy_in_env(env_true, solution_pi_true)
            opt_rewards.append(
                compute_rewards_traj(opt_traj, env_true.rewards, env_true.discount_factor)
            )

        # Count repetitions where the BC policy reproduced the optimal policy.
        if np.allclose(mle_policy.matrix, solution_pi_true.matrix):
            mle_equals_opt_count += 1
        mle_rewards_per_seed.append(np.mean(mle_rewards))
        opt_rewards_per_seed.append(np.mean(opt_rewards))

    mean_mle_rewards = np.mean(mle_rewards_per_seed)
    mean_opt_rewards = np.mean(opt_rewards_per_seed)
    if mle_equals_opt_count == N_seeds:
        print("MLE policy is = optimal policy on all seeds, returning None for the MLE avg reward")
        return None, mean_opt_rewards
    print(f"MLE policy is = optimal policy on {mle_equals_opt_count} seeds")
    return mean_mle_rewards, mean_opt_rewards
