import json
import torch
import pathlib

from copy import deepcopy
from typing import List, Tuple, Optional, TypeVar, Type
from pydantic.dataclasses import dataclass
from transformers import LlamaTokenizer

import sys
sys.path.append('/home/wutong1/PoSE')
from src.my_configuration_llama import LlamaConfig
from src.train_pose import smart_tokenizer_and_embedding_resize


# Special-token strings that load_model_tokenizer registers on the tokenizer when it
# has no pad token.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
# NOTE(review): the UNK token is the same string as the EOS token ("</s>"), so unknown
# tokens and end-of-sequence share an id. Confirm this aliasing is intentional.
DEFAULT_UNK_TOKEN = "</s>"


def read_file(file_name):
    """Load the records stored in a file of concatenated JSON values.

    The file is expected to contain JSON values (typically objects) one after another,
    e.g. one object per line or pretty-printed objects. Consecutive top-level values may
    be separated by a newline (optionally with surrounding spaces or blank lines) or by
    a comma. The values are returned as a list in file order; an empty file yields
    ``[]``. The file is decoded as UTF-8.

    Only names containing the substring ``"txt"`` are accepted. This historical check
    is kept as-is for backward compatibility.

    Args:
        file_name: Path to the file, as ``str`` or ``pathlib.Path``.

    Returns:
        List of the decoded JSON values.

    Raises:
        ValueError: If ``file_name`` does not contain ``"txt"``. Previously this
            surfaced as an ``UnboundLocalError``.
        json.JSONDecodeError: If the contents cannot be parsed.
    """
    import re  # function-level import, as elsewhere in this module

    path = str(file_name)  # also accept pathlib.Path, for which `in` would raise
    if 'txt' not in path:
        raise ValueError(f"Unsupported file {path!r}: expected a concatenated-JSON '.txt' file")

    with open(path, encoding="utf-8") as f:
        content = f.read()

    # Rewrite "...}\n{..." record boundaries to "...},\n{..." so that the whole file
    # parses as one JSON array. The match must span a newline. Raw newlines are illegal
    # inside JSON strings, so this never touches text inside a string value. Allowing
    # blanks around the newline also tolerates trailing spaces and blank lines between
    # records.
    content = re.sub(r"\}[ \t\r]*\n[ \t\r\n]*\{", "},\n{", content)
    return json.loads("[" + content + "]")
        

def load_model_tokenizer(use_flash_attn, model_name_or_path, model_max_position_embeddings, rope_scaling_factor, rope_scaling_type, device="cuda"):
    """Load a Llama causal LM (optionally with RoPE scaling) and its tokenizer.

    Args:
        use_flash_attn: If truthy, use ``LlamaForCausalLM`` from
            ``src.my_flash_modeling_llama``; otherwise from ``src.my_modeling_llama``.
        model_name_or_path: Location passed to every ``from_pretrained`` call.
        model_max_position_embeddings: Base position count that RoPE scaling is
            applied to (also stored as ``original_max_position_embeddings`` for yarn).
        rope_scaling_factor: Factor written into ``config.rope_scaling`` and
            multiplied into ``max_position_embeddings`` when scaling is configured here.
            May be ``None`` when ``rope_scaling_type`` is ``None``.
        rope_scaling_type: Scaling type for ``config.rope_scaling`` (e.g. ``"yarn"``),
            or ``None`` to leave the loaded config unchanged.
        device: Target passed to ``model.to``. Defaults to ``"cuda"``.

    Returns:
        ``(model, tokenizer)``. The model is in bfloat16 and in eval mode.
    """
    # Import lazily so that only the selected attention implementation is loaded.
    if use_flash_attn:
        from src.my_flash_modeling_llama import LlamaForCausalLM
    else:
        from src.my_modeling_llama import LlamaForCausalLM

    config = LlamaConfig.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)

    # A rope_scaling already present in the checkpoint config takes precedence: the
    # scaling arguments are applied only when the config has none.
    if config.rope_scaling is None and rope_scaling_type is not None:
        # Computed only here (previously unconditionally), so callers that disable
        # scaling can pass rope_scaling_factor=None without a TypeError.
        scaled_max_position_embeddings = int(model_max_position_embeddings * rope_scaling_factor)
        config.rope_scaling = {"type": rope_scaling_type, "factor": rope_scaling_factor}
        config.max_position_embeddings = scaled_max_position_embeddings
        if rope_scaling_type == "yarn":
            config.rope_scaling["original_max_position_embeddings"] = model_max_position_embeddings
    print(config)

    print(f"load model from {model_name_or_path}")
    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        config=config, torch_dtype=torch.bfloat16)
    model.to(device)
    model.eval()

    print("load tokenizer")
    # NOTE(review): `use_fast` is an AutoTokenizer option. With the concrete (slow)
    # LlamaTokenizer class it presumably has no effect — confirm.
    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
        # NOTE(review): these are added after the embedding resize above. This
        # presumably relies on the strings already being in the vocabulary (so no new
        # ids are created). If any were new, the embeddings would not cover them.
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    return model, tokenizer


# Type variable that lets `Document.from_dict` be typed as returning an instance of `cls`.
T = TypeVar("T")

@dataclass(frozen=True)
class Document:
    """Immutable document record: a title and a text plus optional metadata fields."""

    title: str
    text: str
    id: Optional[str] = None
    score: Optional[float] = None
    hasanswer: Optional[bool] = None
    isgold: Optional[bool] = None
    original_retrieval_index: Optional[int] = None

    @classmethod
    def from_dict(cls: Type[T], data: dict) -> T:
        """Construct an instance from a mapping of field names to values.

        The mapping is deep-copied first, so the caller's object is never modified.
        Missing ``id`` and ``score`` become ``None``; a present ``score`` is converted
        with ``float``. All other keys are passed to the constructor as keyword
        arguments.

        Raises:
            ValueError: If ``data`` is empty or otherwise falsy.
        """
        kwargs = deepcopy(data)
        if not kwargs:
            raise ValueError("Must provide data for creation of Document from dict.")
        doc_id = kwargs.pop("id", None)
        raw_score = kwargs.pop("score", None)
        kwargs["id"] = doc_id
        kwargs["score"] = raw_score if raw_score is None else float(raw_score)
        return cls(**kwargs)


def get_qa_prompt(
    question: str, documents: List[Document], mention_random_ordering: bool, query_aware_contextualization: bool
):
    """Render a multi-document question-answering prompt.

    A template under ``PROMPTS_ROOT`` is chosen from the two flags, trailing newlines
    are stripped, and ``{question}`` / ``{search_results}`` are substituted. Each
    document is rendered on its own line as ``Document [i](Title: <title>) <text>``,
    where ``i`` is 1-based.

    Args:
        question: The question text; must be truthy.
        documents: The documents to list; must be truthy. Each needs ``title`` and
            ``text`` attributes.
        mention_random_ordering: Use ``qa_ordered_randomly.prompt``.
        query_aware_contextualization: Use ``qa_with_query_aware_contextualization.prompt``.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``question`` or ``documents`` is falsy, or if both flags are set.
    """
    if not question:
        raise ValueError(f"Provided `question` must be truthy, got: {question}")
    if not documents:
        raise ValueError(f"Provided `documents` must be truthy, got: {documents}")
    if mention_random_ordering and query_aware_contextualization:
        raise ValueError("Mentioning random ordering cannot be currently used with query aware contextualization")

    template_name = "qa.prompt"
    if mention_random_ordering:
        template_name = "qa_ordered_randomly.prompt"
    elif query_aware_contextualization:
        template_name = "qa_with_query_aware_contextualization.prompt"
    prompt_template = (PROMPTS_ROOT / template_name).read_text().rstrip("\n")

    search_results = "\n".join(
        f"Document [{position}](Title: {doc.title}) {doc.text}"
        for position, doc in enumerate(documents, start=1)
    )
    return prompt_template.format(question=question, search_results=search_results)


def get_closedbook_qa_prompt(question: str):
    """Render the closed-book QA prompt, which contains only the question.

    Reads ``closedbook_qa.prompt`` from ``PROMPTS_ROOT``, strips trailing newlines and
    substitutes ``{question}``.

    Raises:
        ValueError: If ``question`` is empty or otherwise falsy.
    """
    if question:
        prompt_template = (PROMPTS_ROOT / "closedbook_qa.prompt").read_text().rstrip("\n")
        return prompt_template.format(question=question)
    raise ValueError(f"Provided `question` must be truthy, got: {question}")


# Absolute path of the "prompts" directory next to this module. The prompt builders in
# this file read their *.prompt templates from it.
PROMPTS_ROOT = pathlib.Path(__file__).parent.joinpath("prompts").resolve()
def get_kv_retrieval_prompt(
    data: List[Tuple[str, str]],
    key: str,
    query_aware_contextualization: bool = False,
):
    """Render a key-value retrieval prompt that queries the value stored under ``key``.

    The records are laid out as an object literal with one ``"key": "value"`` pair per
    line. This text is substituted, together with ``key``, into a ``*.prompt`` template
    under ``PROMPTS_ROOT``. Keys and values are inserted verbatim, with no quote
    escaping.

    Args:
        data: At least two records. Element 0 of each record is its key and element 1
            is its value. Keys must be unique.
        key: The key to query; must be truthy and present in ``data``.
        query_aware_contextualization: Use the query-aware template variant.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``data`` or ``key`` is falsy, ``key`` is missing, keys repeat,
            or ``data`` has fewer than two records.
    """
    if not data:
        raise ValueError(f"Provided `data` must be truthy, got: {data}")
    if not key:
        raise ValueError(f"Provided `key` must be truthy, got: {key}")
    record_keys = [record[0] for record in data]
    if key not in record_keys:
        raise ValueError(f"Did not find provided `key` {key} in data {data}")
    if len(data) != len(set(record_keys)):
        raise ValueError(f"`data` has duplicate keys: {data}")
    if len(data) < 2:
        raise ValueError(f"Must have at least 2 items in data: {data}")

    template_name = (
        "kv_retrieval_with_query_aware_contextualization.prompt"
        if query_aware_contextualization
        else "kv_retrieval.prompt"
    )
    prompt_template = (PROMPTS_ROOT / template_name).read_text().rstrip("\n")

    # Produces '{"k0": "v0",\n "k1": "v1",\n ... "kN": "vN"}'.
    pairs = ",\n ".join(f'"{record[0]}": "{record[1]}"' for record in data)
    return prompt_template.format(formatted_kv_records="{" + pairs + "}", key=key)
