import itertools
import os
import pickle
from functools import partial
from typing import Self

import jax
import jax.numpy as jnp
import jax.random as random
from flax import struct
from flax import nnx

from nais.gflownet import GFlowNet

from nais.gym.base import (
    Environment,
    EnvironmentConfig,
    LogRewardBase,
    LogRewardConfig,
    EnvState,
)

HERE = os.path.dirname(os.path.abspath(__file__))


# This closely resembles the representation we use in the Sets environment
# @struct.dataclass
# class SequenceState(struct.PyTreeNode):
#     state: jax.Array
#     size: jax.Array
#     forward_mask: jax.Array
#     backward_mask: jax.Array
#     stopped: jax.Array
#     is_initial: jax.Array


def _get_action_from_index(
    index: jax.Array, vocab_size: int, ngram_size: int
) -> jax.Array:
    base = jnp.logspace(
        start=1, stop=ngram_size, base=vocab_size, num=ngram_size, dtype=index.dtype
    )
    # base = (vocab_size, vocab_size**2, ...)
    item = jnp.mod(index[:, None], base)  # (b, 1) -> (b, ngram_size)
    item_to_action = jnp.eye(ngram_size) - jnp.eye(
        ngram_size, k=1
    )  # (ngram_size, ngram_size)
    actions = item @ item_to_action  # (b, ngram_size)
    actions = actions / (base / vocab_size)
    return actions.astype(index.dtype)


def _get_index_from_action(
    action: jax.Array, vocab_size: int, ngram_size: int
) -> jax.Array:
    # the index for an action (a1, a2, ..., an)
    # is given by a1 + a2 * vocab_size + ... + an * vocab_size**(ngram_size - 1)
    base = jnp.logspace(
        start=0,
        stop=ngram_size - 1,
        base=vocab_size,
        num=ngram_size,
        dtype=action.dtype,
    )  # (ngram_size,)
    # action = (b, ngram_size)
    index = action @ base  # (b,)
    return index.astype(jnp.int32)


def append(
    env_state: EnvState,
    actions: jax.Array,
    active_mask: jax.Array,
    batch_ids: jax.Array,
    seq_size: int,
    vocab_size: int,
    ngram_size: int,
) -> EnvState:
    """Append the n-gram encoded by ``actions`` to every active sequence.

    Each action index is decoded into ``ngram_size`` vocabulary tokens. The
    tokens are shifted by +1, so that 0 keeps meaning "empty slot", and written
    directly after the last occupied position of the row. Inactive rows have
    their current values written back, i.e. they are left unchanged.

    Args:
        env_state: current batched environment state.
        actions: (b,) flat action indices.
        active_mask: (b,) boolean mask of the rows to extend.
        batch_ids: (b,) row indices into ``env_state.state``.
        seq_size: maximum sequence length; a row is marked stopped once it is full.
        vocab_size: number of vocabulary tokens.
        ngram_size: number of tokens written per action.

    Returns:
        ``env_state`` with ``state``, ``stopped``, ``is_initial`` and
        ``backward_mask`` updated.
    """
    rows = batch_ids[:, None]  # (b, 1)
    row_active = active_mask[:, None]  # (b, 1), broadcast over the n-gram

    # Occupied slots are exactly the non-zero ones, so the fill level is a count.
    filled = jnp.sum(env_state.state != 0, axis=1)  # (b,)
    # Inactive rows point at slot 0; their values are written back unchanged.
    start = jnp.where(active_mask, filled, 0)
    cols = start[:, None] + jnp.arange(ngram_size)  # (b, ngram_size)

    # Tokens are stored as vocabulary index + 1.
    tokens = _get_action_from_index(actions, vocab_size, ngram_size) + 1
    previous = env_state.state[rows, cols]
    new_state = env_state.state.at[rows, cols].set(
        jnp.where(row_active, tokens, previous)
    )

    new_size = jnp.where(active_mask, filled + ngram_size, filled)

    new_stopped = jnp.where(
        active_mask,
        (new_size == seq_size).astype(env_state.stopped.dtype),
        env_state.stopped,
    )
    new_is_initial = jnp.where(
        active_mask,
        (new_size == 0).astype(env_state.is_initial.dtype),
        env_state.is_initial,
    )
    has_tokens = (new_size > 0)[:, None].astype(env_state.backward_mask.dtype)
    new_backward_mask = jnp.where(row_active, has_tokens, env_state.backward_mask)

    return env_state.replace(
        state=new_state,
        stopped=new_stopped,
        is_initial=new_is_initial,
        backward_mask=new_backward_mask,
    )


def apply_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """``jax.lax.scan`` step that samples one forward action per trajectory.

    Only trajectories whose ``stopped`` flag is 0 are extended. Their sampled
    n-gram is appended to the sequence and their forward log-probability is
    recorded at column ``gflownet.state.idx`` of ``log_pf``. Stopped rows get 0.

    Args:
        gflownet_key: ``(gflownet, key)`` scan carry.
        _: unused scan input.

    Returns:
        ``((gflownet, next_key), (backward_actions, is_active, env_state))``,
        where ``backward_actions`` are all zeros (the backward action space has
        a single entry).

    Raises:
        ValueError: if the forward action-space size is not an exact power
            ``vocab_size ** ngram_size``.
    """
    gflownet, key = gflownet_key

    env_state = gflownet.state.env_state
    is_active = env_state.stopped == 0.0

    out_pf = gflownet.pf.sample_actions(env_state, key=key)

    seq_size = env_state.state.shape[-1]
    ngram_size = seq_size // env_state.max_trajectory_length
    # The forward action space has vocab_size**ngram_size entries. Recover the
    # integer root with ``round``: a float root such as 64 ** (1 / 3) evaluates
    # to 3.9999999999999996, which ``int`` would truncate to 3.
    num_actions = env_state.forward_mask.shape[-1]
    vocab_size = round(num_actions ** (1 / ngram_size))
    if vocab_size**ngram_size != num_actions:
        raise ValueError(
            f"forward action space of size {num_actions} is not a power "
            f"vocab_size**{ngram_size}"
        )

    env_state = append(
        env_state=env_state,
        actions=out_pf.actions,
        active_mask=is_active,
        batch_ids=env_state.batch_ids,
        seq_size=seq_size,
        vocab_size=vocab_size,
        ngram_size=ngram_size,
    )

    log_pf = gflownet.state.log_pf.at[:, gflownet.state.idx].set(
        jnp.where(is_active, out_pf.log_pf, 0.0)
    )

    gflownet.state = gflownet.state.replace(
        env_state=env_state, log_pf=log_pf, idx=gflownet.state.idx + 1
    )

    return (gflownet, out_pf.key), (jnp.zeros_like(out_pf.actions), is_active, env_state)


def remove(
    env_state: EnvState,
    active_mask: jax.Array,
    batch_ids_with_offset: jax.Array,
    current_size: jax.Array,
    ngram_size: int,
) -> EnvState:
    """Clear the trailing ``ngram_size`` tokens of every active sequence.

    For an active row of length ``s``, positions ``s-1, s-2, ..., s-ngram_size``
    are set to 0 and the row is marked as not stopped. Inactive rows have their
    current values written back unchanged.

    Args:
        env_state: current batched environment state.
        active_mask: (b,) boolean mask of the rows to shrink.
        batch_ids_with_offset: (b, 1) row indices into ``env_state.state``.
        current_size: (b,) number of occupied positions per row.
        ngram_size: number of tokens removed per row.

    Returns:
        ``env_state`` with ``state``, ``stopped``, ``is_initial`` and
        ``backward_mask`` updated.
    """
    row_active = active_mask[:, None]  # (b, 1)

    # Inactive rows use base 0, i.e. negative positions; their (unchanged)
    # values are simply written back.
    base = jnp.where(active_mask, current_size, 0)
    positions = base[:, None] - (jnp.arange(ngram_size) + 1)  # (b, ngram_size)

    existing = env_state.state[batch_ids_with_offset, positions]
    cleared = jnp.where(row_active, 0.0, existing)
    new_state = env_state.state.at[batch_ids_with_offset, positions].set(cleared)

    new_size = jnp.where(active_mask, current_size - ngram_size, current_size)

    new_stopped = jnp.where(active_mask, 0, env_state.stopped)
    new_is_initial = jnp.where(
        active_mask,
        (new_size == 0).astype(env_state.is_initial.dtype),
        env_state.is_initial,
    )
    still_has_tokens = (new_size > 0)[:, None].astype(env_state.backward_mask.dtype)
    new_backward_mask = jnp.where(
        row_active, still_has_tokens, env_state.backward_mask
    )

    return env_state.replace(
        state=new_state,
        stopped=new_stopped,
        is_initial=new_is_initial,
        backward_mask=new_backward_mask,
    )


def backward_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """``jax.lax.scan`` step that undoes the last n-gram of every trajectory.

    Rows whose ``is_initial`` flag is 0 lose their trailing ``ngram_size``
    tokens. The backward step always removes the trailing n-gram, so the key is
    passed through untouched and ``log_pb`` is reset to zeros.

    Args:
        gflownet_key: ``(gflownet, key)`` scan carry.
        _: unused scan input.

    Returns:
        ``((gflownet, key), (actions, is_active, env_state))``, where
        ``actions`` are the forward-action indices of the removed n-grams.

    Raises:
        ValueError: if the forward action-space size is not an exact power
            ``vocab_size ** ngram_size``.
    """
    gflownet, key = gflownet_key

    env_state = gflownet.state.env_state
    is_active = env_state.is_initial == 0.0
    # trajectory_length = seq_size / ngram_size, i.e., ngram_size = seq_size / trajectory_length
    ngram_size = env_state.state.shape[-1] // env_state.max_trajectory_length
    # Integer root of the action-space size. ``int`` would truncate inexact
    # float roots (64 ** (1 / 3) == 3.9999999999999996), so round instead.
    num_actions = env_state.forward_mask.shape[-1]
    vocab_size = round(num_actions ** (1 / ngram_size))
    if vocab_size**ngram_size != num_actions:
        raise ValueError(
            f"forward action space of size {num_actions} is not a power "
            f"vocab_size**{ngram_size}"
        )

    current_size = jnp.sum(env_state.state != 0, axis=1)

    def slice_row(row, size):
        # Last ``ngram_size`` tokens of the row, in the order they were appended.
        start = jnp.maximum(size - ngram_size, 0)
        return jax.lax.dynamic_slice_in_dim(row, start, ngram_size, axis=0)

    # current_size: shape (batch_size,)
    forward_actions = jax.vmap(slice_row, in_axes=(0, 0))(env_state.state, current_size)

    env_state = remove(
        env_state=env_state,
        active_mask=is_active,
        batch_ids_with_offset=env_state.batch_ids[..., None],
        current_size=current_size,
        ngram_size=ngram_size,
    )

    gflownet.state = gflownet.state.replace(
        env_state=env_state, log_pb=jnp.zeros_like(gflownet.state.log_pb)
    )

    # Stored tokens are vocabulary index + 1.
    actions = _get_index_from_action(
        forward_actions - 1, vocab_size=vocab_size, ngram_size=ngram_size
    )

    return (gflownet, key), (actions, is_active, env_state)


def factory(
    seq_size: int, vocab_size: int, config: EnvironmentConfig, *, ngram_size: int = 1
):
    """Build the initial (all-empty) batched state of the sequence environment.

    Each forward action writes ``ngram_size`` tokens, so there are
    ``vocab_size ** ngram_size`` forward actions and a single backward action.
    A full sequence takes ``seq_size // ngram_size`` steps to build.

    Args:
        seq_size: length of a complete sequence.
        vocab_size: number of vocabulary tokens.
        config: environment config; ``config.batch_size`` sets the batch size.
        ngram_size: tokens written per action.

    Returns:
        An ``EnvState`` with an all-zero (empty) ``state``.

    Raises:
        ValueError: if ``ngram_size < 1`` or ``seq_size`` is not a positive
            multiple of ``ngram_size``. Otherwise no sequence could ever reach
            ``seq_size`` tokens and trajectories would never be marked stopped.
    """
    if ngram_size < 1:
        raise ValueError(f"ngram_size must be >= 1, got {ngram_size}")
    if seq_size <= 0 or seq_size % ngram_size != 0:
        raise ValueError(
            "seq_size must be a positive multiple of ngram_size, got "
            f"seq_size={seq_size}, ngram_size={ngram_size}"
        )

    return EnvState(
        state=jnp.zeros((config.batch_size, seq_size)),
        forward_mask=jnp.ones((config.batch_size, vocab_size**ngram_size)),
        backward_mask=jnp.zeros((config.batch_size, 1)),
        batch_ids=jnp.arange(config.batch_size),
        stopped=jnp.zeros((config.batch_size,)),
        is_initial=jnp.ones((config.batch_size,)),
        max_trajectory_length=seq_size // ngram_size,
        batch_size=config.batch_size,
    )


# # We will start with fixed-length, forwardly-constructed sequences.
# # Then we might proceed to variable-length, forward-backwardly-constructed sequences.
# class Sequences(Environment):
#     state: jax.Array

#     def __init__(
#         self,
#         seq_size: int,
#         vocab_size: int,
#         config: EnvironmentConfig,
#         *,
#         ngram_size: int = 1,
#     ):
#         super().__init__(config=config)
#         self.seq_size = seq_size
#         self.vocab_size = vocab_size
#         self.ngram_size = ngram_size

#         # Each action corresponds to a `self.ngram_size`-sized string of the vocabulary
#         self.num_actions = self.vocab_size**self.ngram_size

#         # We should ensure that seq_size is a multiple of ngram_size
#         assert seq_size % ngram_size == 0, "seq_size must be a multiple of ngram_size"

#         self._state = SequenceState(
#             state=jnp.zeros((self.batch_size, self.seq_size)),
#             size=jnp.zeros((self.batch_size,), dtype=jnp.int32),
#             forward_mask=jnp.ones((self.batch_size, self.num_actions)),
#             backward_mask=jnp.zeros((self.batch_size, 1)),
#             stopped=self.stopped,
#             is_initial=self.is_initial,
#         )

#         self.max_trajectory_length = seq_size

#         self._sync_views()

#     def _sync_views(self):
#         self.state = self._state.state
#         self.size = self._state.size
#         self.forward_mask = self._state.forward_mask
#         self.backward_mask = self._state.backward_mask
#         self.stopped = self._state.stopped
#         self.is_initial = self._state.is_initial

#     def apply(self, actions: jax.Array, active_mask: jax.Array | None = None) -> jax.Array:
#         active_mask = self._active_mask(active_mask)
#         self._state = _apply_impl(
#             self._state,
#             actions,
#             active_mask,
#             self.batch_ids,
#             self.seq_size,
#             self.vocab_size,
#             self.ngram_size,
#         )
#         self._sync_views()
#         return actions

#     def backward(self, actions: jax.Array, active_mask: jax.Array | None = None) -> jax.Array:
#         active_mask = self._active_mask(active_mask)
#         batch_ids_with_offset = jnp.expand_dims(self.batch_ids, axis=1)

#         indices = jnp.expand_dims(self.size, axis=1) - jnp.arange(self.ngram_size)

#         forward_actions = (self.state[batch_ids_with_offset, indices].astype(jnp.int32)) - 1  # -1 because the actions start at 1

#         self._state = _backward_impl(
#             self._state,
#             active_mask,
#             batch_ids_with_offset,
#             self.ngram_size,
#         )
#         self._sync_views()
#         return _get_index_from_action(forward_actions, self.vocab_size, self.ngram_size)

#     def fwd_to_bcw_actions(self, actions: jax.Array):
#         return jnp.zeros_like(actions)

#     def merge(self, batch_state: Self):
#         super().merge(batch_state)
#         self._state = SequenceState(
#             state=self.state,
#             size=self.size,
#             forward_mask=self.forward_mask,
#             backward_mask=self.backward_mask,
#             stopped=self.stopped,
#             is_initial=self.is_initial,
#         )

#     @property
#     def log_space_size(self):
#         return self.seq_size * jnp.log(self.vocab_size)


class LogReward(LogRewardBase):
    """Additive log reward built from random token and position log utilities.

    The log reward of a sequence sums, over its occupied positions, the
    position's log utility plus the log utility of the token stored there.
    The total is divided by the temperature.
    """

    def __init__(self, seq_size: int, vocab_size: int, config: LogRewardConfig):
        super().__init__(config)
        self.seq_size = seq_size
        self.vocab_size = vocab_size

        # NOTE(review): both tables below are drawn from the same PRNG key (42),
        # so they are not independent draws (JAX keys should not be reused).
        # Splitting the key would fix this but would change every reward value,
        # so it is left as-is.

        # Log utilities for vocabulary items
        self.log_utilities = random.normal(key=random.key(42), shape=(vocab_size,))

        # Log utilities for positioning items
        self.log_position_utilities = random.normal(
            key=random.key(42), shape=(seq_size,)
        )

    def __call__(self, state: EnvState):
        # Stored tokens are vocabulary index + 1; 0 marks an empty slot.
        tokens = state.state  # (b, n)
        log_position_utilities = jnp.where(
            tokens > 0, self.log_position_utilities, 0.0
        )  # (b, n)
        vocab_indices = tokens.astype(jnp.int32) - 1
        # Vocabulary index 0 is a real token (stored as 1), so only empty slots
        # (index -1) are masked out. The previous ``> 0`` test dropped item 0.
        log_utilities = jnp.where(
            vocab_indices >= 0, self.log_utilities[vocab_indices], 0.0
        )  # (b, n)
        return (log_position_utilities + log_utilities).sum(
            axis=1
        ) / self.temperature  # (b,)


@partial(jax.jit, static_argnums=(1,))
def _sequence_to_idx(sequence: jax.Array, seq_size: int) -> jax.Array:
    # This maps a sequence into its position
    # in the lexicographic ordering of the possible sequences
    index_exp = 4 ** jnp.arange(seq_size)
    sequence = jnp.flip(sequence, axis=1)
    index = (sequence * index_exp).sum(axis=1)
    return index.astype(jnp.int32)


class LogRewardTFN(LogRewardBase):
    """TF-Bind-8/10 log reward, looked up in an exhaustive oracle table.

    Oracle rewards are raised to ``exp`` and rescaled so that the maximum is
    ``max_val``. The results are stored in a table indexed by
    ``_sequence_to_idx``.
    """

    def __init__(self, n, config: LogRewardConfig, *, max_val=10, exp=3):
        super().__init__(config)
        self.seq_size = n
        self.max_val = max_val
        self.exp = exp

        assert self.seq_size in [8, 10]

        # NOTE: pickle.load can execute arbitrary code; only point this at the
        # trusted dataset files shipped with the repository.
        with open(f"{HERE}/../../datasets/tfbind{n}-exact-v0-all.pkl", "rb") as f:
            oracle_d = pickle.load(f)

        states, rewards = oracle_d["x"], oracle_d["y"]

        if n == 10:
            from scipy.special import expit

            rewards = expit(rewards * 3)

        states = jnp.array(states)

        # ``ndarray.max()`` instead of the builtin ``max``, which iterates
        # row-by-row in Python over up to 4**10 rows.
        powered = rewards**exp
        scaled = jnp.array(max_val * powered / powered.max()).squeeze(axis=-1)

        # Scatter each oracle value to the lexicographic index of its own
        # sequence, so that ``scaled_oracle[_sequence_to_idx(s)]`` is the reward
        # of ``s``. A gather (``scaled[indices]``) would apply the inverse
        # permutation and is only correct when the rows are already sorted.
        indices = _sequence_to_idx(states, self.seq_size)
        self.scaled_oracle = (
            jnp.zeros((4**self.seq_size,), dtype=scaled.dtype)
            .at[indices]
            .set(scaled)
        )

    def __call__(self, batch_state: EnvState):
        # Stored tokens are nucleotide index + 1.
        indices = _sequence_to_idx(batch_state.state - 1, self.seq_size)
        return jnp.log(self.scaled_oracle[indices] + 1e-3) / self.temperature

    @property
    def mode_th(self):
        """Mode threshold: 99.5% quantile of the log oracle values over temperature."""
        return jnp.quantile(jnp.log(self.scaled_oracle), q=0.995) / self.temperature


def sim_likelihood(logits: jax.Array, dim: int, max_w: int):
    """Turn preference logits into Bernoulli probabilities kept away from 0 and 1.

    ``dim`` and ``max_w`` are accepted for interface compatibility and are not
    used.
    """
    eps = 1e-6
    # Clipping keeps log(p) and log(1 - p) finite downstream.
    return jnp.clip(jax.nn.sigmoid(logits), eps, 1 - eps)


def sim_pref_data(size: int, dim: int, w: jax.Array, max_w: float, key: int = 42):
    """Simulate pairwise preference data from a logistic (Bradley-Terry) model.

    ``size`` binary feature vectors are drawn uniformly from ``{0, 1}**dim``.
    Then ``2 * size`` distinct index pairs ``(a, b)`` with ``a < b`` are sampled
    without replacement. Each pair gets a label
    ``y ~ Bernoulli(sim_likelihood(x_a . w - x_b . w))``.

    Args:
        size: number of feature vectors.
        dim: feature dimension (length of ``w``).
        w: (dim,) ground-truth weight vector.
        max_w: forwarded to ``sim_likelihood``.
        key: integer seed.

    Returns:
        (2 * size, 2 * dim + 1) float array whose rows are ``[x_a, x_b, y]``.
    """
    # Independent subkeys for features, pair selection and labels. Reusing one
    # key for every Bernoulli draw would give all labels the same underlying
    # uniform, turning ``y`` into a deterministic threshold of ``p``.
    x_key, pair_key, label_key = random.split(random.key(key), 3)

    # Binary (0/1) features
    X = random.choice(a=2, key=x_key, shape=(size, dim)).astype(jnp.float32)

    # Randomly selected subset of the distinct pairs (a, b), a < b
    pairs = jnp.array(list(itertools.combinations(range(size), 2)))
    chosen = random.choice(
        a=pairs,
        key=pair_key,
        shape=(2 * size,),
        replace=False,
    )  # (2 * size, 2)
    x_first, x_second = X[chosen[:, 0]], X[chosen[:, 1]]

    logits = x_first @ w - x_second @ w  # (2 * size,)
    p = sim_likelihood(logits, dim, max_w)
    y = random.bernoulli(key=label_key, p=p).astype(jnp.float32)

    return jnp.hstack((x_first, x_second, y[:, None]))


class LogRewardPreferences(LogRewardBase):
    """Log posterior over integer weight vectors given simulated pairwise preferences.

    A state is read as a weight vector: token ``t`` at position ``i`` selects
    ``w_i = self.Ws[t - 1]`` from the grid ``min_w, ..., max_w``. The log reward
    is the Bernoulli log-likelihood of the training preferences (through
    ``sim_likelihood``) plus a constant uniform log prior over the grid, divided
    by the temperature.
    """

    #  python examples/train_compiled.py --env sequences --iterations 1024 --hidden-size 32 --batch-size 32 --num-trajectories 64 --lr 1e-3 --criterion cb --log-reward-type seqs-preferences
    def __init__(
        self,
        config: LogRewardConfig,
        *,
        min_w: int = -4,
        max_w: int = 4,
        dim: int = 10,
        size: int = 200,
        seed: int = 42,
    ):
        super().__init__(config)
        self.min_w = min_w
        self.max_w = max_w
        self.dim = dim
        self.size = size
        self.seed = seed

        # Ground-truth weights. ``randint`` excludes ``maxval``, so pass
        # ``max_w + 1`` to sample from the same inclusive grid [min_w, max_w]
        # that ``self.Ws`` and the prior in ``__call__`` cover.
        self.w = jax.random.randint(
            key=random.key(seed),
            minval=min_w,
            maxval=max_w + 1,
            shape=(dim,),
            dtype=jnp.int32,
        )

        # Weight grid min_w, min_w + 1, ..., max_w (as floats)
        self.Ws = jnp.linspace(min_w, max_w, num=max_w - min_w + 1, endpoint=True)

        # Training preferences; each row is [x_a, x_b, y]
        self.dataset = sim_pref_data(size, dim, self.w, self.max_w, self.seed)
        self.X1, self.X2 = jnp.split(self.dataset[:, :-1], 2, axis=1)
        self.y = self.dataset[:, -1]

        # Held-out preferences: a second dataset of the same size, generated
        # with seed + 1
        self.dataset_heldout = sim_pref_data(
            size, dim, self.w, self.max_w, self.seed + 1
        )

        self.X1_heldout, self.X2_heldout = jnp.split(
            self.dataset_heldout[:, :-1], 2, axis=1
        )
        self.y_heldout = self.dataset_heldout[:, -1]

    def __call__(self, state: EnvState) -> jax.Array:
        # Uniform prior over the grid ``self.Ws`` for each of the ``dim``
        # coordinates: the log prior is a constant, so the data enter only
        # through the Bernoulli likelihood.
        # NOTE(review): assumes complete sequences; an empty slot (0) would
        # index ``self.Ws[-1]``.
        items = (state.state - 1).astype(jnp.int32)
        w = self.Ws[items]  # (b, dim)

        log_prior = -self.dim * jnp.log(self.max_w - self.min_w + 1)

        logits = w @ (self.X1 - self.X2).T  # (b, num_pairs)
        p = sim_likelihood(logits, self.dim, self.max_w)
        log_p = self.y * jnp.log(p) + (1 - self.y) * jnp.log(1 - p)  # (b, num_pairs)
        return (log_p.sum(axis=1) + log_prior) / self.temperature

    def get_heldout_likelihood(self, state: EnvState) -> jax.Array:
        """Log-likelihood of the held-out preferences under each state's weights (no prior, no temperature)."""
        items = (state.state - 1).astype(jnp.int32)
        w = self.Ws[items]
        logits = w @ (self.X1_heldout - self.X2_heldout).T
        p = sim_likelihood(logits, self.dim, self.max_w)

        # Probability of the observed label: p if y == 1, else 1 - p
        p = self.y_heldout * p + (1 - self.y_heldout) * (1 - p)
        return jnp.log(p).sum(axis=1)  # sum over samples


def get_predictive_marginal(
    gflownet: GFlowNet,
    log_reward: LogRewardPreferences,
    key: jax.Array,
    num_samples: int,
):
    """Monte Carlo estimate of the log predictive likelihood of the held-out data.

    Samples ``num_samples`` batches of terminal states with the forward policy
    (policy epsilon temporarily set to 0.0). It then averages
    ``p(D_heldout | state)`` over all sampled states in log space:

        logsumexp(log p(D_heldout | state_i)) - log(num_samples * batch_size)

    Args:
        gflownet: sampler whose ``state.env_state`` is the starting state.
        log_reward: provides ``get_heldout_likelihood``.
        key: PRNG key; a fresh subkey is split off for every batch.
        num_samples: number of sampled batches.

    Returns:
        Scalar log predictive likelihood estimate.
    """
    apply_jitted = nnx.jit(apply_fn)
    trajectory_length = gflownet.state.env_state.max_trajectory_length

    per_sample = []
    gflownet.set_policy_eps(0.0)
    try:
        for _ in range(num_samples):
            # Reusing ``key`` for every batch would make all batches identical.
            key, sample_key = random.split(key)
            # NOTE(review): assumes ``jax.lax.scan`` does not write the updated
            # state back into ``gflownet``, so every batch starts from the same
            # initial env_state; confirm.
            (res, _), _ = jax.lax.scan(
                apply_jitted,
                init=(gflownet, sample_key),
                length=trajectory_length,
            )
            # p(D | env_state)
            per_sample.append(
                log_reward.get_heldout_likelihood(res.state.env_state)
            )
    finally:
        # Restore exploration settings even if sampling fails.
        gflownet.reset_policy_eps()

    log_rewards = jnp.concatenate(per_sample) if per_sample else jnp.zeros((0,))
    return jax.nn.logsumexp(log_rewards, axis=0) - jnp.log(
        num_samples * gflownet.state.env_state.batch_size
    )



class LogRewardBits(LogRewardBase):
    """Reward ``1 - d / seq_size``, with ``d`` the Hamming distance to the nearest mode.

    Each mode is a concatenation of ``seq_size // 8`` randomly chosen 8-bit
    components, zero-padded to ``seq_size`` bits.
    """

    def __init__(
        self,
        config: LogRewardConfig,
        *,
        src_size: int,
        seq_size: int,
        num_modes: int,
        key: jax.Array,
    ):
        super().__init__(config)
        self.src_size = src_size  # NOTE(review): stored but not used in this class
        self.seq_size = seq_size
        self.num_modes = num_modes

        # Candidate 8-bit building blocks for the modes
        mode_components = jnp.array([
            [0, 0, 0, 0, 0, 0, 0, 0],  # '00000000'
            [1, 1, 1, 1, 1, 1, 1, 1],  # '11111111'
            [1, 1, 1, 1, 0, 0, 0, 0],  # '11110000'
            [0, 0, 0, 0, 1, 1, 1, 1],  # '00001111'
            [0, 0, 1, 1, 1, 1, 0, 0],  # '00111100'
        ], dtype=int)

        # Sample the components of each mode uniformly at random
        num_components = seq_size // 8
        indices = jax.random.randint(
            minval=0, maxval=len(mode_components), shape=(num_modes, num_components), key=key,
        )

        # Explicit leading dimension: ``reshape(-1, 0)`` is ambiguous and fails
        # when seq_size < 8 (no full component fits).
        modes = mode_components[indices].reshape(num_modes, num_components * 8)
        padding = jnp.zeros((num_modes, seq_size - modes.shape[1]), dtype=modes.dtype)
        self.modes = jnp.hstack([modes, padding])  # (num_modes, seq_size)

    def __call__(self, batch_state: EnvState):
        # Stored tokens are bit + 1 (0 marks an empty slot), so shift back to
        # {0, 1} before comparing against the 0/1 modes.
        bits = batch_state.state - 1
        distance_to_modes = jnp.abs(
            bits.reshape(-1, 1, self.seq_size) - self.modes.reshape(1, -1, self.seq_size)
        ).sum(axis=2)
        assert distance_to_modes.shape[1] == self.num_modes, distance_to_modes.shape
        # NOTE(review): unlike the other rewards in this module, this is not
        # divided by ``self.temperature``; confirm that is intended.
        return 1 - distance_to_modes.min(axis=1) / self.seq_size

if __name__ == "__main__":
    # Smoke test: simulate a small preference dataset and print its shape.
    size = 100
    dim = 10
    w = jnp.array([0.1] * dim)

    # ``max_w`` is a required argument of ``sim_pref_data`` (it is forwarded to
    # ``sim_likelihood``).
    dataset = sim_pref_data(size, dim, w, max_w=1.0)
    print(dataset.shape)  # expected: (2 * size, 2 * dim + 1)
