import math
from dataclasses import dataclass
from typing import List

import torch
from tqdm import tqdm

from utils import ModelHook, jsonlload, load_model


@dataclass
class InputQA:
    """Pre-tokenized prompt/answer pair fed to ``get_hidden_state``.

    ``get_hidden_state`` builds each model input as
    ``prompt_token + answer_token`` and, when ``need_prompt`` is false, uses
    ``len(prompt_token)`` as the offset at which the kept hidden states start.
    """

    # Token ids of the prompt part; placed first in the model input.
    prompt_token: List[int]
    # Token ids of the answer part; appended directly after the prompt ids.
    answer_token: List[int]


def _resolve_pad_token_id(tokenizer):
    """Return the token id used to right-pad batches.

    Preference order: the tokenizer's pad token id, then its eos token id,
    then whatever id ``"<|endoftext|>"`` converts to.

    Raises:
        ValueError: If none of the candidates yields an id.
    """
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id

    error_message = (
        "The pad token id cannot be determined. Please check whether the "
        "tokenizer defines the pad token or eos token"
    )
    try:
        fallback_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    except Exception as err:
        raise ValueError(error_message) from err
    # The conversion may hand back None for an unknown token instead of
    # raising; fail here rather than later inside torch.tensor().
    if fallback_id is None:
        raise ValueError(error_message)
    return fallback_id


def get_hidden_state(
    model,
    tokenizer,
    all_qa: List[InputQA],
    key_list: List[str],
    batch_size=1,
    save_device="cpu",
    need_prompt=True,
):
    """Run the model over tokenized QA pairs and collect hooked module outputs.

    Each sample is fed as ``prompt_token + answer_token``. Samples are
    right-padded to the longest sequence in their batch, and the padded
    positions are dropped again before the outputs are stored.

    Args:
        model: Model to run; it must expose ``device`` and be usable through
            ``ModelHook(model, tokenizer).model.generate``.
        tokenizer: Tokenizer used only to determine the padding token id.
        all_qa: Samples to process, in order.
        key_list: Hook names passed to ``ModelHook.hook_model_input_output``;
            they are also the keys of the returned dict.
        batch_size: Number of samples per forward pass; must be >= 1.
        save_device: Device each stored slice is moved to.
        need_prompt: If true, keep the outputs for prompt and answer tokens;
            otherwise keep only the positions after the prompt.

    Returns:
        Dict mapping every key in ``key_list`` to a list with one tensor per
        sample (same order as ``all_qa``). Each tensor is the slice
        ``[token_start:len(prompt + answer), :]`` of that sample's row.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or no padding token
            id can be determined from the tokenizer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Resolve the pad id before installing hooks so that a failure here does
    # not leave hooks attached to the model.
    pad_token_id = _resolve_pad_token_id(tokenizer)

    # hidden state
    hooked_model = ModelHook(model, tokenizer)
    hooked_model.hook_model_input_output(hook_name_list=key_list)

    all_emb = {key: [] for key in key_list}

    try:
        # batch
        print(f"batch size: {batch_size}, batch num: {math.ceil(len(all_qa) / batch_size)}")
        for batch_start in tqdm(range(0, len(all_qa), batch_size)):
            batch_qa: List[InputQA] = all_qa[batch_start : batch_start + batch_size]
            batch_input_ids = [qa.prompt_token + qa.answer_token for qa in batch_qa]
            max_length = max(len(seq) for seq in batch_input_ids)

            # Right-pad every sequence to the batch maximum and mask the padding.
            input_ids = torch.tensor(
                [seq + [pad_token_id] * (max_length - len(seq)) for seq in batch_input_ids]
            ).to(model.device)
            attention_mask = torch.tensor(
                [[1] * len(seq) + [0] * (max_length - len(seq)) for seq in batch_input_ids]
            ).to(model.device)

            # Drop whatever the hooks recorded for the previous batch.
            hooked_model.clear_io()
            hooked_model.model.generate(
                input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=1
            )

            for key in key_list:
                # First recorded call, first output element.
                # NOTE(review): assumed to be (batch, seq, hidden) - confirm
                # against ModelHook's recording format.
                key_out = hooked_model.output_dict[key][0][0]

                for row, (qa, seq) in enumerate(zip(batch_qa, batch_input_ids)):
                    token_start = 0 if need_prompt else len(qa.prompt_token)
                    # Slice up to the true length so padded positions are excluded.
                    all_emb[key].append(
                        key_out[row, token_start : len(seq), :].to(save_device)
                    )
    finally:
        # Always detach the hooks, even when a batch fails part-way through.
        hooked_model.remove_all_hook()

    return all_emb


def main(data, model, tokenizer, save_path):
    """Extract last-layer hidden states for a fixed sample set and save them.

    Puts ``model`` in eval mode, runs ``get_hidden_state`` on three hard-coded
    ``InputQA`` samples with a hook on ``model.layers.<last index>``, and
    writes the result together with the hallucination flags to ``save_path``
    via ``torch.save``.

    Args:
        data: Not used by the current body.
            NOTE(review): presumably meant to replace the hard-coded samples.
        model: Model exposing ``eval()`` and ``model.layers``.
        tokenizer: Tokenizer forwarded to ``get_hidden_state``.
        save_path: Destination passed to ``torch.save``.
    """
    model.eval()

    # Hard-coded token-id samples.
    samples = [
        InputQA([1, 2, 3], [4, 5, 6]),
        InputQA([21, 31], [41, 42, 43]),
        InputQA([1, 2, 3], [4, 5]),
    ]
    # NOTE(review): two flags for three samples - confirm the intended labels.
    hallucination_flags = [True, False]

    # Hook only the last layer; earlier layers (e.g. num_layers - 3,
    # num_layers - 2) can be added to this list in the same format.
    num_layers = len(model.model.layers)
    hook_names = [f"model.layers.{num_layers - 1}"]

    embeddings = get_hidden_state(
        model=model,
        tokenizer=tokenizer,
        all_qa=samples,
        key_list=hook_names,
        batch_size=10,
    )

    # save
    torch.save(
        {"all_emb": embeddings, "all_hallucination_flag": hallucination_flags},
        save_path,
    )

