import logging
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from tqdm.auto import tqdm

from . import transformer
from .caching import TransformerCache

# Configure the root logger as a side effect of importing this module:
# INFO level, with each record rendered as "<timestamp> - <level> - <message>".
# NOTE(review): logging.basicConfig does nothing if the root logger already
# has handlers, so an application that configured logging earlier keeps its
# own setup.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


@partial(jax.jit, static_argnums=(2, 3))
def sample_step(logits, random_key, top_p, temp):
    """Draw one token id per row with temperature scaling and top-p filtering.

    The logits are divided by ``temp``, sorted in descending order along the
    last axis, and every entry that comes after the cumulative softmax
    probability has already exceeded ``top_p`` is masked out. A categorical
    sample is then drawn from what remains and mapped back to the original
    vocabulary index.

    Args:
        logits: Array of unnormalised scores; the last axis is sampled over.
            Any number of leading axes is accepted.
        random_key: JAX PRNG key used for the categorical draw.
        top_p: Cumulative-probability threshold. Declared static for
            ``jax.jit``, so it must be hashable and each distinct value
            compiles a separate version of this function.
        temp: Temperature the logits are divided by. Also static. A value of
            ``0`` selects the highest-scoring entry directly instead of
            dividing by zero.

    Returns:
        Integer array of sampled indices with the shape of ``logits`` minus
        its last axis.
    """
    # `temp` is a static (Python-level) value, so this branch is resolved at
    # trace time. Dividing by zero would fill the logits with inf/nan.
    if temp == 0:
        return jnp.argmax(logits, axis=-1)

    scaled = logits / temp

    # Descending order; gather the values with the same permutation rather
    # than sorting a second time.
    order = jnp.argsort(scaled, axis=-1)[..., ::-1]
    sorted_logits = jnp.take_along_axis(scaled, order, axis=-1)

    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)

    # Shift the "threshold exceeded" flags one slot to the right so that the
    # entry which first crosses `top_p` is kept and only later ones are
    # dropped. Built with concatenate so it works for any input rank.
    exceeded = cumulative_probs > top_p
    drop = jnp.concatenate(
        [jnp.zeros_like(exceeded[..., :1]), exceeded[..., :-1]], axis=-1
    )

    masked_logits = jnp.where(drop, -jnp.inf, sorted_logits)

    choice_in_sorted = jax.random.categorical(random_key, masked_logits, axis=-1)

    # Translate the position within the sorted order back to a vocabulary id.
    return jnp.take_along_axis(order, choice_in_sorted[..., None], axis=-1).squeeze(
        -1
    )


def generate(
    max_new_tokens,
    tokenizer,
    params,
    config,
    random_key,
    run_fn=transformer.run,
    inputs=None,
    tokenized_inputs=None,
    temp=0.6,
    top_p=0.9,
    return_text=True,
    verbose=False,
):
    """Autoregressively sample up to ``max_new_tokens`` tokens per sequence.

    The prompt tokens are extended with ``max_new_tokens`` pad slots, and the
    loop repeatedly calls ``run_fn``, samples the next token with
    ``sample_step`` and writes it into the next free slot of each row. The
    loop stops early once every row has produced the EOS token.

    Args:
        max_new_tokens: Maximum number of sampling steps.
        tokenizer: Object providing ``pad_token_id``, ``eos_token_id``,
            ``bos_token_id``, ``id_to_token(id)`` and, when ``inputs`` is
            used, ``encode_batch_fast(inputs)`` returning items with ``.ids``.
        params: Passed through to ``run_fn`` unchanged.
        config: Passed to ``TransformerCache.create`` and to ``run_fn``.
        random_key: JAX PRNG key; a per-step key is derived with ``fold_in``.
        run_fn: Called as ``run_fn(model_inputs, cache, params, config)``; its
            first two return values are used as the logits and the updated
            cache, any further values are ignored.
        inputs: Sequence of raw prompts to encode. Takes precedence over
            ``tokenized_inputs`` when both are given.
        tokenized_inputs: Array of prompt token ids with the batch on axis 0;
            used only when ``inputs`` is ``None``.
        temp: Sampling temperature forwarded to ``sample_step``.
        top_p: Nucleus threshold forwarded to ``sample_step``.
        return_text: If true, return decoded strings; otherwise return the
            token array.
        verbose: If true, wrap the step loop in a ``tqdm`` progress bar.

    Returns:
        A list with one string per row when ``return_text`` is true (pad, BOS
        and EOS ids are skipped and the remaining token strings are joined
        without a separator); otherwise the host copy of the full token
        buffer, prompt included.

    Raises:
        ValueError: If both ``inputs`` and ``tokenized_inputs`` are ``None``.
    """
    if inputs is None and tokenized_inputs is None:
        raise ValueError("Either raw inputs or encoded inputs must be specified")

    if inputs is None:
        tokens = tokenized_inputs
        batch_size = tokenized_inputs.shape[0]
    else:
        batch_size = len(inputs)
        encodings = tokenizer.encode_batch_fast(inputs)
        # NOTE(review): stacking into one array presumably requires every
        # encoding to have the same length (e.g. tokenizer-side padding) -
        # confirm against the tokenizer configuration.
        tokens = jnp.array([enc.ids for enc in encodings])

    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    # Reserve pad-filled slots for the tokens to be generated. The block is
    # created directly as int32 so the ids never pass through float32.
    generation_slots = jnp.full(
        (batch_size, max_new_tokens), pad_id, dtype=jnp.int32
    )
    tokens = jnp.concatenate([tokens, generation_slots], axis=-1).astype(jnp.int32)

    # Non-pad tokens keep their column index as position; pad slots get -1.
    positions = jnp.where(tokens != pad_id, jnp.arange(tokens.shape[-1]), -1)
    # Number of non-pad tokens per row; also used as the next column to write.
    # NOTE(review): this assumes pad tokens only appear to the right of the
    # real tokens - confirm for left-padded prompts.
    seq_lens = jnp.sum(positions >= 0, axis=-1)

    model_inputs = tokens

    cache = TransformerCache.create(
        positions=positions,
        model_config=config,
        dynamic=True,
    )

    # A row starts out finished only if its prompt already holds an EOS.
    # Pad positions are excluded so that a tokenizer whose pad id equals its
    # EOS id does not mark every row as finished because of the freshly
    # appended generation slots.
    finished = jnp.any((tokens == eos_id) & (positions >= 0), axis=-1)

    step_iter = range(max_new_tokens)
    if verbose:
        step_iter = tqdm(step_iter)

    batch_indices = jnp.arange(batch_size, dtype=jnp.int32)

    for step in step_iter:
        logits, cache, *_ = run_fn(model_inputs, cache, params, config)
        # NOTE(review): once a single-token input is fed, the logits
        # presumably have a length-1 sequence axis; this index would then be
        # out of range and rely on JAX clamping gather indices - confirm
        # against run_fn.
        next_token_logits = logits[batch_indices, seq_lens - 1, :]

        step_key = random.fold_in(random_key, step)
        next_tokens = sample_step(
            logits=next_token_logits, random_key=step_key, top_p=top_p, temp=temp
        )

        # Rows that already finished only ever receive padding.
        next_tokens = jnp.where(finished, pad_id, next_tokens)

        tokens = tokens.at[batch_indices, seq_lens].set(next_tokens)
        finished = finished | (next_tokens == eos_id)

        # With a dynamic cache only the newest token is fed back in;
        # otherwise the whole buffer is re-run.
        model_inputs = next_tokens[..., None] if cache.dynamic else tokens
        cache = cache.roll()
        seq_lens += 1

        # Pulls a scalar to the host every step to allow an early exit.
        if bool(jnp.all(finished).item()):
            break

    tokens = jax.device_get(tokens)

    if not return_text:
        return tokens

    ids = np.asarray(tokens)
    seq_lens_np = np.asarray(seq_lens)
    id2tok = tokenizer.id_to_token
    skipped_ids = (pad_id, tokenizer.bos_token_id, eos_id)

    # TODO: find what slows down built-in decoding and replace this manual
    # id -> token-string join.
    decoded_sequences = [
        "".join(id2tok(int(i)) for i in row[:length] if i not in skipped_ids)
        for row, length in zip(ids, seq_lens_np)
    ]

    if len(decoded_sequences) != batch_size:
        logging.warning(
            "No. decoded sequences differs from batch size: got %d, expected: %d",
            len(decoded_sequences),
            batch_size,
        )
        logging.warning("Dumping first ten generated sequences: %s", ids[:10])
    return decoded_sequences
