import jax
import jax.numpy as jnp
import flax.nnx as nnx
import flax.struct as struct

from copy import deepcopy
from functools import partial
from dataclasses import dataclass

from nais.gym.base import Environment, EnvState
from nais.gflownet import GFlowNet


def _top_k(arr: jax.Array, k: int):
    return jax.lax.top_k(arr, k)


@struct.dataclass
class TopKQueue:
    """Immutable record of the ``size`` highest-log-reward unique states seen so far.

    The queue is used to keep track of the average log-reward of the K most
    rewarding samples throughout training. Instances are never mutated:
    :meth:`push` returns a new instance built with ``replace``.
    """

    # Maximum number of samples retained after each push.
    size: int
    # Number of columns of a stored state row (1 for 1-D environment states).
    dim: int
    # Retained states, one row per sample; starts out with shape (0, dim).
    queue_states: jax.Array
    # Log-rewards aligned with the rows of ``queue_states``.
    queue_log_rewards: jax.Array
    # One entry per push: mean log-reward of the samples retained by that push.
    # Declared with ``pytree_node=False``, i.e. not treated as a pytree leaf.
    history: jax.Array = struct.field(pytree_node=False)

    @classmethod
    def create(cls, states: EnvState, size: int):
        """Build an empty queue whose row width is derived from ``states.state``.

        Args:
            states: Environment states; only the shape of ``states.state`` is
                read. Its second axis gives the row width, or 1 when the
                array has at most one axis.
            size: Maximum number of samples the queue keeps.

        Returns:
            A ``TopKQueue`` with no stored samples and an empty history.
        """
        dim = states.state.shape[1] if len(states.state.shape) > 1 else 1
        return cls(
            size=size,
            dim=dim,
            queue_states=jnp.zeros((0, dim)),
            queue_log_rewards=jnp.zeros((0,)),
            history=jnp.zeros((0,)),
        )

    def push(self, states: EnvState, log_rewards: jax.Array):
        """Merge a batch of samples into the queue and keep the top ``size``.

        Args:
            states: Environment states; ``states.state`` holds one sample per
                row (a 1-D array is treated as a single column).
            log_rewards: Log-reward of each sample, aligned with the rows of
                ``states.state``.

        Returns:
            A new ``TopKQueue`` holding the (at most ``size``) unique states
            with the highest log-rewards, with the mean of those retained
            log-rewards appended to ``history``.
        """
        state = states.state
        if len(state.shape) <= 1:
            state = state[:, None]

        # De-duplicate the incoming batch, keeping one log-reward per state.
        batch_states, keep = jnp.unique(state, axis=0, return_index=True)
        batch_log_rewards = log_rewards[keep]

        # Pool the batch with the samples that are already stored ...
        merged_states = jnp.vstack([batch_states, self.queue_states])
        merged_log_rewards = jnp.hstack([batch_log_rewards, self.queue_log_rewards])

        # ... and de-duplicate again so a state already in the queue is not
        # stored twice.
        merged_states, keep = jnp.unique(merged_states, axis=0, return_index=True)
        merged_log_rewards = merged_log_rewards[keep]

        # The pool may hold fewer than ``size`` samples, so clamp k.
        k = min(self.size, len(merged_log_rewards))
        queue_log_rewards, top_indices = _top_k(merged_log_rewards, k=k)
        queue_states = merged_states[top_indices]

        # Record the mean of the *updated* top-k. Using the previous queue
        # here would log NaN on the first push (mean of an empty array) and
        # lag one step behind afterwards.
        history = jnp.hstack([self.history, queue_log_rewards.mean()[None]])

        return self.replace(
            queue_states=queue_states,
            queue_log_rewards=queue_log_rewards,
            history=history,
        )


@struct.dataclass
class ModeQueue:
    """Immutable collection of unique states whose log-reward reaches ``th``.

    :meth:`push` never mutates the instance; it returns an updated copy
    produced by ``replace``.
    """

    # Log-reward threshold a sample must reach (``>=``) to be retained.
    th: float
    # Number of columns of a stored state row (1 for 1-D environment states).
    dim: int
    # Retained states, one row per sample; starts out with shape (0, dim).
    queue_states: jax.Array
    # Log-rewards aligned with the rows of ``queue_states``.
    queue_log_rewards: jax.Array
    # One entry per push: how many unique samples passed the threshold.
    # Declared with ``pytree_node=False``, i.e. not treated as a pytree leaf.
    history: jax.Array = struct.field(pytree_node=False)

    @classmethod
    def create(cls, states: EnvState, th: float):
        """Build an empty queue whose row width is derived from ``states.state``.

        Args:
            states: Environment states; only the shape of ``states.state`` is
                read. Its second axis gives the row width, or 1 when the
                array has at most one axis.
            th: Log-reward threshold used by :meth:`push`.

        Returns:
            A ``ModeQueue`` with no stored samples and an empty history.
        """
        state_shape = states.state.shape
        dim = 1 if len(state_shape) <= 1 else state_shape[1]

        empty_states = jnp.zeros((0, dim))
        empty_log_rewards = jnp.zeros((0,))
        empty_history = jnp.zeros((0,))

        return cls(
            th=th,
            dim=dim,
            queue_states=empty_states,
            queue_log_rewards=empty_log_rewards,
            history=empty_history,
        )

    def push(self, states: EnvState, log_rewards: jax.Array):
        """Merge a batch into the queue, keeping unique states at or above ``th``.

        Args:
            states: Environment states; ``states.state`` holds one sample per
                row (a 1-D array is treated as a single column).
            log_rewards: Log-reward of each sample, aligned with the rows of
                ``states.state``.

        Returns:
            A new ``ModeQueue`` containing only the unique pooled states whose
            log-reward is ``>= th``, with the count of such states appended
            to ``history``.
        """
        raw = states.state
        rows = raw if len(raw.shape) > 1 else raw[:, None]

        # Pool the incoming batch with what is already stored.
        pooled_states = jnp.vstack([rows, self.queue_states])
        pooled_log_rewards = jnp.hstack([log_rewards, self.queue_log_rewards])

        # Collapse duplicate rows and pick the matching log-reward for each.
        unique_states, keep = jnp.unique(pooled_states, axis=0, return_index=True)
        unique_log_rewards = pooled_log_rewards[keep]

        above_threshold = unique_log_rewards >= self.th
        mode_count = above_threshold.sum()[None]

        return self.replace(
            queue_states=unique_states[above_threshold],
            queue_log_rewards=unique_log_rewards[above_threshold],
            history=jnp.hstack([self.history, mode_count]),
        )


# def flatten_nested_dict(d: nnx.State):
#     arrays = []

#     def recurse(subdict: nnx.State):
#         for key in sorted(subdict.keys()):
#             val = subdict[key]
#             if isinstance(val, nnx.State):
#                 recurse(val)
#             else:
#                 arrays.append(val.ravel())

#     recurse(d)
#     return jnp.concatenate(arrays)


# def get_stepwise_grad_correlation(model: GFlowNet, states: Environment):
#     # This will be a bit expensive, but that is ok

#     out = model.sample_traj(states)

#     # compute partial L / partial theta
#     def loss_fn(model, out):
#         out = model.evaluate_on_snapshots(out.snapshots)
#         return model.evaluate_loss_on_transitions(out)

#     _, grads_loss = nnx.value_and_grad(loss_fn)(model, out)  # partial L / partial theta
#     grads_loss = flatten_nested_dict(grads_loss["pf"])

#     # compute partial log_pf / partial theta
#     def loss_fn(model, out, t):
#         out = model.evaluate_on_snapshots(out.snapshots)
#         log_pf = out.log_pf[:, t]
#         return jnp.mean(log_pf)

#     stepwise_grad_correlation = []

#     cov_loss = grads_loss[:, None] @ grads_loss[None, :]
#     eigvals_loss = jnp.linalg.eig(cov_loss)

#     for t in range(states.max_trajectory_length):
#         _, grads_log_pf = nnx.value_and_grad(loss_fn)(model, out, t)

#         grads_log_pf = flatten_nested_dict(grads_log_pf["pf"])

#         cov = grads_log_pf[:, None] @ grads_log_pf[None, :]
#         eigvals = jnp.linalg.eig(cov)

#         stepwise_grad_correlation.append(jnp.real(eigvals.eigenvalues[0] / eigvals_loss.eigenvalues[0]).item())

#     print(stepwise_grad_correlation)
#     return stepwise_grad_correlation


@struct.dataclass
class MetadataMetrics:
    """Pair of scalar policy diagnostics stored as one ``MetadataQueue`` history entry."""

    # KL-divergence value measured against a uniform distribution.
    # Uniform is an arbitrarily chosen reference distribution
    kl_div_to_uniform: float
    # Entropy value reported for the policy.
    policy_entropy: float


@struct.dataclass
class MetadataQueue:
    """Immutable, append-only log of ``MetadataMetrics`` snapshots.

    :meth:`push` does not mutate the instance; it returns a copy (via
    ``replace``) whose history has one more entry.
    """

    # Recorded snapshots, oldest first. Declared with ``pytree_node=False``,
    # i.e. not treated as a pytree leaf.
    history: list[MetadataMetrics] = struct.field(pytree_node=False)

    @classmethod
    def create(cls):
        """Return a queue with an empty history."""
        no_entries = []
        return cls(history=no_entries)

    def push(self, states: EnvState, gflownet: GFlowNet):
        """Evaluate the forward policy on ``states`` and log the result.

        Args:
            states: States passed to ``gflownet.pf.kl_to_uniform`` and
                ``gflownet.pf.entropy``.
            gflownet: Model whose ``pf`` attribute provides both metrics.

        Returns:
            A new ``MetadataQueue`` with one extra ``MetadataMetrics`` entry
            holding the two values converted to Python scalars.
        """
        # These metrics represent the KL/entropy at the initial state
        kl_to_uniform = gflownet.pf.kl_to_uniform(states)
        policy_entropy = gflownet.pf.entropy(states)

        snapshot = MetadataMetrics(
            kl_div_to_uniform=kl_to_uniform.item(),
            policy_entropy=policy_entropy.item(),
        )

        # Build a fresh list so the previous instance's history is untouched.
        extended = self.history + [snapshot]
        return self.replace(history=extended)
