from copy import deepcopy
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp
import optax

from nais.gym.base import Environment, EnvState, LogRewardBase
from nais.policies.base import BackwardPolicyBase, FlowFunctionBase, ForwardPolicyBase

T = TypeVar("T", bound=Environment)


def huber_loss(diff: jax.Array, delta: float = 9.0):
    """Return the mean Huber loss of the residuals ``diff``.

    Entries with ``|diff| <= delta`` contribute ``0.5 * diff**2``; larger entries
    contribute ``delta * (|diff| - 0.5 * delta)``. The two branches agree at
    ``|diff| == delta`` and the penalty grows only linearly beyond it.

    Args:
        diff: Residuals, any shape.
        delta: Threshold between the quadratic and the linear regime.

    Returns:
        Scalar array holding the mean of the element-wise loss over all entries.
    """
    abs_diff = jnp.abs(diff)
    quadratic = 0.5 * diff * diff
    # The tail must be linear in |diff|; using diff**2 here would keep the loss
    # quadratic everywhere and defeat the purpose of the Huber loss.
    linear = delta * (abs_diff - 0.5 * delta)
    return jnp.where(abs_diff > delta, linear, quadratic).mean()


def _subtraj_loss(
    i: jax.Array,
    log_pf: jax.Array,
    log_pb: jax.Array,
    log_F: jax.Array,
    last_idx: jax.Array,
    lamb: float = 1.05,
):
    """Weighted squared balance residual over sub-trajectories that start at step 0.

    For every end index in ``i`` the residual is::

        sum_{t <= i} (log_pf[:, t] - log_pb[:, t]) - log_F[:, i + 1] + log_F[:, 0]

    Each squared residual is weighted by ``lamb ** i`` and masked per row with
    ``last_idx >= i``; the weighted average per row is then averaged over rows.

    Args:
        i: 1-D integer array of end indices (columns of ``log_pf`` / ``log_pb``).
        log_pf: Per-step forward log-probabilities, one row per trajectory.
        log_pb: Per-step backward log-probabilities, same shape as ``log_pf``.
        log_F: Per-state log-flows with one more column than ``log_pf``.
        last_idx: Per-row index used to mask out end indices beyond it.
        lamb: Base of the geometric weighting over ``i``.

    Returns:
        Scalar array with the mean weighted loss.
    """
    # Prefix sums give the accumulated log-probability from step 0 up to each end index.
    cum_pf = jnp.cumsum(log_pf, axis=1)[:, i]
    cum_pb = jnp.cumsum(log_pb, axis=1)[:, i]

    flow_at_end = log_F[:, i + 1]
    flow_at_start = log_F[:, [0] * len(i)]
    residual = cum_pf - cum_pb - flow_at_end + flow_at_start

    # NOTE(review): the mask keeps i == last_idx, whose end flow is log_F[:, last_idx + 1];
    # the samplers in this file store the log-reward at log_F[:, last_idx]. Confirm whether
    # a strict `>` was intended here.
    in_range = (last_idx[:, None] >= i[None, :]).astype(residual.dtype)
    weights = (lamb**i)[None, :] * in_range

    per_row = ((residual**2) * weights).sum(axis=1) / weights.sum(axis=1)
    return per_row.mean()


@dataclass
class EnvironmentSnapshot(Generic[T]):
    """One recorded step of a trajectory: an environment copy, its forward actions, and an activity mask.

    Instances are appended to a list while trajectories are rolled forward
    (``GFlowNet.sample_traj``) or backward (``GFlowNet.marginal_prob``).
    """

    # Boolean mask of the batch elements that were still being advanced at this step.
    is_active: jax.Array
    # Forward actions associated with `states` (the actions taken from that stored state).
    actions: Optional[jax.Array]
    # Copy of the batched environment the actions are taken from.
    states: T


@struct.dataclass
class GFlowNetConfig:
    """Static (non-pytree) hyper-parameters consumed by ``GFlowNet``."""

    # Loss selector; `GFlowNet.evaluate_loss` dispatches on "tb", "cb", "db" and "subtb".
    criterion: str = struct.field(pytree_node=False, default="tb")
    # Geometric weight base forwarded as `lamb` to the sub-trajectory balance loss.
    lamb_subtb: float = struct.field(pytree_node=False, default=0.99)

    # Number of outer sampling rounds in `GFlowNet.evaluate_fcs_on_trajectories`.
    fcs_num_iterations: int = struct.field(pytree_node=False, default=32)
    # Number of backward trajectories drawn per terminal state in each round.
    fcs_num_back_traj: int = struct.field(pytree_node=False, default=8)

    # Stored on the model as `step_size_reg`; not otherwise used in the code visible here.
    step_size_regularizer: float = struct.field(pytree_node=False, default=1e-5)


@struct.dataclass
class GFlowNetState:
    """Per-batch buffers created by ``GFlowNet.init_state`` and read by the ``self.state``-based losses."""

    # Batched environment state the buffers were sized from.
    env_state: EnvState

    # Per-step forward / backward log-probabilities, initialised to zeros of shape
    # (batch_size, max_trajectory_length) by `init_state`.
    log_pf: jax.Array
    log_pb: jax.Array

    # Step counter; `init_state` sets it to 0.
    # NOTE(review): `GFlowNet._detailed_balance` reads `state.last_idx`, which is not
    # declared here — confirm which attribute is intended.
    idx: int

    # Per-state log-flows of shape (batch_size, max_trajectory_length + 1); column 0 is
    # filled from the flow function by `init_state`.
    log_F: jax.Array


@dataclass
class GFlowNetSampleOutput:
    """Result of rolling a batch of trajectories (see ``GFlowNet.sample_traj`` / ``evaluate_on_snapshots``)."""

    # Per-step forward log-probabilities, (batch_size, max_trajectory_length); zero where inactive.
    log_pf: jax.Array
    # Per-step backward log-probabilities, same shape and masking as `log_pf`.
    log_pb: jax.Array
    # Log-flows, (batch_size, max_trajectory_length + 1); entry [b, last_idx[b]] holds the log-reward.
    log_F: jax.Array
    # Number of steps each batch element was active for.
    last_idx: jax.Array
    # Environment after the rollout (the object that was stepped in place).
    states: Environment
    # Log-rewards of the final states, wrapped in stop_gradient.
    log_rewards: jax.Array
    # Snapshots associated with the rollout, one per executed step.
    snapshots: list[EnvironmentSnapshot]


class Constant(nnx.Module):
    """Parameter-free flow function that outputs zero for every batch element.

    ``GFlowNet`` falls back to this module when no flow function is supplied.
    """

    def __call__(self, state: EnvState):
        """Return a ``(state.batch_size, 1)`` array of zeros; nothing else is read from ``state``."""
        output_shape = (state.batch_size, 1)
        return jnp.zeros(output_shape)


class GFlowNet(Generic[T], nnx.Module):
    def __init__(
        self,
        pf: ForwardPolicyBase,
        pb: BackwardPolicyBase,
        flow_function: FlowFunctionBase,
        log_reward: LogRewardBase,
        config: GFlowNetConfig,
    ):
        """Assemble a GFlowNet from its policies, flow function, reward and configuration.

        Args:
            pf: Forward policy used to sample and score forward actions.
            pb: Backward policy used to sample and score backward actions.
            flow_function: State-flow estimator; a falsy value is replaced by ``Constant()``.
            log_reward: Callable evaluated on environment states to obtain log-rewards.
            config: Hyper-parameters; ``criterion``, ``lamb_subtb`` and
                ``step_size_regularizer`` are also copied onto the instance.
        """
        super(GFlowNet, self).__init__()
        self.pf = pf
        self.pb = pb

        # Fall back to an all-zero flow when no flow function is provided.
        self.flow_function = flow_function or Constant()

        # Learnable scalar added to the trajectory-balance residual.
        self.log_z = nnx.Param(value=0.0)

        self.criterion = config.criterion
        self.lamb_subtb = config.lamb_subtb

        self.log_reward = log_reward
        self.config = config

        # Placeholder until `init_state` installs a GFlowNetState.
        self.state = nnx.data(None)
        self.step_size_reg = config.step_size_regularizer

    def init_state(self, env_state: EnvState):
        batch_size = env_state.batch_size
        log_pf = jnp.zeros((batch_size, env_state.max_trajectory_length))
        log_pb = jnp.zeros((batch_size, env_state.max_trajectory_length))

        log_F = jnp.zeros((batch_size, env_state.max_trajectory_length + 1))
        log_F = log_F.at[:, 0].set(self.flow_function(env_state).squeeze(axis=-1))

        self.state = GFlowNetState(
            env_state=env_state,
            log_pf=log_pf,
            log_pb=log_pb,
            log_F=log_F,
            idx=0,
        )

        self.pf.lazy_init(env_state)
        self.pb.lazy_init(env_state)

    def lazy_init_fwd(self):
        """Call the forward policy's ``lazy_init`` with the environment state stored by ``init_state``."""
        self.pf.lazy_init(self.state.env_state)

    def lazy_init_bcw(self):
        """Call the backward policy's ``lazy_init`` with the environment state stored by ``init_state``."""
        self.pb.lazy_init(self.state.env_state)

    def __call__(self, batch_state: T):
        """Sample trajectories from ``batch_state`` and return the loss evaluated on them.

        ``batch_state`` is stepped in place by ``sample_traj``.
        """
        out = self.sample_traj(batch_state)

        # NOTE(review): `evaluate_loss_on_transitions` is not defined in the part of the
        # file visible here — presumably provided elsewhere (e.g. a subclass); confirm.
        loss = self.evaluate_loss_on_transitions(out)
        return loss

    def evaluate_loss(self, log_rewards: jax.Array):
        match self.criterion:
            case "tb":
                loss = self._trajectory_balance(log_rewards)
            case "cb":
                loss = self._contrastive_balance_full(log_rewards)
            case "db":
                loss = self._detailed_balance(log_rewards)
            case "subtb":
                loss = self._subtrajectory_balance(log_rewards)
            case _:
                raise ValueError(f"{self.criterion} should be either tb, cb, db, or dbc")

        return loss

    def sample_traj(self, batch_state: T) -> GFlowNetSampleOutput:
        """Roll the forward policy from ``batch_state`` until every element stops.

        ``batch_state`` is advanced IN PLACE via ``batch_state.apply``; a deep copy of
        the pre-step environment is stored in a snapshot at every step. Per step the
        method records the forward log-probability, the backward log-probability of
        the reverse action and the log-flow of the pre-step state. After the rollout
        the (gradient-stopped) log-reward is written into ``log_F`` at each element's
        final step index.

        Args:
            batch_state: Batched environment to sample from (mutated).

        Returns:
            A ``GFlowNetSampleOutput`` whose ``states`` is the mutated ``batch_state``.
        """
        max_traj_len = batch_state.max_trajectory_length
        log_pf = jnp.zeros((batch_state.batch_size, max_traj_len))
        log_pb = jnp.zeros((batch_state.batch_size, max_traj_len))
        log_F = jnp.zeros((batch_state.batch_size, max_traj_len + 1))

        # Number of steps each batch element has been active for.
        steps_taken = jnp.zeros((batch_state.batch_size,), dtype=jnp.int32)

        snapshots: list[EnvironmentSnapshot[T]] = []

        for t in range(max_traj_len):
            is_active = ~batch_state.stopped.astype(bool)
            # Host-side check: stop as soon as no element is active any more.
            if not bool(jnp.any(is_active)):
                break

            # The policy sees the states of all previously recorded steps as history.
            out_pf = self.pf.sample_actions(batch_state, history=[s.states for s in snapshots])
            # Inactive elements get action 0 as a placeholder.
            actions = jnp.where(is_active, out_pf.actions, 0)

            log_Ft = self.flow_function(batch_state)  # (batch_size, 1)
            log_Ft = log_Ft.squeeze(axis=1)  # (batch_size,)

            # Save the state prior to modification into the history; the deep copy is
            # required because `apply` below mutates `batch_state` in place.
            snapshots.append(
                EnvironmentSnapshot(
                    is_active=is_active,
                    actions=actions,
                    states=deepcopy(batch_state),
                )
            )

            # Apply the actions
            batch_state.apply(actions, active_mask=is_active)

            # Corresponding backward actions; we could apply a non-Markovian, post-hoc processing to this
            out_pb = self.pb.sample_actions(batch_state, batch_state.fwd_to_bcw_actions(actions))

            # Save values (zero for elements that were not active at this step)
            log_pf = log_pf.at[:, t].set(jnp.where(is_active, out_pf.log_pf, 0.0))
            log_pb = log_pb.at[:, t].set(jnp.where(is_active, out_pb.log_pb, 0.0))

            log_F = log_F.at[:, t].set(jnp.where(is_active, log_Ft, 0.0))

            steps_taken = steps_taken + is_active.astype(jnp.int32)

        # The flow at each element's final index is replaced by its log-reward.
        log_rewards = jax.lax.stop_gradient(self.log_reward(batch_state))
        log_F = log_F.at[batch_state.batch_ids, steps_taken].set(log_rewards)

        return GFlowNetSampleOutput(
            log_pf=log_pf,
            log_pb=log_pb,
            log_F=log_F,
            last_idx=steps_taken,
            log_rewards=log_rewards,
            states=batch_state,
            snapshots=snapshots,
        )

    def evaluate_on_snapshots(self, snapshots: list[EnvironmentSnapshot[T]]):
        first_snapshot = snapshots[0]

        # We copy the first state to avoid modifying the origional snapshots
        states = deepcopy(first_snapshot.states)

        max_traj_len = states.max_trajectory_length
        log_pf = jnp.zeros((states.batch_size, max_traj_len))
        log_pb = jnp.zeros((states.batch_size, max_traj_len))
        log_F = jnp.zeros((states.batch_size, max_traj_len + 1))

        steps_taken = jnp.zeros((states.batch_size,), dtype=jnp.int32)

        for t in range(max_traj_len):
            is_active = ~states.stopped.astype(bool)
            if not bool(jnp.any(is_active)):
                break

            out_pf = self.pf.sample_actions(
                states,
                history=[s.states for s in snapshots],
                actions=snapshots[t].actions,
            )
            actions = jnp.where(is_active, out_pf.actions, 0)

            log_Ft = self.flow_function(states)  # (batch_size, 1)
            log_Ft = log_Ft.squeeze(axis=1)  # (batch_size,)

            # Apply the actions
            states.apply(actions, active_mask=is_active)

            # Corresponding backward actions; we could apply a non-Markovian, post-hoc processing to this
            out_pb = self.pb.sample_actions(states, states.fwd_to_bcw_actions(actions))

            # Save values
            log_pf = log_pf.at[:, t].set(jnp.where(is_active, out_pf.log_pf, 0.0))
            log_pb = log_pb.at[:, t].set(jnp.where(is_active, out_pb.log_pb, 0.0))

            log_F = log_F.at[:, t].set(jnp.where(is_active, log_Ft, 0.0))

            steps_taken = steps_taken + is_active.astype(jnp.int32)

        log_rewards = jax.lax.stop_gradient(self.log_reward(states))
        log_F = log_F.at[states.batch_ids, steps_taken].set(log_rewards)

        return GFlowNetSampleOutput(
            log_pf=log_pf,
            log_pb=log_pb,
            log_F=log_F,
            last_idx=steps_taken,
            log_rewards=log_rewards,
            states=states,
            snapshots=snapshots,
        )

    def num_parameters(self) -> int:
        params = nnx.state(self, nnx.Param)

        def accumulate(total: int, value):
            if isinstance(value, jax.Array):
                return total + int(value.size)
            if hasattr(value, "value") and isinstance(value.value, jax.Array):
                return total + int(value.value.size)
            return total

        return jax.tree_util.tree_reduce(accumulate, params, 0)

    def _trajectory_balance(self, log_rewards: jax.Array):
        loss = (self.state.log_pf - self.state.log_pb).sum(axis=1) - log_rewards + self.log_z
        return huber_loss(loss, delta=1.0)

    def _detailed_balance(self, log_rewards: jax.Array):
        loss = jnp.power(
            self.state.log_pf + self.state.log_F[:, :-1] - self.state.log_pb - self.state.log_F[:, 1:],
            2,
        )
        denom = jnp.maximum(self.state.last_idx, 1)
        loss_avg = (loss.sum(axis=1) / denom).mean()
        return loss_avg

    # def _subtrajectory_balance(self, out: GFlowNetSampleOutput):
    #     max_traj_length = out.states.max_trajectory_length

    #     log_pf = jnp.concatenate([jnp.zeros_like(out.log_pf[:, [0]]), out.log_pf])
    #     log_pb = jnp.concatenate([jnp.zeros_like(out.log_pb[:, [0]]), out.log_pb])

    #     i, j = jnp.triu_indices(max_traj_length + 1, max_traj_length + 1, offset=1)
    #     return _subtraj_loss(i, j, log_pf, log_pb, out.log_F, out.last_idx, lamb=self.lamb_subtb)

    def _subtrajectory_balance(self, out: GFlowNetSampleOutput):
        i = jnp.arange(out.log_pf.shape[1])
        return _subtraj_loss(i, out.log_pf, out.log_pb, out.log_F, out.last_idx, lamb=self.lamb_subtb)

    def _contrastive_balance_full(self, log_rewards: jax.Array):
        loss = (self.state.log_pf - self.state.log_pb).sum(axis=1) - log_rewards
        loss = loss[:, None] - loss[None, :]
        return huber_loss(loss, delta=1.0)

    def sample(self, batch_state: T) -> T:
        """Roll out trajectories from ``batch_state`` and return the resulting environment.

        ``sample_traj`` steps ``batch_state`` in place, so the returned object is the
        same (mutated) environment that was passed in.
        """
        out = self.sample_traj(batch_state)
        return out.states

    def marginal_prob(self, batch_state: T, copy_env: bool = False) -> jax.Array:
        # Use importance sampling to estimate the marginal probabilities
        if copy_env:
            batch_state = deepcopy(batch_state)
        log_pf = jnp.zeros((batch_state.batch_size, batch_state.max_trajectory_length))
        log_pb = jnp.zeros((batch_state.batch_size, batch_state.max_trajectory_length))

        idx = 0
        # I will first sample the backward trajectories,
        # and then I will compute the forward probabilities for each trajectory.
        snapshots: list[EnvironmentSnapshot[T]] = []
        is_initial = batch_state.is_initial.astype(bool)
        for _ in range(batch_state.max_trajectory_length):
            is_active = ~is_initial
            if not bool(jnp.any(is_active)):
                break

            # Estimate the backward probabilities
            out_pb = self.pb.sample_actions(batch_state)

            backward_actions = jnp.where(is_active, out_pb.actions, 0)
            forward_actions = batch_state.backward(backward_actions, active_mask=is_active)

            snapshots.append(
                EnvironmentSnapshot(
                    is_active=is_active,
                    actions=forward_actions,
                    states=deepcopy(batch_state),
                )
            )

            backward_log_prob = jnp.where(is_active, out_pb.log_pb, 0.0)
            log_pb = log_pb.at[:, idx].set(backward_log_prob)

            is_initial = batch_state.is_initial.astype(bool)
            idx += 1

        idx = 0
        for snapshot in snapshots[::-1]:
            # Estimate the forward probabilities
            out_pf = self.pf.sample_actions(snapshot.states, actions=snapshot.actions)
            forward_log_prob = jnp.where(snapshot.is_active, out_pf.log_pf, 0.0)
            log_pf = log_pf.at[:, idx].set(forward_log_prob)
            idx += 1

        marginal_log = (log_pf - log_pb).sum(axis=1)  # = log_pf.sum + log_pb.sum
        return marginal_log

    def sample_many_backward(self, batch_states: T, num_trajectories: int):
        marginal_log = jnp.zeros((batch_states.batch_size, num_trajectories))
        for idx in range(num_trajectories):
            marginal_log = marginal_log.at[:, idx].set(self.marginal_prob(batch_states, copy_env=True))
        return marginal_log

    def evaluate_fcs_on_trajectories(self, states: T, return_marginals: bool = False):
        iterations = self.config.fcs_num_iterations
        num_trajectories = self.config.fcs_num_back_traj

        fcs = 0
        marginals = {"est": list(), "true": list()}
        for _ in range(iterations):
            # Get terminal states
            terminal_states = self.sample(deepcopy(states))

            # Get backward trajectories
            marginal_est = list()
            for _ in range(num_trajectories):
                marginal_prob = self.marginal_prob(terminal_states, copy_env=True)
                marginal_prob = jnp.expand_dims(marginal_prob, axis=1)
                marginal_est.append(marginal_prob)

            marginal_est = jnp.concatenate(marginal_est, axis=1)
            marginal_est = jax.nn.logsumexp(marginal_est, axis=1) - jnp.log(num_trajectories)

            marginal_true = self.log_reward(terminal_states)

            # Normalize
            marginal_est_norm = marginal_est - jax.nn.logsumexp(marginal_est, axis=0)
            marginal_true_norm = marginal_true - jax.nn.logsumexp(marginal_true, axis=0)

            fcs += jnp.abs(jnp.exp(marginal_true_norm) - jnp.exp(marginal_est_norm)).sum() / 2

            marginals["est"].append(marginal_est_norm)
            marginals["true"].append(marginal_true_norm)

        fcs /= iterations
        if return_marginals:
            return fcs, {
                "est": jnp.hstack(marginals["est"]),
                "true": jnp.hstack(marginals["true"]),
            }
        return fcs

    def set_policy_eps(self, eps: float):
        """Forward ``eps`` to the forward policy's ``set_eps``."""
        self.pf.set_eps(eps)

    def reset_policy_eps(self):
        """Call the forward policy's ``reset_eps``."""
        self.pf.reset_eps()

    def entropy(self, initial_state: T, num_samples: int):
        # This provides an (empirical, MC-based) estimate of the GFlowNet's entropy
        # Recall: H[p] = E_p [- log p]
        entropy = jnp.zeros(initial_state.batch_size)

        # Entropy should be computed on-policy
        self.pf.set_is_off_policy(False)
        for _ in range(num_samples):
            state = deepcopy(initial_state)
            trajectory = self.sample_traj(state)
            log_pf = trajectory.log_pf.sum(axis=1)
            entropy -= log_pf
        self.pf.set_is_off_policy(True)

        entropy = (entropy / num_samples).mean()

        # We can also compute the normalized entropy,
        # in case the state space size is provided
        if initial_state.log_space_size is not None:
            entropy = entropy / jnp.log(initial_state.log_space_size)

        return entropy


class TemperatureSchedule(nnx.Module):
    """Anneal the ``temperature`` attribute of a log-reward from ``initial_t`` towards ``final_t``.

    After ``k`` calls to ``step`` the temperature is::

        final_t + (initial_t - final_t) / (1 + c * k)

    which equals ``initial_t`` at ``k = 0``, approaches ``final_t`` as ``k`` grows
    and stays constant when both temperatures are equal.

    Note that constructing the schedule does not touch ``log_reward.temperature``;
    it is only written by ``step`` and ``finish``.
    """

    def __init__(self, log_reward: LogRewardBase, initial_t: float = 1.0, final_t: float = 1.0):
        """Store the reward and the temperature range.

        Args:
            log_reward: Object whose ``temperature`` attribute is updated.
            initial_t: Temperature before the first step.
            final_t: Asymptotic temperature; must not exceed ``initial_t``.
        """
        self.log_reward = log_reward
        self.initial_t = jnp.array(initial_t)
        self.final_t = jnp.array(final_t)

        self.t = self.initial_t
        self.k = 0  # number of steps taken so far
        self.c = 1e-3  # cauchy schedule

        assert self.final_t <= self.initial_t

    def step(self) -> LogRewardBase:
        """Advance the schedule by one step and return the updated log-reward."""
        self.k += 1
        # Only the excess over `final_t` decays; adding `final_t` to the full
        # `initial_t / (1 + c * k)` would start the schedule near
        # `initial_t + final_t` instead of `initial_t`.
        excess = self.initial_t - self.final_t
        self.t = self.final_t + excess / (1 + self.c * self.k)
        self.log_reward.temperature = self.t
        return self.log_reward

    def finish(self) -> LogRewardBase:
        """Jump straight to ``final_t`` and return the updated log-reward."""
        self.t = self.final_t
        self.log_reward.temperature = self.t
        return self.log_reward


def create_adam_optimizer(
    model: GFlowNet,
    default_lr: float = 1e-2,
    flow_function_lr: float = 1e-2,
    log_z_lr: float = 1e-1,
    private_lr: float = 1e-2,
    log_pb_lr: float = 1e-3,
    transition_steps: int = 512,
):
    """Build an ``nnx.Optimizer`` that applies Adam with a separate learning rate per parameter group.

    Every ``nnx.Param`` of ``model`` is assigned one of five labels from its path:

    * ``"private"``: a path component equals ``"_W"`` or ``"_z_norm"``
    * ``"flow_function"``: a path component equals ``"flow_function"``
    * ``"log_z"``: the path is exactly ``("log_z",)``
    * ``"log_pb"``: a path component equals ``"pb"``
    * ``"default"``: everything else

    The first matching rule wins, in the order listed. All groups except
    ``"private"`` use a polynomial schedule (power 1.0) from the given rate down to
    1% of it over ``transition_steps``; ``"private"`` uses the constant
    ``private_lr``. The ``"log_z"`` group uses Adam with ``nesterov=True``.

    Args:
        model: Model whose parameters are optimised.
        default_lr: Initial rate of the ``"default"`` group.
        flow_function_lr: Initial rate of the ``"flow_function"`` group.
        log_z_lr: Initial rate of the ``"log_z"`` group.
        private_lr: Constant rate of the ``"private"`` group.
        log_pb_lr: Initial rate of the ``"log_pb"`` group.
        transition_steps: Length of the decay of every schedule.

    Returns:
        An ``nnx.Optimizer`` over the ``nnx.Param`` state of ``model``.
    """

    def decaying(initial_lr: float):
        # Shared recipe: decay to 1% of the initial rate with power 1.0.
        return optax.schedules.polynomial_schedule(
            init_value=initial_lr,
            end_value=initial_lr * 1e-2,
            power=1.0,
            transition_steps=transition_steps,
        )

    def label_param(path, _):
        # `path` membership tests compare whole path components, not substrings.
        if any(marker in path for marker in ("_W", "_z_norm")):
            return "private"
        if "flow_function" in path:
            return "flow_function"
        if path == ("log_z",):
            return "log_z"
        return "log_pb" if "pb" in path else "default"

    param_labels = nnx.map_state(label_param, nnx.state(model, nnx.Param))

    transforms = {
        "log_z": optax.adam(learning_rate=decaying(log_z_lr), nesterov=True),
        "default": optax.adam(learning_rate=decaying(default_lr)),
        "flow_function": optax.adam(learning_rate=decaying(flow_function_lr)),
        "private": optax.adam(learning_rate=private_lr),
        "log_pb": optax.adam(learning_rate=decaying(log_pb_lr)),
    }
    tx = optax.multi_transform(transforms, param_labels)

    return nnx.Optimizer(model, tx=tx, wrt=nnx.Param)
