import os
import tqdm
import json
import time
import asyncio
import warnings
import numpy as np
from copy import deepcopy
from typing import Optional
from utils.io_utils import load_file, dump_file
from utils.templates.cot import cot_prompt
from utils.judge_manipulation_utils import construct_prompting_context
from core.algo.schema import DebiasStrategy
from core.reasoning.schema import ReasoningStep, ReasoningTrajectory, ReasoningMode
from core.policy.schema import SingleSample, Policy
from core.domain.schema import ProblemDomain, Problem, BinaryProblem


def _env_flag(name: str, default: str) -> bool:
    """
    Read a boolean flag from the environment without evaluating arbitrary code.

    Accepts Python literals such as ``1``, ``0``, ``True`` and ``False``. These were the
    values the previous ``eval``-based parsing understood. Also accepts the spellings
    ``true/false``, ``yes/no`` and ``on/off`` (case-insensitive). An empty value is False.

    :param name: Name of the environment variable.
    :type name: str
    :param default: Raw value to use when the variable is unset.
    :type default: str
    :return: The parsed flag.
    :rtype: bool
    :raises ValueError: If the value cannot be interpreted as a flag.
    """
    import ast

    raw = os.environ.get(name, default).strip()
    lowered = raw.lower()
    if lowered in ("true", "yes", "on"):
        return True
    if lowered in ("false", "no", "off", ""):
        return False
    try:
        # literal_eval only parses literals, so this keeps the old "1"/"0"/"True" semantics
        # without executing code taken from the environment.
        return bool(ast.literal_eval(raw))
    except (ValueError, TypeError, SyntaxError) as err:
        raise ValueError(
            f"Environment variable {name}={raw!r} is not a valid boolean flag"
        ) from err


class ChainOfThought(ReasoningMode):
    """
    Chain of thought as a reasoning strategy.

    The policy answers a CoT prompt in one completion. That completion is split on blank
    lines into reasoning steps, and beliefs are attached to the resulting trajectory.
    """

    async def generate_sample_async(
        self,
        policy: Policy,
        domain: ProblemDomain,
        problem: BinaryProblem,
        expected_prior: Optional[float] = None,
    ) -> SingleSample:
        """
        Generate a single chain-of-thought sample for `problem`.

        The user turn is `cot_prompt` filled with the problem statement. If `expected_prior`
        is given, the context returned by `construct_prompting_context` is prepended to the
        dialogue to steer the model towards that prior.

        When the environment variable ``SAVE_FULL_HISTORY`` is exactly ``"1"``, the sample's
        ``__dict__`` is also dumped (``write_mode="a"``) to ``cot-history.jsonl``. That path
        is relative to the current working directory.

        :param policy: The model policy that generates the chain of thought.
        :type policy: Policy
        :param domain: The problem domain where questions are sampled from. Not used by this
            implementation.
        :type domain: ProblemDomain
        :param problem: The problem to reason on.
        :type problem: BinaryProblem
        :param expected_prior: Specified prior position that the model should be manipulated
            to hold (if any), defaults to None.
        :type expected_prior: Optional[float], optional

        :return: The inference sample (prompt dialogue, model output, and the requested prior
            in ``aux_info``).
        :rtype: SingleSample
        """
        prompt = cot_prompt.format(problem_statement=problem.question)
        dialogue = [{"role": "user", "content": prompt}]

        if expected_prior is not None:
            # The manipulation context is a list of messages placed before the CoT prompt.
            dialogue = construct_prompting_context(problem, expected_prior) + dialogue

        response = await policy.infer_single_async(dialogue)

        sample = SingleSample(
            history=dialogue,
            output=response,
            aux_info={"expected_prior": expected_prior},
        )

        if os.environ.get("SAVE_FULL_HISTORY", "0") == "1":
            dump_file(
                "cot-history.jsonl",
                sample.__dict__,
                write_mode="a",
                indent=None,
            )

        return sample

    async def sample_to_trajectory_async(
        self,
        policy: Policy,
        domain: ProblemDomain,
        problem: BinaryProblem,
        sample: SingleSample,
        algo: DebiasStrategy,
        expected_prior: Optional[float] = None,
    ) -> ReasoningTrajectory:
        """
        Given the CoT, construct the reasoning trajectory object. Private method to be called by `reasoning_rollout`; not for calling from outside.

        The CoT in ``sample.output`` is split on blank lines into non-empty, stripped steps.
        The trajectory has ``len(steps) + 1`` entries. The leading entry has empty content
        and represents the state before any reasoning. Entry ``i`` holds the ``i``-th step.

        Beliefs depend on the environment variable ``USE_PER_TRAJ_BELIEF_MEASURE`` (default
        ``"1"``):

        - Truthy: every step belief starts as ``None``. The whole trajectory is then passed
          to ``algo.measure_belief_async``, which must return a ``ReasoningTrajectory``.
        - Falsy: a belief is measured concurrently for each step prefix. Prefix 0 is measured
          without ``reasoning_steps``.

        :param policy: The model policy passed to ``algo.measure_belief_async``.
        :type policy: Policy
        :param domain: The problem domain where questions are sampled from. Not used by this
            implementation.
        :type domain: ProblemDomain
        :param problem: The problem to reason on.
        :type problem: BinaryProblem
        :param sample: The inference sample containing the CoT.
        :type sample: SingleSample
        :param algo: The debias strategy to use. Only used for measuring beliefs.
        :type algo: DebiasStrategy
        :param expected_prior: Specified prior position that the judge should be manipulated
            to hold (if any), defaults to None.
        :type expected_prior: Optional[float], optional

        :return: The reasoning trajectory object.
        :rtype: ReasoningTrajectory
        :raises ValueError: If ``USE_PER_TRAJ_BELIEF_MEASURE`` is not a recognizable flag.
        :raises TypeError: If the per-trajectory belief measurement does not return a
            ``ReasoningTrajectory``.
        """
        use_per_traj_belief_measure = _env_flag("USE_PER_TRAJ_BELIEF_MEASURE", "1")

        # Each blank-line-separated, non-empty paragraph of the CoT is one reasoning step.
        cot = [step.strip() for step in sample.output.split("\n\n") if step.strip()]
        num_turns = len(cot)

        # Here we manipulate judges with a registered prior (by news headlines).
        belief_kwargs = {}
        if expected_prior is not None:
            belief_kwargs["related_news_reports"] = construct_prompting_context(
                problem, expected_prior, raw_headlines=True
            )

        async def measure_prefix_belief(turn_id: int) -> float | None:
            # Belief after the first `turn_id` steps. Turn 0 means no reasoning yet, so the
            # `reasoning_steps` kwarg is omitted. Each concurrent call gets its own copy.
            kwargs = deepcopy(belief_kwargs)
            if turn_id > 0:
                kwargs["reasoning_steps"] = cot[:turn_id]
            return await algo.measure_belief_async(problem, policy, **kwargs)

        beliefs: list[float | None]
        if use_per_traj_belief_measure:
            # Filled in below by a single trajectory-level measurement.
            beliefs = [None] * (num_turns + 1)
        else:
            beliefs = await asyncio.gather(
                *(measure_prefix_belief(turn_id) for turn_id in range(num_turns + 1))
            )

        steps = [
            ReasoningStep(
                content=cot[turn_id - 1] if turn_id > 0 else "",
                belief=beliefs[turn_id],
                trainable=True,
            )
            for turn_id in range(num_turns + 1)
        ]
        traj = ReasoningTrajectory(problem=problem, steps=steps)

        if use_per_traj_belief_measure:
            traj = await algo.measure_belief_async(problem, policy, traj)
            # Explicit check rather than `assert`, so it still runs under `python -O`.
            if not isinstance(traj, ReasoningTrajectory):
                raise TypeError(
                    "algo.measure_belief_async must return a ReasoningTrajectory when given "
                    f"a trajectory, got {type(traj).__name__}"
                )

        return traj

    def trajectory_to_samples(self, traj: ReasoningTrajectory) -> list[SingleSample]:
        """
        Given a reasoning trajectory, turn this into one or more OpenAI-format dialogue (e.g. containing a single user prompt relaying the debate record).
        The final turn of the dialogue should be generated by the trained model. If there are multiple such turns, return a list of all prefixes ending at such a turn.
        Each dialogue is represented as a SingleSample, where the output field is the last turn by the model.

        :raises NotImplementedError: Always; not yet implemented for chain of thought.
        """
        raise NotImplementedError(
            "ChainOfThought.trajectory_to_samples is not implemented."
        )
