from collections import defaultdict
from typing import Any, Callable
from warnings import catch_warnings, filterwarnings

from flax import nnx
from gymnasium import Env
from gymnasium.spaces import Box
from jax import Array, jit, numpy as jnp
import numpy as np
import tensorflow as tf

from offline.envs.registration import GetNormalizedScore
from offline.modules.policy import Policy, StateT
from offline.types import ArrayLike, FloatArray, SummaryWriter


# Signature of the jitted action function produced by ``compile_act``. It is
# called as ``act_fn(graphdef, graphstate, observations, state)`` and returns
# ``(actions, state, info)``; the evaluation/rollout loops below feed the
# returned ``state`` back into the next call.
ActFunction = Callable[
    [
        nnx.GraphDef[Policy[StateT]],  # static graph from ``nnx.split(policy)``
        nnx.GraphState | nnx.VariableState,  # state half of ``nnx.split``
        ArrayLike,  # raw observations (compile_act normalizes them itself)
        StateT,  # policy state threaded from one call to the next
    ],
    tuple[Array, StateT, dict[str, Any]],  # (actions, new state, info)
]


def compile_act(
    action_space: Box,
    mean: FloatArray | float,
    std: FloatArray | float,
    unsquash: bool,
) -> ActFunction[StateT]:
    """Build a jitted function that turns observations into env actions.

    The returned function rebuilds the policy from its split graph, standardizes
    the observations as ``(observations - mean) / std``, runs the policy,
    optionally applies ``tanh``, and finally maps the result affinely so that
    ``-1`` lands on ``action_space.low`` and ``+1`` on ``action_space.high``.

    Args:
        action_space: Box whose ``low``/``high`` define the affine rescaling
            and whose ``shape`` is the per-sample action shape.
        mean: Value subtracted from the observations before the policy call.
        std: Value the centered observations are divided by.
        unsquash: If true, apply ``tanh`` to the raw policy output before
            rescaling. Otherwise the output is rescaled as-is (presumably it is
            already in ``[-1, 1]`` -- confirm against the policy).

    Returns:
        A ``jax.jit``-compiled ``(graphdef, graphstate, observations, state)
        -> (actions, state, info)`` function whose ``actions`` have shape
        ``(batch, *action_space.shape)``.
    """
    high, low = action_space.high, action_space.low
    # Midpoint and half-width of the box, with a leading axis of size one that
    # broadcasts over the batch dimension of the policy output.
    loc = np.expand_dims((high + low) / 2, 0)
    scale = np.expand_dims((high - low) / 2, 0)
    # The leading -1 is the batch axis; every action dimension is kept after it
    # so that a batch of B > 1 actions broadcasts against ``loc``/``scale`` of
    # shape ``(1, *action_space.shape)``. Dropping the first action dimension
    # would merge the batch into one flat axis and break that broadcast.
    shape = (-1,) + tuple(action_space.shape)

    @jit
    def inner(
        graphdef: nnx.GraphDef[Policy[StateT]],
        graphstate: nnx.GraphState | nnx.VariableState,
        observations: ArrayLike,
        state: StateT,
    ) -> tuple[Array, StateT, dict[str, Any]]:
        policy = nnx.merge(graphdef, graphstate)
        observations_jnp: Array = (jnp.asarray(observations) - mean) / std
        actions, state, info = policy(observations_jnp, state)
        if unsquash:
            actions = jnp.tanh(actions)
        actions = actions.reshape(shape)
        # Affine map from [-1, 1] onto [low, high].
        actions = actions * scale + loc
        return actions, state, info

    return inner


def evaluate_seed(
    act_fn: ActFunction[StateT],
    env: Env,
    policy: Policy[StateT],
    seed: int,
    state: StateT,
) -> tuple[float, int]:
    """Play one episode of ``env`` reset with ``seed`` and score it.

    Observations are flattened before being handed to ``act_fn``; the first
    entry along the leading axis of each returned action is sent to the env.
    DeprecationWarnings raised while stepping are suppressed.

    Returns:
        ``(episode_return, episode_length)`` -- the summed rewards as a float
        and the number of ``env.step`` calls made.
    """
    obs, _ = env.reset(seed=seed)
    obs = obs.ravel()
    graphdef, graphstate = nnx.split(policy)
    episode_return, steps = 0.0, 0
    with catch_warnings():
        filterwarnings("ignore", category=DeprecationWarning)
        finished = False
        while not finished:
            batched_action, state, _ = act_fn(graphdef, graphstate, obs, state)
            obs, reward, terminated, truncated, _ = env.step(batched_action[0])
            finished = terminated or truncated
            obs = obs.ravel()
            episode_return += float(reward)
            steps += 1
    return episode_return, steps


def evaluate(
    act_fn: ActFunction[StateT],
    env: Env,
    eval_episodes: int,
    get_normalized_score: GetNormalizedScore | None,
    policy: Policy[StateT],
    seed: int,
    state: StateT,
) -> dict[str, FloatArray]:
    """Evaluate ``policy`` over ``eval_episodes`` consecutive episodes.

    Episode ``i`` is reset with seed ``seed + i``, and every episode starts
    from the same initial policy ``state``.

    Returns:
        ``"returns"`` and ``"lengths"`` arrays with one entry per episode, plus
        ``"normalized_scores"`` when ``get_normalized_score`` is provided.
    """
    outcomes = [
        evaluate_seed(
            act_fn=act_fn,
            env=env,
            policy=policy,
            seed=seed + offset,
            state=state,
        )
        for offset in range(eval_episodes)
    ]
    returns = [episode_return for episode_return, _ in outcomes]
    lengths = [episode_length for _, episode_length in outcomes]
    results = {
        "returns": np.asarray(returns),
        "lengths": np.asarray(lengths),
    }
    if get_normalized_score is not None:
        # The scorer gets its own array, distinct from results["returns"].
        results["normalized_scores"] = get_normalized_score(np.asarray(returns))
    return results


def log_evaluation_results(
    results: dict[str, FloatArray], step: int, writer: SummaryWriter
):
    """Write mean/std evaluation summaries to ``writer`` at ``step``.

    Emits ``eval/length`` and ``eval/reward`` and, when ``results`` contains
    ``"normalized_scores"``, ``eval/norm``; each is paired with a ``*_std``
    scalar holding the standard deviation.
    """
    returns = results["returns"]
    lengths = results["lengths"]
    # (mean tag, std tag, values), in the order the scalars are written.
    summaries = [
        ("eval/length", "eval/length_std", lengths),
        ("eval/reward", "eval/reward_std", returns),
    ]
    if "normalized_scores" in results:
        summaries.append(
            ("eval/norm", "eval/norm_std", results["normalized_scores"])
        )
    with writer.as_default(step=step):
        for mean_tag, std_tag, values in summaries:
            tf.summary.scalar(mean_tag, np.mean(values))
            tf.summary.scalar(std_tag, np.std(values))


def rollout(
    act_fn: ActFunction[StateT],
    env: Env,
    policy: Policy[StateT],
    seed: int,
    state: StateT,
) -> dict[str, np.ndarray]:
    """Run one episode of ``env`` reset with ``seed`` and record every step.

    Observations are flattened before being handed to ``act_fn``. The env is
    stepped with the first entry along the leading (batch) axis of each action,
    matching how ``evaluate_seed`` drives the env. DeprecationWarnings raised
    while stepping are suppressed.

    Returns:
        Mapping from key to the stacked per-step values:
        ``"observations"`` (includes the reset observation, so it has one more
        entry than the others), ``"actions"`` (as returned by ``act_fn``,
        leading axis included), ``"dones"``, ``"reward"``, ``"terminated"``,
        ``"truncated"``, and ``"info/<key>"`` for every key of the policy's
        info dict.
    """
    results: dict[str, list[Any]] = defaultdict(list)
    observation: FloatArray
    observation, _ = env.reset(seed=seed)
    observation = observation.ravel()
    results["observations"].append(np.copy(observation))
    done = False
    graphdef, graphstate = nnx.split(policy)
    with catch_warnings():
        filterwarnings("ignore", category=DeprecationWarning)
        while not done:
            action, state, info = act_fn(
                graphdef, graphstate, observation, state
            )
            # act_fn (e.g. from compile_act) returns actions with a leading
            # batch axis; the env must receive a single, unbatched action.
            observation, reward, terminated, truncated, _ = env.step(action[0])
            done = terminated or truncated
            observation = observation.ravel()
            # Stored with its leading axis so the returned "actions" array keeps
            # the same shape callers already rely on.
            results["actions"].append(np.copy(action))
            results["dones"].append(done)
            results["observations"].append(np.copy(observation))
            results["reward"].append(reward)
            results["terminated"].append(terminated)
            results["truncated"].append(truncated)
            for key, value in info.items():
                results[f"info/{key}"].append(np.copy(value))
    return {key: np.asarray(value) for key, value in results.items()}
