import os
import pathlib
from enum import Enum

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import typer

from divgfn.eval_utils import (
    Config,
    EnvConfig,
    MainOutput,
    fcs,
    run_div_gfn,
    run_gfn,
    run_random_sampler,
    run_sa_gfn,
    run_teacher_gfn,
)
from divgfn.policies import GFlowNet, sample_with_mask, take
from divgfn.utils import create_queue, save_and_plot_results

# Typer CLI application; Typer's pretty exception formatting is disabled so
# errors surface as plain tracebacks.
app = typer.Typer(pretty_exceptions_enable=False)

# Absolute path of the directory containing this file.
# NOTE(review): not referenced anywhere in the visible code.
HERE = os.path.dirname(os.path.abspath(__file__))


class LogRewardType(Enum):
    """Selectable log-reward variant (the `logr_t` CLI option of `main`).

    Its `.value` is embedded in the output directory name. Only the
    utility-based reward is currently defined.
    """

    UTILITY = "utility"


@struct.dataclass
class EnvState:
    """Batched state of the bag (multiset) environment.

    Each bag is one row of per-item counts: a forward step adds one copy of an
    item, a backward step removes one. Fields declared with
    `pytree_node=False` are static metadata, not pytree leaves.
    """

    # Total item count at which a bag is complete: the forward process stops
    # once a row of `state` sums to this value.
    max_size: int = struct.field(pytree_node=False)
    num_items: int = struct.field(pytree_node=False)  # Number of distinct item types

    state: jax.Array  # Per-item counts of each bag, shape (bs, num_items) as built by create_state

    # Action masks (1 = allowed). fmask: items that may be added;
    # bmask: items currently present (count >= 1) that may be removed.
    fmask: jax.Array
    bmask: jax.Array

    # Whether the (forward or backward) process has stopped, one flag per bag
    fstopped: jax.Array
    bstopped: jax.Array

    # Maximum trajectory length
    L: int = struct.field(pytree_node=False)

    # Row indices 0..bs-1, used to scatter one action per bag into `state`
    batch_ids: jax.Array


class LogReward(nnx.Module):
    """Log-reward given by the total utility of a bag.

    Each item `i` gets a fixed utility `u[i]` drawn from a standard lognormal
    distribution (PRNG seed `seed`), so utilities are strictly positive. For a
    batch of bags with per-item counts `x.state`, the log-reward is

        (log(sum_i state[b, i] * u[i]) - norm) / t

    where `norm = log(max(u) * max_size / 2)` rescales the utility and `t` is a
    temperature (a smaller `t` sharpens the reward).
    """

    def __init__(self, num_items: int, max_size: int, seed: int, t: float = 1e-1):
        key = jax.random.key(seed)
        self.num_items = num_items
        self.max_size = max_size
        self.t = t
        # Per-item utilities (lognormal samples, always > 0).
        self.u = jax.random.lognormal(key, shape=(num_items,))

        # Numerical stability: shift by the log of half the largest total utility
        # a bag of `max_size` items can reach. True division is deliberate:
        # floor division (`// 2`) on this float rounds to 0 whenever
        # max(u) * max_size < 2, which made `norm` -inf and every reward +inf.
        self.norm = jnp.log(self.u.max() * self.max_size / 2)

    def __call__(self, x: EnvState):
        # Total utility per bag: sum_i count[b, i] * u[i].
        lin = jnp.einsum("bi,i->b", x.state, self.u)
        # NOTE: an empty bag (lin == 0) yields log(0) = -inf.
        return (jnp.log(lin) - self.norm) / self.t


def create_state(max_size: int, num_items: int, bs: int):
    """Build the initial batched state: `bs` empty bags over `num_items` items.

    Every forward action is allowed and no backward action is. The forward
    process is running, the backward process is already stopped (the bags are
    empty), and the trajectory-length bound `L` equals `max_size`.
    """
    per_bag = (bs,)
    per_item = (bs, num_items)
    return EnvState(
        max_size=max_size,
        num_items=num_items,
        state=jnp.zeros(per_item),
        fmask=jnp.ones(per_item),
        bmask=jnp.zeros(per_item),
        fstopped=jnp.zeros(per_bag, dtype=jnp.bool),
        bstopped=jnp.ones(per_bag, dtype=jnp.bool),
        L=max_size,
        batch_ids=jnp.arange(bs),
    )


def umask(state: jax.Array, max_size: int):
    """Recompute the per-bag stop flags and the backward action mask.

    Args:
        state: per-item counts, one row per bag.
        max_size: total item count at which a bag is complete.

    Returns:
        `(fstopped, bstopped, bmask)`. `fstopped[b]` is True when bag `b` holds
        exactly `max_size` items. `bstopped[b]` is True when every count in bag
        `b` is zero. `bmask` is a float32 indicator of items with count >= 1.
    """
    present = state >= 1
    is_full = state.sum(axis=1) == max_size
    is_empty = ~(state != 0).any(axis=1)
    return is_full, is_empty, present.astype(jnp.float32)


def fapply(state: EnvState, actions: jax.Array) -> EnvState:
    """Apply one forward (add-an-item) transition to every bag in the batch.

    Bag `b` gains one copy of item `actions[b]`, unless its forward process has
    already stopped, in which case it is left unchanged. The stop flags and
    `bmask` are recomputed from the new counts via `umask`. `fmask` is not
    modified.
    """
    increment = jnp.where(state.fstopped, 0, 1)
    counts = state.state.at[state.batch_ids, actions].add(increment)
    fstopped, bstopped, bmask = umask(counts, state.max_size)
    return state.replace(
        state=counts,
        fstopped=fstopped,
        bstopped=bstopped,
        bmask=bmask,
    )


def bapply(state: EnvState, actions: jax.Array) -> EnvState:
    """Apply one backward (remove-an-item) transition to every bag in the batch.

    Bag `b` loses one copy of item `actions[b]`, unless its backward process has
    already stopped (the bag is empty), in which case it is left unchanged.
    Recomputing `bmask` via `umask` is what makes an item whose count drops to
    zero no longer removable. The stop flags are refreshed the same way.
    """
    decrement = jnp.where(state.bstopped, 0, -1)
    counts = state.state.at[state.batch_ids, actions].add(decrement)
    fstopped, bstopped, bmask = umask(counts, state.max_size)
    return state.replace(state=counts, fstopped=fstopped, bstopped=bstopped, bmask=bmask)


def bpol(state: EnvState, _: GFlowNet = None):
    """Uniform backward policy over the items currently present in each bag.

    The second argument is ignored; it only keeps the signature aligned with
    `fpol`. The result holds log-probabilities: `log(1/k)` on each of the `k`
    items with `bmask == 1`, and `-inf` elsewhere.

    NOTE(review): for an empty bag every logit is -inf, which log_softmax likely
    turns into NaN. `bstep` zeroes such rows via its `bactive` mask.
    """
    removable = state.bmask == 1
    scores = jnp.where(removable, 1, -jnp.inf)
    return nnx.log_softmax(scores, axis=1)


def fpol(state: EnvState, gfn: GFlowNet):
    """Forward policy: the GFlowNet's logits restricted to allowed forward actions.

    The result holds per-item log-probabilities normalized over the items with
    `fmask == 1`. Masked items get `-inf`.
    """
    allowed = state.fmask == 1
    masked_logits = jnp.where(allowed, gfn(state), -jnp.inf)
    return nnx.log_softmax(masked_logits, axis=-1)


def fstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module, eps: float = 0.0):
    """One forward transition for a batch of bags.

    Has the `(carry, x) -> (carry, y)` shape of a `jax.lax.scan` body; presumably
    `gfn`/`eps` are bound by the caller. The carry is `(state, key)`. The output
    is `(flogits, blogits, actions)`, with both log-prob terms zeroed for bags
    whose forward process had already stopped before this step.
    """
    state, key = carry
    was_running = ~state.fstopped

    # Sample one action per bag. flogits presumably holds the forward
    # log-probability of each sampled action (see sample_with_mask).
    flogits, actions, key = sample_with_mask(gfn, state, key, mask=state.fmask, eps=eps)
    next_state = fapply(state, actions)

    # Uniform backward policy: log(1 / number of item types present after the step).
    num_parents = next_state.bmask.sum(axis=1)
    blogits = -jnp.log(num_parents)

    # Finished bags contribute nothing to the trajectory log-probabilities.
    flogits = jnp.where(was_running, flogits, 0)
    blogits = jnp.where(was_running, blogits, 0)
    return (next_state, key), (flogits, blogits, actions)


def bstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module):
    """One backward transition for a batch of bags.

    Samples an item to remove from the uniform backward policy `bpol`, applies
    the removal, and scores the reverse (forward) move with `fpol` on the
    resulting state. The carry is `(state, key)`. The output is
    `(flogits, blogits, actions)`, with both log-prob terms zeroed for bags
    whose backward process had already stopped before this step.
    """
    state, key = carry
    was_running = ~state.bstopped

    key, sample_key = jax.random.split(key, 2)
    backward_logp = bpol(state, gfn)
    actions = jax.random.categorical(sample_key, backward_logp)
    blogits = take(backward_logp, actions)

    prev_state = bapply(state, actions)
    # Forward log-prob of re-adding the removed item from the smaller bag.
    flogits = take(fpol(prev_state, gfn), actions)

    flogits = jnp.where(was_running, flogits, 0)
    blogits = jnp.where(was_running, blogits, 0)
    return (prev_state, key), (flogits, blogits, actions)


def make_output_dir(
    num_items: int,
    max_size: int,
    seed: int,
    log_reward_type: LogRewardType,
    *,
    root: str | os.PathLike[str] = "outputs",
) -> pathlib.Path:
    """Create (if needed) and return the output directory for one experiment.

    The directory is named
    `bags_items{num_items}_maxs{max_size}_seed{seed}_{log_reward_type.value}`
    and placed under `root`. The default is `"outputs"`, relative to the current
    working directory, which matches the previously hard-coded location. An
    existing directory is reused.
    """
    dir_name = f"bags_items{num_items}_maxs{max_size}_seed{seed}_{log_reward_type.value}"
    output_dir = pathlib.Path(root) / dir_name
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir


@app.command()
def main(
    num_items: int,
    max_size: int,
    iterations: int = 512,  # number of iterations
    bs: int = 64,  # batch size
    dmid: int = 128,  # hidden dim
    nlayers: int = 2,  # number of layers
    seed: int = 42,
    reward_seed: int = 43,
    logr_t: LogRewardType = LogRewardType.UTILITY,
) -> MainOutput:
    """Train and compare GFlowNet variants on the bag environment.

    Trains the diversity, standard, random, teacher and SA samplers on the same
    reward. It scores each one with `fcs`, prints the scores, and saves results
    and plots to the experiment's output directory.
    """
    config = Config(
        din=num_items,
        dout=num_items,
        iterations=iterations,
        bs=bs,
        dmid=dmid,
        nlayers=nlayers,
        key=jax.random.key(seed),
    )
    env = EnvConfig(
        fstep=fstep,
        bstep=bstep,
        fapply=fapply,
        fpol=fpol,
        bpol=bpol,
    )
    output_dir = make_output_dir(num_items, max_size, seed, logr_t)

    s_o = create_state(max_size, num_items, config.bs)

    logr = LogReward(num_items, max_size, seed=reward_seed)
    queue = create_queue(num_items, capacity=200)

    # The diversity and SA variants are trained with half the batch size.
    out_div = run_div_gfn(config.replace(bs=config.bs // 2), env, logr, s_o, queue)
    out = run_gfn(config, env, logr, s_o, queue)
    out_random = run_random_sampler(config, env, logr, s_o, queue)
    out_teacher = run_teacher_gfn(config, env, logr, s_o, queue)
    out_sa = run_sa_gfn(config.replace(bs=config.bs // 2), env, logr, s_o, queue)

    # TODO: also implement other cooperative techniques.
    iterations_for_fcs = 32
    # Every run is scored with the same PRNG key (presumably so the scores
    # are directly comparable).
    key = config.key

    def score(run):
        # Evaluate `fcs` on one trained run's `s_o` and `gfn`.
        return fcs(
            run.s_o,
            run.gfn,
            logr=logr,
            key=key,
            iterations=iterations_for_fcs,
            env=env,
        )

    # Scored in the same order as before (div, tb, random, teacher, sa).
    gfn_div_fcs = score(out_div)
    gfn_fcs = score(out)
    gfn_random_fcs = score(out_random)
    gfn_teacher_fcs = score(out_teacher)
    gfn_sa_fcs = score(out_sa)

    jax.debug.print(
        "div {} tb {} random {} teacher {} sa {}",
        gfn_div_fcs,
        gfn_fcs,
        gfn_random_fcs,
        gfn_teacher_fcs,
        gfn_sa_fcs,
    )

    # Save to `output_dir`
    save_and_plot_results(
        output_dir,
        out_div,
        out,
        out_random,
        out_teacher,
        out_sa,
        gfn_div_fcs,
        gfn_fcs,
        gfn_random_fcs,
        gfn_teacher_fcs,
        gfn_sa_fcs,
    )

    # The random sampler's run is saved/plotted above but not returned.
    return MainOutput(
        std=out,
        div=out_div,
        teacher=out_teacher,
        sa=out_sa,
    )


if __name__ == "__main__":
    # Run the Typer CLI, which dispatches to the `main` command.
    app()
