#!/usr/bin/env python

"""
WARN: The greedy agent is implemented with gradient descent; it will only
find local optima given a non-convex objective.

NOTE: Whether the greedy agent has access to the dynamics depends on the loss
function used. If the loss function does NOT depend on `current_state` (see
fairgym/envs/default_reward.py), then no information about the dynamics is used.
If the loss function DOES depend on `current_state`, then information about the
induced dynamics is used to optimize the policy.
"""

from jax import numpy as jnp

from fairgym.plotting.phase import make_replicator_phase_plot
from fairgym.envs.default_reward import known_loss_func_dict
from fairgym.envs.base_env import BaseEnv
from functools import partial

import gymnasium as gym
import argparse

from fairgym.agent import GreedyAgent
from fairgym.experiment import Experiment

# Use to view when/how many times Jax traces and compiles functions
# import jax
# jax.config.update("jax_log_compiles", True)


def to_plot(env, agent, info):
    """Select the series to draw for one rendered step.

    Returns a dict mapping a legend label to a ``(value, color)`` pair.
    The per-group qualification rates are always present. Accept rates and
    the reward are appended only when ``info["prev_results"]`` is not None.

    ``env`` and ``agent`` are not used here. They are accepted so that this
    function can be handed to the experiment recorder as a ``to_plot``
    callback (presumably ``Experiment.record`` — TODO confirm signature).
    """
    state, prev = info["current_state"], info["prev_results"]

    qualification = state.pr_Y1
    series = {
        "Qualification Rate (g0)": (qualification[0], "#13a5cd"),
        "Qualification Rate (g1)": (qualification[1], "#ff24da"),
    }

    if prev is None:
        # The very first render happens before any action has been taken.
        return series

    accept = prev.accept_rate
    series["Accept Rate (g0)"] = (accept[0], "green")
    series["Accept Rate (g1)"] = (accept[1], "red")
    # Possible extra series: group disparity / classifier utility.
    series["Reward"] = (info["reward"], "black")
    return series


################################################################################
# //  ____                            _
# // |  _ \ _____      ____ _ _ __ __| |
# // | |_) / _ \ \ /\ / / _` | '__/ _` |
# // |  _ <  __/\ V  V / (_| | | | (_| |
# // |_| \_\___| \_/\_/ \__,_|_|  \__,_|
# //


def _reward_fn(
    lambda_,
    loss_term,
    disparity_term,
    prev_state,
    prev_action,
    prev_results,
    current_state,
):
    """
    What the environment's reward signal is, as a function
    of the action we just took

    TODO Even though this gets jitted in env.step, we could jit it here to hold the
     (changing on reset) args as static.
    """
    loss = loss_term(prev_state, prev_action, prev_results, current_state)
    disparity = disparity_term(prev_state, prev_action, prev_results, current_state)

    return 1 - (loss * (1 - lambda_) + lambda_ * disparity)


################################################################################
# //  ____
# // |  _ \ _   _ _ __
# // | |_) | | | | '_ \
# // |  _ <| |_| | | | |
# // |_| \_\\__,_|_| |_|
# //
# Run simulation


def run_experiment(
    title,
    reward_fn,
    loss_name,
    disparity_name,
    env_seed=0,
    res=4,
    H=False,
    num_runs=3,
    adult=False,
):
    """Build a replicator env plus a greedy agent and produce a phase plot.

    Args:
        title: Passed to ``make_replicator_phase_plot`` as ``filename``.
        reward_fn: Reward callable given to the environment constructor.
        loss_name: Forwarded to ``make_replicator_phase_plot``.
        disparity_name: Forwarded to ``make_replicator_phase_plot``.
        env_seed: Forwarded to ``make_replicator_phase_plot``.
        res: Forwarded as ``res`` (presumably the phase-plot grid
            resolution — TODO confirm).
        H: Forwarded as ``H``. The ``__main__`` block passes ``False`` for
            phase diagrams and an integer (presumably an episode horizon)
            when saving mean episodic losses.
        num_runs: Forwarded as ``num_runs``.
        adult: Truthy selects "AdultQualificationReplicator-v0", otherwise
            "GroupQualificationReplicator-v0".
    """
    env_id = (
        "AdultQualificationReplicator-v0"
        if adult
        else "GroupQualificationReplicator-v0"
    )
    env: BaseEnv = gym.make(env_id, reward_fn=reward_fn, use_jit=True)
    agent = GreedyAgent(env, use_jit=True)

    # Single-run recording (disabled); re-enable to render one trajectory:
    # experiment = Experiment(agent, env)
    # experiment.reset(
    #     seed=env_seed,
    #     options={
    #         #         "pr_q": [0.2, 0.8],
    #         "pr_G": [0.5, 0.5]
    #     },
    # )
    # print("Recording single-run experiment")
    # observation, reward, terminated, truncated, info = experiment.record(
    #     title, n_steps=100, to_plot=to_plot, render_flag=True
    # )

    phase_options = {"pr_G": [0.5, 0.5]}
    make_replicator_phase_plot(
        agent,
        env,
        loss_name,
        disparity_name,
        env_seed=env_seed,
        filename=title,
        res=res,
        options=phase_options,
        H=H,
        num_runs=num_runs,
    )


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description=(
            "Run the greedy agent on the Adult and Group replicator "
            "environments and save phase plots and mean episodic losses for "
            "the objective (1 - lbd) * loss + lbd * disparity."
        )
    )

    # Loss and disparity terms are both looked up in the same registry, so
    # validate CLI names against it instead of failing later with KeyError.
    known_terms = list(known_loss_func_dict)

    parser.add_argument(
        "--loss",
        type=str,
        default="tp_loss",
        choices=known_terms,
        help="name of the loss term (key of known_loss_func_dict)",
    )

    parser.add_argument(
        "--disparity",
        type=str,
        default="Demographic Disparity",
        choices=known_terms,
        help="name of the disparity term (key of known_loss_func_dict)",
    )

    # weight of disparity function
    parser.add_argument(
        "--lbd",
        type=float,
        default=0.0,
        help="weight of the disparity term; the loss gets weight 1 - lbd",
    )

    parser.add_argument(
        "--env_seed",
        type=int,
        default=0,
        help="environment seed forwarded to the phase-plot routine",
    )

    args = parser.parse_args()

    loss_name = args.loss
    lbd = args.lbd
    disparity_name = args.disparity
    env_seed = args.env_seed

    print(1 - lbd, loss_name, "+", lbd, disparity_name)

    # Bind (lambda_, loss_term, disparity_term); the env supplies the rest.
    reward_fn = partial(
        _reward_fn,
        lbd,
        known_loss_func_dict[loss_name],
        known_loss_func_dict[disparity_name],
    )

    for adult in [True, False]:

        title = (
            f"{1-lbd}{loss_name}+{lbd}{disparity_name}"
            f"_env_seed={env_seed}_Adult={adult}"
        )

        # Value passed as H when saving mean episodic losses (presumably the
        # episode horizon — TODO confirm in make_replicator_phase_plot).
        H = 150 if adult else 100

        # generate phase diagrams
        print("Generating high-resolution phase plot")
        run_experiment(
            title,
            reward_fn=reward_fn,
            loss_name=loss_name,
            disparity_name=disparity_name,
            # Forward the seed: previously it only appeared in the title while
            # every run silently used the default seed 0.
            env_seed=env_seed,
            res=32,
            H=False,
            num_runs=3,
            adult=adult,
        )

        # NOTE(review): both calls share `title` as the output filename;
        # confirm make_replicator_phase_plot distinguishes them (e.g. via H).
        print("Saving mean episodic losses")
        run_experiment(
            title,
            reward_fn=reward_fn,
            loss_name=loss_name,
            disparity_name=disparity_name,
            env_seed=env_seed,
            res=4,
            H=H,
            num_runs=3,
            adult=adult,
        )
