"""
An example script for running the College Admission Gym environment and plotting the results.
"""

import numpy as np

import gymnasium
import fair_gym

from fair_gym.utils import plotting
from fair_gym import CollegeAdmissionMetrics

# Environment configuration. Each tuple holds one value per group; the two
# groups differ in score/budget means but share the same standard deviation.
env_kwargs = dict(
    n_groups=2,
    group_distribution=(0.5, 0.5),
    score_distribution_mean=(8, 6),
    score_distribution_std=(1, 1),
    budget_distribution_mean=(3, 2),
    budget_distribution_std=(1, 1),
    max_budget=5,
    epsilon=0.5,
    population_size=1000,
)
# Maximum number of steps in an episode (passed as max_episode_steps).
episode_length = 1000

# Create the environment and the metrics tracker.
# Seed NumPy's global RNG for any code that draws from np.random directly.
np.random.seed(0)

env = gymnasium.make(
    "fair_gym/CollegeAdmissionEnv", max_episode_steps=episode_length, **env_kwargs
)
metrics = CollegeAdmissionMetrics(env)

# Run a single episode, taking actions sampled from the action space.
obs, info = env.reset(seed=0)
# env.reset(seed=...) seeds only the environment's own RNG. Gymnasium spaces
# keep an independent generator (not np.random), so without this call the
# sampled actions -- and therefore all results -- would differ on every run.
env.action_space.seed(0)
terminated, truncated = False, False
time_step = 0  # number of env.step() calls made in the episode

while not (terminated or truncated):
    time_step += 1
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

# Get the metrics collected by the tracker
results = metrics.get_all_metrics()

# Directory passed to every plotting call below.
plot_dir = "plots"

# Plot the score and budget distributions at the start and at the end of the
# run, in the order: initial_score, initial_budget, final_score, final_budget.
for stage in ("initial", "final"):
    for quantity in ("score", "budget"):
        plotting.plot_distribution(
            results[f"{stage}_{quantity}_distribution"],
            title=f"{stage}_{quantity}",
            path=plot_dir,
        )

# Plot cumulative admissions per step
plotting.plot_cumulative_metric(
    results["cumulative_admissions"],
    title="cumulative_admissions",
    path=plot_dir,
)

# Plot the acceptance ratio
plotting.plot_acceptance_ratio(
    results["acceptance_ratio"],
    title="acceptance_ratio",
    path=plot_dir,
)

# Print the number of steps the episode ran for
print(f"Time step: {time_step}")
# Print the agent's recall and precision for each group
print(f"Agent's recall: {results['recall']}")
print(f"Agent's precision: {results['precision']}")
# Print the average cost paid by accepted students
print(
    f"Average cost paid by accepted students: {results['average_cost_paid_by_accepted']}"
)

# Release any resources held by the environment now that all results have
# been extracted from it.
env.close()
