import os
import pickle
import random

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from Levenshtein import distance

from medium_rl.network.encoder import EncoderTransformer
from medium_rl.utils import load_model

# Absolute path of the directory containing this module. EncoderProxy resolves
# each proxy's data sub-directory (e.g. "amp", "utr", "gfp") relative to it.
DIR = os.path.dirname(os.path.abspath(__file__))


def load_proxy(path: str):
    """Load a proxy model and its validation statistics from a directory.

    Args:
        path: Directory expected to contain ``proxy.pkl`` (read through
            ``load_model``) and ``val_stats.pkl`` (a pickle file).

    Returns:
        A ``(model_cfg, params, val_stats)`` tuple, where ``model_cfg`` and
        ``params`` come from ``load_model`` and ``val_stats`` is the unpickled
        contents of ``val_stats.pkl``.

    Note:
        ``pickle.load`` can execute arbitrary code; only point this at trusted
        files.
    """
    proxy_file = os.path.join(path, "proxy.pkl")
    stats_file = os.path.join(path, "val_stats.pkl")

    model_cfg, params = load_model(proxy_file)
    with open(stats_file, "rb") as handle:
        val_stats = pickle.load(handle)

    return model_cfg, params, val_stats


class Proxy:
    """Interface for sequence-scoring proxies.

    The concrete proxies in this module return a ``(scores, aux)`` pair from
    :meth:`evaluate` and a ``(scores, embedding)`` pair from :meth:`get_embed`.
    The base implementations raise instead of silently returning ``None``, so a
    subclass that forgets to override them fails loudly.
    """

    def evaluate(self, x: Array) -> tuple[Array, object]:
        """Score a batch of token sequences.

        Args:
            x: Batch of token-index sequences.

        Returns:
            ``(scores, aux)``. ``aux`` is implementation-defined and may be
            ``None``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement evaluate()")

    def get_embed(self, x: Array) -> tuple[Array, Array]:
        """Score a batch of token sequences and return an embedding for each.

        Args:
            x: Batch of token-index sequences.

        Returns:
            ``(scores, embedding)``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement get_embed()")


class EncoderProxy(Proxy):
    """Proxy backed by a pretrained ``EncoderTransformer`` loaded from disk.

    The model config, parameters and validation statistics are loaded from
    ``os.path.join(DIR, name)`` via ``load_proxy``. Raw model outputs are
    standardised with the stored validation ``mean``/``std`` and then
    exponentiated, so every score is strictly positive.
    """

    def __init__(self, name: str) -> None:
        """Load the proxy called ``name`` and build its jitted scoring functions.

        Args:
            name: Sub-directory of this module's directory containing the
                files read by ``load_proxy``.
        """
        self.name = name
        model_cfg, params, val_stats = load_proxy(os.path.join(DIR, name))

        forward = hk.transform(
            lambda x, is_training=False, get_embed=False: EncoderTransformer(**model_cfg)(x, is_training, get_embed)
        )
        self.forward = forward
        self.params = params
        self.model_cfg = model_cfg
        self.mean = val_stats["mean"]
        self.std = val_stats["std"]

        self.embed_dim = model_cfg["embed_dim"]
        self.eval_fn = self.make_eval_fn()
        self.embed_fn = self.make_embed_fn()

    def make_eval_fn(self):
        """Return a callable ``x -> scores`` backed by a jitted forward pass.

        ``self.params`` is passed to the compiled function as an argument on
        every call rather than captured by the closure. Python state that a
        jitted function reads is frozen at trace time, so a captured ``params``
        would be embedded as constants in the compiled program and would go
        stale if ``self.params`` were later reassigned. ``self.mean`` and
        ``self.std`` are still read at trace time (they are fixed validation
        statistics).
        """

        @jax.jit
        def _eval(params, x):
            # rng=None: the forward pass is run with is_training=False; assumes
            # the model needs no randomness in that mode.
            output_logits = self.forward.apply(params, None, x, is_training=False)
            return self.get_exponential_normalized_logit(output_logits)

        def eval_fn(x):
            return _eval(self.params, x)

        return eval_fn

    def make_embed_fn(self):
        """Return a callable ``x -> (scores, embedding)`` backed by a jitted forward pass.

        As in :meth:`make_eval_fn`, the current ``self.params`` is passed as an
        argument on each call instead of being frozen into the trace.
        """

        @jax.jit
        def _embed(params, x):
            logits, embedding = self.forward.apply(params, None, x, is_training=False, get_embed=True)
            return self.get_exponential_normalized_logit(logits), embedding

        def embed_fn(x):
            return _embed(self.params, x)

        return embed_fn

    def evaluate(self, x):
        """Score ``x``; returns ``(scores, None)`` because there is no auxiliary output."""
        return self.eval_fn(x), None

    def get_embed(self, x):
        """Return ``(scores, embedding)`` for ``x`` from the encoder's embedding output."""
        return self.embed_fn(x)

    def get_exponential_normalized_logit(self, output_logits):
        """Standardise logits with the validation mean/std, then exponentiate."""
        normalized_logits = (output_logits - self.mean) / self.std
        return jnp.exp(normalized_logits)


class AMPProxy(EncoderProxy):
    """Encoder proxy whose model files live in the ``amp`` sub-directory."""

    def __init__(self):
        super(AMPProxy, self).__init__(name="amp")


class UTRProxy(EncoderProxy):
    """Encoder proxy whose model files live in the ``utr`` sub-directory."""

    def __init__(self):
        super(UTRProxy, self).__init__(name="utr")


class GFPProxy(EncoderProxy):
    """Encoder proxy whose model files live in the ``gfp`` sub-directory."""

    def __init__(self):
        super(GFPProxy, self).__init__(name="gfp")


class BitSeqProxy(Proxy):
    """Synthetic proxy scoring bit strings by distance to random target modes.

    ``num_modes`` target bit strings ("modes") are built by concatenating
    8-character chunks drawn from a fixed vocabulary. A sequence scores
    ``exp(1 - d_min / max_len)``, where ``d_min`` is its Levenshtein distance
    to the closest mode.
    """

    # Token indices <= this value are dropped when decoding; presumably special
    # tokens (padding/BOS/EOS) — TODO confirm against the tokenizer.
    _MAX_SPECIAL_TOKEN_IDX = 2

    # Fixed-width building blocks concatenated to form each mode.
    _MODE_VOCAB = ("00000000", "11111111", "11110000", "00001111", "00111100")

    def __init__(self, num_modes: int, max_len: int, alphabet: list[str], mode_seed: int = 0) -> None:
        """Generate the target modes.

        Args:
            num_modes: Number of target modes to generate.
            max_len: Maximum sequence length; also used as ``embed_dim`` and to
                normalise distances. Modes are ``max_len // 8`` chunks long, so
                they are shorter than ``max_len`` when it is not a multiple of 8.
            alphabet: Maps token index -> symbol. Non-special symbols are
                assumed to be ``"0"``/``"1"`` (``get_embed`` calls ``int`` on them).
            mode_seed: Seed for mode generation.
        """
        self.name = "BitSeq"

        self.num_modes = num_modes
        self.max_len = max_len
        self.alphabet = alphabet
        self.mode_seed = mode_seed
        self.embed_dim = max_len

        # A private RNG seeded with mode_seed produces exactly the same draws as
        # random.seed(mode_seed) + random.choices, but it does not reset the
        # process-wide `random` state as a side effect.
        rng = random.Random(mode_seed)
        num_chunks = max_len // len(self._MODE_VOCAB[0])
        self.modes_str = ["".join(rng.choices(self._MODE_VOCAB, k=num_chunks)) for _ in range(num_modes)]
        self.modes = jnp.array([[bit == "1" for bit in mode] for mode in self.modes_str], dtype=bool)

    def _decode(self, x) -> list[str]:
        """Turn each row of token indices in ``x`` into a string, skipping special tokens."""
        alphabet = self.alphabet
        cutoff = self._MAX_SPECIAL_TOKEN_IDX
        return ["".join(alphabet[idx] for idx in seq if idx > cutoff) for seq in np.asarray(x)]

    def _score(self, token_strs: list[str]):
        """Return ``(scores, dists)`` for already-decoded sequences."""
        rows = [[distance(seq, mode) for mode in self.modes_str] for seq in token_strs]
        # Explicit reshape keeps dists 2-D (batch, num_modes) even for an empty
        # batch, where jnp.array([]) would otherwise be 1-D and break .min(axis=-1).
        dists = jnp.array(rows).reshape(len(token_strs), len(self.modes_str))
        scores = jnp.exp(1 - dists.min(axis=-1) / self.max_len)[..., None]
        return scores, dists

    def evaluate(self, x):
        """Score a batch of token sequences.

        Args:
            x: Token indices; rows are iterated as individual sequences
                (assumed shape (batch, seq_len) — TODO confirm).

        Returns:
            ``(scores, dists)``: scores of shape (batch, 1) and Levenshtein
            distances to every mode of shape (batch, num_modes).
        """
        return self._score(self._decode(x))

    def get_embed(self, x):
        """Return ``(scores, bits)``, where ``bits`` holds each decoded sequence as integers.

        Decoded sequences of different lengths cannot be stacked and raise.
        """
        token_strs = self._decode(x)
        bits = jnp.array([[int(bit) for bit in bitstring] for bitstring in token_strs])
        scores, _ = self._score(token_strs)
        return scores, bits
