from typing import Callable

import jax
import jax.numpy as jnp
import Levenshtein
import numpy as np
from jax import Array
from tqdm import tqdm
from vendi_score import vendi

from medium_rl.alg.tgm import gen_traj
from medium_rl.envs.sequence_env import SequenceEnv
from medium_rl.init import RunState
from medium_rl.policy import make_alg_policy


def eval(
    run_state: RunState,
    forward: Callable,
    policy_fn: Callable,
    env: SequenceEnv,
    num_samples: int,
    top_k: int,
):
    """Sample from the policy, score every sample, and summarize the best ones.

    Args:
        run_state: State threaded through trajectory generation.
        forward: Model forward function passed through to ``gen_eval_samples``.
        policy_fn: Action-selection policy passed through to ``gen_eval_samples``.
        env: Environment whose proxy scores and embeds the samples.
        num_samples: Number of samples to generate.
        top_k: Number of highest-reward samples to summarize.

    Returns:
        A 5-tuple ``(metrics, rewards, embeds, samples, top_samples)``. ``metrics``
        holds ``"mean_all_reward"`` (mean over all samples), ``"mean_top_k_reward"``
        (mean over the ``top_k`` best samples), and ``"vendi_score"``
        (``vendi.score_dual`` of the top-k embeddings with ``normalize=True``).
    """
    samples = gen_eval_samples(run_state, forward, policy_fn, env, num_samples)
    rewards, embeds = batch_eval_samples(samples, env)

    # jax.lax.top_k yields (values, indices); only the indices are needed.
    _, best = jax.lax.top_k(rewards, top_k)

    metrics = {
        "mean_all_reward": rewards.mean().item(),
        "mean_top_k_reward": rewards[best].mean().item(),
        "vendi_score": vendi.score_dual(embeds[best], normalize=True).item(),
    }
    return metrics, rewards, embeds, samples, samples[best]


def batch_eval_samples(samples: Array, env: SequenceEnv, batch_size: int = 256):
    """Score ``samples`` with ``env.proxy.get_embed`` in batches of at most ``batch_size``.

    Args:
        samples: Samples to score, indexed along axis 0.
        env: Environment whose ``proxy.get_embed(batch)`` returns ``(rewards, embeds)``
            for a batch, and whose ``proxy.embed_dim`` gives the embedding width.
        batch_size: Maximum number of samples per ``get_embed`` call; the final
            batch may be smaller.

    Returns:
        ``(rewards, embeds)`` with shapes ``(num_samples,)`` and
        ``(num_samples, env.proxy.embed_dim)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = samples.shape[0]
    rewards = jnp.zeros((num_samples,))
    embeds = jnp.zeros((num_samples, env.proxy.embed_dim))

    # Step through all samples, including a trailing partial batch; floor-dividing
    # num_samples by batch_size would leave the remainder with zero reward/embedding.
    for start in range(0, num_samples, batch_size):
        end = min(start + batch_size, num_samples)
        curr_rewards, curr_embeds = env.proxy.get_embed(samples[start:end])
        # Rewards come back with a trailing singleton axis (presumably (B, 1));
        # squeeze it to match the 1-D `rewards` buffer.
        rewards = rewards.at[start:end].set(curr_rewards.squeeze(-1))
        embeds = embeds.at[start:end].set(curr_embeds)

    return rewards, embeds


def gen_eval_samples(
    run_state: RunState,
    forward: Callable,
    policy_fn: Callable,
    env: SequenceEnv,
    num_eval_samples: int,
):
    """Repeatedly call ``gen_traj`` and collect the resulting observations as samples.

    The number of ``gen_traj`` calls is ``num_eval_samples // num_envs``, where
    ``num_envs`` is the leading dimension of ``run_state.env_state.obs``; each call
    presumably contributes one observation per parallel environment.

    Args:
        run_state: Initial state; threaded through the ``jax.lax.scan`` steps.
        forward: Model forward function passed to ``gen_traj``.
        policy_fn: Policy passed to ``gen_traj``.
        env: Environment passed to ``gen_traj``.
        num_eval_samples: Total number of samples; must be a multiple of ``num_envs``.

    Returns:
        The stacked observations flattened to 2-D ``(-1, D)``, where ``D`` is the last
        observation dimension. This is ``(num_eval_samples, D)`` assuming each step's
        ``obs`` is ``(num_envs, D)`` — NOTE(review): confirm against ``gen_traj``.

    Raises:
        ValueError: If ``num_eval_samples`` is not divisible by ``num_envs``.
    """
    num_envs = run_state.env_state.obs.shape[0]
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if num_eval_samples % num_envs != 0:
        raise ValueError(
            f"num_eval_samples={num_eval_samples} must be divisible by num_envs={num_envs}"
        )

    @jax.jit
    def _sample_obj(run_state, _):
        run_state, sub_traj_batch = gen_traj(run_state, forward, policy_fn, env)
        return run_state, sub_traj_batch.obs

    _, samples = jax.lax.scan(_sample_obj, run_state, None, num_eval_samples // num_envs)

    # scan stacks per-step outputs on a new leading axis; fold everything but the
    # last dimension into the sample axis.
    return samples.reshape(-1, samples.shape[-1])


def eval_temp(
    run_state: RunState,
    forward: Callable,
    env: SequenceEnv,
    alg_cfg,
    num_samples: int,
    top_k: int,
    inverse_temp_mod: float,
):
    """Run ``eval`` with a policy built for a given inverse-temperature modifier.

    The policy comes from ``make_alg_policy(0.0, alg_cfg, inverse_temp_mod)``.
    NOTE(review): the meaning of the leading ``0.0`` argument is defined by
    ``make_alg_policy`` and is not visible here — confirm.

    Returns:
        Whatever ``eval`` returns for that policy.
    """
    return eval(
        run_state,
        forward,
        make_alg_policy(0.0, alg_cfg, inverse_temp_mod),
        env,
        num_samples,
        top_k,
    )


def eval_sweep_temp(
    run_state: RunState,
    forward: Callable,
    env: SequenceEnv,
    alg_cfg,
    num_samples: int,
    top_k: int,
    inverse_temp_mods: list[float],
):
    """Evaluate the policy once per inverse-temperature modifier.

    Args:
        inverse_temp_mods: Modifiers to sweep over; each is passed to ``eval_temp``.
        (Remaining arguments are forwarded unchanged to ``eval_temp``.)

    Returns:
        Dict keyed by modifier. Each value holds ``"inverse_temp_mod"``, ``"rewards"``,
        ``"embeds"``, ``"samples"``, ``"top_samples"``, plus every metric from ``eval``.
    """
    results = {}
    for temp_mod in tqdm(inverse_temp_mods):
        metrics, rewards, embeds, samples, top_samples = eval_temp(
            run_state, forward, env, alg_cfg, num_samples, top_k, temp_mod
        )
        entry = dict(
            inverse_temp_mod=temp_mod,
            rewards=rewards,
            embeds=embeds,
            samples=samples,
            top_samples=top_samples,
        )
        # Metrics go in last so they take precedence, as with a trailing **-merge.
        entry.update(metrics)
        results[temp_mod] = entry
    return results


def eval_modes(run_state: RunState, forward: Callable, env: SequenceEnv, cfg):
    """Run the inverse-temperature sweep, then extract distinct high-reward modes.

    Reads ``cfg.alg``, ``cfg.num_eval_samples``, ``cfg.top_k``,
    ``cfg.eval_sweep_temps`` and ``cfg.env.mode_delta``.

    Returns:
        ``(sweep_res, mode_res)``: the per-setting sweep results and the output of
        ``top_k_modes`` over them.
    """
    sweep_res = eval_sweep_temp(
        run_state,
        forward,
        env,
        alg_cfg=cfg.alg,
        num_samples=cfg.num_eval_samples,
        top_k=cfg.top_k,
        inverse_temp_mods=cfg.eval_sweep_temps,
    )
    return sweep_res, top_k_modes(
        sweep_res, env.alphabet, max_num_modes=cfg.top_k, delta=cfg.env.mode_delta
    )


""" GREEDY MODE EVAL """


def list_to_str(seq, alphabet, cls_token: int = 0, eos_token: int = 2):
    """Decode a token-id sequence into a string, dropping special tokens.

    Characters are taken strictly between the first ``cls_token`` (CLS) and the
    first ``eos_token`` (EOS); the EOS and anything after it (e.g. padding) are
    discarded.

    Args:
        seq: 1-D array-like of integer token ids.
        alphabet: Indexable mapping from token id to character.
        cls_token: Id of the start token (default 0).
        eos_token: Id of the end token (default 2).

    Returns:
        The decoded string. If ``seq`` has no EOS, decoding runs to the end of
        ``seq``; if it has no CLS, decoding starts at the beginning of ``seq``.
    """
    seq = np.asarray(seq)
    cls_pos = np.flatnonzero(seq == cls_token)
    eos_pos = np.flatnonzero(seq == eos_token)
    # Fall back to the sequence boundaries instead of raising IndexError when a
    # special token is absent (e.g. a sequence truncated before emitting EOS).
    start = cls_pos[0] + 1 if cls_pos.size else 0
    end = eos_pos[0] if eos_pos.size else len(seq)
    return "".join(alphabet[tok] for tok in seq[start:end])


def get_levenshtein_dist(seq1, seq2, alphabet, normalize: bool = False):
    """Return the Levenshtein (edit) distance between two token-id sequences.

    Both sequences are first decoded to strings with ``list_to_str``.

    Args:
        seq1: First token-id sequence.
        seq2: Second token-id sequence.
        alphabet: Token id -> character mapping used for decoding.
        normalize: If True, divide the distance by the length of the longer decoded
            string, giving a value in [0, 1]. Two empty strings yield 0.0.

    Returns:
        The raw integer edit distance, or a float in [0, 1] when ``normalize`` is set.
    """
    str1 = list_to_str(seq1, alphabet)
    str2 = list_to_str(seq2, alphabet)
    dist = Levenshtein.distance(str1, str2)
    if normalize:
        longest = max(len(str1), len(str2))
        # Guard the empty/empty case, which would otherwise divide by zero.
        return dist / longest if longest else 0.0
    return dist


def top_k_modes(sweep_res, alphabet, max_num_modes=100, delta=5):
    """Greedily pick up to ``max_num_modes`` high-reward, mutually distinct samples.

    Samples and rewards from every sweep setting are pooled and visited in
    descending reward order. A sample becomes a new mode only if its Levenshtein
    distance (on decoded strings) to every mode kept so far is at least ``delta``.
    The highest-reward sample is therefore always the first mode.

    Args:
        sweep_res: Mapping whose values each contain ``"samples"`` (2-D, one row per
            sample) and ``"rewards"`` (one per sample), e.g. ``eval_sweep_temp`` output.
            All settings must have the same number of samples.
        alphabet: Token id -> character mapping used to decode samples.
        max_num_modes: Maximum number of modes to return.
        delta: Minimum edit distance required between any two modes.

    Returns:
        Dict with ``"modes"`` (decoded mode strings), ``"mode_rewards"`` (np.ndarray of
        their rewards) and ``"mean_mode_reward"`` (float mean of those rewards).

    Raises:
        ValueError: If ``sweep_res`` is empty.
    """
    if not sweep_res:
        raise ValueError("sweep_res must contain at least one setting")

    all_samps = np.array([res["samples"] for res in sweep_res.values()])
    all_rews = np.array([res["rewards"] for res in sweep_res.values()])

    # Pool settings: (num_settings, num_samples, D) -> (num_settings * num_samples, D).
    all_samps = all_samps.reshape(-1, all_samps.shape[-1])
    all_rews = all_rews.reshape(-1)

    order = np.argsort(all_rews)[::-1]  # highest reward first
    all_samps = all_samps[order]
    all_rews = all_rews[order]

    mode_strs = []
    mode_rewards = []
    for sample, reward in zip(all_samps, all_rews):
        # Decode each candidate once and compare with cached mode strings, rather
        # than re-decoding both sequences for every pairwise distance.
        cand = list_to_str(sample, alphabet)
        if all(Levenshtein.distance(cand, mode) >= delta for mode in mode_strs):
            mode_strs.append(cand)
            mode_rewards.append(reward)
        if len(mode_strs) >= max_num_modes:
            break

    mean_reward = np.mean(mode_rewards)

    return {
        "modes": mode_strs,
        "mode_rewards": np.array(mode_rewards),
        "mean_mode_reward": mean_reward.item(),
    }
