
from concurrent.futures import ThreadPoolExecutor, as_completed
from parse import parse
import json

from lm_polygraph.generation_metrics.openai_fact_check import *
from lm_polygraph.stat_calculators.extract_claims import *
from synthetic_dataset_generation.utils.deepseek_chat import DeepSeekChat


# Default prompt version used by StepFactCheck; it is used as the lookup key
# into both PROMPT1_TEMPLATE and PROMPT2_TEMPLATE.
VERSION = 'correctness_redundancy'

# First-stage prompts, keyed by prompt version. Each template is filled via
# str.format with the fields `problem`, `answer` and `steps` and asks the judge
# model for a free-form assessment of the student's steps.
# The text is sent to the model verbatim, so any edit changes runtime behavior.
PROMPT1_TEMPLATE = {
    'correctness_redundancy': r'''You are given a problem, a ground-truth solution, and a step-by-step student solution. Your task is to analyze each step in the student’s solution to determine whether it is both correct and informative. 

Correctness: if a step is correct, it contains no mistakes in calculation and logic.
Informative: if a step is informative, it provides new information that is not a paraphrase of existing context and previous steps, and it contributes towards getting closer to the answer.

Instructions:
- Carefully examine each student step for logical/calculation errors or unnecessary/redundant reasoning.
- If all steps are correct and they lead to the same final answer as the ground-truth solution, conclude that there are no errors.
- If any step is incorrect (contains logical or calculation error) or non-informative (redundant or has no contribution to the final answer), identify and report those specific steps with an explanation.

PROBLEM:
{problem}

GROUND-TRUTH SOLUTION:
{answer}

STUDENT'S SOLUTION STEPS:
{steps}

Now, please evaluate whether the student’s steps are correct and logical.''',
    'correctness': r'''You are given a problem, a ground-truth solution, and a step-by-step student solution. Your task is to analyze each step in the student’s solution to determine whether it is both logically correct and relevant.

Instructions:
- Carefully examine each student step for logical errors or unnecessary/redundant reasoning.
- If all steps are correct and they lead to the same final answer as the ground-truth solution, conclude that there are no errors.
- If any step contains an error that would prevent the student from reaching the correct solution, identify and report those specific steps with an explanation.

PROBLEM:
{problem}

GROUND-TRUTH SOLUTION:
{answer}

STUDENT'S SOLUTION STEPS:
{steps}

Now, please evaluate whether the student’s steps are correct and logical.'''
}


# Second-stage prompts, keyed by prompt version. Each template is filled via
# str.format with `problem`, `steps` and `reply` (the first-stage assessment);
# the 'correctness_redundancy' template additionally uses `list_length` and
# `list_length_1`. The 'correctness_redundancy' version asks for a JSON object,
# the 'correctness' version for a plain Python list of 0/1 labels.
# The text is sent to the model verbatim, so any edit changes runtime behavior.
PROMPT2_TEMPLATE = {
    'correctness_redundancy': r'''
You are given:
- A problem
- A student's step-by-step solution (as a Python list of string steps)
- An assessment of student's solution

Your task:
Output a json object with the following fields:
- "correctness": a list of 0/1 values, where 1 (correct) indicates the step contains no mistakes in calculation and logic; otherwise 0 (incorrect).
- "informativeness": a list of 0/1 values, where 1 (informative) means the step provides new information that is not a paraphrase of existing context and previous steps, and it contributes towards getting closer to the answer. Otherwise 0 (non-informative).

Important:
- Output only the json object with the fields "correctness" and "informativeness", nothing else.
- The correctness list must have correctness labels for all steps and the final answer (in this case, list length should be {list_length}).
- The informativeness list must have one fewer entry than the number of steps (i.e., {list_length_1}), because it should only score the reasoning steps and NOT the final answer step.

PROBLEM:
{problem}

STUDENT'S SOLUTION STEPS:
{steps}

ASSESSMENT OF STUDENT SOLUTION STEPS:
{reply}

OUTPUT JSON:
''',
    'correctness': r"""
You are given:
- A problem
- A student's step-by-step solution (as a Python list of string steps)
- An assessment of student's solution

Your task:
Output a single Python list where each element is:
- 1 if the corresponding step is correct
- 0 if the step is incorrect

Important:
- Output only the list, nothing else.
- The list must have the same length as the number of steps.

PROBLEM:
{problem}

STUDENT'S SOLUTION STEPS:
{steps}

ASSESSMENT OF STUDENT SOLUTION STEPS:
{reply}

OUTPUT LIST:
""",
}



class StepFactCheck(GenerationMetric):
    """Label the steps of a step-by-step solution with an LLM judge.

    Every sample is scored with two chat requests. The first one
    (``PROMPT1_TEMPLATE``) asks the judge for a free-form assessment of the
    student's steps against the ground-truth solution; the second one
    (``PROMPT2_TEMPLATE``) asks it to turn that assessment into
    machine-readable 0/1 labels.

    The returned scores are inverted relative to the judge's labels: a step
    labelled ``0`` by the judge is scored ``1``, any other numeric label is
    scored ``0``, and a step that cannot be scored is ``np.nan``.
    """

    # Label types that ``_score_single`` knows how to select.
    _LABEL_TYPES = ('correctness', 'informativeness')
    # Prompt versions whose second-stage reply is a JSON object.
    _JSON_VERSIONS = ('correctness_redundancy',)

    def __init__(
            self,
            prompt_file: str,
            cache_path: str = "~/.cache",
            model: str = 'deepseek-reasoner',
            api_key: str | None = None,
            progress_bar: bool = True,
            n_threads: int = 1,
            wait_times: tuple = (5, 10, 30, 60, 120),
            version: str = VERSION,
            label_type: str = 'correctness',
    ):
        """Create the metric.

        Args:
            prompt_file: Path to the generation prompt template. Its text is
                used by ``parse_problem`` to recover the raw problem (field
                ``q``) from a prompted input text.
            cache_path: Forwarded to the chat client.
            model: Judge model name. Names containing ``'deepseek'`` use
                ``DeepSeekChat``, everything else ``OpenAIChat``.
            api_key: Forwarded to ``DeepSeekChat`` only.
            progress_bar: Show a tqdm progress bar while scoring.
            n_threads: Number of worker threads issuing chat requests.
            wait_times: Forwarded to ``DeepSeekChat`` only (presumably retry
                delays -- confirm against ``DeepSeekChat``).
            version: Key into ``PROMPT1_TEMPLATE`` / ``PROMPT2_TEMPLATE``.
            label_type: Which judge labels to report, ``'correctness'`` or
                ``'informativeness'``.

        Raises:
            ValueError: If ``label_type`` is not supported.
        """
        super().__init__(["input_texts", "claims"], "claim")

        # Fail fast: otherwise the error only shows up after every queued
        # sample has already spent two chat requests.
        if label_type not in self._LABEL_TYPES:
            raise ValueError(f"Label type {label_type} not supported")

        # Explicit encoding so the prompt is read identically on all platforms.
        with open(prompt_file, 'r', encoding='utf-8') as f:
            self.prompt = f.read()

        self.json_output = version in self._JSON_VERSIONS

        if label_type == 'informativeness' and not self.json_output:
            # Only the JSON prompt versions ask the judge for informativeness.
            log.warning(
                "Prompt version %r produces no informativeness labels; "
                "every claim will be scored NaN", version)

        if 'deepseek' in model:
            self.chat = DeepSeekChat(cache_path, model=model, api_key=api_key, wait_times=wait_times)
        else:
            self.chat = OpenAIChat(model, cache_path=cache_path)
        # NOTE(review): an earlier revision suggested
        # DeepSeekChat(api_base=None, model='gpt-4o', ...) for OpenAI models.

        self.label_type = label_type
        self.progress_bar = progress_bar
        self.n_threads = n_threads
        self.version = version

    def __str__(self):
        return "StepFactCheck" + "_" + self.label_type

    def parse_problem(self, input_text: str):
        """Extract the raw problem from a prompted input text.

        Falls back to returning ``input_text`` unchanged when it does not
        match the prompt template (e.g. run_extract_verify_claims.py passes
        raw questions without a prompt) or when the template has no ``q``
        field.
        """
        try:
            match = parse(self.prompt, input_text)
        except Exception:
            # Deliberately broad: any failure inside the third-party matcher
            # means "not a prompted input", never a scoring error.
            return input_text
        if match is None or 'q' not in match.named:
            return input_text
        return match.named['q']

    def prompt1(self, input_text: str, claims: list[Claim], answer: str) -> str:
        """Build the first-stage (free-form assessment) prompt."""
        steps = '\n'.join(cl.claim_text.strip() for cl in claims)
        return PROMPT1_TEMPLATE[self.version].format(
            problem=self.parse_problem(input_text),
            answer=answer,
            steps=steps,
        )

    def prompt2(self, input_text: str, claims: list[Claim], answer: str, reply: str) -> str:
        """Build the second-stage (label extraction) prompt.

        ``reply`` is the judge's first-stage assessment. ``answer`` is unused
        and kept only for signature compatibility.
        """
        # The steps are intentionally formatted as a Python list repr.
        steps = [cl.claim_text.strip() for cl in claims]
        # str.format ignores unused keyword arguments, so the length fields
        # can be passed to templates that do not reference them.
        return PROMPT2_TEMPLATE[self.version].format(
            problem=self.parse_problem(input_text),
            steps=steps,
            reply=reply,
            list_length=len(steps),
            list_length_1=len(steps) - 1,
        )

    def parse_reply(self, reply: str) -> list[int] | None:
        """Parse a plain-list second-stage reply into integer labels.

        Returns:
            The parsed integers, an empty list when the reply states that all
            steps are correct, or ``None`` when the reply cannot be parsed.
        """
        # NOTE(review): the empty list only matches the claim count when there
        # are zero claims, so for more than one claim this reply ends up
        # scored as all-NaN by `_score_single` -- confirm that this is intended.
        if 'all steps are correct' in reply.lower():
            return []
        text = reply.strip().replace(' ', '').replace('Step', '')
        if '```python' in text:
            text = text.split('```python')[-1].split('```')[0].strip()
        if text.startswith('[') and text.endswith(']'):
            text = text[1:-1]
        try:
            return [int(x) for x in text.split(',')]
        except ValueError:
            log.warning('Skipping text, because could not parse DeepSeek reply: {}'.format(reply))
            return None

    @staticmethod
    def _coerce_label(value) -> float:
        """Convert one judge label to float; unusable values become NaN."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return np.nan

    def _score_single(self, args: tuple[list, str, str]) -> list:
        """Score the claims of one sample.

        Args:
            args: ``(claims, input_text, answer)``.

        Returns:
            One value per claim: ``1`` if the judge labelled the step ``0``,
            ``0`` for any other numeric label, ``np.nan`` if the claim has no
            aligned tokens or its label is missing/unusable. All values are
            ``np.nan`` when the reply cannot be parsed or has the wrong
            number of labels.

        Raises:
            ValueError: If ``self.label_type`` is not supported.
        """
        claims, input_text, answer = args
        # Checked before any request so a misconfiguration costs no API calls.
        if self.label_type not in self._LABEL_TYPES:
            raise ValueError(f"Label type {self.label_type} not supported")
        skipped = [np.nan] * len(claims)  # will be skipped at evaluation

        q1 = self.prompt1(input_text, claims, answer)
        assessment = self.chat.ask(q1, json_output=False)
        q2 = self.prompt2(input_text, claims, answer, assessment)
        reply = self.chat.ask(q2, json_output=self.json_output)

        if self.json_output:
            try:
                json_reply = json.loads(reply)
                correctness_labels = json_reply['correctness']
                informativeness_labels = json_reply['informativeness']
            except (ValueError, TypeError, KeyError):
                log.warning(f"Skipping text, because could not parse DeepSeek reply: {reply}")
                return skipped
        else:
            correctness_labels = self.parse_reply(reply)
            informativeness_labels = None

        if self.label_type == 'correctness':
            claim_labels = correctness_labels
        else:
            claim_labels = informativeness_labels

        if claim_labels is None:
            return skipped
        if not isinstance(claim_labels, list):
            # A JSON field that is not a list cannot be aligned with the claims.
            log.warning(f"Skipping text, because could not parse DeepSeek reply: {reply}")
            return skipped

        # Coerce up front so that null / string labels coming from JSON become
        # NaN (or a number) instead of raising inside np.isnan below.
        claim_labels = [self._coerce_label(label) for label in claim_labels]
        if len(claim_labels) + 1 == len(claims):
            claim_labels.append(np.nan)  # last answer is undefined
        if len(claim_labels) != len(claims):
            log.warning(f"Prompt 2: {q2}")
            log.warning(
                'Skipping text, because of inconsistend number of '
                'labels in DeepSeek reply: expected {}, got {}'.format(len(claims), reply))
            return skipped

        return [
            (
                np.nan if len(claim.aligned_token_ids) == 0 or np.isnan(label) else
                1 if label == 0 else
                0
            ) for claim, label in zip(claims, claim_labels)
        ]

    def __call__(
            self,
            stats: Dict[str, np.ndarray],
            target_texts: List[str],
    ) -> list:
        """Score every sample in ``stats``.

        Args:
            stats: Must contain ``"input_texts"`` and ``"claims"``; if it also
                contains ``"answers"``, those replace ``target_texts`` as the
                ground-truth solutions.
            target_texts: Ground-truth solutions, one per input.

        Returns:
            One list of per-claim scores per input, in input order.
        """
        input_texts = stats["input_texts"]
        if "answers" in stats:
            target_texts = stats["answers"]

        all_inputs = [
            (claims, input_text, answer)
            for input_text, claims, answer in zip(input_texts, stats["claims"], target_texts)
        ]

        with ThreadPoolExecutor(max_workers=self.n_threads) as executor:
            futures = [executor.submit(self._score_single, item) for item in all_inputs]
            # Iterating in submission order keeps results aligned with inputs;
            # result() re-raises any exception from the worker.
            return [
                future.result()
                for future in tqdm(futures, desc=f"Verifying claims ({self.label_type})", disable=not self.progress_bar)
            ]

