import os
import pathlib
import pickle
from enum import Enum

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import typer

from divgfn.eval_utils import (
    Config,
    EnvConfig,
    MainOutput,
    fcs,
    run_div_gfn,
    run_gfn,
    run_random_sampler,
    run_sa_gfn,
    run_teacher_gfn,
)
from divgfn.policies import GFlowNet, sample, take
from divgfn.utils import create_queue, save_and_plot_results

# Typer CLI application; `main` below is registered on it via `@app.command()`.
# Typer's "pretty" exception formatting is disabled.
app = typer.Typer(pretty_exceptions_enable=False)

# Absolute path of the directory containing this file; used to locate `datasets/`.
HERE = os.path.dirname(os.path.abspath(__file__))


class LogRewardType(Enum):
    """Which log-reward `main` builds (selected through its `logr_t` option)."""

    SIMPLE = "simple"  # SimpleLogReward: additive per-token + per-position utilities
    TFN = "tfn"  # LogRewardTFN: TF-Bind oracle lookup (main requires vocab_size == 4)
    BITS = "bits"  # LogRewardBits: distance to random bit-string modes (main requires vocab_size == 2)


@struct.dataclass
class EnvState:
    """Batched state of the left-to-right sequence-building environment.

    `L` and `vocab_size` are declared with `pytree_node=False`, so they are
    static metadata rather than pytree leaves; the remaining fields are arrays.
    """

    L: int = struct.field(pytree_node=False)  # Sequence length
    vocab_size: int = struct.field(pytree_node=False)  # Vocabulary size
    # (bs, L) grid created by create_state; 0 marks an empty slot, and fapply
    # writes `action + 1` into filled slots.
    state: jax.Array  # Sequence
    # Per-row cursor: the next position fapply writes to (starts at 0).
    last_idx: jax.Array

    # Row indices 0..bs-1 (from create_state), paired with last_idx for per-row indexing.
    batch_ids: jax.Array

    # True once a row's cursor has reached L (sequence complete).
    fstopped: jax.Array
    # True while a row's cursor is at 0 (initial, empty state).
    bstopped: jax.Array


class SimpleLogReward(nnx.Module):
    """Additive log-reward with a closed-form partition function.

    For a sequence `x` (tokens stored as `action + 1`, with 0 marking an empty
    slot) the log-reward is

        (sum_i u[x_i - 1] + sum_i v[i]) / t

    where both sums run over the non-empty positions `i`.

    Args:
        seq_size: Sequence length `L`.
        vocab_size: Number of tokens.
        key: PRNG key used to draw `u` and `v` uniformly from [0, 1).
        t: Temperature that divides the log-reward.
    """

    def __init__(self, seq_size: int, vocab_size: int, key: jax.Array, t: float = 1):
        self.seq_size = seq_size
        self.vocab_size = vocab_size

        self.t = t
        key, subkey = jax.random.split(key, 2)
        self.u = jax.random.uniform(key=key, shape=(self.vocab_size,))  # per-token utility
        self.v = jax.random.uniform(key=subkey, shape=(self.seq_size,))  # per-position utility

    def __call__(self, state: EnvState):
        x = state.state
        # Empty slots (0) map to index -1 here; they are masked out below.
        indices = (x - 1).astype(jnp.int32)
        mask = x != 0
        log_r_u = jnp.where(mask, self.u[indices], 0).sum(axis=1)
        log_r_v = jnp.where(mask, self.v[None, ...], 0).sum(axis=1)
        log_r = log_r_u + log_r_v
        return log_r / self.t

    @property
    def logz(self):
        """Exact log partition function over all complete length-`L` sequences.

        A sequence is terminal only once all `L` positions are filled, so the
        position term is the constant `sum(v)`. The token term factorises over
        positions:

            Z     = exp(sum(v) / t) * (sum_k exp(u_k / t)) ** L
            log Z = sum(v) / t + L * logsumexp(u / t)

        The previous implementation used `logsumexp(v)` for the position term.
        It also applied the temperature outside the logsumexp over `u`, which
        is wrong whenever `t != 1`.
        """
        v_term = self.v.sum() / self.t
        u_term = jax.nn.logsumexp(self.u / self.t, axis=0)
        return v_term + self.seq_size * u_term


class LogRewardTFN(nnx.Module):
    """Log-reward backed by the exhaustive TF-Bind-8 / TF-Bind-10 oracle tables.

    How the scores are prepared:

    - They are read from `datasets/tfbind{seq_size}-exact-v0-all.pkl` next to this file.
    - For `seq_size == 10` they are squashed with a sigmoid.
    - They are raised to the power `exp` and rescaled so the best sequence scores `max_val`.
    - They are stored in a dense table indexed by `sequence_to_idx` of each sequence.

    Args:
        seq_size: Sequence length; only 8 and 10 have bundled datasets.
        max_val: Reward given to the best-scoring sequence.
        exp: Exponent applied to the raw scores before rescaling.

    Raises:
        ValueError: if `seq_size` is not 8 or 10.
    """

    def __init__(self, seq_size: int, max_val: float = 10, exp: float = 3):
        if seq_size not in (8, 10):
            raise ValueError(f"LogRewardTFN supports seq_size 8 or 10, got {seq_size}")

        self.seq_size = seq_size
        self.max_val = max_val
        self.exp = exp

        # NOTE: pickle can execute arbitrary code; only load trusted, bundled files.
        with open(f"{HERE}/datasets/tfbind{self.seq_size}-exact-v0-all.pkl", "rb") as f:
            data = pickle.load(f)

        states, rewards = data["x"], data["y"]

        if self.seq_size == 10:
            rewards = jsp.special.expit(rewards)

        # Scores may be stored as an (N, 1) column; flatten to (N,).
        powered = jnp.ravel(jnp.asarray(rewards**exp))
        scaled = max_val * powered / jnp.max(powered)

        # Scatter each score to its own sequence's slot, so that
        # table[sequence_to_idx(s)] is the score of `s` whatever order the
        # dataset lists its rows in. A gather (`scaled[indices]`) would apply
        # the inverse permutation, which only coincides when that permutation
        # is an involution.
        indices = self.sequence_to_idx(jnp.asarray(states))
        table = jnp.zeros((4**self.seq_size,), dtype=scaled.dtype)
        self.scaled_oracle = table.at[indices].set(scaled)

    def sequence_to_idx(self, sequence: jax.Array) -> jax.Array:
        """Map each row of 0-based tokens in {0..3} to its base-4 index (position 0 least significant)."""
        weights = 4 ** jnp.arange(self.seq_size)
        return (sequence * weights).sum(axis=1).astype(jnp.int32)

    def __call__(self, state: EnvState):
        # Stored tokens are `action + 1`; shift back to 0-based before indexing.
        indices = self.sequence_to_idx(state.state - 1)
        # Epsilon guards log(0) for zero-score (or missing) sequences.
        return jnp.log(self.scaled_oracle[indices] + 1e-8)

    @property
    def logz(self):
        """Log of the total scaled oracle mass over all indexed sequences.

        NOTE(review): unlike `__call__`, no 1e-8 epsilon is added here.
        """
        return jax.nn.logsumexp(jnp.log(self.scaled_oracle), axis=0)


class LogRewardBits(nnx.Module):
    """Bit-string log-reward based on distance to a set of random modes.

    How the modes are built:

    - Each of the `num_modes` modes concatenates `seq_size // 8` 8-bit
      components, each drawn uniformly from a fixed set of five.
    - Modes are right-padded with zeros when `seq_size` is not a multiple of 8.

    A sequence's log-reward is `(1 - d_min / seq_size) / t`, where `d_min` is
    its L1 distance (the Hamming distance for 0/1 entries) to the closest mode.
    """

    def __init__(self, seq_size: int, num_modes: int, key: jax.Array, t: float = 5e-2):
        super().__init__()
        self.seq_size = seq_size
        self.num_modes = num_modes
        self.t = t

        # The five 8-bit building blocks every mode is assembled from.
        components = jnp.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0],  # '00000000'
                [1, 1, 1, 1, 1, 1, 1, 1],  # '11111111'
                [1, 1, 1, 1, 0, 0, 0, 0],  # '11110000'
                [0, 0, 0, 0, 1, 1, 1, 1],  # '00001111'
                [0, 0, 1, 1, 1, 1, 0, 0],  # '00111100'
            ],
            dtype=jnp.int32,
        )

        # For each mode, pick which component fills each 8-bit chunk.
        chunks_per_mode = seq_size // 8
        choice = jax.random.randint(
            key,
            shape=(num_modes, chunks_per_mode),
            minval=0,
            maxval=len(components),
        )
        modes = components[choice].reshape(num_modes, -1)

        # Right-pad with zeros up to the full sequence length.
        missing = seq_size - modes.shape[1]
        if missing > 0:
            zeros_tail = jnp.zeros((num_modes, missing), dtype=modes.dtype)
            modes = jnp.concatenate([modes, zeros_tail], axis=1)

        self.modes = modes

    def __call__(self, state: EnvState):
        # Stored tokens are `bit + 1`; shift them back to {0, 1}.
        bits = state.state.astype(jnp.float32) - 1

        # (bs, 1, L) vs (1, num_modes, L) -> (bs, num_modes) distances.
        per_mode = jnp.abs(bits[:, None, :] - self.modes[None, :, :]).sum(axis=2)
        closest = per_mode.min(axis=1)

        return (1 - closest / self.seq_size) / self.t


def create_state(seq_size: int, vocab_size: int, bs: int):
    """Build a batch of `bs` empty initial states.

    All sequences start as zeros (empty slots) with the write cursor at 0.
    Each row is flagged as not forward-stopped and as backward-stopped, since
    it is the initial state.

    Args:
        seq_size: Sequence length `L`.
        vocab_size: Vocabulary size, stored as static metadata.
        bs: Batch size.

    Returns:
        The initial `EnvState`.
    """
    empty_sequences = jnp.zeros((bs, seq_size))
    cursors = jnp.zeros((bs,), dtype=jnp.int32)
    rows = jnp.arange(bs)
    not_done = jnp.zeros((bs,), dtype=jnp.bool)
    at_root = jnp.ones((bs,), dtype=jnp.bool)

    return EnvState(
        L=seq_size,
        vocab_size=vocab_size,
        state=empty_sequences,
        last_idx=cursors,
        batch_ids=rows,
        fstopped=not_done,
        bstopped=at_root,
    )


def fpol(state: EnvState, gfn: GFlowNet):
    """Forward policy: return whatever `gfn` outputs for `state` (its logits)."""
    return gfn(state)


def bpol(state: EnvState, _: GFlowNet = None):
    """Backward policy: a single zero logit per batch row.

    Positions are filled strictly left to right (see `fapply`/`bapply`), so
    the backward move out of a state is fixed. No network is needed, and the
    second argument is ignored.
    """
    batch_size, _seq_len = state.state.shape
    return jnp.zeros((batch_size, 1))


def fapply(state: EnvState, actions: jax.Array):
    """Take one forward step: write `actions + 1` at each row's cursor, then advance it.

    The `+ 1` keeps 0 free as the empty-slot marker. The stop flags are
    recomputed from the new cursor.

    NOTE(review): for a row whose cursor already equals `L`, the write index
    is out of range. Confirm that callers never step a finished row.
    """
    next_idx = state.last_idx + 1
    written = state.state.at[state.batch_ids, state.last_idx].set(actions + 1)
    return state.replace(
        state=written,
        last_idx=next_idx,
        fstopped=next_idx >= state.L,
        bstopped=next_idx <= 0,
    )


def bapply(state: EnvState):
    """Undo the last forward step: move each cursor back one slot and clear that slot.

    NOTE(review): on a row whose cursor is already 0, the cursor becomes -1
    and the write index is -1. Presumably callers never step back from the
    initial state; confirm.
    """
    prev_idx = state.last_idx - 1
    cleared = state.state.at[state.batch_ids, prev_idx].set(0)
    return state.replace(
        state=cleared,
        last_idx=prev_idx,
        fstopped=prev_idx >= state.L,
        bstopped=prev_idx <= 0,
    )


def fstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module, eps: float = 0.0):
    """One forward transition, in `(carry, x) -> (carry, y)` form.

    That is the shape of a `jax.lax.scan` body; presumably it is driven by a
    scan in `eval_utils` (confirm). Actions are drawn with `sample`, which
    also returns the advanced PRNG key; `eps` is forwarded to it.

    Returns:
        `((next_state, key), (flogits, blogits, actions))`, where `blogits` is
        the backward-policy logit evaluated at the new state.
    """
    current, rng = carry
    flogits, actions, rng = sample(gfn, current, rng, eps=eps)
    successor = fapply(current, actions)
    back_logit = bpol(successor)[:, 0]
    return (successor, rng), (flogits, back_logit, actions)


def bstep(carry: tuple[EnvState, jax.Array], _, gfn: nnx.Module):
    """One backward transition, in the same `(carry, x) -> (carry, y)` form as `fstep`.

    The move itself is deterministic (the last filled slot is removed), so no
    sampling happens and the PRNG key is passed through unchanged. The forward
    policy is scored on the action that leads from the parent back to the
    current state.
    """
    current, rng = carry
    parent = bapply(current)

    # The token just removed is the forward action parent -> current (stored as action + 1).
    removed = current.state[current.batch_ids, parent.last_idx].astype(jnp.int32)
    actions = removed - 1

    flogits = take(fpol(parent, gfn), actions)
    back_logit = bpol(current)[:, 0]
    return (parent, rng), (flogits, back_logit, actions)


def make_output_dir(
    seq_size: int,
    vocab_size: int,
    seed: int,
    log_reward_type: LogRewardType,
) -> pathlib.Path:
    """Create, if needed, and return `outputs/<run-name>` relative to the CWD.

    The run name encodes the sequence length, vocabulary size, training seed
    and reward type. Runs that share these values therefore write to the same
    directory.
    """
    run_name = f"seq{seq_size}_vocab{vocab_size}_seed{seed}_{log_reward_type.value}"
    target = pathlib.Path("outputs").joinpath(run_name)
    target.mkdir(parents=True, exist_ok=True)
    return target


def _build_log_reward(
    logr_t: LogRewardType,
    seq_size: int,
    vocab_size: int,
    reward_seed: int,
):
    """Construct the log-reward selected by `logr_t`, validating `vocab_size` for it."""
    reward_key = jax.random.key(reward_seed)

    if logr_t == LogRewardType.TFN:
        if vocab_size != 4:
            raise typer.BadParameter(
                f"logr_t={logr_t.value!r} requires vocab_size == 4 (got {vocab_size})",
                param_hint="vocab_size",
            )
        # LogRewardTFN is a fixed oracle lookup; its constructor takes no PRNG key.
        logr = LogRewardTFN(seq_size=seq_size)
        jax.debug.print("{}", logr.logz)
    elif logr_t == LogRewardType.BITS:
        if vocab_size != 2:
            raise typer.BadParameter(
                f"logr_t={logr_t.value!r} requires vocab_size == 2 (got {vocab_size})",
                param_hint="vocab_size",
            )
        logr = LogRewardBits(seq_size, num_modes=1024, key=reward_key)
    else:
        logr = SimpleLogReward(
            seq_size=seq_size,
            vocab_size=vocab_size,
            key=reward_key,
        )
        jax.debug.print("{}", logr.logz)
    return logr


@app.command()
def main(
    seq_size: int,
    vocab_size: int,
    iterations: int = 512,  # number of iterations
    bs: int = 64,  # batch size
    dmid: int = 128,  # hidden dim
    nlayers: int = 2,  # number of layers
    seed: int = 42,
    logr_t: LogRewardType = LogRewardType.SIMPLE,
    reward_seed: int = 43,
    fcs_iterations: int = 512,  # iterations for each final `fcs` evaluation
):
    """Train every sampler on the selected reward and compare them.

    Runs the diversity, standard (TB), random, teacher and SA samplers. It
    then estimates `fcs` for each, prints the estimates, and saves results
    and plots under `outputs/`.
    """
    config = Config(
        din=seq_size,
        dout=vocab_size,
        iterations=iterations,
        bs=bs,
        dmid=dmid,
        nlayers=nlayers,
        key=jax.random.key(seed),
    )
    env = EnvConfig(
        fstep=fstep,
        bstep=bstep,
        fpol=fpol,
        bpol=bpol,
        fapply=fapply,
    )

    # Build (and validate) the reward before touching the filesystem, so an
    # invalid CLI combination does not leave an empty output directory behind.
    logr = _build_log_reward(logr_t, seq_size, vocab_size, reward_seed)
    output_dir = make_output_dir(seq_size, vocab_size, seed, logr_t)

    queue = create_queue(seq_size, capacity=200)
    s_o = create_state(seq_size, vocab_size, bs)

    # The div and SA runs use half the configured batch size.
    half_bs_config = config.replace(bs=config.bs // 2)
    out_div = run_div_gfn(half_bs_config, env, logr, s_o, queue)
    out = run_gfn(config, env, logr, s_o, queue)
    out_random = run_random_sampler(config, env, logr, s_o, queue)
    out_teacher = run_teacher_gfn(config, env, logr, s_o, queue)
    out_sa = run_sa_gfn(half_bs_config, env, logr, s_o, queue)
    # TODO: also implement other cooperative techniques.

    # Every sampler is evaluated with the same key (config.key).
    eval_key = config.key
    gfn_div_fcs, gfn_fcs, gfn_random_fcs, gfn_teacher_fcs, gfn_sa_fcs = [
        fcs(
            run.s_o,
            run.gfn,
            logr=logr,
            key=eval_key,
            env=env,
            iterations=fcs_iterations,
        )
        for run in (out_div, out, out_random, out_teacher, out_sa)
    ]
    jax.debug.print(
        "div {} tb {} random {} teacher {} sa {}",
        gfn_div_fcs,
        gfn_fcs,
        gfn_random_fcs,
        gfn_teacher_fcs,
        gfn_sa_fcs,
    )

    # Save to `output_dir`
    save_and_plot_results(
        output_dir,
        out_div,
        out,
        out_random,
        out_teacher,
        out_sa,
        gfn_div_fcs,
        gfn_fcs,
        gfn_random_fcs,
        gfn_teacher_fcs,
        gfn_sa_fcs,
    )

    return MainOutput(
        std=out,
        div=out_div,
        teacher=out_teacher,
        sa=out_sa,
    )


if __name__ == "__main__":
    # Dispatch command-line arguments to the Typer app (i.e. to `main`).
    app()
