from typing import Any
import numpy as np
import shutil
from gymnasium import Env
import gymnasium as gym
import torch
from regawa.gnn import GraphAgent
from gymnasium.spaces import MultiDiscrete, Dict
import numpy as np


def _format(state: dict[str, Any], width: int = 80, indent: int = 4):
    if len(state) == 0:
        return str(state)
    state = {str(key): str(value) for (key, value) in state.items()}
    klen = max(map(len, state.keys())) + 1
    vlen = max(map(len, state.values())) + 1
    cols = max(1, (width - indent) // (klen + vlen + 3))
    result = " " * indent
    for count, (key, value) in enumerate(state.items(), 1):
        result += f"{key.rjust(klen)} = {value.ljust(vlen)}"
        if count % cols == 0:
            result += "\n" + " " * indent
    return result


@torch.inference_mode()
def evaluate(
    agent: GraphAgent,
    env: Env,
    episodes: int = 1,
    discount: float = 1.0,
    verbose: bool = False,
    render: bool = False,
    seed: int | None = None,
    deterministic: bool = True,
) -> tuple[dict[str, float], np.typing.NDArray[np.float64]]:
    """Roll out ``agent`` in ``env`` for several episodes and summarize the returns.

    Args:
        agent: Policy queried once per step via
            ``agent.sample_from_obs(state, deterministic=...)``; the first
            returned element is used as the action tensor.
        env: Gymnasium-style environment (``reset``/``step``). When ``verbose``
            is set, the ``info`` dicts it returns must contain ``"rddl_state"``,
            and those from ``step`` must also contain ``"rddl_action"``.
        episodes: Number of episodes to simulate; must be at least 1.
        discount: Per-step discount factor applied to the rewards.
        verbose: Print the initial state and, for every step, the action,
            state, reward and termination flag.
        render: Currently unused; kept for interface compatibility.
        seed: Base seed. Episode ``i`` is reset with ``seed + i``. When
            ``None``, every reset is unseeded.
        deterministic: Forwarded to ``agent.sample_from_obs``.

    Returns:
        A dict with the mean, median, min, max and std of the discounted
        episode returns, and the per-episode returns as an array of length
        ``episodes``.

    Raises:
        ValueError: If ``episodes`` is less than 1.
    """
    if episodes < 1:
        raise ValueError(f"episodes must be at least 1, got {episodes}")

    if verbose:
        # size separator bars and state tables to the current terminal
        width = shutil.get_terminal_size().columns
        sep_bar = "-" * width

    history = np.zeros((episodes,))
    for episode in range(episodes):
        total_reward, cuml_gamma = 0.0, 1.0

        # Offset the base seed per episode so episodes differ but stay
        # reproducible; `seed + episode` would raise TypeError for seed=None.
        episode_seed = None if seed is None else seed + episode
        state, info = env.reset(seed=episode_seed)

        if verbose:
            print(f"initial state = \n{_format(info['rddl_state'], width)}")

        # simulate until the environment terminates or truncates the episode
        done = False
        step = 0
        while not done:
            action, *_ = agent.sample_from_obs(state, deterministic=deterministic)
            np_action = action.squeeze().cpu().detach().numpy()
            next_state, reward, terminated, truncated, info = env.step(np_action)
            total_reward += reward * cuml_gamma
            cuml_gamma *= discount
            done = terminated or truncated

            if verbose:
                print(
                    f"{sep_bar}\n"
                    f"step   = {step}\n"
                    f"action = \n{_format(info['rddl_action'], width)}\n"
                    f"state  = \n{_format(info['rddl_state'], width)}\n"
                    f"reward = {reward}\n"
                    f"done   = {done}"
                )
            state = next_state
            step += 1

        if verbose:
            print(
                f'\n'
                f'episode {episode + 1} ended with return {total_reward}\n'
                f'{"=" * width}'
            )
        history[episode] = total_reward

    # summary statistics over all episodes
    return {
        "mean": np.mean(history),
        "median": np.median(history),
        "min": np.min(history),
        "max": np.max(history),
        "std": np.std(history),
    }, history


@torch.inference_mode()
def evaluate_instance(
    env_id: str,
    domain: str,
    instance: int,
    agent: GraphAgent,
    remove_false: bool,
    episodes: int = 10,
    verbose: bool = False,
    deterministic: bool = True,
) -> tuple[dict[str, float], np.typing.NDArray[np.float64]]:
    """Build the environment ``env_id`` and evaluate ``agent`` on it.

    ``domain``, ``instance`` and ``remove_false`` are forwarded to ``gym.make``
    together with ``optimize=True``. Evaluation always uses base seed 0, so
    repeated calls are reproducible as far as the environment and agent
    allow. The environment is closed before returning, including when
    evaluation raises.

    Args:
        env_id: Registered Gymnasium environment id.
        domain: Forwarded to ``gym.make``.
        instance: Forwarded to ``gym.make``.
        agent: Policy to evaluate.
        remove_false: Forwarded to ``gym.make``.
        episodes: Number of episodes to simulate.
        verbose: Print per-step details during evaluation.
        deterministic: Forwarded to the agent's action sampling.

    Returns:
        The summary-statistics dict and the per-episode returns produced by
        ``evaluate``.
    """
    env: gym.Env[Dict, MultiDiscrete] = gym.make(  # type: ignore
        env_id,
        domain=domain,
        instance=instance,
        remove_false=remove_false,
        optimize=True,
    )
    try:
        data, h = evaluate(
            agent, env, episodes, verbose=verbose, seed=0, deterministic=deterministic
        )
    finally:
        # release the environment's resources even if evaluation fails
        env.close()
    return data, h
