import torch
import dataclasses as dc
import numpy as np
import random

from typing import Literal, Any, Iterable, Final, Callable
from math import inf
from pathlib import Path

from torch.utils.data import Sampler, WeightedRandomSampler
from core.utils.iterate import selfproducts
from core.model import Tokenizer, LLM

from .utils import check_validity
from .types import Symbol, BinaryVerifiable, DetailVerifiable
from ..data import ReflSFTDataset, ReflSFTDataModule, ReflectiveSample, _OutputInfo
from ...task import ReasoningTask


# Tag stored with every SFT sample produced in this module: a `Symbol` member
# for verification/revision samples, "skip-accept"/"skip-reject" for samples
# whose target skips the reflection, and None for non-reflective samples.
type VerificationType = Symbol | Literal["skip-reject", "skip-accept"] | None


class SVReflSFTData(ReflSFTDataset):
    """SFT dataset for self-verification.

    Each reflective source sample (a prompt plus several candidate outputs) is
    expanded into:

    * verification samples: ``prompt + branch + <begin>`` mapped to a
      reflection that ends with an accept or a reject symbol;
    * "skip" samples (only when ``optional_reflection``) whose target is just
      the EOS token, i.e. no reflection is written;
    * revision samples (only when ``include_revision_examples``) that map
      rejected attempts to an accepted attempt, or to the failure symbol.

    How a branch is labelled depends on ``label_type``:

    * ``"weight"``: rejected when its evaluation is more than ``tolerance``
      below the best evaluation among the outputs of the same sample;
    * ``"binary"``: the task's binary verifier decides accept/reject;
    * ``"detailed"``: the task's detail verifier writes the reflection text and
      the branch is rejected when that text contains the reject symbol;
    * ``None``: no verification samples are produced.
    """

    def __init__(self,
        tokenizer: Tokenizer,
        task: ReasoningTask | None = None,
        tolerance: float = 0,
        decay: float = 1,
        temperature: float = 1,
        label_type: Literal["weight", "binary", "detailed"] | None = "weight",
        optional_reflection: bool = False,
        include_revision_examples: Literal[False] | int = False,  # whether to learn a revisory mapping
        max_revised_attempts: int = 0,  # max number of revisory attempts
        mode: ReflSFTDataset._Mode = "input-label",
    ):
        """Configure the dataset and pre-encode the control symbols.

        Args:
            tokenizer: Used to encode the control symbols and, for the
                "binary"/"detailed" label types, to decode prompts and outputs.
            task: Source of the verifiers; required for the "binary" and
                "detailed" label types.
            tolerance: Margin below the best evaluation within which a branch
                is still accepted ("weight" label type).
            decay: Stored on the dataset; not read inside this class.
            temperature: Stored on the dataset; not read inside this class.
            label_type: How verification labels are produced (see class doc).
            optional_reflection: Also emit samples whose target is EOS only.
            include_revision_examples: ``False`` to disable revision samples;
                otherwise forwarded as ``k`` to ``selfproducts``.
            max_revised_attempts: Maximum number of rejected attempts that may
                precede a revision.
            mode: Forwarded to the base class.

        Raises:
            NotImplementedError: The task lacks the verifier required by
                ``label_type``.
            ValueError: ``label_type`` needs a verifier but no task was given.
        """
        super().__init__(mode=mode)
        
        self.type: list[VerificationType] = []

        # NOTE(review): `_attr_setting_mode` comes from the base class; it
        # presumably permits assigning otherwise guarded attributes -- confirm.
        with self._attr_setting_mode():
            self._tokenizer = tokenizer
            self.label_type: Final = label_type
            self.optional_reflection: Final = optional_reflection
            self.tolerance = tolerance
            self.decay = decay
            self.temperature = temperature
            self.include_revision_examples = include_revision_examples
            self.max_revised_attempts = max_revised_attempts

            self._binary_verifier: Callable[[str, str], bool] | None = None
            self._detailed_verifier: Callable[[str, str], str] | None = None
            if task is not None:
                if isinstance(task, BinaryVerifiable):
                    self._binary_verifier = task.binary_verifier()
                if isinstance(task, DetailVerifiable):
                    self._detailed_verifier = task.detail_verifier(Symbol.pass_, Symbol.reject)
                if label_type == "binary" and self._binary_verifier is None:
                    raise NotImplementedError("the task does not implement `BinaryVerifiable` protocol.")
                if label_type == "detailed" and self._detailed_verifier is None:
                    raise NotImplementedError("the task does not implement `DetailVerifiable` checker.")
            elif label_type == "binary" or label_type == "detailed":
                raise ValueError(f"{label_type} reflection requires task-specific checker.")
        
            # Pre-encode every control symbol once; they are reused per sample.
            self._tok_revise = tokenizer.encode(Symbol.revise)
            self._tok_begin = tokenizer.encode(Symbol.begin)
            self._tok_reject = tokenizer.encode(Symbol.reject)
            self._tok_pass = tokenizer.encode(Symbol.pass_)
            self._tok_accept = tokenizer.encode(Symbol.accept)
            self._tok_failure = tokenizer.encode(Symbol.failure)
            if tokenizer.eos_id is None:
                self._tok_eos = None
            else:
                self._tok_eos = torch.tensor([tokenizer.eos_id], dtype=torch.int32)

    def _verification_sft_sample(
        self,
        prompt: torch.Tensor,
        branch: torch.Tensor, 
        refl_output: torch.Tensor,
        _type: Literal[Symbol.accept, Symbol.reject, "skip-accept", "skip-reject"],
    ):
        """Build one verification sample: ``prompt + branch + <begin>`` -> ``refl_output``."""
        refl_begin = self._tok_begin
        return dict(
            input_ids=torch.cat((prompt, branch, refl_begin)),
            output_ids=refl_output,
            type=_type,
        )

    def _revision_sft_sample(
        self,
        prompt: torch.Tensor,
        branches: tuple[torch.Tensor, ...], 
        revision_output: torch.Tensor,
    ):
        """Build one revision sample: the prompt followed by the previous attempts -> ``revision_output``.

        Every previous attempt is terminated with the revise symbol.
        """
        sep = self._tok_revise
        if sep is not None:
            branches = tuple(torch.cat((b, sep)) for b in branches)
        return dict(
            input_ids=torch.cat((prompt, *branches)),
            output_ids=revision_output,
            type=Symbol.revise,
        )
        
    def _convert_non_reflective(self, tokenizer: Tokenizer, /,
                                instruction: str, output: str, **kwargs):
        """Encode a plain instruction/output pair as a sample with ``type=None``."""
        return dict(
            input_ids=tokenizer.encode(instruction),
            output_ids=tokenizer.encode(output),
            type=None,
        )

    def _pass_then_verdict(self, accepted: bool):
        """Return ``(reflection_tokens, type)`` for a pass symbol followed by the verdict.

        The leading pass symbol differentiates "reflected and decided" from
        "no reflection": in optional-reflection experiments its presence shows
        that the model tried to reflect.
        """
        if accepted:
            return torch.cat((self._tok_pass, self._tok_accept)), Symbol.accept
        return torch.cat((self._tok_pass, self._tok_reject)), Symbol.reject

    def _extract_sft_samples(self, src: ReflectiveSample) -> Iterable[dict[str, Any]]:
        """Yield every SFT sample derived from one reflective source sample.

        Verification samples come first (one per output, each optionally
        followed by its "skip" twin), then the revision samples if enabled and
        the source has both accepted and rejected outputs.

        Raises:
            ValueError: Optional reflection is enabled but the tokenizer has no
                EOS token (raised when iteration starts).
        """
        if self.optional_reflection and self._tok_eos is None:
            raise ValueError("EOS token must be set for optional reflection")
        
        rejected_outputs: list[torch.Tensor] = []
        accepted_outputs: list[torch.Tensor] = []

        def refl_samples(
            branch: _OutputInfo,
            refl_tokens: torch.Tensor,
            _type: Literal[Symbol.accept, Symbol.reject]
        ):
            # Remember the verdict so that revision examples can be built later.
            if _type == Symbol.reject:
                rejected_outputs.append(branch.tokens)
            elif _type == Symbol.accept:
                accepted_outputs.append(branch.tokens)
            yield self._verification_sft_sample(src.prompt, branch.tokens, refl_tokens, _type)
            if self.optional_reflection:
                # Twin sample whose target is EOS only, i.e. the reflection is skipped.
                assert self._tok_eos is not None
                yield self._verification_sft_sample(
                    src.prompt, branch.tokens, self._tok_eos,
                    _type=("skip-accept" if _type == Symbol.accept else "skip-reject"),
                )

        if self.label_type == "weight":
            best_output = max(src.outputs, key=lambda o: o.evaluation)
            for branch in src.outputs:
                # Written as `not (a < b)` on purpose: an unordered comparison
                # (e.g. NaN) keeps counting as accepted.
                accepted = not (branch.evaluation < best_output.evaluation - self.tolerance)
                r, _type = self._pass_then_verdict(accepted)
                yield from refl_samples(branch, r, _type)
        elif self.label_type == "binary":
            verify = self._binary_verifier
            assert verify is not None
            prompt = self._tokenizer.decode(src.prompt, skip_special_tokens=False)
            for branch in src.outputs:
                output = self._tokenizer.decode(branch.tokens, skip_special_tokens=False)
                r, _type = self._pass_then_verdict(bool(verify(prompt, output)))
                yield from refl_samples(branch, r, _type)
        elif self.label_type == "detailed":
            verify = self._detailed_verifier
            assert verify is not None
            prompt = self._tokenizer.decode(src.prompt, skip_special_tokens=False)
            for branch in src.outputs:
                output = self._tokenizer.decode(branch.tokens, skip_special_tokens=False)
                r_text = verify(prompt, output)
                if Symbol.reject in r_text:
                    _type = Symbol.reject
                else:
                    # The verifier found nothing to reject: close with an accept.
                    _type = Symbol.accept
                    r_text = r_text + Symbol.accept
                yield from refl_samples(branch, self._tokenizer.encode(r_text), _type)
        else:
            assert self.label_type is None

        if self.include_revision_examples and rejected_outputs and accepted_outputs:
            # Revision examples map incorrect attempts to a correct attempt, so
            # they need at least one rejected and one accepted output.
            n_revision_examples = self.include_revision_examples
            for n_previous_attempts in range(1, 1 + self.max_revised_attempts):
                # sample any accepted (correct) attempt as the revision target
                target_attempt = random.choice(accepted_outputs)
                for previous_attempts in selfproducts(rejected_outputs, n_previous_attempts, k=n_revision_examples):
                    yield self._revision_sft_sample(src.prompt, previous_attempts, target_attempt)
            # Once max_revised_attempts rejected attempts have been made, the target is Failure.
            # NOTE(review): with max_revised_attempts == 0 this asks `selfproducts`
            # for tuples of length 0 -- confirm the helper handles that as intended.
            for previous_attempts in selfproducts(rejected_outputs, self.max_revised_attempts, k=n_revision_examples):
                yield self._revision_sft_sample(src.prompt, previous_attempts, self._tok_failure)

    def print_text(self, i: int, tokenizer: Tokenizer | None = None, skip_special_tokens=False):
        """Delegate to the base ``print_text``, defaulting to the dataset's own tokenizer."""
        if tokenizer is None:
            tokenizer = self._tokenizer
        return super().print_text(i, tokenizer, skip_special_tokens)


@dc.dataclass
class SVReflSFTDataModule(ReflSFTDataModule):
    """Data module that builds `SVReflSFTData` datasets and their sampler.

    `reflection_frequency < 1` enables optional reflection in the dataset and
    re-weights the "skip" samples so that reflection targets are drawn with
    roughly that frequency. `use_weighted_sampler` additionally balances the
    accept and reject classes.
    """

    task: ReasoningTask | None = None
    tolerance: float = 0
    decay: float = 1
    temperature: float = 1
    # False: no mixing; "file": load `<split>.sft.json[l]` from the task root;
    # "task": call `task.get_sft_data(split)`; dict: extra keyword arguments
    # for `task.get_sft_data`.
    non_reflective_data: Literal[False, "file", "task"] | dict = False
    reflection_label: Literal["weight", "binary", "detailed"] | None = "weight"
    reflection_frequency: float = 1
    include_revision_examples: Literal[False] | int = False
    max_revised_attempts: int = 0
    use_weighted_sampler: bool = False

    def _init_dataset(self) -> ReflSFTDataset:
        """Create the dataset from this module's configuration."""
        return SVReflSFTData(
            self.tokenizer,
            task=self.task,
            tolerance=self.tolerance,
            decay=self.decay,
            temperature=self.temperature,
            label_type=self.reflection_label,
            optional_reflection=(self.reflection_frequency < 1),
            include_revision_examples=self.include_revision_examples,
            max_revised_attempts=self.max_revised_attempts,
            mode="input-label",
        )

    def _setup_dataset(self, dataset: ReflSFTDataset, split: str):
        """Run the base setup, then mix in non-reflective samples if configured.

        Raises:
            ValueError: Mixing is requested but no task is set.
            FileNotFoundError: ``non_reflective_data == "file"`` and neither
                ``<split>.sft.json`` nor ``<split>.sft.jsonl`` exists in the
                task root.
        """
        super()._setup_dataset(dataset, split)

        # mix in non-reflective data if needed
        if self.non_reflective_data is not False:
            if (task := self.task) is None:
                raise ValueError("The task is required for mixing non-reflective data.")
            if self.non_reflective_data == "file":
                path: Path | None = None
                # No early exit: when both files exist the later suffix wins.
                for suffix in (".sft.json", ".sft.jsonl"):
                    _f = task.root / (split + suffix)
                    if _f.exists() and _f.is_file():
                        path = _f
                if path is None:
                    raise FileNotFoundError(f"Can not find non-reflective SFT data in {task.root}.")
                dataset.mix_non_reflecive_samples(path, self.tokenizer)
            else:
                if self.non_reflective_data == "task":
                    non_refl_data = task.get_sft_data(split)
                else:
                    assert isinstance(self.non_reflective_data, dict)
                    non_refl_data = task.get_sft_data(split, **self.non_reflective_data)
                dataset.mix_non_reflecive_samples(non_refl_data, self.tokenizer)

    def _get_sampler(self, dataset: ReflSFTDataset) -> Sampler | None:
        """Return a weighted sampler, or ``None`` when default sampling suffices.

        A sampler is only needed when class balancing is requested or when
        reflection is optional (``reflection_frequency < 1``). ``None`` is also
        returned when the dataset holds no accept/reject samples, because there
        is nothing to re-weight.

        Raises:
            ValueError: ``reflection_frequency`` is not positive.
        """
        if not (self.use_weighted_sampler or self.reflection_frequency < 1):
            return None

        assert isinstance(dataset, SVReflSFTData)
        if self.reflection_frequency <= 0:
            # (1 - f) / f below is undefined for f == 0 and negative for f < 0.
            raise ValueError(
                f"reflection_frequency must be positive, got {self.reflection_frequency}."
            )

        n_reject = dataset.type.count(Symbol.reject)
        n_accept = dataset.type.count(Symbol.accept)
        n_verification = n_reject + n_accept
        if n_verification == 0:
            return None

        if self.use_weighted_sampler and n_reject and n_accept:
            # Inverse-frequency weights: both classes get the same total mass.
            w_reject = 2 * n_accept / n_verification
            w_accept = 2 * n_reject / n_verification
        else:
            # Balancing is off, or only one class exists (balancing would give
            # that class weight 0 and exclude every verification sample).
            w_reject = w_accept = 1.0

        # Odds of skipping vs. writing a reflection.
        w_optional = (1 - self.reflection_frequency) / self.reflection_frequency
        w_dict: dict[VerificationType, float] = {
            Symbol.reject: w_reject,
            Symbol.accept: w_accept,
            "skip-reject": w_reject * w_optional,
            "skip-accept": w_accept * w_optional,
        }
        # Any other sample type (e.g. None, revise) keeps weight 1.
        weights = [w_dict.get(t, 1) for t in dataset.type]
        n = int(n_verification * (1 + w_optional))
        return WeightedRandomSampler(weights, n)


@torch.inference_mode()
def evaluate_reflection_accuracy(
    llm: LLM,
    data: SVReflSFTData,
    context_length: int,
    batch_size: int,
    sampling_args: dict | None = None,
    verbose: int = 1,
):
    """Measure how often the model's reflection agrees with the dataset labels.

    Only samples typed `Symbol.accept` or `Symbol.reject` are evaluated, in a
    random order. A generation counts as a predicted reject when the reject
    symbol appears after the reflection starts, or when the sequence has not
    stopped (overflow); otherwise it counts as a predicted accept.

    Args:
        llm: Model under evaluation.
        data: Dataset providing `type` and `input_ids` per sample.
        context_length: Passed to `Session`.
        batch_size: Number of inputs per inference batch.
        sampling_args: Extra keyword arguments for `Sampling`; ``None`` means
            no extra arguments.
        verbose: ``>= 1`` prints the final summary, ``>= 2`` also prints the
            running summary after every batch.

    Returns:
        A dict of counts: ``"all"`` is the number of evaluated samples and each
        ``(truth, prediction)`` key (both the string form of the accept/reject
        symbol) counts the samples with that ground truth and that prediction.
    """
    from core.inference import Sampling, Session
    from pprint import pprint

    REJ = llm.preprocessor.encode(Symbol.reject)
    ACC = llm.preprocessor.encode(Symbol.accept)
    EOS = torch.tensor([llm.preprocessor.tokenizer.eos_id],
                       dtype=torch.int32, device=llm.preprocessor.device)

    def reflection_batches():
        """Yield ``(inputs, is_error)`` batches over the accept/reject samples."""
        inputs: list[torch.Tensor] = []
        errors: list[bool] = []
        for i in np.random.permutation(len(data)):
            i = int(i)
            type_ = data.type[i]
            if type_ == Symbol.accept:
                is_error = False
            elif type_ == Symbol.reject:
                is_error = True
            else:
                # skip / revise / non-reflective samples are not evaluated
                continue
            inputs.append(data.input_ids[i])
            errors.append(is_error)

            if len(inputs) >= batch_size:
                yield inputs, torch.tensor(errors, dtype=torch.bool)
                inputs, errors = [], []

        if inputs:
            yield inputs, torch.tensor(errors, dtype=torch.bool)

    session = Session(context_length)
    sampling = Sampling(llm, **(sampling_args or {}))
    sampling.connect(session)

    R = str(Symbol.reject)
    A = str(Symbol.accept)
    # Keys are (ground truth, prediction).
    cnts: dict[tuple[str, str] | Literal["all"], int] = {
        "all": 0,
        (R, R): 0,
        (R, A): 0,
        (A, R): 0,
        (A, A): 0,
    }

    def print_cnts():
        """Pretty-print the confusion counts; "accept" is the positive class."""
        pos = cnts[A, R] + cnts[A, A]
        neg = cnts[R, A] + cnts[R, R]
        summary = {
            "ALL": cnts["all"],
            "P": pos,
            "N": neg,
            "TP": cnts[A, A],
            "TN": cnts[R, R],
            "FP": cnts[R, A],
            "FN": cnts[A, R],
            "FP/N (e+)": cnts[R, A] / neg if neg else None,
            "FN/P (e-)": cnts[A, R] / pos if pos else None,
        }
        pprint(summary)

    try:
        for inputs, errors in reflection_batches():
            n = len(inputs)
            sampling.launch(sampling.preprocess(inputs))
            context = session.context
            context.set_event("__reflection__")
            sampling.stopped = False
            sampling.infer_sequence([REJ, ACC, EOS])
            # NOTE(review): presumably `infer_sequence` turns `stopped` into a
            # per-sequence boolean tensor -- confirm against `Sampling`.
            overflown = ~sampling.stopped
            session.release()

            # Sequences that never stopped are counted as rejections.
            rej = context.contains(REJ, since=context.when("__reflection__")) | overflown
            errors = errors.to(device=rej.device)
            cnts["all"] += n
            cnts[R, R] += int(torch.count_nonzero(rej & errors))
            cnts[R, A] += int(torch.count_nonzero((~rej) & errors))
            cnts[A, R] += int(torch.count_nonzero(rej & (~errors)))
            cnts[A, A] += int(torch.count_nonzero((~rej) & (~errors)))

            if verbose >= 2:
                print_cnts()

        if verbose >= 1:
            print_cnts()
    finally:
        # Release the sampler even when a batch fails.
        sampling.release()

    return cnts
