import json
import os
import pathlib
import subprocess
import warnings
from io import BytesIO
from typing import Protocol

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Globally silence RuntimeWarnings whose message matches "invalid value encountered in cast".
# NOTE(review): presumably raised by the non-finite float -> int32 conversions
# performed in `push` — confirm before removing this filter.
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in cast")


def is_kitty():
    """Return True when the ``TERM`` environment variable starts with ``xterm-kitty``.

    An unset ``TERM`` is treated as an empty string, which yields False.
    """
    term = os.getenv("TERM", "")
    return term.startswith("xterm-kitty")


def show_kitty():
    """Render the current Matplotlib figure inline in a kitty terminal.

    Does nothing when `is_kitty()` is False. Otherwise the current figure is
    saved as PNG into memory, piped to ``kitty +kitten icat`` on stdin, and
    then closed.

    The figure is closed in a ``finally`` clause so it is released even when
    saving fails or the ``kitty`` executable cannot be launched; any such
    exception still propagates to the caller.
    """
    if not is_kitty():
        return

    try:
        png = BytesIO()
        plt.savefig(png, format="png")

        # Pipe directly to kitty. getvalue() returns the whole buffer regardless
        # of the stream position, so no seek is required. The exit status is
        # deliberately not checked (best-effort display).
        subprocess.run(["kitty", "+kitten", "icat"], input=png.getvalue(), check=False)
    finally:
        plt.close()


class HasState(Protocol):
    """Structural type for any object that exposes a ``state`` array attribute."""

    # Array payload; `push` stacks it beneath `ModeQueue.x` with jnp.vstack.
    state: jax.Array


def merge(x, y):
    """Merge two corresponding EnvState leaves.

    When ``x`` is a ``jax.Array`` the result is ``x`` and ``y`` concatenated
    along axis 0. Any other ``x`` is returned unchanged and ``y`` is ignored.
    """
    if not isinstance(x, jax.Array):
        return x
    return jnp.concatenate((x, y), axis=0)


@struct.dataclass
class RunOutput:
    """Container bundling the objects and metric series produced by a run.

    NOTE(review): the per-field meanings marked "presumably" are inferred from
    the field names only — confirm against the code that builds RunOutput.
    """

    s_o: HasState  # presumably the final environment/sampler state — TODO confirm
    gfn: nnx.Module  # presumably the trained GFlowNet module — TODO confirm
    queue: struct.PyTreeNode  # presumably a ModeQueue — TODO confirm
    logzs: list[float]  # series plotted in panel 1 and saved by save_and_plot_results
    hlogr: list[float]  # series plotted in panel 2 and saved by save_and_plot_results
    hr: list[float]  # series plotted in panel 3 and saved by save_and_plot_results

    # Optional additional metric series; defaults to a fresh empty list.
    extra_metrics: list[float] = struct.field(default_factory=list)

    # Optional second module; defaults to None when not supplied.
    gfn_teach: nnx.Module = None


@struct.dataclass
class ModeQueue:
    """Fixed-capacity buffer of states and their log-rewards.

    `create_queue` builds an empty instance and `push` returns updated copies.
    """

    # Number of slots. Declared with pytree_node=False, i.e. excluded from the
    # pytree leaves.
    capacity: int = struct.field(pytree_node=False)

    # Stored states; `create_queue` allocates shape (capacity, dim).
    x: jax.Array
    # Log-reward per slot; `create_queue` fills shape (capacity,) with -inf.
    y: jax.Array

    mask: jax.Array  # Represents which states are valid (`push` sets it to ~isinf(y))
    hlogr: jax.Array  # Average log reward (None until the first `push`)
    hr: jax.Array  # Average reward (None until the first `push`)


def create_queue(dim: int, capacity: int = int(5e1)):
    """Build an empty `ModeQueue`.

    Args:
        dim: Number of columns of each stored state row.
        capacity: Number of slots in the queue (50 by default).

    Returns:
        A `ModeQueue` whose states are all zero, whose log-rewards are all
        ``-inf``, whose mask is all False, and whose ``hlogr`` / ``hr``
        averages are ``None``.
    """
    blank_states = jnp.zeros((capacity, dim))
    # Every slot starts at -inf log-reward.
    lowest_logr = -jnp.inf * jnp.ones((capacity,))
    nothing_valid = jnp.zeros((capacity,), dtype=bool)

    return ModeQueue(
        capacity=capacity,
        x=blank_states,
        y=lowest_logr,
        mask=nothing_valid,
        hlogr=None,
        hr=None,
    )


@jax.jit
def push(queue: ModeQueue, state: HasState, logr: jax.Array):
    """Merge a batch of states into the queue, keeping the best unique entries.

    The batch (``state.state``, ``logr``) is stacked beneath the current queue
    contents. Rows ``[-log-reward | state]`` are de-duplicated at 1e-6
    resolution and the first ``queue.capacity`` unique rows in ascending
    lexicographic order are kept, so the highest log-rewards come first.

    When fewer than ``capacity`` unique rows exist, ``jnp.unique`` pads its
    index output with a repeat of a real row. Those surplus slots are reset to
    the empty representation (zero state, ``-inf`` log-reward); otherwise they
    would duplicate an existing entry, be flagged valid, and bias the averages.

    Args:
        queue: Current queue.
        state: Object whose ``state`` array holds one candidate per row, with
            the same number of columns as ``queue.x``.
        logr: Log-reward of each candidate row.

    Returns:
        A copy of ``queue`` with ``x``, ``y``, ``mask``, ``hlogr`` and ``hr``
        replaced. ``mask`` is True where the log-reward is not infinite;
        ``hlogr`` / ``hr`` are the mean log-reward / reward over those slots
        and are NaN when no slot is valid (0 / 0).
    """
    cand_x = jnp.vstack([queue.x, state.state])
    cand_y = jnp.hstack([queue.y, logr])

    # Merge [y | x]; the sign flip makes ascending order rank high rewards first.
    yx = jnp.hstack([-cand_y[..., None], cand_x])

    # Integer keys at one-millionth resolution. +/-2_147 * 1e6 still fits in
    # int32, so every value (infinite or merely large) is clamped to that
    # range before the cast; NaN is keyed as 0.
    max_int_millionth = 2_147
    min_int_millionth = -2_147
    clean_scaled = jnp.clip(jnp.nan_to_num(yx, nan=0), min_int_millionth, max_int_millionth)
    yx_millionth = jnp.floor(1_000_000 * clean_scaled).astype(jnp.int32)

    # The unique values themselves are discarded, so no fill_value is needed.
    # Padded output slots are recognisable by a count of zero.
    _, idx, counts = jnp.unique(
        yx_millionth,
        size=queue.capacity,
        return_index=True,
        return_counts=True,
        axis=0,
    )
    is_real = counts > 0

    x = jnp.where(is_real[:, None], yx[idx, 1:], 0)
    y = jnp.where(is_real, -yx[idx, 0], -jnp.inf)
    mask = ~jnp.isinf(y)

    n_valid = mask.sum()
    avg_logr = jnp.where(mask, y, 0).sum() / n_valid
    avg_r = jnp.where(mask, jnp.exp(y), 0).sum() / n_valid

    return queue.replace(x=x, y=y, mask=mask, hlogr=avg_logr, hr=avg_r)


def save_and_plot_results(
    output_dir: pathlib.Path,
    out_div: RunOutput,
    out: RunOutput,
    out_random: RunOutput,
    out_teacher: RunOutput,
    out_sa: RunOutput,
    fcs_div: jax.Array,
    fcs_tb: jax.Array,
    fcs_random: jax.Array,
    fcs_teacher: jax.Array,
    fcs_sa: jax.Array,
):
    """Write the metrics of five runs to JSON and plot them side by side.

    The runs are labelled ``div``, ``tb``, ``random``, ``teacher`` and ``sa``.
    ``output_dir / "results.json"`` receives the FCS score of each run
    (obtained via ``.item()``) together with each run's ``logzs``, ``hlogr``
    and ``hr`` series. The three series are then drawn in a 1x3 subplot grid,
    one panel per metric with one labelled line per run, and every panel gets
    a legend.

    The figure is shown through `show_kitty` when `is_kitty()` is True and
    through ``plt.show()`` otherwise.

    Args:
        output_dir: Existing directory in which ``results.json`` is written.
        out_div, out, out_random, out_teacher, out_sa: Run outputs for the
            ``div``, ``tb``, ``random``, ``teacher`` and ``sa`` labels.
        fcs_div, fcs_tb, fcs_random, fcs_teacher, fcs_sa: Single-element
            arrays holding the FCS score of the corresponding run.
    """
    # Insertion order fixes both the JSON key order and the plotting order.
    runs = {"div": out_div, "tb": out, "random": out_random, "teacher": out_teacher, "sa": out_sa}
    fcs = {"div": fcs_div, "tb": fcs_tb, "random": fcs_random, "teacher": fcs_teacher, "sa": fcs_sa}
    metrics = ("logzs", "hlogr", "hr")

    results = {"fcs_scores": {label: score.item() for label, score in fcs.items()}}
    for metric in metrics:
        results[metric] = {label: getattr(run, metric) for label, run in runs.items()}

    with open(output_dir / "results.json", "w") as f:
        json.dump(results, f, indent=2)

    for position, metric in enumerate(metrics, start=1):
        plt.subplot(1, 3, position)
        for label, series in results[metric].items():
            plt.plot(series, label=label)
        plt.legend()

    plt.tight_layout()
    if is_kitty():
        show_kitty()
    else:
        plt.show()
