from transformers import AutoModelForCausalLM, AutoTokenizer
from eagle.model.ea_model import EaModel
from datasets import load_dataset
import json
import torch
import tqdm


def get_dataset(data_kwargs):
    """Load a dataset, keep its first ``num_prompts`` rows, and build prompts.

    Args:
        data_kwargs: Mapping read for the following keys:
            ``"ds_name"`` (required) is passed to ``load_dataset`` as the
            dataset identifier; ``"subset"`` and ``"split"`` (optional,
            default ``None``) are forwarded as ``name`` and ``split``;
            ``"num_prompts"`` (required) is the number of leading rows kept;
            ``"prompt_constructor"`` (optional) is a callable applied to each
            row, whose result is stored under the row's ``"prompt"`` key.

    Returns:
        The selected rows, mapped through the prompt constructor when one is
        configured, otherwise returned as selected.
    """
    loaded = load_dataset(
        data_kwargs["ds_name"],
        name=data_kwargs.get("subset"),
        split=data_kwargs.get("split"),
    )
    head = loaded.select(range(data_kwargs["num_prompts"]))

    # Without a constructor the rows are used as they are.
    if "prompt_constructor" not in data_kwargs:
        return head

    def add_prompt(example):
        example["prompt"] = data_kwargs["prompt_constructor"](example)
        return example

    return head.map(add_prompt)


def one_batch(
    prompts,
    tokenizer,
    target_model,
    draft_model,
    data_kwargs,
    generation_kwargs,
    device,
    is_eagle,
):
    """Generate from the target model and re-score the result with the draft.

    The prompts are tokenized with left padding, the target model generates a
    continuation, and the draft model is then run once over prompt plus
    continuation so that both models provide logits for every generated
    position.

    Args:
        prompts: Batch of prompt strings handed to ``tokenizer``.
        tokenizer: Callable tokenizer that also exposes ``eos_token_id``.
        target_model: Model used for generation. When ``is_eagle`` is true it
            must expose ``base_model`` (with ``generate`` and ``lm_head``)
            and ``ea_layer``; otherwise ``generate`` is called on it directly.
        draft_model: Model called on the full sequence when ``is_eagle`` is
            false; unused in the EAGLE branch.
        data_kwargs: Mapping; ``"max_input_length"`` (optional) is used as
            the tokenizer's ``max_length``.
        generation_kwargs: Mapping providing ``"max_tokens"`` and
            ``"temperature"``. Sampling is enabled only when the temperature
            is greater than zero.
        device: Device the tokenized inputs are moved to.
        is_eagle: Selects the EAGLE code path.

    Returns:
        Tuple ``(logits_q, logits_p, valid_mask)``. ``logits_q`` are the draft
        logits sliced to the generated positions, ``logits_p`` are the target
        scores returned by ``generate`` stacked along the second-to-last
        dimension, and ``valid_mask`` is a boolean tensor shaped like the
        generated ids that is false for every token after the first EOS. In
        the EAGLE branch both logit tensors are cast to float32. The draft
        logits are computed without gradient tracking.
    """
    in_dict = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        truncation=True,
        max_length=data_kwargs.get("max_input_length", None),
    )
    input_ids = in_dict["input_ids"].to(device)
    input_mask = in_dict["attention_mask"].to(device)
    prompt_len = input_ids.shape[-1]

    # Generation settings common to both code paths.
    temperature = generation_kwargs["temperature"]
    shared_generate_kwargs = dict(
        attention_mask=input_mask,
        max_new_tokens=generation_kwargs["max_tokens"],
        temperature=temperature,
        do_sample=bool(temperature > 0),
        top_k=0,
        top_p=1,
        num_return_sequences=1,
        output_scores=True,
        return_dict_in_generate=True,
    )

    if is_eagle:
        out = target_model.base_model.generate(
            input_ids,
            output_hidden_states=True,
            use_cache=False,
            **shared_generate_kwargs,
        )
        # Last entry of the last entry of ``hidden_states``. Presumably the
        # final layer at the final generation step, which with
        # ``use_cache=False`` should span the whole sequence; confirm against
        # the transformers generate-output documentation.
        last_hidden_states = out.hidden_states[-1][-1]
    else:
        out = target_model.generate(input_ids, **shared_generate_kwargs)

    # Keep only the newly generated tokens.
    output_ids = out.sequences[:, prompt_len:]

    # Tokens after the first EOS are discarded. The "EOS seen so far" flag is
    # shifted right by one so that the EOS token itself stays valid; the
    # wrapped-around first column is cleared explicitly.
    seen_eos = output_ids.eq(tokenizer.eos_token_id).cumsum(dim=-1) > 0
    discard_mask = seen_eos.roll(shifts=1, dims=-1)
    discard_mask[..., 0] = 0
    valid_mask = ~discard_mask

    logits_p = torch.stack(out.scores, dim=-2)

    # Draft input: prompt plus generated tokens, minus the final token, whose
    # successor was never generated.
    full_input_ids = torch.cat([input_ids, output_ids], dim=-1)[..., :-1]
    full_input_mask = torch.cat([input_mask, valid_mask], dim=-1)[..., :-1]

    # The draft pass is scoring only; without no_grad every call would keep
    # an autograd graph alive for as long as the returned logits exist.
    with torch.no_grad():
        if is_eagle:
            # Hidden states, token ids and mask are offset by one position
            # against each other before being fed to the EAGLE layer.
            all_hidden_states_used = last_hidden_states[:, :-1, :]
            out_small_hidden = target_model.ea_layer(
                all_hidden_states_used,
                input_ids=full_input_ids[:, 1:],
                attention_mask=full_input_mask[:, :-1],
                use_cache=False,
            )
            # The one-position shift above moves the first generated position
            # to index ``prompt_len - 2``.
            logits_q = target_model.base_model.lm_head(
                out_small_hidden[:, prompt_len - 2 :]
            )
            return logits_q.float(), logits_p.float(), valid_mask

        draft_out = draft_model(
            full_input_ids,
            attention_mask=full_input_mask,
            return_dict=True,
        )
        # The logits at position ``prompt_len - 1`` predict the first
        # generated token.
        logits_q = draft_out.logits[:, prompt_len - 1 :]
        return logits_q, logits_p, valid_mask


def get_logits_generator(
    data_kwargs, model_kwargs, generation_kwargs, reproducibility_kwargs
):
    """Yield the ``one_batch`` result for every batch of prompts.

    This is a generator function: models, tokenizer and dataset are loaded
    when iteration starts, not when the function is called.

    Args:
        data_kwargs: Dataset configuration, forwarded to ``get_dataset`` and
            ``one_batch``. ``"num_prompts"`` is also used to size the
            progress bar.
        model_kwargs: Mapping providing ``"target_model_str"`` and
            ``"draft_model_str"``. A draft identifier containing ``"EAGLE"``
            selects the EAGLE path, in which a single ``EaModel`` wraps both
            models and no separate draft model is loaded.
        generation_kwargs: Mapping providing ``"batch_size"`` plus the
            settings read by ``one_batch``.
        reproducibility_kwargs: Mapping providing ``"seed"``. Batch ``idx``
            is generated after seeding torch with ``seed + idx``.

    Yields:
        Whatever ``one_batch`` returns for each batch of prompts.
    """
    device = torch.device("cuda")
    print(
        f"device: {device}, torch.cuda.is_available(): {torch.cuda.is_available()}, data_kwargs: {data_kwargs}, model_kwargs: {model_kwargs}, generation_kwargs: {generation_kwargs}, reproducibility_kwargs: {reproducibility_kwargs}"
    )
    is_eagle = "EAGLE" in model_kwargs["draft_model_str"]
    if is_eagle:
        target_model = EaModel.from_pretrained(
            base_model_path=model_kwargs["target_model_str"],
            ea_model_path=model_kwargs["draft_model_str"],
            # bfloat16 for target names containing "Qwen2", float16 otherwise.
            torch_dtype=(
                torch.float16
                if "Qwen2" not in model_kwargs["target_model_str"]
                else torch.bfloat16
            ),
            low_cpu_mem_usage=True,
            device_map=device,
            total_token=-1,
        )
        target_model.eval()
        draft_model = None
        tokenizer = target_model.tokenizer
    else:
        target_model = AutoModelForCausalLM.from_pretrained(
            model_kwargs["target_model_str"],
            low_cpu_mem_usage=True,
            device_map=device,
        )
        draft_model = AutoModelForCausalLM.from_pretrained(
            model_kwargs["draft_model_str"],
            low_cpu_mem_usage=True,
            device_map=device,
        )
        target_model.eval()
        draft_model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_kwargs["target_model_str"])

    # Batched tokenization pads, so a pad token is required; fall back to EOS.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = get_dataset(data_kwargs)
    batch_size = generation_kwargs["batch_size"]
    batched_dataset = dataset.iter(batch_size=batch_size)

    # Ceiling division: ``iter`` is called without ``drop_last_batch``, so a
    # trailing partial batch is still yielded (default per the datasets docs)
    # and must be counted in the progress-bar total.
    num_batches = -(-data_kwargs["num_prompts"] // batch_size)

    for idx, batch in tqdm.tqdm(enumerate(batched_dataset), total=num_batches):
        # A distinct, deterministic seed per batch keeps sampling reproducible.
        seed = reproducibility_kwargs["seed"] + idx
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        prompts = batch["prompt"]

        yield one_batch(
            prompts,
            tokenizer,
            target_model,
            draft_model,
            data_kwargs,
            generation_kwargs,
            device,
            is_eagle,
        )

