import torch
import dataclasses as dc
import json
from typing import Literal, Callable
from pathlib import Path

from core.tokenization import Vocabulary
from ..base import ReflectionBase
from ...reasoners.generative import GenerativeReasoner, MTPReasoner
from ...rm import RewardModel, RewardWrapper
from .types import Symbol, PathLike, BinaryVerifiable, RejectRatioEvaluated
from .impl import convert_reasoner_impl


@dc.dataclass
class SelfVerify(ReflectionBase):

    budget: int | None = None
    context_length: int | None = None
    reject_mode: Literal["retry", "revise"] = "retry"
    max_retry: int | None = None
    external_verifier: Literal["vf", "oracle", "random"] | None = None
    revise_temperature: float | None = None
    reflect_temperature: float = 0
    enable_statistics: bool = False
    vf_path: PathLike | None = None
    vf_tolerance: float | None = None
    random_reject_ratio: float | PathLike = 0  # used in random reflector

    required_vocab = Vocabulary(**Symbol.__members__)

    def _get_random_reject_ratio_map(self, checkpoint_dir: PathLike):
        """
        Acquire the mapping of reflective rejection ratio, using evlaution results stored in
        `checkpoint_dir/random_reject_ratio`. This is used for ablation experiments.
        """
        if self.task is None or not isinstance(self.task, RejectRatioEvaluated):
            raise ValueError("To infer reject ratio from evaluation results, please provide a task "
                             "that implements `RejectRatioEvaluated`")
        assert isinstance(self.random_reject_ratio, (str, Path))
        file_path = Path(checkpoint_dir) / (self.random_reject_ratio)
        try:
            with open(file_path, "rt") as f:
                eval_results = json.load(f)
        except FileNotFoundError:
            eval_results = None
        return self.task.reject_ratio_map(eval_results)

    def convert_reasoner[T: type[GenerativeReasoner]](self, cls: T) -> T:
        if not issubclass(cls, MTPReasoner):
            raise TypeError("The original reasoner class should have a recursive interface.")
        
        oracle = None
        if self.task is not None and isinstance(self.task, BinaryVerifiable):
            oracle = self.task.binary_verifier()
        if self.external_verifier == "oracle" and oracle is None:
            raise ValueError("Missing oracle verifier. Please provide `self.task` with "
                             "a task that implements `BinaryVefifiable`")
        
        if self.external_verifier == "random":
            if isinstance(self.random_reject_ratio, (str, Path)):
                random_reject_ratio = None
                reject_ratio_map_getter = self._get_random_reject_ratio_map
            elif isinstance(self.random_reject_ratio, (float, int)):
                random_reject_ratio = self.random_reject_ratio
                reject_ratio_map_getter = None
            else:
                raise TypeError
        else:
            random_reject_ratio = None
            reject_ratio_map_getter = None

        return convert_reasoner_impl(
            cls,
            self.budget,
            self.context_length,
            self.reject_mode,
            self.max_retry,
            self.external_verifier,
            self.revise_temperature,
            self.reflect_temperature,
            self.enable_statistics,
            oracle,
            self.vf_path,
            self.vf_tolerance,
            random_reject_ratio,
            reject_ratio_map_getter,
        )


# =================================
# Wrap a reward model with reflection
# ================================


class SVRewardWrapper(RewardWrapper):
    """Reward wrapper that handles outputs carrying a trailing reflection segment.

    An output is split at the first occurrence of `Symbol.begin`: the text before it
    is scored by the wrapped model, and the text from that marker onward is treated
    as the reflection. Outputs without the marker are delegated untouched.

    NOTE(review): `self.base` is presumably set by `RewardWrapper.__init__` -- confirm.
    """

    def __init__(self, base: RewardModel, reject_coef: float = 0):
        """Wrap `base`; `reject_coef` scales the reward of steps whose reflection rejects them."""
        super().__init__(base)
        self.reject_coef = reject_coef

    def _parse_process(self, prompt: str, output: str):
        """Parse `output` with the base model, recording the reflection split when present.

        When `Symbol.begin` occurs in `output`, the returned items additionally hold
        the stripped pre-marker text under "_base_output" and the stripped remainder
        under "_reflection".
        """
        split_at = output.find(Symbol.begin)
        if split_at < 0:
            # No reflection marker: behave exactly like the wrapped model.
            return self.base._parse_process(prompt, output)

        base_output = output[:split_at].strip()
        reflection_text = output[split_at:].strip()
        parsed = self.base._parse_process(prompt, base_output)
        parsed["_base_output"] = base_output
        parsed["_reflection"] = reflection_text
        return parsed

    def process_reward(self, prompt: str, output: str, **references) -> float:
        """Return the base reward, multiplied by `reject_coef` if the reflection rejects.

        Without a "_reflection" entry in `references` the call is delegated unchanged.
        Otherwise the base model scores "_base_output" (falling back to `output`).
        """
        reflection_text = references.get("_reflection")
        if reflection_text is None:
            return self.base.process_reward(prompt, output, **references)

        base_output = references.get("_base_output", output)
        assert isinstance(reflection_text, str) and isinstance(base_output, str)
        score = self.base.process_reward(prompt, base_output, **references)
        rejected = Symbol.reject in reflection_text
        return self.reject_coef * score if rejected else score

    def abort_process(self, prompt: str, output: str, **references) -> bool:
        """Decide whether to abort; a step whose reflection rejects it never aborts.

        Without a "_reflection" entry in `references` the call is delegated unchanged.
        For a non-rejecting reflection the base model decides on "_base_output"
        (falling back to `output`).
        """
        reflection_text = references.get("_reflection")
        if reflection_text is None:
            return self.base.abort_process(prompt, output, **references)
        if Symbol.reject in reflection_text:
            return False

        base_output = references.get("_base_output", output)
        assert isinstance(reflection_text, str) and isinstance(base_output, str)
        return self.base.abort_process(prompt, base_output, **references)

