import os
from typing import List

import openai
from tqdm import tqdm
from loguru import logger
from dotenv import load_dotenv

from src.schema import (
    QA,
    Evidence,
    ConflictGraph,
)


# Prompt template for a yes/no conflict check between two pieces of content.
# It is filled via ``str.format`` with the ``content_a`` and ``content_b``
# placeholders. It instructs the model to answer "yes" (conflict) or "no".
# ``.strip()`` removes the leading/trailing newlines of the triple-quoted literal.
CONFLICT_DETECTION_PROMPT_TEMPLATE = """
Is there a conflict between the following two contents?

Answer "yes" if there is a conflict, and "no" otherwise.

# Content A

{content_a}

# Content B

{content_b}

# Answer
""".strip()


class ConflictDetectionStage:
    """Build a ``ConflictGraph`` whose nodes are the answers of a QA dataset.

    Two kinds of edges are added:

    * intra-conflicts: every answer of a QA sample is linked to every other
      answer of the same sample;
    * inter-conflicts: an answer of one sample is linked to an answer of a
      different sample when :meth:`detect_conflict` reports a conflict
      between their evidences.
    """

    def __init__(self, use_mock: bool = True) -> None:
        """Create the stage and its OpenAI client.

        Args:
            use_mock: When ``True`` (the default, matching the previous
                behaviour), :meth:`detect_conflict` returns a random
                placeholder result instead of querying the chat model.

        Raises:
            KeyError: if ``OPENAI_API_KEY`` is not set after loading ``.env``.
        """
        logger.info("Initializing ConflictDetectionStage.")

        self.use_mock = use_mock

        load_dotenv()
        self.client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def detect_conflict(self, evidences_a: List[Evidence], evidences_b: List[Evidence]) -> bool:
        """Return whether two sets of evidences conflict with each other.

        Args:
            evidences_a: Evidences backing the first answer.
            evidences_b: Evidences backing the second answer.

        Returns:
            ``True`` if a conflict is detected, ``False`` otherwise.

        Raises:
            ValueError: (LLM mode only) if the model replies with anything
                other than "yes" or "no".
        """
        if self.use_mock:
            # TODO: remove this mock impl once the LLM-based detector is enabled.
            # Placeholder: reports a conflict with probability 0.001. It uses the
            # global ``random`` state, so callers can seed it for reproducibility.
            import random

            return random.random() < 0.001

        return self._detect_conflict_with_llm(evidences_a, evidences_b)

    @staticmethod
    def _format_evidences(evidences: List[Evidence]) -> str:
        """Render evidences as blank-line-separated ``Q: ...`` / ``A: ...`` pairs."""
        return "\n\n".join(f"Q: {e.question}\nA: {e.answer}" for e in evidences).strip()

    def _detect_conflict_with_llm(self, evidences_a: List[Evidence], evidences_b: List[Evidence]) -> bool:
        """Ask the chat model whether the two evidence sets conflict.

        Raises:
            ValueError: if the model's reply is neither "yes" nor "no".
        """
        prompt = CONFLICT_DETECTION_PROMPT_TEMPLATE.format(
            content_a=self._format_evidences(evidences_a),
            content_b=self._format_evidences(evidences_b),
        )

        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo-0125",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1,  # a single token is enough for "yes"/"no"
            temperature=0.0,  # deterministic verdicts
        )

        # ``message.content`` may be None. Treat that as an empty reply so it is
        # reported by the ValueError below instead of crashing with AttributeError.
        result = (response.choices[0].message.content or "").strip().lower()
        if result == "yes":
            return True
        if result == "no":
            return False
        raise ValueError(f"Unexpected response: {result!r}")

    def run(self, qa_dataset: List[QA]) -> ConflictGraph:
        """Build the conflict graph for ``qa_dataset``.

        Each answer of each QA sample becomes one node. Node ids are assigned
        consecutively in dataset order. Answers of the same sample are linked to
        each other unconditionally. Answers of different samples are linked when
        :meth:`detect_conflict` returns ``True`` for their evidences. Every
        edge is stored in both directions in ``adjacency_dict``.

        Args:
            qa_dataset: The QA samples whose answers form the graph nodes.

        Returns:
            The resulting ``ConflictGraph``.
        """
        n_nodes = 0
        answer_id_to_node_id = {}
        node_id_to_answer_id = {}
        adjacency_dict = {}
        # Node ids of each sample's answers, aligned with ``qa_dataset``. The
        # inter-conflict pass reuses these instead of looking nodes up by
        # answer id, so duplicate answer ids cannot misroute edges.
        sample_node_ids: List[List[int]] = []

        # intra-conflict processing
        logger.info("Detecting intra-conflicts...")
        for qa_sample in qa_dataset:
            answers = qa_sample.answers
            node_ids = list(range(n_nodes, n_nodes + len(answers)))
            for node_id, answer in zip(node_ids, answers):
                answer_id = answer.answer_id
                if answer_id in answer_id_to_node_id:
                    logger.warning(
                        f"Duplicate answer_id {answer_id!r}; answer_id_to_node_id keeps the last node."
                    )
                answer_id_to_node_id[answer_id] = node_id
                node_id_to_answer_id[node_id] = answer_id
                # Every other answer to the same question conflicts with this one.
                adjacency_dict[node_id] = [other for other in node_ids if other != node_id]
            sample_node_ids.append(node_ids)
            n_nodes += len(node_ids)

        # inter-conflict processing
        logger.info("Detecting inter-conflicts...")
        n_samples = len(qa_dataset)
        # One progress bar over all unordered sample pairs (i < j). The visit
        # order is the same as the nested loops it replaces.
        sample_pairs = ((i, j) for i in range(n_samples) for j in range(i + 1, n_samples))
        for i, j in tqdm(sample_pairs, total=n_samples * (n_samples - 1) // 2):
            answers_a = qa_dataset[i].answers
            answers_b = qa_dataset[j].answers
            for answer_a, node_id_a in zip(answers_a, sample_node_ids[i]):
                for answer_b, node_id_b in zip(answers_b, sample_node_ids[j]):
                    if self.detect_conflict(answer_a.evidences, answer_b.evidences):
                        adjacency_dict[node_id_a].append(node_id_b)
                        adjacency_dict[node_id_b].append(node_id_a)

        # generate conflict graph
        logger.info("Generating conflict graph...")
        return ConflictGraph(
            n_nodes=n_nodes,
            answer_id_to_node_id=answer_id_to_node_id,
            node_id_to_answer_id=node_id_to_answer_id,
            adjacency_dict=adjacency_dict,
        )
