import os
import re
import tqdm
import json
import time
import asyncio
import warnings
import numpy as np
from copy import deepcopy
from typing import Optional
from utils.templates.debate import initial_prompt
from utils.judge_manipulation_utils import construct_prompting_context
from core.algo.schema import DebiasStrategy
from core.reasoning.schema import ReasoningStep, ReasoningTrajectory, ReasoningMode
from core.policy.schema import SingleSample, Policy
from core.domain.schema import ProblemDomain, Problem, BinaryProblem


class SelfDebate(ReasoningMode):
    """
    Self-debate as a reasoning strategy.
    This is distinct from multi-agent debate - the debaters and the judge are the same model, and all turns are considered trainable.

    Runtime configuration is read from environment variables on every call:

    - ``NUM_TURNS`` (default ``2``): number of debate rounds; each round contains one speech per debater.
    - ``SILENT`` (default ``1``): when on, the problem printout and the progress bar are suppressed.
    - ``USE_PER_TRAJ_BELIEF_MEASURE`` (default ``1``): when on, beliefs are measured with a single call on the
      assembled trajectory instead of one call per truncated transcript.
    """

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """
        Read a boolean switch from the environment.

        The empty string, ``0`` (or any numeric zero), ``false``, ``no``, ``off`` and ``none`` (case-insensitive)
        switch the flag off; every other value switches it on. The value is never evaluated as code.

        :param name: Name of the environment variable.
        :type name: str
        :param default: Raw string used when the variable is unset.
        :type default: str

        :return: The parsed flag.
        :rtype: bool
        """
        raw = os.environ.get(name, default).strip().lower()
        if raw in ("", "false", "no", "off", "none"):
            return False
        try:
            # Numeric spellings: "0", "0.0", ... are off; any non-zero number is on.
            return float(raw) != 0
        except ValueError:
            # Any other non-empty string counts as "on".
            return True

    @staticmethod
    def _num_turns() -> int:
        """
        Number of debate rounds, read from ``NUM_TURNS`` (defaults to 2).

        :return: The configured number of rounds.
        :rtype: int
        """
        return int(os.environ.get("NUM_TURNS", "2"))

    async def generate_sample_async(
        self, 
        policy: Policy, 
        domain: ProblemDomain, 
        problem: BinaryProblem, 
        expected_prior: Optional[float] = None,
    ) -> SingleSample:
        """
        The two sides debate the proposition. Private method to be called by `reasoning_rollout`; not for calling from outside.

        :param policy: The model policy that serves as both debaters and judge.
        :type policy: Policy
        :param domain: The problem domain. Not used by this method.
        :type domain: ProblemDomain
        :param problem: the current problem to solve.
        :type problem: BinaryProblem
        :param expected_prior: Specified prior position that the judge should be manipulated to hold (if any), defaults to None
        :type expected_prior: Optional[float], optional

        :return: Sample whose ``history`` is a single user message stating the problem and its options, whose
            ``output`` is the human-readable debate transcript, and whose ``aux_info`` records ``expected_prior``.
        :rtype: SingleSample
        """
        problem_statement = problem.question
        option_yes, option_no = problem.options
        num_turns = self._num_turns()
        # Parsed explicitly: bool("0") is True, so a plain bool() cast could never un-silence.
        silent = self._env_flag("SILENT", "1")

        if not silent:
            print(f"Problem statement: {problem_statement}")

        # Start a new history. The opponent gets the same prompt with the two options swapped.
        chat_history_prop = [
            {
                "role": "user",
                "content": initial_prompt.format(
                    problem_statement=problem_statement,
                    option_yes=option_yes,
                    option_no=option_no,
                ),
            }
        ]
        chat_history_oppo = [
            {
                "role": "user",
                "content": initial_prompt.format(
                    problem_statement=problem_statement,
                    option_yes=option_no,
                    option_no=option_yes,
                ),
            }
        ]

        if expected_prior is not None:
            prefix = construct_prompting_context(problem, expected_prior)
            chat_history_prop = prefix + chat_history_prop
            # Deep copy so the two histories do not share message dicts.
            chat_history_oppo = deepcopy(prefix) + chat_history_oppo

        # History to be read by experimenters, not as prompts for LLMs. One entry per round.
        debate_rounds = []

        # Progress bar
        with tqdm.tqdm(total=num_turns * 2, disable=silent) as pbar:

            for _ in range(num_turns):
                # Both debaters speak concurrently within a round.
                results = await asyncio.gather(
                    policy.infer_single_async(chat_history_prop),
                    policy.infer_single_async(chat_history_oppo),
                )

                # Debater1's argument. Blank lines are collapsed because "\n\n" is the
                # separator between speeches in the transcript.
                argument_prop = re.sub(r"\n{2,}", "\n", results[0])
                chat_history_prop.append(
                    {"role": "assistant", "content": argument_prop}
                )  # Add to one's own history
                pbar.update(1)  # Move progress bar forward by 1

                # Debater2's argument
                argument_oppo = re.sub(r"\n{2,}", "\n", results[1])
                chat_history_oppo.append(
                    {"role": "assistant", "content": argument_oppo}
                )  # Add to one's own history
                pbar.update(1)

                # Disclose both speeches to the other party
                chat_history_prop.append({"role": "user", "content": argument_oppo})
                chat_history_oppo.append({"role": "user", "content": argument_prop})
                debate_rounds.append(
                    f"Debater A ({option_yes}): {argument_prop}\n\nDebater B ({option_no}): {argument_oppo}"
                )

        return SingleSample(
            history=[
                {
                    "role": "user",
                    "content": f"{problem_statement}\n\nOptions: {option_yes} | {option_no}",
                }
            ],
            output="\n\n".join(debate_rounds).strip(),
            aux_info={"expected_prior": expected_prior},
        )
    
    async def sample_to_trajectory_async(
        self, 
        policy: Policy, 
        domain: ProblemDomain, 
        problem: BinaryProblem,
        sample: SingleSample, 
        algo: DebiasStrategy,
        expected_prior: Optional[float] = None,
        **kwargs
    ) -> ReasoningTrajectory:
        """
        Reading the two sides' debate, the judge reaches a conclusion. The method constructs the reasoning trajectory object by measuring the judge's belief at each turn.

        :param policy: The model policy that serves as both debaters and judge.
        :type policy: Policy
        :param domain: The problem domain where questions are sampled from.
        :type domain: ProblemDomain
        :param problem: The current question being debated.
        :type problem: BinaryProblem
        :param sample: The inference sample containing the two sides' debate.
        :type sample: SingleSample
        :param algo: The debias strategy to use. Only used for measuring beliefs.
        :type algo: DebiasStrategy
        :param expected_prior: Specified prior position that the judge should be manipulated to hold (if any), defaults to None
        :type expected_prior: Optional[float], optional
        :param kwargs: Extra keyword arguments; accepted but not used by this method.

        :return: Trajectory object containing the judge's belief at each turn.
        :rtype: ReasoningTrajectory
        :raises AssertionError: If the transcript does not contain exactly ``2 * NUM_TURNS`` speeches.
        """
        # Parsed as a plain flag instead of eval(), which would execute whatever the variable contains.
        use_per_traj_belief_measure = self._env_flag("USE_PER_TRAJ_BELIEF_MEASURE", "1")

        num_turns = self._num_turns()
        # Speeches are separated by blank lines in the transcript.
        speeches = (part.strip() for part in sample.output.split("\n\n"))
        debate_history = [speech for speech in speeches if speech]
        assert len(debate_history) == num_turns * 2, (
            f"Expected {num_turns * 2} speeches (NUM_TURNS={num_turns}) "
            f"but the transcript contains {len(debate_history)}"
        )

        # Here we manipulate judges with a registered prior (by news headlines)
        # NOTE(review): these kwargs are only forwarded on the per-prefix path below,
        # not on the per-trajectory path - confirm that this is intended.
        belief_kwargs = {}
        if expected_prior is not None:
            belief_kwargs["related_news_reports"] = construct_prompting_context(
                problem, expected_prior, raw_headlines=True
            )

        # Evaluate prior, intermediate, and posterior beliefs by truncating the debate transcript
        if use_per_traj_belief_measure:
            # Placeholders; beliefs are filled in by the per-trajectory measurement below.
            verdicts = [None] * (num_turns + 1)
        else:
            tasks = []
            for prefix_len in range(num_turns + 1):

                # Keeping only the first prefix_len turns for each debater
                truncated_transcript = "\n\n".join(debate_history[:prefix_len * 2]).strip()

                # Fresh dict per call so one prefix's transcript can never leak into another call.
                call_kwargs = dict(belief_kwargs)
                if truncated_transcript:
                    call_kwargs["debate_transcript"] = truncated_transcript

                tasks.append(algo.measure_belief_async(problem, policy, **call_kwargs))

            verdicts = await asyncio.gather(*tasks)

        def get_turn_content(turn_id: int) -> str:
            if turn_id == 0:  # Prior; judge hasn't read any debate yet
                return ""

            # Get the content of the debate turn (one speech from each debater)
            return "\n\n".join(debate_history[(turn_id - 1) * 2: turn_id * 2]).strip()

        # Assemble the reasoning trajectory from the verdicts (aka beliefs)
        traj = ReasoningTrajectory(
            problem=problem,
            steps=[
                ReasoningStep(
                    content=get_turn_content(turn_id),
                    belief=verdicts[turn_id],
                    trainable=True,
                )
                for turn_id in range(num_turns + 1)
            ],
        )
        
        if use_per_traj_belief_measure:
            traj = await algo.measure_belief_async(problem, policy, traj)
            assert isinstance(traj, ReasoningTrajectory)

        return traj

    def trajectory_to_samples(self, traj: ReasoningTrajectory) -> list[SingleSample]:
        """
        Given a reasoning trajectory, turn this into one or more OpenAI-format dialogue (e.g. containing a single user prompt relaying the debate record).
        The final turn of the dialogue should be generated by the trained model. If there are multiple such turns, return a list of all prefixes ending at such a turn.
        Each dialogue is represented as a SingleSample, where the output field is the last turn by the model.

        :raises NotImplementedError: Always; this conversion is not implemented for self-debate.
        """
        raise NotImplementedError
