import jax
import jax.numpy as jnp
from flax import struct

from mfax.algos.hsm.evaluate import evaluate_given_sequence
from mfax.algos.hsm.sequence import make_mf_sequence, mf_sequence
from mfax.algos.hsm.best_response import br
from mfax.envs.pushforward.base import PushforwardMFSequence


@struct.dataclass
class ExploitabilityResults:
    """Scalar summary of an exploitability evaluation.

    Every field is produced by a ``.sum(axis=-1).mean()`` reduction over JAX
    arrays. So at runtime each field is a 0-d JAX array (a tracer inside
    ``jax.jit``), not a Python ``float``.
    """

    # Best-response return minus policy return (both m0-weighted means).
    exploitability: jnp.ndarray
    # Policy's discounted return, weighted by the initial mean field m0.
    mean_policy_return: jnp.ndarray
    # Best response's discounted return, weighted by the initial mean field m0.
    mean_br_return: jnp.ndarray


@struct.dataclass
class BRResults:
    """Outputs of the best-response solver against a fixed mean-field sequence.

    Each field is taken directly from the return value of ``br(...)``. Shapes
    are whatever that solver returns (not visible here — TODO confirm).
    """

    disc_returns: jnp.ndarray  # best-response discounted returns
    actions: jnp.ndarray  # actions selected by the best response
    rewards: jnp.ndarray  # rewards collected by the best response


@struct.dataclass
class EvalResults:
    """Everything produced by one exploitability evaluation, as a single pytree."""

    exploitability: ExploitabilityResults  # scalar summary (BR vs. policy)
    mf_sequence: PushforwardMFSequence  # mean-field trajectory induced by the policy
    policy_disc_returns: jnp.ndarray  # policy's discounted returns along that trajectory
    br: BRResults  # best-response outputs against the same trajectory


def exploitability(
    rng,
    env,
    agent,
    state_type,
    gamma,
    num_envs,
    max_steps_in_episode,
) -> "EvalResults":
    """Calculate the exploitability of a given policy.

    The function:

    1. rolls out the mean-field sequence induced by ``agent``;
    2. evaluates the policy's discounted returns along that sequence;
    3. computes a best response against the same mean-field trajectory;
    4. compares the two, with both weighted by the initial mean field m0.

    Args:
        rng: PRNG key forwarded to ``mf_sequence``.
        env: Environment used for the rollout, the policy evaluation and the
            best-response computation.
        agent: Policy whose exploitability is measured.
        state_type: Forwarded unchanged to ``mf_sequence``.
        gamma: Discount factor used for both policy and best-response returns.
        num_envs: Forwarded to ``mf_sequence``.
        max_steps_in_episode: Forwarded to ``mf_sequence``.

    Returns:
        An ``EvalResults`` bundling:

        - the exploitability summary;
        - the generated mean-field sequence;
        - the policy's discounted returns;
        - the best-response outputs.
    """
    # --- generate trajectory sequence using the given policy ---
    traj_batch = mf_sequence(
        rng,
        env,
        agent,
        state_type,
        num_envs=num_envs,
        max_steps_in_episode=max_steps_in_episode,
    )

    # --- evaluate policy's performance ---
    policy_disc_returns, _ = evaluate_given_sequence(env, traj_batch, gamma=gamma)

    # --- compute best-response rewards against policy's mean-field trajectory ---
    br_disc_returns, _, br_actions, br_rewards = br(env, traj_batch, gamma=gamma)

    # Rollout training optimizes the mean-field-weighted value at t=0, so
    # weight both BR and policy values by m0.
    m0 = traj_batch.global_s.m[0]
    mean_br_return = (br_disc_returns * m0).sum(axis=-1).mean()
    mean_policy_return = (policy_disc_returns * m0).sum(axis=-1).mean()
    # Named distinctly so it does not shadow this function's own name.
    exploitability_value = mean_br_return - mean_policy_return

    return EvalResults(
        exploitability=ExploitabilityResults(
            exploitability=exploitability_value,
            mean_policy_return=mean_policy_return,
            mean_br_return=mean_br_return,
        ),
        mf_sequence=traj_batch,
        policy_disc_returns=policy_disc_returns,
        br=BRResults(
            disc_returns=br_disc_returns,
            actions=br_actions,
            rewards=br_rewards,
        ),
    )


def make_exploitability(
    env,
    agent,
    state_type,
    gamma,
    num_envs,
    max_steps_in_episode,
):
    """Build a jitted exploitability evaluator for a fixed environment and agent.

    The mean-field rollout function is created once via ``make_mf_sequence``.
    The returned function only needs an RNG key and the agent parameters, so it
    can be re-run cheaply as parameters change during training.

    Args:
        env: Environment used for the rollout, the policy evaluation and the
            best-response computation. It is captured by closure.
        agent: Agent definition forwarded to ``make_mf_sequence``.
        state_type: Forwarded to ``make_mf_sequence``.
        gamma: Discount factor, captured by closure.
        num_envs: Forwarded to ``make_mf_sequence``.
        max_steps_in_episode: Forwarded to ``make_mf_sequence``.

    Returns:
        A ``jax.jit``-compiled function ``(rng, agent_params) -> EvalResults``.
    """
    # Named distinctly from the module-level ``mf_sequence`` import to avoid
    # shadowing it inside this factory.
    rollout_mf_sequence = make_mf_sequence(
        env, agent, num_envs, max_steps_in_episode, state_type
    )

    @jax.jit
    def _exploitability(
        rng,
        agent_params,
    ) -> "EvalResults":
        """Calculate the exploitability of the policy given by ``agent_params``.

        Args:
            rng: PRNG key forwarded to the mean-field rollout.
            agent_params: Parameters of the policy being evaluated.

        Returns:
            An ``EvalResults`` bundling:

            - the exploitability summary;
            - the generated mean-field sequence;
            - the policy's discounted returns;
            - the best-response outputs.
        """
        # --- generate trajectory sequence using the given policy ---
        traj_batch = rollout_mf_sequence(rng, agent_params)

        # --- evaluate policy's performance ---
        policy_disc_returns, _ = evaluate_given_sequence(
            env, traj_batch, gamma=gamma
        )

        # --- compute best-response rewards against policy's mean-field trajectory ---
        br_disc_returns, _, br_actions, br_rewards = br(
            env, traj_batch, gamma=gamma
        )

        # Rollout training optimizes the mean-field-weighted value at t=0, so
        # weight both BR and policy values by m0.
        m0 = traj_batch.global_s.m[0]
        mean_br_return = (br_disc_returns * m0).sum(axis=-1).mean()
        mean_policy_return = (policy_disc_returns * m0).sum(axis=-1).mean()
        exploitability_value = mean_br_return - mean_policy_return

        return EvalResults(
            exploitability=ExploitabilityResults(
                exploitability=exploitability_value,
                mean_policy_return=mean_policy_return,
                mean_br_return=mean_br_return,
            ),
            mf_sequence=traj_batch,
            policy_disc_returns=policy_disc_returns,
            br=BRResults(
                disc_returns=br_disc_returns,
                actions=br_actions,
                rewards=br_rewards,
            ),
        )

    return _exploitability
