"""
This file contains the abstract class for a debiasing strategy.
Concrete strategies are implemented as its subclasses. They can be found in the subdirectories.
"""

import abc
import asyncio
import warnings
import json
import os
from typing import Optional
from core.reasoning.schema import ReasoningTrajectory, ReasoningMode
from core.domain.schema import ProblemDomain, Problem, BinaryProblem
from core.policy.schema import (
    EvaluatedSample,
    SingleSample,
    Sample,
    Policy,
)
from utils.io_utils import load_file, dump_file, extract_json_from_str
from utils.templates.belief_eval import (
    belief_eval_judge_prompt,
    additional_info_interlude,
    additional_info_item,
    additional_info_ending,
    belief_eval_judge_prompt_with_traj,
)
from utils.async_utils import run_coroutine



class DebiasStrategy(abc.ABC):
    """Strategy for identifying belief entrenchment and mitigating it.

    Subclasses must implement :meth:`measure_belief_entrenchment` and are encouraged to
    override :meth:`construct_mitigation_dataset`. The shared :meth:`measure_belief` /
    :meth:`measure_belief_async` helpers ask a policy (or the fixed judge policy, when the
    ``USE_FIXED_JUDGE`` environment flag is truthy) for its belief on a binary problem.
    """

    # Number of re-queries attempted when a trajectory-wise measurement returns output
    # that cannot be aligned with the trajectory's steps.
    _TRAJ_MEASUREMENT_RETRIES = 5

    def __init__(self, judge_policy: Policy | None = None):
        """
        :param judge_policy: Policy used as the belief judge when the ``USE_FIXED_JUDGE``
            environment flag is truthy. Optional.
        :type judge_policy: Policy | None
        """
        self.judge_policy = judge_policy

    @abc.abstractmethod
    def measure_belief_entrenchment(self, samples: list[ReasoningTrajectory]) -> float:
        """
        Given a batch of reasoning trajectories from a policy, estimate the magnitude of that policy's bias.
        Zero stands for a perfectly unbiased policy, while one stands for a maximally biased policy.
        Each subclass should re-implement this function.
        """
        raise NotImplementedError

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Parse a boolean flag from environment variable `name`, falling back to `default`.

        Accepts ``true/false/yes/no/on/off`` (case-insensitive) plus any Python literal
        (e.g. ``"0"``, ``"1"``, ``"True"``), interpreted by truthiness. ``ast.literal_eval``
        is used instead of ``eval`` so the environment cannot execute arbitrary code.

        :raises ValueError: If the value is neither a recognised flag word nor a Python literal.
        """
        import ast  # local import: only needed for flag parsing

        raw = os.environ.get(name, default).strip()
        lowered = raw.lower()
        if lowered in ("true", "yes", "on"):
            return True
        if lowered in ("false", "no", "off"):
            return False
        try:
            return bool(ast.literal_eval(raw))
        except (ValueError, SyntaxError, TypeError) as err:
            raise ValueError(
                f"Environment variable {name}={raw!r} is not a valid boolean flag."
            ) from err

    @staticmethod
    def _is_valid_traj_output(parsed: object, num_steps: int) -> bool:
        """Return True iff `parsed` is a list of `num_steps` dicts that each carry a ``"belief"`` key."""
        return (
            isinstance(parsed, list)
            and len(parsed) == num_steps
            and all(isinstance(item, dict) and "belief" in item for item in parsed)
        )

    @staticmethod
    def _parse_belief(response: str) -> float | None:
        """Extract the float ``"belief"`` field from a judge response; None if it cannot be parsed."""
        try:
            return float(extract_json_from_str(response)["belief"])
        except Exception:  # best-effort: malformed responses are dropped, not fatal
            return None

    def measure_belief(
        self, problem: Problem, policy: Policy, traj: ReasoningTrajectory | None = None, name_steps: str = "reasoning steps", temperature: float = 0.3, **kwargs
    ) -> float | ReasoningTrajectory | None:
        """
        Synchronous wrapper around :meth:`measure_belief_async` (run via ``run_coroutine``).

        Measure the model policy's opinion on `problem`, e.g. its subjective probability in [0,1] assigned to a 'Yes' answer after averaging over repeated queries.

        :param problem: The problem to measure the belief on.
        :type problem: Problem
        :param policy: The model policy to measure the belief on.
        :type policy: Policy
        :param traj: The reasoning trajectory to measure the belief on, with belief fields set to None. Optional.
        :type traj: ReasoningTrajectory | None
        :param name_steps: The meaning of the steps. For example, "reasoning steps" or "debate turns".
        :type name_steps: str
        :param temperature: The temperature to use for the belief measurement.
        :type temperature: float
        :param kwargs: Extra arguments to pass to the belief measurement prompt. E.g. debate transcripts when traj is None, to help the judge reach a decision.
        :type kwargs: dict

        :return: The belief measurement (if traj is None; None if no response could be parsed), or the completed reasoning trajectory (if traj is not None).
        :rtype: float | ReasoningTrajectory | None
        """
        return run_coroutine(self.measure_belief_async(problem, policy, traj, name_steps, temperature, **kwargs))

    async def measure_belief_async(
        self, problem: Problem, policy: Policy, traj: ReasoningTrajectory | None = None, name_steps: str = "reasoning steps", temperature: float = 0.3, **kwargs
    ) -> float | ReasoningTrajectory | None:
        """
        Measure the model policy's opinion on `problem`, e.g. its subjective probability in [0,1] assigned to a 'Yes' answer after averaging over repeated queries.

        Environment variables read:
          - ``USE_FIXED_JUDGE`` (default "0"): if truthy, ``self.judge_policy`` replaces `policy`.
          - ``DISABLE_SYSTEM_PROMPT_IN_BELIEF_MEASUREMENT`` (default "True"): forwarded as ``disable_system_prompt``.
          - ``BELIEF_MEASUREMENT_REPETITIONS`` (default "1"): number of identical queries averaged in the scalar case.

        Side effects: appends the prompt and parsed output to ``belief-measurement.jsonl`` via
        ``dump_file``; on trajectory-wise failures also appends to ``failed_measurement.jsonl`` and,
        after exhausting retries, writes ``last_repeated_failure.json``. When `traj` is given, its
        steps' ``belief`` attributes are overwritten in place.

        :param problem: The problem to measure the belief on.
        :type problem: Problem
        :param policy: The model policy to measure the belief on.
        :type policy: Policy
        :param traj: The reasoning trajectory to measure the belief on, with belief fields set to None. Optional.
        :type traj: ReasoningTrajectory | None
        :param name_steps: The meaning of the steps. For example, "reasoning steps" or "debate turns".
        :type name_steps: str
        :param temperature: The temperature to use for the belief measurement.
        :type temperature: float
        :param kwargs: Extra arguments to pass to the belief measurement prompt. E.g. debate transcripts when traj is None, to help the judge reach a decision.
        :type kwargs: dict

        :return: The mean belief (if traj is None; None with a warning if no response could be parsed), or the completed reasoning trajectory (if traj is not None).
        :rtype: float | ReasoningTrajectory | None
        :raises NotImplementedError: If `problem` is not a BinaryProblem.
        :raises ValueError: On inconsistent configuration (fixed judge requested without a judge
            policy, invalid flag value, repetitions < 1, or repetitions > 1 with a trajectory), or
            when trajectory-wise measurement keeps failing after all retries.
        """
        if not isinstance(problem, BinaryProblem):
            raise NotImplementedError(
                "Belief measurement on non-binary problems is not yet supported."
            )

        if self._env_flag("USE_FIXED_JUDGE", "0"):
            if self.judge_policy is None:
                raise ValueError("USE_FIXED_JUDGE is set to True but no judge policy is specified when instantiating DebiasStrategy.")
            policy = self.judge_policy
        elif self.judge_policy is not None:
            warnings.warn("USE_FIXED_JUDGE is set to False but a judge policy is specified when instantiating DebiasStrategy. Double check your configuration.")

        problem_statement = problem.question
        option_yes, option_no = problem.options
        disable_system_prompt_in_measurement = self._env_flag(
            "DISABLE_SYSTEM_PROMPT_IN_BELIEF_MEASUREMENT", "True"
        )
        belief_measurement_repetitions = int(
            os.environ.get("BELIEF_MEASUREMENT_REPETITIONS", "1")
        )
        # TODO: add support for question variants (BELIEF_MEASUREMENT_VARIANTS, intended default 20).

        if belief_measurement_repetitions < 1:
            raise ValueError(
                f"BELIEF_MEASUREMENT_REPETITIONS must be >= 1, got {belief_measurement_repetitions}."
            )
        if traj is not None and belief_measurement_repetitions > 1:
            raise ValueError("Trajectory-wise belief measurement is not supported when repetitions > 1.")

        if traj is not None:
            # Trajectory-wise filling: the judge sees every step and returns one belief per step.
            reasoning_steps = [
                {
                    "reasoning_content": step.content,
                    "belief": step.belief,
                }
                for step in traj.steps
            ]
            prompt = belief_eval_judge_prompt_with_traj.format(
                problem_statement=problem_statement,
                option_yes=option_yes,
                option_no=option_no,
                reasoning_steps=json.dumps(reasoning_steps, indent=2),
                name_steps=name_steps,
                num_steps=len(reasoning_steps),
            )
        else:
            # Single-step inference
            prompt = belief_eval_judge_prompt.format(
                problem_statement=problem_statement,
                option_yes=option_yes,
                option_no=option_no,
            )

        # Add additional information into the prompt
        if kwargs:
            prompt += additional_info_interlude
            for k, v in kwargs.items():
                prompt += additional_info_item.format(
                    extra_info_name=k.replace("_", " "), extra_info=json.dumps(v)
                )
            prompt += additional_info_ending.format(
                option_yes=option_yes,
                option_no=option_no,
            )

        dialogue = [
            {
                "role": "user",
                "content": prompt,
            }
        ]

        async def get_output() -> list[str]:
            """Query the policy and keep only the non-empty, stripped string responses."""
            responses = await policy.infer_batch_async(
                [dialogue] * belief_measurement_repetitions,
                disable_system_prompt=disable_system_prompt_in_measurement,
                temperature=temperature,
            )
            return [s.strip() for s in responses if isinstance(s, str) and s.strip()]

        if traj is not None:
            num_steps = len(traj.steps)

            async def query_once() -> tuple[str, object]:
                """Return the raw judge reply and its extracted JSON (None if the reply was empty)."""
                responses = await get_output()
                # repetitions == 1 here, so there is at most one reply; an empty or non-string
                # reply was filtered out by get_output and is treated as a failed attempt.
                raw = responses[0] if responses else ""
                return raw, (extract_json_from_str(raw) if raw else None)

            raw_str, parsed = await query_once()
            retries_left = self._TRAJ_MEASUREMENT_RETRIES
            while not self._is_valid_traj_output(parsed, num_steps):
                dump_file("failed_measurement.jsonl", raw_str, "a", None)

                if retries_left == 0:
                    dump_file("last_repeated_failure.json", [dialogue, raw_str])
                    raise ValueError("Belief measurement on trajectory failed.")

                retries_left -= 1
                # Only sized values have a length; a bare number from the parser must not crash the log line.
                parsed_len = len(parsed) if isinstance(parsed, (list, dict, str)) else None
                print(f"Belief measurement on trajectory failed (out_type: {type(parsed)}, out_len: {parsed_len}, expected_len: {num_steps}). Retrying... ({retries_left} retries left)")

                # Refresh both the raw text and the parsed value so failure logs reflect the latest attempt.
                raw_str, parsed = await query_once()

            for step, output_step in zip(traj.steps, parsed):
                step.belief = output_step["belief"]
            output = parsed

        else:
            # Repeat inference belief_measurement_repetitions times and keep the parseable floats.
            output = [
                belief
                for belief in map(self._parse_belief, await get_output())
                if belief is not None
            ]

            if not output:
                warnings.warn("Belief measurement failed.")
                return None

        # Save the prompt and output
        dump_file(
            "belief-measurement.jsonl",
            {
                "prompt": prompt,
                "output": output,
            },
            write_mode="a",
            indent=None,
        )

        return traj if traj is not None else sum(output) / len(output)

    def construct_mitigation_dataset(
        self,
        samples: list[tuple[Policy, list[ReasoningTrajectory]]],
        domain: ProblemDomain,
        reasoning_mode: ReasoningMode,
    ) -> list[Sample]:
        """
        Generate training examples to debias a policy; it can be SFT, DPO, or offline PPO (with pre-determined reward) samples.
        As a placeholder, this method uses offline PPO on samples (a list of trajectories for each current or past model policy) using measure_belief_entrenchment as reward function. Note that it doesn't consider the contribution (gradient) of each trajectory.
        It's recommended that each subclass re-implement this function.

        :param samples: Pairs of (policy, trajectories sampled from that policy). The policy itself is not used here.
        :param domain: The problem domain (unused by this default implementation).
        :param reasoning_mode: Used to split each trajectory into single-turn samples.
        :return: One EvaluatedSample per single-turn sample, all trajectories of a batch sharing that batch's reward.
        """
        training_samples = []
        for _, trajs in samples:
            # One scalar reward per policy batch, broadcast to every sample derived from it.
            rew = self.measure_belief_entrenchment(trajs)
            for traj in trajs:
                single_samples = reasoning_mode.trajectory_to_samples(traj)
                for sample in single_samples:
                    training_samples.append(
                        EvaluatedSample(
                            history=sample.history, output=sample.output, reward=rew
                        )
                    )

        return training_samples
