import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import tqdm
import json
import os

from copy import deepcopy
from dataclasses import dataclass, asdict

# from nais.gym.hypergrids import Hypergrids, HypergridState
# from nais.gym.lines import Lines, LineState

from nais.gym import EnvironmentEnum

from nais.gflownet import GFlowNet
from nais.gym.base import LogRewardBase, LogRewardEnum
from nais.policies.base import RecurrenceType
from nais.queues import ModeQueue, TopKQueue, MetadataMetrics


# def draw_hypergrid(
#     sampler: GFlowNet[Hypergrids],
#     env: Hypergrids,
#     num_samples: int,
#     log_rewards: jax.Array | None = None,
# ):
#     assert env.dims == 2, "Drawable hypergrids must be 2-dimensional"
#     iterations = (num_samples + env.batch_size) // env.batch_size

#     sampler.set_policy_eps(0.0)

#     states: list[jax.Array] = []
#     for _ in tqdm.trange(iterations):
#         samples = deepcopy(env)
#         samples = sampler.sample(samples)
#         states.append(samples.state)

#     sampler.reset_policy_eps()

#     # Compute the average concentration of the states in the hypergrids
#     states = jnp.stack(states)
#     states = states.astype(jnp.int32)

#     grid = jnp.zeros((env.size, env.size))
#     grid = grid.at[states[:, 0], states[:, 1]].add(1)
#     grid = grid / grid.sum(keepdims=True)

#     if log_rewards is not None:
#         plt.imshow(jnp.exp(log_rewards))
#         plt.colorbar()
#         plt.show()
#         return grid

#     plt.imshow(grid)
#     plt.colorbar()
#     plt.show()

#     return grid


# def get_grid_empirical_dist(
#     sampler: GFlowNet[Hypergrids],
#     env: Hypergrids,
#     num_samples: int,
#     log_reward: LogRewardBase,
#     show_figure: bool = False,
# ) -> tuple[jax.Array, jax.Array]:
#     # We first create a hypergrid in which each state is on a different square
#     states = jnp.meshgrid(jnp.arange(env.size), jnp.arange(env.size))
#     states = jnp.stack(states, axis=0)
#     states = states.reshape(env.dims, -1).T

#     # We create a hypergrid with the appropriate batch size
#     env.config.batch_size = states.shape[0]
#     env = Hypergrids(env.size, env.dims, env.config)

#     # Update the masks to only allow the stop action
#     forward_mask = jnp.zeros((env.batch_size, env.dims + 1))
#     backward_mask = jnp.zeros((env.batch_size, env.dims + 1))

#     forward_mask = forward_mask.at[:, -1].set(1.0)
#     backward_mask = backward_mask.at[:, -1].set(1.0)

#     stopped = jnp.ones((env.batch_size,))
#     is_initial = jnp.zeros((env.batch_size,))

#     # Then, we set the states to the appropriate values
#     env._state = HypergridState(
#         state=states,
#         forward_mask=forward_mask,
#         backward_mask=backward_mask,
#         stopped=stopped,
#         is_initial=is_initial,
#     )
#     env._sync_views()

#     # We then compute the empirical distribution
#     empirical_dist = sampler.sample_many_backward(env, num_trajectories=num_samples)
#     empirical_dist = jax.nn.logsumexp(empirical_dist, axis=1) - jnp.log(num_samples)
#     empirical_dist = empirical_dist - jax.nn.logsumexp(empirical_dist, axis=0)

#     empirical_dist = jnp.exp(empirical_dist)

#     # We afterwards plot the empirical distribution
#     grid = jnp.zeros((env.size, env.size))
#     grid = grid.at[env.state[:, 0], env.state[:, 1]].set(empirical_dist)
#     if show_figure:
#         plt.imshow(grid)
#         plt.colorbar()
#         plt.show()

#     # We also compute the log_reward for the complete grid
#     log_rewards = log_reward(env)
#     target_dist = log_rewards - jax.nn.logsumexp(log_rewards, axis=0)
#     target_dist = jnp.exp(target_dist)

#     grid_target = jnp.zeros((env.size, env.size))
#     grid_target = grid_target.at[env.state[:, 0], env.state[:, 1]].set(target_dist)

#     return grid, grid_target


# def get_lines_empirical_dist(
#     sampler: GFlowNet[Lines],
#     env: Lines,
#     num_samples: int,
#     log_reward: LogRewardBase,
#     show_figure: bool = False,
# ) -> tuple[jax.Array, jax.Array]:
#     # We first create an agent with batch size equal to the length
#     config = env.config
#     config.batch_size = env.length
#     env = Lines(env.length, env.max_step_size, env.config)

#     # We then update the forward and backward masks as if the state was finished
#     forward_mask = jnp.zeros((env.batch_size, env.num_actions))
#     backward_mask = jnp.zeros((env.batch_size, env.num_actions))

#     forward_mask = forward_mask.at[:, -1].set(1.0)
#     backward_mask = backward_mask.at[:, -1].set(1.0)

#     # Afterwards, we update the actual state with a batch in each line coordinate
#     states = jnp.arange(env.length)
#     env._state = LineState(
#         state=states,
#         forward_mask=forward_mask,
#         backward_mask=backward_mask,
#         stopped=jnp.ones_like(env.stopped),
#         is_initial=jnp.zeros_like(env.is_initial),
#     )
#     env._sync_views()

#     # Building on this, we compute the empirical and target distributions
#     empirical_dist = sampler.sample_many_backward(env, num_trajectories=num_samples)
#     empirical_dist = jax.nn.logsumexp(empirical_dist, axis=1) - jnp.log(num_samples)
#     empirical_dist = empirical_dist - jax.nn.logsumexp(empirical_dist, axis=0)
#     empirical_dist = jnp.exp(empirical_dist)

#     log_rewards = log_reward(env)
#     target_dist = log_rewards - jax.nn.logsumexp(log_rewards, axis=0)
#     target_dist = jnp.exp(target_dist)

#     # We then plot both empirical and target distributions into a grid
#     grid = jnp.zeros((2, env.length))
#     grid = grid.at[0, :].set(empirical_dist)
#     grid = grid.at[1, :].set(target_dist)
#     if show_figure:
#         plt.imshow(grid)
#         plt.yticks([0, 1], ["empirical", "target"])
#         plt.show()

#     return empirical_dist, target_dist


def get_folder_name_from_config(
    env: EnvironmentEnum,
    recurrence_type: RecurrenceType,
    criterion: str,
    log_reward_type: LogRewardEnum,
) -> str:
    """Build the output folder path for one experiment configuration.

    The returned path has the form
    ``_metrics/<env>#<recurrence_type>#<criterion>#<log_reward_type>``, where
    the enum arguments are rendered through their ``.value``.

    Side effect: the ``_metrics`` root directory is created if missing. The
    returned leaf folder itself is NOT created here.
    """
    root = "_metrics"
    # exist_ok=True replaces the racy exists()/mkdir() pair: concurrent runs
    # can no longer fail with FileExistsError between the check and the mkdir.
    os.makedirs(root, exist_ok=True)
    return f"{root}/{env.value}#{recurrence_type.value}#{criterion}#{log_reward_type.value}"


@dataclass
class StoredMetrics:
    """Scalar summary metrics that ``save_state`` writes to ``metrics.json``.

    Field order is significant: it fixes the key order of the JSON output
    produced via ``dataclasses.asdict``.
    """

    # FCS value recorded prior to training.
    fcs_prior_to_training: float
    # FCS value recorded after training.
    fcs_after_training: float
    # Entropy value stored next to the FCS values.
    entropy: float


@dataclass
class StoredData:
    """Container for the data that ``save_state`` serializes to ``data.json``.

    Fields hold plain snapshots rather than live objects: ``save_state``
    fills them with queue ``.history`` values, ``.tolist()`` lists and
    ``.item()`` scalars, so the annotations below describe those snapshots.
    """

    # Snapshot of ``ModeQueue.history`` (not the queue object itself).
    # NOTE(review): element type of ``history`` is not visible here; it must be
    # JSON-serializable for save_state to succeed — confirm.
    mode_queue: object
    # Snapshot of ``TopKQueue.history`` (not the queue object itself).
    topk_queue: object
    # Learned distribution converted to Python lists via ``.tolist()``.
    learned_dist: list
    # Target distribution converted to Python lists via ``.tolist()``.
    target_dist: list
    # Mode threshold taken from ``ModeQueue.th`` as a Python scalar.
    mode_th: float
    metadata: list[MetadataMetrics]


def _as_python_scalar(value):
    """Return ``value`` as a plain Python scalar.

    Array scalars (jax/numpy) are unwrapped through their ``.item()`` method;
    values without a callable ``item`` attribute (e.g. a plain ``float``) are
    returned unchanged, so callers may pass either form.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def save_state(
    fcs_prior_to_training: float,
    fcs_after_training: float,
    entropy: float,
    mode_queue: ModeQueue,
    topk_queue: TopKQueue,
    learned_dist: jax.Array,
    target_dist: jax.Array,
    metadata: list[MetadataMetrics],
    folder: str,
):
    """Persist training metrics and data into ``folder`` as JSON files.

    Creates ``folder`` if needed, then writes ``data.json`` (a ``StoredData``)
    and ``metrics.json`` (a ``StoredMetrics``).

    Args:
        fcs_prior_to_training: FCS value before training; a Python number or
            an array scalar exposing ``.item()``.
        fcs_after_training: FCS value after training; same accepted forms.
        entropy: Entropy value; same accepted forms.
        mode_queue: Queue whose ``history`` and ``th`` are stored.
        topk_queue: Queue whose ``history`` is stored.
        learned_dist: Array converted with ``.tolist()`` before storing.
        target_dist: Array converted with ``.tolist()`` before storing.
        metadata: Stored as-is; passed through ``dataclasses.asdict``.
        folder: Output directory.

    Returns:
        ``(metrics, data)``: the ``StoredMetrics`` and ``StoredData`` written.

    Raises:
        TypeError: if the collected data is not JSON-serializable. In that
            case no file is opened, so existing JSON files are left intact.
    """
    os.makedirs(folder, exist_ok=True)

    metrics = StoredMetrics(
        fcs_prior_to_training=_as_python_scalar(fcs_prior_to_training),
        fcs_after_training=_as_python_scalar(fcs_after_training),
        entropy=_as_python_scalar(entropy),
    )

    data = StoredData(
        mode_queue=mode_queue.history,
        topk_queue=topk_queue.history,
        learned_dist=learned_dist.tolist(),
        target_dist=target_dist.tolist(),
        mode_th=_as_python_scalar(mode_queue.th),
        metadata=metadata,
    )

    # Serialize both payloads BEFORE opening any file: open(..., "w")
    # truncates immediately, so a serialization failure inside json.dump
    # would otherwise leave a truncated/destroyed data.json behind.
    data_json = json.dumps(asdict(data))
    metrics_json = json.dumps(asdict(metrics))

    with open(os.path.join(folder, "data.json"), "w", encoding="utf-8") as f:
        f.write(data_json)

    with open(os.path.join(folder, "metrics.json"), "w", encoding="utf-8") as f:
        f.write(metrics_json)

    return metrics, data
