from typing import (
    List,
    Dict,
    Union,
)
from abc import (
    ABC,
    abstractmethod,
)

from torch.utils.data import DataLoader

from src.service.document_service import DocumentService
from src.service.memory_service import MemoryService
from src.service.generation_service import GenerationService
from src.schema import Prediction


# Token budget for one chat request: the prompt may use whatever remains of the
# context window after reserving room for the response and a safety margin.
MODEL_CONTEXT_LIMIT = 128_000  # NOTE(review): assumed context window of the generator model — confirm for the model in use
RESPONSE_MAX_TOKENS = 1_000  # tokens reserved for the model's response
SAFETY_MARGIN_TOKENS = 1_000  # presumably slack for approximate token counting
MAX_INPUT_TOKENS = MODEL_CONTEXT_LIMIT - (RESPONSE_MAX_TOKENS + SAFETY_MARGIN_TOKENS)


def count_tokens_from_text(text: str) -> int:
    """Return the number of tokens in *text*.

    Uses tiktoken's ``cl100k_base`` encoding when tiktoken can be imported and
    the encoding loaded. Otherwise it falls back to a rough word-based estimate
    of ``int(word_count * 1.3) + 4``.

    Text containing special-token markup (e.g. ``"<|endoftext|>"``) is counted
    as ordinary text. Without ``disallowed_special=()``, ``encode`` raises on
    such text and the count would silently degrade to the coarse estimate.
    """
    try:
        import tiktoken
        enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text, disallowed_special=()))
    except Exception:
        # Deliberate best effort: tiktoken missing or the encoding unavailable
        # in this environment -> approximate from whitespace-separated words.
        return int(len(text.split()) * 1.3) + 4


def count_tokens_from_messages(messages: List[Dict[str, str]]) -> int:
    """Estimate the total token count of a list of chat messages.

    Each message is rendered as ``role + ": " + content`` and measured with
    ``count_tokens_from_text``. The per-message counts are summed. A missing
    ``role`` or ``content`` key is treated as an empty string. An empty
    message list yields 0.
    """
    return sum(
        count_tokens_from_text(msg.get("role", "") + ": " + msg.get("content", ""))
        for msg in messages
    )


def truncate_text_by_tokens(text: str, max_tokens: int) -> str:
    """Return a prefix of *text* that fits in roughly *max_tokens* tokens.

    Returns ``""`` when ``max_tokens <= 0``. Text already within budget is
    returned unchanged.

    With tiktoken available, the text is encoded with ``cl100k_base``, cut to
    ``max_tokens`` tokens and decoded. Special-token markup is treated as plain
    text so it does not make ``encode`` raise.

    Without tiktoken, the function falls back to a word estimate that keeps
    about ``max_tokens / 1.3`` whitespace-separated words (at least one). It
    slices the ORIGINAL string at the end of the last kept word, so newlines
    and indentation in the kept part are preserved rather than collapsed to
    single spaces.
    """
    if max_tokens <= 0:
        return ""
    try:
        import tiktoken
        enc = tiktoken.get_encoding("cl100k_base")
        tokens = enc.encode(text, disallowed_special=())
        if len(tokens) <= max_tokens:
            return text
        return enc.decode(tokens[:max_tokens])
    except Exception:
        # Deliberate best effort when tiktoken is unavailable.
        import re

        approx_allow = max(1, int(max_tokens / 1.3))
        words = list(re.finditer(r"\S+", text))
        if len(words) <= approx_allow:
            return text
        return text[: words[approx_allow - 1].end()]


def parse_sub_queries(sub_queries: List[str]) -> str:
    """Render sub-queries as a numbered list under a fixed header line.

    Each entry becomes ``"<index>. <sub_query>"`` with numbering starting at 0.
    Entries are newline-separated, and the joined listing is stripped of
    surrounding whitespace before being appended below the header.
    """
    header = "You can consider the following sub_queries:"
    listing = "\n".join(f"{idx}. {query}" for idx, query in enumerate(sub_queries))
    return header + "\n" + listing.strip()


def parse_pers_graph(pers_graph: Dict[str, List[str]]) -> str:
    """Render an option -> sub-queries mapping as a prompt section.

    Each option becomes a paragraph headed
    ``"For the option '<option>', you can consider the following sub_queries:"``
    followed by its sub-queries numbered from 0. Paragraphs are separated by
    a blank line in the mapping's iteration order, and the combined text is
    stripped before being placed under the options header.
    """

    def render_option(option, sub_queries):
        # One paragraph per option; the numbered listing is stripped on its own.
        listing = "\n".join(f"{idx}. {query}" for idx, query in enumerate(sub_queries)).strip()
        return f"For the option '{option}', you can consider the following sub_queries:\n{listing}"

    sections = "\n\n".join(
        render_option(option, sub_queries) for option, sub_queries in pers_graph.items()
    )
    return "You can consider the following options:\n\n" + sections.strip()


class BaseBatchProcessor(ABC):
    """Base class for batch generation back-ends.

    Provides ``build_request``, which turns QA items plus their retrieval
    predictions into a list of chat-completion request dicts with the keys
    ``custom_id``, ``method``, ``url`` and ``body``, targeting
    ``/v1/chat/completions``. Concrete subclasses must implement ``infer`` and
    ``parse_response``.
    """

    # Accepted values of ``generation_strategy`` in ``build_request``; keep in
    # sync with ``_build_forwarding_content``.
    _FORWARDING_STRATEGIES = ("naive", "gen_sq", "oracle_sq", "gen_pg", "oracle_pg")

    def __init__(
        self,
        generator_name: str,
    ) -> None:
        """
        Args:
            generator_name: Model identifier written into the ``"model"``
                field of every request body built by ``build_request``.
        """
        self.generator_name = generator_name

    def build_request(
        self,
        data_loader: DataLoader,
        predictions: List[Prediction],
        doc_service: DocumentService,
        mem_service: MemoryService,
        generation_service: GenerationService,
        generation_strategy: str,
        generator_temperature: float,
        semantic_memory_enabled: bool,
    ) -> List[Dict]:
        """Build one chat-completion request per (QA item, answer) pair.

        Args:
            data_loader: Iterable of QA items. Each item must expose ``qa_id``,
                ``context``, ``question`` and ``answers``. Items are matched to
                ``predictions`` by position.
            predictions: One prediction per QA item, in the same order, each
                with ``qa.qa_id`` and ``pred_answers``.
            doc_service: Resolves ``ret_document_ids`` to documents.
            mem_service: Resolves ``ret_memory_ids`` to memories and supplies
                the semantic-memory prompt.
            generation_service: Renders the prompt via ``get_prompt``.
            generation_strategy: Which forwarding content to include. One of
                ``"naive"``, ``"gen_sq"``, ``"oracle_sq"``, ``"gen_pg"`` or
                ``"oracle_pg"``.
            generator_temperature: Sampling temperature placed in every body.
            semantic_memory_enabled: If truthy, include the user's semantic
                memory prompt. Otherwise use a "No information available"
                placeholder.

        Returns:
            A list of request dicts, one per answer, with ``custom_id`` set to
            the answer's ``answer_id``.

        Raises:
            ValueError: If ``generation_strategy`` is unknown, or if predictions
                are not aligned with the data (qa_id mismatch, answer-count
                mismatch, or answer_id mismatch).
        """
        # Fail fast, before any document/memory lookups are issued.
        if generation_strategy not in self._FORWARDING_STRATEGIES:
            raise ValueError(f"Unknown forwarding strategy: {generation_strategy}")

        requests = []
        for i, qa in enumerate(data_loader):
            context = qa.context
            question = qa.question

            prediction = predictions[i]
            # Explicit checks instead of ``assert`` so alignment is still
            # validated when Python runs with -O.
            if qa.qa_id != prediction.qa.qa_id:
                raise ValueError(
                    f"Prediction {i} has qa_id {prediction.qa.qa_id!r}, expected {qa.qa_id!r}"
                )

            answers = list(qa.answers)
            pred_answers = list(prediction.pred_answers)
            # ``zip`` would otherwise silently drop unmatched answers and emit
            # fewer requests than expected.
            if len(answers) != len(pred_answers):
                raise ValueError(
                    f"QA {qa.qa_id!r}: {len(answers)} answers but "
                    f"{len(pred_answers)} predicted answers"
                )

            for answer, pred_answer in zip(answers, pred_answers):
                if answer.answer_id != pred_answer.answer_id:
                    raise ValueError(
                        f"QA {qa.qa_id!r}: answer_id {answer.answer_id!r} does not match "
                        f"predicted answer_id {pred_answer.answer_id!r}"
                    )
                answer_id = answer.answer_id

                document_contents = self._resolve_document_contents(pred_answer, doc_service)
                memory_contents = self._resolve_memory_contents(pred_answer, mem_service)
                profile_content = self._resolve_profile_content(
                    answer, mem_service, semantic_memory_enabled
                )
                forwarding_content = self._build_forwarding_content(
                    answer, pred_answer, generation_strategy
                )

                prompt = generation_service.get_prompt(
                    context=context,
                    question=question,
                    document_contents=document_contents,
                    memory_contents=memory_contents,
                    profile_content=profile_content,
                    forwarding_content=forwarding_content,
                )

                # Keep the prompt within the input budget. Truncation keeps a
                # prefix, so whatever get_prompt places last is what gets cut.
                # NOTE(review): if the template ends with the question, it can
                # be dropped for very long contexts — confirm against get_prompt.
                if count_tokens_from_messages([{"role": "user", "content": prompt}]) > MAX_INPUT_TOKENS:
                    prompt = truncate_text_by_tokens(prompt, MAX_INPUT_TOKENS)

                requests.append(
                    {
                        "custom_id": answer_id,
                        "method": "POST",
                        "url": "/v1/chat/completions",
                        "body": {
                            "model": self.generator_name,
                            "messages": [{"role": "user", "content": prompt}],
                            "temperature": generator_temperature,
                            # Same budget that MAX_INPUT_TOKENS reserves for the response.
                            "max_tokens": RESPONSE_MAX_TOKENS,
                            "metadata": {"custom_id": answer_id},
                            "store": True,
                        },
                    }
                )
        return requests

    @staticmethod
    def _resolve_document_contents(pred_answer, doc_service: DocumentService) -> List:
        """Return the contents of the retrieved documents, or [] if no ids were recorded."""
        doc_ids = pred_answer.ret_document_ids
        if doc_ids is None:
            return []
        return [doc.content for doc in doc_service.get_documents(doc_ids=doc_ids)]

    @staticmethod
    def _resolve_memory_contents(pred_answer, mem_service: MemoryService) -> List:
        """Return retrieved memory contents.

        Looks memories up by ``ret_memory_ids`` when present. Otherwise it uses
        the inline ``ret_memory_contents``, or [] if that is also None.
        """
        memory_ids = pred_answer.ret_memory_ids
        if memory_ids is None:
            contents = pred_answer.ret_memory_contents
            return [] if contents is None else contents
        return [mem.content for mem in mem_service.get_memories(memory_ids=memory_ids)]

    @staticmethod
    def _resolve_profile_content(answer, mem_service: MemoryService, semantic_memory_enabled: bool):
        """Return the user's semantic-memory prompt, or a placeholder when disabled."""
        if not semantic_memory_enabled:
            return "No information available"
        return mem_service.get_semantic_memory_prompt(user_id=answer.user_id)

    @staticmethod
    def _build_forwarding_content(answer, pred_answer, generation_strategy: str) -> str:
        """Return the forwarding section selected by ``generation_strategy``.

        ``gen_*`` strategies use the generated sub-queries or graph on the
        prediction. ``oracle_*`` strategies use the gold ones on the answer.
        """
        if generation_strategy == "naive":
            return "No information available"
        if generation_strategy == "gen_sq":
            return parse_sub_queries(pred_answer.gen_sub_queries)
        if generation_strategy == "oracle_sq":
            return parse_sub_queries(answer.oracle_sub_queries)
        if generation_strategy == "gen_pg":
            return parse_pers_graph(pred_answer.gen_pers_graph)
        if generation_strategy == "oracle_pg":
            return parse_pers_graph(answer.oracle_pers_graph)
        raise ValueError(f"Unknown forwarding strategy: {generation_strategy}")

    @abstractmethod
    def infer(self, request_path: str) -> Union[List, Dict]:
        """Run inference for the requests referenced by ``request_path``.

        NOTE(review): presumably a file of serialized ``build_request`` output;
        the exact contract is defined by concrete subclasses — confirm there.
        """
        raise NotImplementedError()

    @abstractmethod
    def parse_response(
        self,
        predictions: List[Prediction],
        batch_path: str,
    ) -> List[Prediction]:
        """Read batch results from ``batch_path`` and return updated predictions.

        The result format and the way results are matched to ``predictions``
        are defined by concrete subclasses.
        """
        raise NotImplementedError()
