"""
This file contains the abstract class for a reasoning mode.
It also contains utility classes, most notably a collection of reasoning trajectories.
"""

import abc
import asyncio
import dataclasses
import warnings
import json
from core.policy.schema import SingleSample, Policy
from core.domain.schema import ProblemDomain, Problem
from utils.io_utils import load_file, dump_file
from utils.async_utils import run_coroutine
from typing import Optional


@dataclasses.dataclass
class ReasoningStep:
    """One step in the model's reasoning process. For example, one turn in the debate, or one bullet point in the CoT output."""

    content: str
    belief: Optional[float]
    trainable: bool  # Whether this step is generated by the trained model, as opposed to non-trainable content (e.g. external input) or human users


@dataclasses.dataclass
class ReasoningTrajectory:
    """A single session in which a model reasons about one topic.

    Attributes:
        problem: The problem the session is about.
        steps: The reasoning steps of the session, in sequential order.
    """

    problem: Problem
    steps: list[ReasoningStep]


class ReasoningMode(abc.ABC):
    """A mode of multi-step model reasoning. Examples include CoT, debate, and many more.

    Subclasses implement `generate_sample_async`, `sample_to_trajectory_async` and
    `trajectory_to_samples`. The synchronous `reasoning_rollout`, `generate_sample` and
    `sample_to_trajectory` methods only run their `*_async` counterparts through
    `run_coroutine`.
    """

    def reasoning_rollout(
        self,
        policy: Policy,
        domain: ProblemDomain,
        algo: "DebiasStrategy",  # String annotation avoids a circular import
        n: int = 1,
        expected_priors: Optional[list[float | None]] = None,
        **kwargs,
    ) -> list[ReasoningTrajectory]:
        """Synchronous wrapper around `reasoning_rollout_async`.

        Args:
            policy: Policy forwarded to sample generation and trajectory construction.
            domain: Domain the problems are sampled from.
            algo: Strategy forwarded to `sample_to_trajectory_async`.
            n: Number of problems to sample; values below 2 sample a single problem.
            expected_priors: Priors to run for every sampled problem. `None` (the default)
                is equivalent to `[None]`, i.e. one rollout without a specified prior.
            **kwargs: Forwarded unchanged to `reasoning_rollout_async`.

        Returns:
            One trajectory per (problem, prior) pair, grouped by problem and, within a
            problem, ordered like `expected_priors`.
        """
        # Resolve the default here as well, so that subclass overrides of the async method
        # that iterate `expected_priors` directly never receive None from this wrapper.
        priors = [None] if expected_priors is None else expected_priors
        return run_coroutine(self.reasoning_rollout_async(policy, domain, algo, n, priors, **kwargs))

    async def reasoning_rollout_async(
        self,
        policy: Policy,
        domain: ProblemDomain,
        algo: "DebiasStrategy",
        n: int = 1,
        expected_priors: Optional[list[float | None]] = None,
        **kwargs,
    ) -> list[ReasoningTrajectory]:
        """Sample problems from `domain` and produce one trajectory per problem and prior.

        Args:
            policy: Policy forwarded to sample generation and trajectory construction.
            domain: Domain the problems are sampled from.
            algo: Strategy forwarded to `sample_to_trajectory_async`.
            n: Number of problems to sample; values below 2 sample a single problem.
            expected_priors: Priors to run for every sampled problem. `None` (the default)
                is equivalent to `[None]`, i.e. one rollout without a specified prior.
            **kwargs: Forwarded to `generate_sample_async`, `sample_to_trajectory_async`
                and `save_record`.

        Returns:
            One trajectory per (problem, prior) pair, grouped by problem and, within a
            problem, ordered like `expected_priors`.
        """
        priors = [None] if expected_priors is None else expected_priors

        if n > 1:
            # Fan out into `n` single-problem rollouts. The priors must be passed along
            # explicitly, otherwise every sub-rollout would fall back to the default.
            per_problem = await asyncio.gather(
                *(
                    self.reasoning_rollout_async(policy, domain, algo, 1, priors, **kwargs)
                    for _ in range(n)
                )
            )
            return [traj for batch in per_problem for traj in batch]

        problem = domain.sample_problems()[0]

        async def generate_sample_and_save_record(expected_prior: float | None) -> ReasoningTrajectory:
            """Generate a sample and save the reasoning record, given a specified prior."""
            sample = await self.generate_sample_async(policy, domain, problem, expected_prior, **kwargs)
            traj = await self.sample_to_trajectory_async(
                policy, domain, problem, sample, algo, expected_prior, **kwargs
            )
            self.save_record(problem, sample, traj, **kwargs)
            return traj

        # Generate reasoning trajectories in parallel for each specified prior
        return await asyncio.gather(*(generate_sample_and_save_record(prior) for prior in priors))

    def generate_sample(
        self,
        policy: Policy,
        domain: ProblemDomain,
        problem: Problem,
        expected_prior: Optional[float] = None,
        **kwargs,
    ) -> SingleSample:
        """Synchronous wrapper around `generate_sample_async`; takes the same arguments."""
        return run_coroutine(self.generate_sample_async(policy, domain, problem, expected_prior, **kwargs))

    @abc.abstractmethod
    async def generate_sample_async(
        self,
        policy: Policy,
        domain: ProblemDomain,
        problem: Problem,
        expected_prior: Optional[float] = None,
        **kwargs,
    ) -> SingleSample:
        """Generate an inference sample for `problem`, a problem sampled from `domain`.

        Args:
            policy: Policy used for generation.
            domain: Domain `problem` belongs to.
            problem: The problem to reason about.
            expected_prior: Prior specified for this rollout, or None when unspecified.
                How it is used is up to the implementing subclass.
            **kwargs: Extra options passed through from the rollout call.

        Returns:
            The generated sample.
        """
        raise NotImplementedError

    def sample_to_trajectory(
        self,
        policy: Policy,
        domain: ProblemDomain,
        problem: Problem,
        sample: SingleSample,
        algo: "DebiasStrategy",
        expected_prior: Optional[float] = None,
        **kwargs,
    ) -> ReasoningTrajectory:
        """Synchronous wrapper around `sample_to_trajectory_async`; takes the same arguments."""
        return run_coroutine(
            self.sample_to_trajectory_async(policy, domain, problem, sample, algo, expected_prior, **kwargs)
        )

    @abc.abstractmethod
    async def sample_to_trajectory_async(
        self,
        policy: Policy,
        domain: ProblemDomain,
        problem: Problem,
        sample: SingleSample,
        algo: "DebiasStrategy",
        expected_prior: Optional[float] = None,
        **kwargs,
    ) -> ReasoningTrajectory:
        """Given an inference sample, return the corresponding reasoning trajectory by performing belief measurement etc.

        Args:
            policy: Policy the sample was generated with.
            domain: Domain `problem` belongs to.
            problem: The problem the sample is about.
            sample: The inference sample to convert.
            algo: Strategy passed through from the rollout call.
            expected_prior: Prior specified for this rollout, or None when unspecified.
            **kwargs: Extra options passed through from the rollout call.

        Returns:
            The reasoning trajectory built from `sample`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def trajectory_to_samples(self, traj: ReasoningTrajectory) -> list[SingleSample]:
        """
        Given a reasoning trajectory (e.g. a debate history), turn it into one or more OpenAI-format dialogues (e.g. containing a single user prompt relaying the debate record).
        The final turn of each dialogue should be generated by the trained model. If there are multiple such turns, return a list of all prefixes ending at such a turn.
        Each dialogue is represented as a SingleSample, where the output field is the last turn by the model.
        """
        raise NotImplementedError

    def save_record(
        self,
        problem: Problem,
        sample: Optional[SingleSample] = None,
        traj: Optional[ReasoningTrajectory] = None,
        log_filename: str = "reasoning-record.json",
        **kwargs,
    ) -> None:
        """Save reasoning record. Private method to be called by `reasoning_rollout`; not for calling from outside.

        The existing history is read from `log_filename` with `load_file`, one entry is
        appended, and the whole history is written back with `dump_file`. A missing or
        undecodable file is treated as an empty history.

        Args:
            problem: The problem that was reasoned about; stored under "question".
            sample: The inference sample; stored under "inference" (None when omitted).
            traj: The resulting trajectory; stored under "trajectory" (None when omitted).
            log_filename: Path of the record file.
            **kwargs: Extra entries merged into the record after the fixed keys, so a
                keyword named like one of the fixed keys replaces it.
        """
        try:
            history = load_file(log_filename)
        except (FileNotFoundError, json.JSONDecodeError):
            history = []

        # `sample` and `traj` default to None, which `dataclasses.asdict` rejects with a
        # TypeError; record a missing value as None instead of failing.
        record = {
            "question": dataclasses.asdict(problem),
            "trajectory": None if traj is None else dataclasses.asdict(traj),
            "inference": None if sample is None else dataclasses.asdict(sample),
            **kwargs,
        }
        history.append(record)

        dump_file(log_filename, history)
