import os
import pathlib
from enum import Enum

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import typer

from divgfn.eval_utils import (
    Config,
    EnvConfig,
    MainOutput,
    fcs,
    run_div_gfn,
    run_gfn,
    run_random_sampler,
    run_sa_gfn,
    run_teacher_gfn,
)
from divgfn.policies import GFlowNet, sample_with_mask, take
from divgfn.utils import create_queue, save_and_plot_results

# Typer application object; commands are registered on it with `@app.command()`.
# Typer's pretty exception rendering is switched off.
app = typer.Typer(pretty_exceptions_enable=False)

# Absolute path of the directory that contains this file.
HERE = os.path.dirname(os.path.abspath(__file__))


class LogRewardType(Enum):
    """Closed set of log-reward variants selectable from the command line.

    The member's string value is embedded in the output directory name built
    by ``make_output_dir``.
    """

    UTILITY = "utility"


@struct.dataclass
class EnvState:
    """Batched knapsack environment state used by the forward and backward processes.

    Fields declared with ``pytree_node=False`` are static (non-array) metadata
    for ``flax.struct``; every other field is a JAX array.
    """

    max_weight: int = struct.field(pytree_node=False)  # Maximum carryable weight
    num_items: int = struct.field(pytree_node=False)  # Number of items
    limit: int = struct.field(pytree_node=False)  # Maximum number of copies of a specific item

    # Weight of each item; create_state draws them with shape (num_items,)
    weights: jax.Array
    # Bag contents, (bs, num_items) in create_state. fapply/bapply add +1/-1 to
    # one entry per row and umask compares entries with `limit`, so the values
    # are per-item counts (binary only when limit == 1).
    state: jax.Array

    # Action masks computed by umask: 1.0 where an item may be added (fmask)
    # or removed (bmask), 0.0 otherwise
    fmask: jax.Array
    bmask: jax.Array

    # Whether the (forward or backward) process has stopped
    fstopped: jax.Array
    bstopped: jax.Array

    # Maximum trajectory length
    L: int = struct.field(pytree_node=False)

    # Row indices, jnp.arange(bs) in create_state; used by fapply/bapply to
    # address one entry per batch row
    batch_ids: jax.Array


def make_knapsack_reward(
    key: jax.Array,
    num_items: int,
    sparsity: float = 0.15,
    sigma: float = 1.0,
    weights: jax.Array = None,
):
    """Sample the parameters of a linear-plus-quadratic knapsack utility.

    Args:
        key: PRNG key; it is split four ways and the last three sub-keys are
            consumed.
        num_items: Number of items; ``u`` has shape ``(num_items,)`` and ``A``
            has shape ``(num_items, num_items)``.
        sparsity: Bernoulli probability that a strictly-upper-triangular
            interaction entry is kept.
        sigma: Scale applied to the standard-normal interaction values.
        weights: Optional per-item weights. When given, entry ``(i, j)`` of
            ``A`` is divided by ``sqrt(weights[i] * weights[j]) + 1e-6``.

    Returns:
        ``(u, A)``: ``u`` is drawn uniformly from ``[0.1, 1)``; ``A`` is a
        symmetric matrix whose diagonal equals the absolute off-diagonal row
        sum plus ``0.1`` (before the optional weight scaling).
    """
    square = (num_items, num_items)
    # The first sub-key is intentionally left unused.
    _, utility_key, pattern_key, value_key = jax.random.split(key, 4)

    # Linear utilities are bounded away from zero, hence strictly positive.
    u = jax.random.uniform(utility_key, (num_items,), minval=0.1, maxval=1)

    # Sparse pattern restricted to the strict upper triangle.
    pattern = jnp.triu(
        jax.random.bernoulli(pattern_key, p=sparsity, shape=square),
        k=1,
    )
    upper = jax.random.normal(value_key, square) * sigma * pattern

    # Mirror the upper triangle to obtain symmetric interactions.
    interactions = upper + upper.T

    # Diagonal = absolute row sum + 0.1, giving strict diagonal dominance.
    dominance = jnp.abs(interactions).sum(axis=1) + 0.1
    A = interactions + jnp.diag(dominance)

    if weights is None:
        return u, A

    scale = jnp.sqrt(weights[:, None] * weights[None, :]) + 1e-6
    return u, A / scale


class LogReward(nnx.Module):
    """Log of a linear-plus-quadratic utility of the bag contents.

    The parameters ``u`` (linear) and ``A`` (pairwise) are sampled once at
    construction time by ``make_knapsack_reward`` from ``seed``.
    """

    def __init__(self, num_items: int, weights: jax.Array, seed: int, t: float = 1):
        """Sample and store the reward parameters.

        Args:
            num_items: Number of items in the knapsack problem.
            weights: Per-item weights, forwarded to ``make_knapsack_reward``.
            seed: Integer seed used to build the PRNG key for the parameters.
            t: Stored as ``self.t``; it is not read by ``__call__``.
        """
        self.num_items = num_items
        self.t = t
        u, A = make_knapsack_reward(jax.random.key(seed), num_items, weights=weights)
        self.u = u
        self.A = A

    def __call__(self, x: EnvState):
        """Return ``log(s @ u + 0.5 * s^T A s)`` for every batch row ``s`` of ``x.state``."""
        bag = x.state
        linear_term = jnp.einsum("bi,i->b", bag, self.u)
        pairwise_term = jnp.einsum("bi,ij,bj->b", bag, self.A, bag)
        return jnp.log(linear_term + 0.5 * pairwise_term)


def create_state(max_weight: int, num_items: int, limit: int, bs: int, key: jax.Array):
    """Build the initial (empty-bag) environment state for a batch of size ``bs``.

    Args:
        max_weight: Maximum total weight a bag may carry; must be at least the
            minimum item weight (2).
        num_items: Number of distinct items.
        limit: Maximum number of copies of a single item.
        bs: Batch size.
        key: PRNG key used to draw the item weights.

    Returns:
        An ``EnvState`` with an all-zero bag, every forward action enabled,
        every backward action disabled, the forward process running and the
        backward process stopped.

    Raises:
        AssertionError: If ``max_weight`` is smaller than the minimum item weight.
    """
    # Local import: the file's import header does not bring `math` into scope.
    import math

    minimum_item_weight = 2
    assert max_weight >= minimum_item_weight
    # `maxval` is exclusive, so for max_weight > 3 the weights lie in
    # [minimum_item_weight, max_weight - 2].
    # NOTE(review): for max_weight <= 3 the requested range is empty; confirm
    # the intended behaviour for such small capacities.
    weights = jax.random.randint(
        key,
        shape=(num_items,),
        minval=minimum_item_weight,
        maxval=max_weight - 1,
    )

    # weights = jnp.ones((num_items,)) * minimum_item_weight

    # The trajectory length is a static Python int, so it is computed on the
    # host with math.ceil rather than through a device array. It is capped by
    # the largest number of items a bag can hold, plus one.
    L = int(min(math.ceil(max_weight / minimum_item_weight), num_items * limit + 1))
    jax.debug.print("Trajectory length: {}", L)

    state = EnvState(
        max_weight=max_weight,
        num_items=num_items,
        limit=limit,
        weights=weights,
        state=jnp.zeros((bs, num_items)),
        fstopped=jnp.zeros((bs,), dtype=jnp.bool),
        bstopped=jnp.ones((bs,), dtype=jnp.bool),
        fmask=jnp.ones((bs, num_items)),
        bmask=jnp.zeros((bs, num_items)),
        L=L,
        batch_ids=jnp.arange(bs),
    )
    return state


def umask(state: jax.Array, max_weight: int, limit: int, weights: jax.Array):
    """Recompute action masks and stop flags for a batch of bags.

    Args:
        state: Per-item counts, one row per batch element.
        max_weight: Maximum total weight of a bag.
        limit: Maximum count allowed for any single item.
        weights: Weight of each item.

    Returns:
        ``(fstopped, bstopped, fmask, bmask)``:
        ``fmask`` is 1.0 where one more copy of the item both fits under
        ``max_weight`` and keeps the count within ``limit``; ``bmask`` is 1.0
        where the item count is at least one; ``fstopped`` is True for rows
        with no valid forward action; ``bstopped`` is True for empty bags.
    """
    item_weights = weights[None, :]

    # Total weight currently carried by each bag, kept as a column for broadcasting.
    load = jnp.sum(state * item_weights, axis=1, keepdims=True)

    fits = load + item_weights <= max_weight
    addable = fits & (state < limit)
    removable = state >= 1

    fmask = addable.astype(jnp.float32)
    bmask = removable.astype(jnp.float32)

    # Forward process stops when nothing can be added; backward when the bag is empty.
    fstopped = ~jnp.any(addable, axis=1)
    bstopped = jnp.all(state == 0, axis=1)

    return fstopped, bstopped, fmask, bmask


# forward apply
def fapply(state: EnvState, actions: jax.Array) -> EnvState:
    """Add the chosen item to every bag whose forward process is still running.

    Rows with ``fstopped`` set receive an increment of zero and are left
    unchanged. Masks and stop flags are then recomputed with ``umask``.
    """
    increment = jnp.where(~state.fstopped, 1, 0)
    bag = state.state.at[state.batch_ids, actions].add(increment)

    refreshed = dict(
        zip(
            ("fstopped", "bstopped", "fmask", "bmask"),
            umask(bag, state.max_weight, state.limit, state.weights),
        )
    )
    return state.replace(state=bag, **refreshed)


# backward apply
def bapply(state: EnvState, actions: jax.Array) -> EnvState:
    """Remove the chosen item from every bag whose backward process is still running.

    Rows with ``bstopped`` set receive a decrement of zero and are left
    unchanged. Masks and stop flags are then recomputed with ``umask``.
    """
    decrement = jnp.where(~state.bstopped, -1, 0)
    bag = state.state.at[state.batch_ids, actions].add(decrement)

    # Only items present in the bag are removable; that is enforced by the
    # backward mask recomputed here.
    refreshed = dict(
        zip(
            ("fstopped", "bstopped", "fmask", "bmask"),
            umask(bag, state.max_weight, state.limit, state.weights),
        )
    )
    return state.replace(state=bag, **refreshed)


def bpol(state: EnvState, _: GFlowNet = None):
    """Return log-probabilities of a backward policy uniform over removable items.

    Entries where ``state.bmask`` is not 1 get a logit of ``-inf`` before the
    row-wise log-softmax. The second argument is ignored.
    """
    removable = state.bmask == 1
    return nnx.log_softmax(jnp.where(removable, 1, -jnp.inf), axis=1)


def fpol(state: EnvState, gfn: GFlowNet):
    """Return forward log-probabilities from ``gfn`` restricted to valid actions.

    The raw outputs of ``gfn(state)`` are replaced by ``-inf`` wherever
    ``state.fmask`` is not 1, then normalised with a log-softmax over the last
    axis.
    """
    raw = gfn(state)
    allowed = state.fmask == 1
    return nnx.log_softmax(jnp.where(allowed, raw, -jnp.inf), axis=-1)


def fstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module, eps: float = 0.0):
    """Advance every running trajectory by one forward action.

    The signature follows the ``(carry, x) -> (carry, out)`` convention
    (presumably for ``jax.lax.scan`` -- confirm against the caller).

    Args:
        carry: ``(state, key)`` pair.
        _: Unused per-step input.
        gfn: Forward policy network passed to ``sample_with_mask``.
        eps: Forwarded to ``sample_with_mask``.

    Returns:
        ``((next_state, key), (flogits, blogits, actions))``. Both log terms
        are zero for rows that were already stopped before this step.
    """
    state, key = carry

    # Sample actions under the forward mask and move to the successor state.
    log_pf, actions, key = sample_with_mask(gfn, state, key, mask=state.fmask, eps=eps)
    next_state = fapply(state, actions)

    # Uniform backward policy: log(1 / number of removable items) at the successor.
    log_pb = -jnp.log(next_state.bmask.sum(axis=1))

    # Rows that were already stopped contribute nothing.
    running = ~state.fstopped
    log_pf = jnp.where(running, log_pf, 0)
    log_pb = jnp.where(running, log_pb, 0)

    return (next_state, key), (log_pf, log_pb, actions)


def bstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module):
    """Move every running trajectory one step backward.

    Actions are sampled from the uniform backward policy ``bpol``; the forward
    log-probability of the same action is then evaluated with ``fpol`` at the
    resulting predecessor state.

    Args:
        carry: ``(state, key)`` pair.
        _: Unused per-step input.
        gfn: Forward policy network passed to ``fpol``.

    Returns:
        ``((prev_state, key), (flogits, blogits, actions))``. Both log terms
        are zero for rows whose backward process had already stopped.
    """
    state, key = carry
    key, sample_key = jax.random.split(key, 2)

    # Draw the item to remove from the backward distribution.
    backward_logp = bpol(state)
    actions = jax.random.categorical(sample_key, backward_logp)

    # Undo the action to reach the predecessor state.
    prev_state = bapply(state, actions)

    # `take` is a project helper; presumably it selects each row's entry at
    # the sampled action -- confirm in divgfn.policies.
    running = ~state.bstopped
    log_pb = jnp.where(running, take(backward_logp, actions), 0)
    log_pf = jnp.where(running, take(fpol(prev_state, gfn), actions), 0)

    return (prev_state, key), (log_pf, log_pb, actions)


def make_output_dir(
    num_items: int,
    max_weight: int,
    limit: int,
    seed: int,
    log_reward_type: LogRewardType,
) -> pathlib.Path:
    """Create (if missing) and return the run-specific directory under ``outputs/``.

    The directory name encodes the problem size, seed and reward type. The
    path is relative, so it is resolved against the current working directory.
    """
    run_name = f"items{num_items}_maxw{max_weight}_limit{limit}_seed{seed}_{log_reward_type.value}"
    target = pathlib.Path("outputs", run_name)
    target.mkdir(parents=True, exist_ok=True)
    return target


@app.command()
def main(
    num_items: int,
    limit: int,
    max_weight: int,
    iterations: int = 512,  # number of iterations
    bs: int = 64,  # batch size
    dmid: int = 128,  # hidden dim
    nlayers: int = 2,  # number of layers
    seed: int = 42,
    reward_seed: int = 43,
    logr_t: LogRewardType = LogRewardType.UTILITY,
) -> MainOutput:
    # Train the five samplers (diverse, standard, random, teacher, SA) on one
    # knapsack instance, score each with `fcs`, then save and plot everything
    # under the run's output directory. Documented with comments rather than a
    # docstring because Typer would expose a docstring as CLI help text.
    config = Config(
        din=num_items,
        dout=num_items,
        iterations=iterations,
        bs=bs,
        dmid=dmid,
        nlayers=nlayers,
        key=jax.random.key(seed),
    )
    env = EnvConfig(fstep=fstep, bstep=bstep, fapply=fapply, fpol=fpol, bpol=bpol)
    output_dir = make_output_dir(num_items, max_weight, limit, seed, logr_t)

    # One sub-key seeds the environment; the other becomes the config key.
    key, state_key = jax.random.split(config.key, 2)
    s_o = create_state(max_weight, num_items, limit, config.bs, key=state_key)

    logr = LogReward(num_items, s_o.weights, seed=reward_seed)
    queue = create_queue(num_items, capacity=200)
    config = config.replace(key=key)

    # The diverse and SA runs use half the batch size of the other samplers.
    out_div = run_div_gfn(config.replace(bs=config.bs // 2), env, logr, s_o, queue)
    out = run_gfn(config, env, logr, s_o, queue)
    out_random = run_random_sampler(config, env, logr, s_o, queue)
    out_teacher = run_teacher_gfn(config, env, logr, s_o, queue)
    out_sa = run_sa_gfn(config.replace(bs=config.bs // 2), env, logr, s_o, queue)

    # We should also implement other cooperative techniques
    iterations_for_fcs = 32

    def score(result):
        # Every sampler is evaluated with the same key, reward and iteration budget.
        return fcs(
            result.s_o,
            result.gfn,
            logr=logr,
            key=key,
            iterations=iterations_for_fcs,
            env=env,
        )

    # Order: diverse, standard (TB), random, teacher, SA.
    scores = [score(result) for result in (out_div, out, out_random, out_teacher, out_sa)]

    jax.debug.print("div {} tb {} random {} teacher {} sa {}", *scores)

    # Save to `output_dir`
    save_and_plot_results(
        output_dir,
        out_div,
        out,
        out_random,
        out_teacher,
        out_sa,
        *scores,
    )

    return MainOutput(std=out, div=out_div, teacher=out_teacher, sa=out_sa)


# Run the Typer application when this file is executed as a script.
if __name__ == "__main__":
    app()
