import os
import pathlib
from concurrent.futures import ThreadPoolExecutor
from enum import Enum

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import typer

from divgfn.eval_utils import (
    Config,
    EnvConfig,
    EvalMetricCallable,
    MainOutput,
    compute_state_prob,
    fcs,
    run_div_gfn,
    run_gfn,
    run_random_sampler,
    run_sa_gfn,
    run_teacher_gfn,
)
from divgfn.policies import GFlowNet, sample_with_mask, take
from divgfn.replay_buffer import ReplayBuffer
from divgfn.utils import create_queue, save_and_plot_results, show_kitty

# Typer application object; `main` below is registered on it as a command.
# Typer's pretty exception rendering is switched off.
app = typer.Typer(pretty_exceptions_enable=False)

# Absolute path of the directory that contains this file.
HERE = os.path.dirname(os.path.abspath(__file__))


class LogRewardType(Enum):
    """Identifier of the reward family selected on the command line.

    The member's value is embedded in the output directory name built by
    `make_output_dir`.
    """

    UTILITY = "utility"


class ModelType(Enum):
    """Training procedures that `main` can run in isolation via `run_only`.

    Each member is mapped to one of the imported `run_*` functions inside
    `main` (`run_gfn`, `run_sa_gfn`, `run_teacher_gfn`, `run_div_gfn`).
    """

    TB = "tb"
    SA = "sa"
    TEACHER = "teacher"
    DIV = "div"


@struct.dataclass
class EnvState:
    """Batched state of the grid environment used by the forward/backward steps.

    `length`, `dim` and `L` are static (non-pytree) fields; everything else is
    a per-batch array.
    """

    length: int = struct.field(pytree_node=False)  # Side length of the grid; moves must keep coordinates < length
    dim: int = struct.field(pytree_node=False)  # Number of grid dimensions; action index `dim` is the stop action

    state: jax.Array  # Grid coordinates, created with shape (bs, dim)

    # Action masks, created with shape (bs, dim + 1); the last column is the stop action
    fmask: jax.Array
    bmask: jax.Array

    # Whether the (forward or backward) process has stopped
    fstopped: jax.Array
    bstopped: jax.Array

    # Maximum trajectory length (`create_state` sets it to length * dim + 1)
    L: int = struct.field(pytree_node=False)

    # Row index of each batch element (`create_state` sets it to arange(bs))
    batch_ids: jax.Array


# TODO: encapsulate the configuration in a config file, and produce the figures properly.
# We could consider further reward functions (e.g., Gaussian), but we should proceed cautiously.


def create_mask(bs: int, dim: int):
    """Return the initial `(fmask, bmask)` pair for a batch of `bs` states.

    Both masks have one column per grid dimension plus one trailing column.
    The forward mask starts fully enabled (all ones) and the backward mask
    fully disabled (all zeros).
    """
    shape = (bs, dim + 1)
    return jnp.ones(shape), jnp.zeros(shape)


class LogReward(nnx.Module):
    """Tempered log-reward over grid coordinates.

    Coordinates are rescaled with `x / (length - 1) * 2 - 1` and taken in
    absolute value. A state earns reward 3 when every coordinate satisfies
    0.6 < |.| < 0.8; with `include_r1` it additionally earns 1 when every
    coordinate satisfies |.| > 0.5. The result is `log(reward + ro) / t`.
    """

    def __init__(
        self,
        length: int,
        dim: int,
        ro: float = 1e-3,
        include_r1: bool = False,
        *,
        t: float = 1,
    ):
        # Grid geometry.
        self.length = length
        self.dim = dim
        # `ro` is the additive offset inside the log; `t` divides the log-reward.
        self.ro = ro
        self.t = t
        self.include_r1 = include_r1

    def __call__(self, x: EnvState):
        """Return the log-reward of every state in `x`, reduced over the last axis."""
        scaled = jnp.abs(x.state / (self.length - 1) * 2 - 1)

        in_band = jnp.logical_and(scaled < 0.8, scaled > 0.6)
        reward = in_band.prod(-1) * 3

        if self.include_r1:
            outer = (scaled > 0.5).prod(-1) * 1
            reward = reward + outer

        # A further mode, (scaled < 0.2).prod(-1) * 1, is currently disabled.
        return jnp.log(reward + self.ro) / self.t


def plot_grid(ax: plt.Axes, log_pi: jax.Array, states: jax.Array, length: int):
    """Draw `exp(log_pi)` on a `length x length` image and return the image array.

    The first two columns of `states` (cast to int32) are used as the row and
    column index of each entry of `log_pi`; cells that are never indexed stay
    at zero. The image is rendered on `ax` with `imshow`.
    """
    cells = states.astype(jnp.int32)
    values = jnp.exp(log_pi)

    grid = jnp.zeros((length, length)).at[cells[:, 0], cells[:, 1]].set(values)

    ax.imshow(grid)
    return grid


def create_state(length: int, dim: int, bs: int):
    """Build the initial `EnvState` for a batch of `bs` trajectories.

    Every coordinate starts at zero, the forward process is active
    (`fstopped` all False) and the backward process is flagged as stopped
    (`bstopped` all True). The maximum trajectory length is set to
    `length * dim + 1`.
    """
    forward_mask, backward_mask = create_mask(bs, dim)
    return EnvState(
        length=length,
        dim=dim,
        L=length * dim + 1,
        state=jnp.zeros((bs, dim)),
        fmask=forward_mask,
        bmask=backward_mask,
        fstopped=jnp.zeros((bs,), dtype=jnp.bool),
        bstopped=jnp.ones((bs,), dtype=jnp.bool),
        batch_ids=jnp.arange(bs),
    )


def create_complete_state(length: int, dim: int):
    """Enumerate every cell of the grid as a batch of stopped states.

    Returns `(s_o, state, coord)`: the initial state with batch size
    `length ** dim`, a copy whose coordinates cover the whole grid with
    `fstopped` set and only the last (stop) column enabled in both masks,
    and the `(length ** dim, dim)` coordinate array itself.
    """
    s_o = create_state(length, dim, bs=length**dim)
    bs, dim = s_o.state.shape

    # One row per grid cell, one column per dimension.
    axes = jnp.meshgrid(*([jnp.arange(length)] * dim))
    coord = jnp.vstack(axes).reshape(dim, -1).transpose()

    # Only the final column (the stop action) is enabled.
    stop_only = jnp.zeros((bs, dim + 1)).at[:, -1].set(1)

    state = s_o.replace(
        state=coord,
        fmask=stop_only,
        bmask=stop_only.copy(),
        fstopped=jnp.ones((bs,), dtype=jnp.bool),
        bstopped=jnp.zeros((bs,), dtype=jnp.bool),
    )
    return s_o, state, coord


@struct.dataclass
class Checkpoint:
    """Outputs of one `EvalTVCallable` evaluation, stored in its `metrics["checkpoints"]` list."""

    # The two values returned by `compute_state_prob` for the main sampler.
    log_pt: jax.Array
    log_pi: jax.Array
    # The same two values for the teacher sampler; `None` when no teacher was given.
    log_pt_teach: jax.Array
    log_pi_teach: jax.Array


class EvalTVCallable(EvalMetricCallable):
    """Callback that periodically evaluates a total-variation metric in the background.

    Every call to `eval_metric` advances an internal counter. When the counter
    is a multiple of `freq` (or equals `checkpoint`), the metric is computed on
    a single worker thread so the caller is not blocked.

    Results are collected lazily: a finished computation is only appended to
    `metrics` on a *later* evaluation call. Consequently
      * entries of `metrics["tv"]` lag one evaluation behind,
      * an evaluation that arrives while the previous one is still running is
        skipped, and
      * the result of the very last submitted computation is never harvested
        by this class.
    NOTE(review): the worker receives the sampler objects by reference; if the
    caller keeps training them in place, the metric may observe later
    parameters — confirm against the training loop.
    """

    def __init__(
        self,
        length: int,
        dim: int,
        freq: int,
        checkpoint: int,  # Iteration to compute the distribution (may be None)
        logr: LogReward,
        key: jax.Array,
    ):
        self.length = length
        self.dim = dim
        self.freq = freq
        # The same key is passed to every evaluation.
        self.key = key
        self.counter = 1

        self.metrics = {"tv": [], "checkpoints": []}

        # Enumerate the whole grid once and cache its log-rewards.
        self.s_o, self.state, _ = create_complete_state(self.length, self.dim)
        self.log_rewards = logr(self.state)

        self.checkpoint = checkpoint

        # Allows asynchronous execution
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.pending_future = None

    def has_checkpoint_been_reached(self):
        """Return True while the counter lies in `[checkpoint, checkpoint + freq]`."""
        return (
            self.checkpoint is not None
            and self.counter >= self.checkpoint
            and self.counter <= self.checkpoint + self.freq
        )

    def eval_metric(self, gfn: GFlowNet, gfn_teach: GFlowNet = None):
        """Advance the counter and, on evaluation iterations, schedule the metric.

        Nothing is returned; finished results are appended to `self.metrics`.
        """
        if self.counter % self.freq != 0 and (
            self.checkpoint is None or self.counter != self.checkpoint
        ):
            self.counter += 1
            return

        # Check if previous computation finished
        if self.pending_future is not None and self.pending_future.done():
            # `result()` re-raises any exception raised in the worker thread.
            metric, log_pt, log_pi, log_pt_teach, log_pi_teach = (
                self.pending_future.result()
            )
            self.metrics["tv"].append(metric)

            checkpoint = Checkpoint(
                log_pt=log_pt,
                log_pi=log_pi,
                log_pt_teach=log_pt_teach,
                log_pi_teach=log_pi_teach,
            )
            self.metrics["checkpoints"].append(checkpoint)

            self.pending_future = None

        # Start new computation if none pending
        if self.pending_future is None:
            self.pending_future = self.executor.submit(
                self._compute_metric, gfn, gfn_teach
            )

        self.counter += 1

    def _compute_metric(self, gfn: GFlowNet, gfn_teach: GFlowNet):
        """Compute `0.5 * sum |exp(log_pt) - exp(log_pi)|` for the main sampler.

        Returns `(metric, log_pt, log_pi, log_pt_teach, log_pi_teach)`; the two
        teacher entries are `None` when `gfn_teach` is `None`. The teacher's
        distributions are recorded but do not enter the metric.
        """
        log_pt, log_pi = compute_state_prob(
            self.state,
            gfn,
            self.key,
            self.log_rewards,
            self.s_o,
            trajectories=10,
            bstep=bstep,
        )

        log_pt_teach = None
        log_pi_teach = None
        if gfn_teach is not None:
            log_pt_teach, log_pi_teach = compute_state_prob(
                self.state,
                gfn_teach,
                self.key,
                self.log_rewards,
                self.s_o,
                trajectories=10,
                bstep=bstep,
            )

        metric = 0.5 * jnp.abs(jnp.exp(log_pt) - jnp.exp(log_pi))
        return (
            metric.sum(),
            log_pt,
            log_pi,
            log_pt_teach,
            log_pi_teach,
        )


def umask(state: jax.Array, fstopped: jax.Array, length: int):
    """Compute the forward and backward action masks for a batch of grid states.

    Args:
        state: `(bs, dim)` array of grid coordinates.
        fstopped: `(bs,)` boolean array, True where the forward process has
            already taken the stop action.
        length: side length of the grid; a coordinate must stay `< length`.

    Returns:
        `(fmask, bmask)`, both of shape `(bs, dim + 1)`. Column `j < dim`
        refers to a unit move along dimension `j`; the last column is the
        stop action.
          * `fmask[:, j]` is 1 when the state is not stopped and
            `state[:, j] + 1 < length`; the stop column is always 1.
          * `bmask[:, j]` is 1 when the state is not stopped and
            `state[:, j] - 1 >= 0`; the stop column equals `fstopped`.
    """
    bs, dim = state.shape
    can_move = ~fstopped[..., None]  # (bs, 1), broadcast over dimensions

    # A unit move only changes one coordinate, so each action can be checked
    # directly on the coordinate it touches.
    fmask = jnp.ones((bs, dim + 1))
    fmask = fmask.at[:, :-1].set(can_move & (state + 1 < length))

    bmask = jnp.zeros((bs, dim + 1))
    bmask = bmask.at[:, :-1].set(can_move & (state - 1 >= 0))
    # Once stopped, the only backward action is undoing the stop.
    bmask = bmask.at[:, -1].set(fstopped)

    return fmask, bmask


# forward apply
def fapply(state: EnvState, actions: jax.Array) -> EnvState:
    """Apply one forward action per batch element and return the new state.

    Action `state.dim` is the stop action. Elements that are already stopped,
    or that choose to stop, receive a zero increment; all others have the
    chosen coordinate increased by one. The masks are recomputed with
    `umask` and `bstopped` is reset to all False.
    """
    stop_chosen = actions == state.dim
    moving = ~state.fstopped & ~stop_chosen
    increment = jnp.where(moving, 1, 0)

    coords = state.state.at[state.batch_ids, actions].add(increment)
    forward_done = jnp.where(stop_chosen, True, state.fstopped)
    backward_done = jnp.zeros_like(forward_done)

    forward_mask, backward_mask = umask(coords, forward_done, state.length)

    return state.replace(
        state=coords,
        fstopped=forward_done,
        bstopped=backward_done,
        fmask=forward_mask,
        bmask=backward_mask,
    )


# backward apply
def bapply(state: EnvState, actions: jax.Array) -> EnvState:
    """Apply one backward action per batch element and return the new state.

    Elements whose backward process is still running have the chosen
    coordinate decreased by one. An element becomes backward-stopped once all
    of its coordinates are zero; `fstopped` is reset to all False and the
    masks are recomputed with `umask`.

    NOTE(review): when the action is the stop index (== dim) it is out of
    range for the coordinate axis; this appears to rely on JAX skipping
    out-of-bounds scatter updates — confirm.
    """
    running = ~state.bstopped
    decrement = jnp.where(running, -1, 0)

    coords = state.state.at[state.batch_ids, actions].add(decrement)
    backward_done = (coords == 0).all(axis=1)
    forward_done = jnp.zeros_like(backward_done)

    # The masks are rebuilt from the decremented coordinates.
    forward_mask, backward_mask = umask(coords, forward_done, state.length)

    return state.replace(
        state=coords,
        fstopped=forward_done,
        bstopped=backward_done,
        fmask=forward_mask,
        bmask=backward_mask,
    )


def bpol(state: EnvState, gfn: GFlowNet):
    """Return the log-probabilities of a uniform backward policy.

    Every action enabled in `state.bmask` gets the same score and every
    disabled action gets `-inf` before a log-softmax over axis 1. `gfn` is
    accepted for interface parity but not used; a learned alternative based
    on `gfn.bcall(state)` is currently disabled.
    """
    allowed = state.bmask == 1
    scores = jnp.where(allowed, 1, -jnp.inf)
    return nnx.log_softmax(scores, axis=1)


def fpol(state: EnvState, gfn: GFlowNet):
    """Return the forward log-probabilities of `gfn` restricted to `state.fmask`.

    Logits of disabled actions are replaced by `-inf` before a log-softmax
    over the last axis.
    """
    raw = gfn(state)
    masked = jnp.where(state.fmask == 1, raw, -jnp.inf)
    return nnx.log_softmax(masked, axis=-1)


def fstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module, eps: float = 0.0):
    """Take one forward step; written in the `(carry, x) -> (carry, y)` scan shape.

    Samples an action with `sample_with_mask`, applies it with `fapply`, and
    evaluates the backward policy at the new state for the same action.
    Log-probabilities of elements that were already stopped before the step
    are zeroed.

    Returns `((new_state, key), (flogits, blogits, actions))`.
    """
    state, key = carry
    was_active = ~state.fstopped

    # Draw the forward action under the current forward mask.
    flogits, actions, key = sample_with_mask(gfn, state, key, mask=state.fmask, eps=eps)

    next_state = fapply(state, actions)

    # Backward log-probability of undoing the action just taken.
    blogits = take(bpol(next_state, gfn), actions)

    flogits = jnp.where(was_active, flogits, 0)
    blogits = jnp.where(was_active, blogits, 0)

    return (next_state, key), (flogits, blogits, actions)


def bstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module):
    """Take one backward step; written in the `(carry, x) -> (carry, y)` scan shape.

    Samples an action from the backward policy `bpol`, applies it with
    `bapply`, and evaluates the forward policy at the resulting state for the
    same action. Log-probabilities of elements whose backward process had
    already stopped before the step are zeroed.

    Returns `((new_state, key), (flogits, blogits, actions))`.
    """
    state, key = carry
    was_active = ~state.bstopped

    # Draw the backward action from the masked backward distribution.
    key, sample_key = jax.random.split(key, 2)
    backward_logp = bpol(state, gfn)
    actions = jax.random.categorical(sample_key, backward_logp)
    blogits = take(backward_logp, actions)

    prev_state = bapply(state, actions)

    # Forward log-probability of re-taking the action from the parent state.
    flogits = take(fpol(prev_state, gfn), actions)

    flogits = jnp.where(was_active, flogits, 0)
    blogits = jnp.where(was_active, blogits, 0)

    return (prev_state, key), (flogits, blogits, actions)


def draw_grid(
    length: int,
    dim: int,
    samplers: list[GFlowNet],
    logr: LogReward,
    key: jax.Array,
    titles: list[str],
):
    """Plot the target distribution next to each sampler's state distribution.

    The first panel shows the softmax of the log-rewards over every grid
    cell; each following panel shows `exp(log_pt)`, where `log_pt` is the
    first value returned by `compute_state_prob` for the matching sampler.
    `plot_grid` indexes only two coordinates, so this is meant for `dim == 2`.

    Args:
        length: side length of the grid.
        dim: number of grid dimensions.
        samplers: samplers to evaluate, one panel each.
        logr: log-reward callable applied to the enumerated states.
        key: PRNG key forwarded unchanged to every `compute_state_prob` call.
        titles: panel titles, paired with `samplers` by position.

    Returns:
        Dict mapping `"target"` and each title to the plotted grid array.
    """
    s_o, state, coord = create_complete_state(length, dim)
    log_rewards = logr(state)
    jax.debug.print("logz {}", jax.nn.logsumexp(log_rewards, axis=0))

    # `plot_grid` exponentiates its input, so it must receive LOG-probabilities.
    log_pi = nnx.log_softmax(log_rewards, axis=0)

    n_panels = len(samplers) + 1
    grids = {}

    ax = plt.subplot(1, n_panels, 1)
    grids["target"] = plot_grid(ax, log_pi, coord, length)
    ax.set_title("target")

    for i, (sampler, title) in enumerate(zip(samplers, titles), start=1):
        log_pt, _ = compute_state_prob(
            state,
            sampler,
            key,
            log_rewards,
            s_o,
            trajectories=64,
            bstep=bstep,
        )

        # Plot the grid with the sampler's distribution
        ax = plt.subplot(1, n_panels, i + 1)
        grids[title] = plot_grid(ax, log_pt, coord, length)
        ax.set_title(title)

    show_kitty()
    return grids


def make_output_dir(
    length: int, dim: int, seed: int, log_reward_type: LogRewardType
) -> pathlib.Path:
    """Create (if needed) and return the results directory for one run.

    The directory lives under `outputs/` relative to the current working
    directory and is named after the grid size, dimension, seed and reward
    type.
    """
    run_dir = pathlib.Path(
        "outputs",
        f"length{length}_dim{dim}_seed{seed}_{log_reward_type.value}",
    )
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


# CLI entry point.
#
# With `run_only` set, trains just that model, optionally draws its grid
# (dim == 2 only), echoes its extra metrics and returns
# `(out, (grids, grids_teach))`.
#
# Otherwise trains the div, sa, tb, random and teacher samplers in turn, draws
# their grids (dim == 2 only), computes the `fcs` score of each, saves
# everything under the run's output directory and returns
# `(MainOutput, grids)`.
@app.command()
def main(
    length: int,
    dim: int,
    iterations: int = 512,  # number of iterations
    bs: int = 64,  # batch size
    dmid: int = 128,  # hidden dim
    nlayers: int = 2,  # number of layers
    seed: int = 42,
    logr_t: LogRewardType = LogRewardType.UTILITY,
    clip: bool = False,
    run_only: ModelType = None,
    include_r1: bool = False,
    use_scheduler: bool = True,
    lr: float = 1e-2,
    lr_teacher: float = 1e-3,
    checkpoint_at: int = None,
):
    config = Config(
        din=dim,
        dout=dim + 1,
        iterations=iterations,
        bs=bs,
        dmid=dmid,
        nlayers=nlayers,
        key=jax.random.key(seed),
        should_clip=clip,
        lr=lr,
        lr_teacher=lr_teacher,
        use_scheduler=use_scheduler,
    )

    env = EnvConfig(
        fstep=fstep,
        bstep=bstep,
        fapply=fapply,
        fpol=fpol,
        bpol=bpol,
    )
    output_dir = make_output_dir(length, dim, seed, logr_t)

    s_o = create_state(length, dim, config.bs)

    buffer = ReplayBuffer.create(create_state(length, dim, config.buffer_size))

    logr = LogReward(length, dim, include_r1=include_r1)
    queue = create_queue(dim, capacity=int(2**dim))
    key = config.key

    func_to_model = {
        ModelType.TB: run_gfn,
        ModelType.SA: run_sa_gfn,
        ModelType.TEACHER: run_teacher_gfn,
        ModelType.DIV: run_div_gfn,
    }

    # `subkey` is reserved for the evaluation callbacks; `key` is used for plots.
    key, subkey = jax.random.split(key, 2)

    def create_eval():
        # A fresh callback per run so that metrics are not shared between models.
        return EvalTVCallable(
            length, dim, freq=128, checkpoint=checkpoint_at, logr=logr, key=subkey
        )

    if run_only:
        out = func_to_model[run_only](
            config, env, logr, s_o, queue, buffer, eval_metric=create_eval()
        )
        if dim == 2:
            grids = draw_grid(
                length, dim, [out.gfn], logr, key=key, titles=[run_only.value]
            )
            if out.gfn_teach is not None:
                grids_teach = draw_grid(
                    length, dim, [out.gfn_teach], logr, key=key, titles=[run_only.value]
                )
            else:
                grids_teach = None
        else:
            grids = None
            grids_teach = None

        typer.echo(out.extra_metrics)
        return out, (grids, grids_teach)

    # The div and sa runs use half the batch size.
    half_config = config.replace(bs=config.bs // 2)

    out_div = run_div_gfn(
        half_config,
        env,
        logr,
        s_o,
        queue,
        buffer,
        eval_metric=create_eval(),
    )
    if dim == 2:
        draw_grid(length, dim, [out_div.gfn], logr, key=key, titles=["div"])
        draw_grid(length, dim, [out_div.gfn_teach], logr, key=key, titles=["div"])

    # NOTE(review): unlike the `run_only` path, no replay buffer is passed
    # here (nor to `run_teacher_gfn` below) — confirm this is intended.
    out_sa = run_sa_gfn(
        half_config,
        env,
        logr,
        s_o,
        queue,
        eval_metric=create_eval(),
    )
    if dim == 2:
        draw_grid(length, dim, [out_sa.gfn], logr, key=key, titles=["sa"])
        draw_grid(length, dim, [out_sa.gfn_teach], logr, key=key, titles=["sa"])

    # Plot the TV
    for name, m in [("div", out_div), ("sa", out_sa)]:
        plt.plot(m.extra_metrics["tv"], label=name)
    plt.legend()

    show_kitty()

    out = run_gfn(config, env, logr, s_o, queue, buffer, eval_metric=create_eval())
    if dim == 2:
        draw_grid(length, dim, [out.gfn], logr, key=key, titles=["tb"])
    out_random = run_random_sampler(config, env, logr, s_o, queue)
    out_teacher = run_teacher_gfn(
        config, env, logr, s_o, queue, eval_metric=create_eval()
    )
    if dim == 2:
        draw_grid(length, dim, [out_teacher.gfn], logr, key=key, titles=["teacher"])
        draw_grid(
            length, dim, [out_teacher.gfn_teach], logr, key=key, titles=["teacher"]
        )

    # Split the *current* key: re-splitting `config.key` would reproduce the
    # exact (key, subkey) pair already handed to the evaluation callbacks.
    key, subkey = jax.random.split(key, 2)
    if dim == 2:
        grids = draw_grid(
            length,
            dim,
            [out_div.gfn, out.gfn, out_random.gfn, out_teacher.gfn, out_sa.gfn],
            logr,
            key=subkey,
            titles=["div", "tb", "random", "teacher", "sa"],
        )
    else:
        grids = None

    # We should also implement other cooperative techniques
    iterations_for_fcs = 32

    # Every sampler is scored with the same key and iteration budget.
    gfn_div_fcs, gfn_fcs, gfn_random_fcs, gfn_teacher_fcs, gfn_sa_fcs = [
        fcs(
            result.s_o,
            result.gfn,
            logr=logr,
            key=key,
            iterations=iterations_for_fcs,
            env=env,
        )
        for result in (out_div, out, out_random, out_teacher, out_sa)
    ]
    jax.debug.print(
        "div {} tb {} random {} teacher {} sa {}",
        gfn_div_fcs,
        gfn_fcs,
        gfn_random_fcs,
        gfn_teacher_fcs,
        gfn_sa_fcs,
    )

    # NOTE(review): the combined TV plot below is disabled, so the
    # `plt.legend()` / `show_kitty()` pair that follows has nothing new to show.
    # # Plot the TV
    # for name, m in [
    #     ("div", out_div),
    #     ("tb", out),
    #     ("random", out_random),
    #     ("teacher", out_teacher),
    #     ("sa", out_sa),
    # ]:
    #     plt.plot(m.extra_metrics["tv"], label=name)
    plt.legend()

    show_kitty()

    # Save to `output_dir`
    save_and_plot_results(
        output_dir,
        out_div,
        out,
        out_random,
        out_teacher,
        out_sa,
        gfn_div_fcs,
        gfn_fcs,
        gfn_random_fcs,
        gfn_teacher_fcs,
        gfn_sa_fcs,
    )

    return (
        MainOutput(
            std=out,
            div=out_div,
            teacher=out_teacher,
            sa=out_sa,
        ),
        grids,
    )


# Run the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    app()
