from typing import cast, Literal, Never, Sequence, Iterable, Any, Callable
from dataclasses import dataclass, field
import torch
import json
import random as rd
import abc
from torch.utils.data import DataLoader, Sampler
from pathlib import Path
from core.inference import Context, Session
from core.inference.sampling import Sampling
from core.utils import TokenBuffer
from core.model import Tokenizer, Preprocessor
from core.utils.th import NamedDataset
from litgpt.data.base import get_sft_collate_fn, DataModule
from litgpt.prompts import PromptStyle, Default
from ..formulation import CoT
from ..reasoners.generative import MTPReasoner
from ..evaluator import CumulatedReward
from ..task import ReasoningTask


def _serialize_tokens(x: torch.Tensor) -> str:
   if not (x.dtype == torch.int32 and x.ndim == 1):
       raise ValueError("Only accept 1-dimensional int32 tensor.")
   return ' '.join(map(str, x.tolist()))


def _unserialize_tokens(b: str, device=None) -> torch.Tensor:
    return torch.tensor(list(map(int, b.split())), dtype=torch.int32, device=device)


def input_labels(input_ids: torch.Tensor, output_ids: torch.Tensor,
                 ignore_input: bool = True, ignore_idx: int = -100) -> dict[str, torch.Tensor]:
    """Concatenate prompt and response tokens and build the matching LM labels.

    Args:
        input_ids: 1-D prompt tokens.
        output_ids: 1-D response tokens.
        ignore_input: if True, label positions of the prompt are set to `ignore_idx`.
        ignore_idx: label value used for masked positions.

    Returns:
        ``{"input_ids": prompt + response, "labels": same tokens as int64}``.
        int64 is what cross-entropy expects for targets.
    """
    assert input_ids.ndim == 1 and output_ids.ndim == 1
    n_prompt = len(input_ids)
    sequence = torch.cat((input_ids, output_ids))
    labels = sequence.type(torch.int64)
    if not ignore_input:
        return {"input_ids": sequence, "labels": labels}
    if labels is sequence:
        # `.type` returns the same tensor when it is already int64; copy it so
        # that masking the prompt does not overwrite `input_ids`.
        labels = labels.clone()
    labels[:n_prompt] = ignore_idx
    return {"input_ids": sequence, "labels": labels}


@dataclass(eq=False, slots=True)
class _OutputInfo:

    tokens: torch.Tensor
    evaluation: float
    weight: float | int

    def __post_init__(self):
        assert self.tokens.ndim == 1

    def duplicated(self, other: torch.Tensor):
        return torch.equal(self.tokens, other)
    
    def as_dict(self):
        return {
            "tokens": _serialize_tokens(self.tokens),
            "evaluation": self.evaluation,
            "weight": self.weight,
        }

    @staticmethod
    def from_dict(d: dict, device=None):
        return _OutputInfo(
            _unserialize_tokens(d['tokens'], device),
            d['evaluation'],
            d['weight'],
        )


@dataclass(eq=False, slots=True)
class ReflectiveSample:
    """A prompt together with the distinct outputs sampled for it.

    Fields:
        prompt: prompt token ids.
        outputs: distinct sampled outputs, each with its evaluation and weight.
    """

    prompt: torch.Tensor
    outputs: list[_OutputInfo]

    def present(self, decoder: Tokenizer | Preprocessor, skip_special_tokens: bool = True, indent=4):
        """Return a human-readable JSON string with the prompt and outputs decoded to text."""

        def decode(tokens: torch.Tensor):
            return decoder.decode(tokens, skip_special_tokens=skip_special_tokens)

        return json.dumps({
            "prompt": decode(self.prompt),
            "outputs": [
                {"tokens": decode(o.tokens), "evaluation": o.evaluation, "weight": o.weight}
                for o in self.outputs
            ],
        }, indent=indent)

    @property
    def average_evaluation(self):
        """Weight-averaged evaluation over all outputs.

        Raises ZeroDivisionError when the total weight is zero (e.g. `outputs` is empty).
        """
        return sum(o.evaluation * o.weight for o in self.outputs) / sum(o.weight for o in self.outputs)

    def as_dict(self):
        """Serialize to a JSON-compatible dict (token tensors become space-separated strings)."""
        return {
            "prompt": _serialize_tokens(self.prompt),
            "outputs": [o.as_dict() for o in self.outputs],
        }

    @staticmethod
    def from_dict(d: dict, device=None):
        """Rebuild a sample from `as_dict` output.

        `device` is applied to the prompt AND to every output's tokens, so the
        whole sample ends up on the same device.
        """
        return ReflectiveSample(
            _unserialize_tokens(d['prompt'], device),
            [_OutputInfo.from_dict(o, device) for o in d['outputs']],
        )


class ReflSFTDataset(NamedDataset[list]):
    """SFT dataset built from `ReflectiveSample` entries, optionally mixed with plain samples.

    Records are stored field-wise in the lists of ``self.__data_dict__``. These
    include at least ``input_ids`` and ``output_ids``, and presumably any extra
    list fields a subclass registers; verify against `NamedDataset`. Subclasses
    implement `_extract_sft_samples` to turn one reflective sample into zero or
    more SFT records.

    In ``"input-label"`` mode (default), items expose ``input_ids``/``labels``
    built by `input_labels`. In ``"input-output"`` mode, the stored fields are
    returned as-is.
    """

    _Mode = Literal["input-output", "input-label"]
    mode: _Mode

    def __init__(self, mode: _Mode = "input-label"):
        super().__init__()

        self.input_ids: list[torch.Tensor] = []
        self.output_ids: list[torch.Tensor] = []

        with self._attr_setting_mode():
            self.mode = mode

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = super().__getitem__(idx)
        if self.mode == "input-label":
            input_ids = item.pop("input_ids")
            output_ids = item.pop("output_ids")
            assert isinstance(input_ids, torch.Tensor) and isinstance(output_ids, torch.Tensor)
            item.update(input_labels(input_ids, output_ids))
        return item

    def load_samples(self, samples: Iterable[ReflectiveSample]):
        """Append the SFT records extracted from each reflective sample, with a progress line."""
        n = len(samples) if isinstance(samples, Sequence) else '?'
        for i, sample in enumerate(samples, start=1):
            for sft_sample in self._extract_sft_samples(sample):
                for k, list_ in self.__data_dict__.items():
                    list_.append(sft_sample[k])
            # Report once per source entry (1-based), even if it produced no SFT records.
            print(f"{i}/{n} reflection entries have been processed. "
                  f"{len(self)} SFT samples have been constructed.",
                  end='\r')
        print()

    def _convert_non_reflective(self, tokenizer: Tokenizer, /,
                                instruction: str, output: str, **kwargs):
        """Encode a plain ``{"instruction", "output", ...}`` record; extra keys are ignored."""
        return {
            "input_ids": tokenizer.encode(instruction),
            "output_ids": tokenizer.encode(output),
        }

    def mix_non_reflecive_samples(self, data: Path | str | list[dict[str, str]], tokenizer: Tokenizer):
        """Append plain (non-reflective) SFT records.

        `data` is either a list of dicts or the path of a JSON file containing one.
        """
        if isinstance(data, (Path, str)):
            with open(data, 'rt', encoding='utf-8') as f:
                data = json.load(f)
        assert isinstance(data, list)
        n = len(data)
        for cnt, raw_sample in enumerate(data, start=1):
            sft_sample = self._convert_non_reflective(tokenizer, **raw_sample)
            for k, list_ in self.__data_dict__.items():
                list_.append(sft_sample[k])
            print(f"Mixing non-reflective SFT samples: {cnt}/{n} processed.", end='\r')
        print()

    def load_file(self, path: str | Path):
        """Load serialized `ReflectiveSample` dicts from `path` and add their SFT records.

        ``.jsonl`` files are read as a stream of JSON objects: one per line, or
        multi-line objects written back to back. Any other file must contain a
        single JSON list.
        """
        with open(path, 'rt', encoding='utf-8') as f:
            text = f.read()
        if Path(path).suffix == ".jsonl":
            decoder = json.JSONDecoder()
            records = []
            pos, end = 0, len(text)
            while True:
                # raw_decode does not skip leading whitespace, so skip it here.
                while pos < end and text[pos].isspace():
                    pos += 1
                if pos >= end:
                    break
                record, pos = decoder.raw_decode(text, pos)
                records.append(record)
        else:
            records = json.loads(text)
            assert isinstance(records, list)
        samples = list(map(ReflectiveSample.from_dict, records))
        return self.load_samples(samples)

    def print_text(self, i: int, tokenizer: Tokenizer, skip_special_tokens=True):
        """Print the decoded input/output of raw record `i` plus any other stored fields."""
        item = super().__getitem__(i)
        input_ids = item.pop("input_ids")
        output_ids = item.pop("output_ids")
        print("Input>", tokenizer.decode(input_ids, skip_special_tokens))
        print("Output>", tokenizer.decode(output_ids, skip_special_tokens))
        for k, v in item.items():
            print(k + '>', v)

    @abc.abstractmethod
    def _extract_sft_samples(self, src: ReflectiveSample) -> Iterable[dict[str, Any]]:
        """Yield SFT records (one value per data field) derived from `src`."""
        raise NotImplementedError


class ReflDataCollector:
    """Collect `ReflectiveSample`s by branching a sampling reasoner at every step.

    At each reasoning step, `n_branches` continuations are sampled in forked
    sessions. Each is scored by a `CumulatedReward`, optionally after a rollout.
    Per unterminated prompt, the distinct continuations and their (averaged)
    scores become one `ReflectiveSample`. The reasoner then takes one regular
    step and the process repeats until it is done.
    """

    # Reference dicts of the batch being processed. Only set while
    # `_collect_batch` runs; read through `_ref`.
    __ref_dicts: list[dict]

    @dataclass
    class EvalArgs:
        """Options controlling how each sampled branch is evaluated."""

        reward: Literal["process", "outcome", "success"] = "outcome"
        discount: float = 1
        # `0` for instant evaluation. `None` for infinite rollout.
        rollout_length: int | None = None

        def __post_init__(self):
            assert self.rollout_length is None or self.rollout_length >= 0
            assert 0 <= self.discount <= 1

    def __init__(
        self,
        task: ReasoningTask,
        reasoner: MTPReasoner,
        n_branches: int,
        propose_temperature: float = 1.0,
        solve_temperature: float = 1.0,
        rollout_temperature: float = 1.0,
        evaluation: EvalArgs | None = None,
        verbose: int = 1,
        bg_device: torch.device | str | None = "cpu",
    ) -> None:
        """
        Args:
            task: supplies the prompt style, reference dicts, reward model and instances.
            reasoner: the reasoner to branch. Its inference must be `Sampling`.
            n_branches: number of continuations sampled at every step.
            propose_temperature: sampling temperature for the branches.
            solve_temperature: sampling temperature for the step actually taken.
            rollout_temperature: sampling temperature while rolling a branch out.
            evaluation: evaluation options. `None` means a fresh `EvalArgs()`,
                which avoids sharing one mutable default between collectors.
            verbose: 0 silent, 1 progress counter, >= 2 print every sample.
            bg_device: `state_device` used for forked session state.

        Raises:
            TypeError: if `reasoner.inference` is not a `Sampling` instance.
        """
        if not isinstance(reasoner.inference, Sampling):
            raise TypeError("The reasoner must uses sampling inference.")
        if evaluation is None:
            evaluation = self.EvalArgs()

        self.task = task
        self.reasoner = reasoner
        self.n_branches = n_branches
        self.verbose = verbose
        self.propose_temperature = propose_temperature
        self.solve_temperature = solve_temperature
        self.rollout_temperature = rollout_temperature
        self.bg_device = torch.device(bg_device) if bg_device is not None else bg_device
        self.evaluation = evaluation

        # prepare for computing the cumulated reward
        self.reasoner.ref = self._ref
        self._cumr = CumulatedReward(task.reward_model(evaluation.reward), evaluation.discount)
        self._cumr.attach(self.reasoner)

    @property
    def sampling(self) -> Sampling:
        """The reasoner's inference, known to be `Sampling` (checked in `__init__`)."""
        return cast(Sampling, self.reasoner.inference)

    def _ref(self, idx: tuple[int, ...]):
        """Reference dict of the batch element `idx[0]`; installed as `reasoner.ref`."""
        i = idx[0]
        return self.__ref_dicts[i]

    def _castsess(self, sess: Session):
        return cast(Session[Literal["__gen__"], Never, Never, Never], sess)

    def _clear_evaluation(self):
        self._cumr.clear_()

    def _evalute(self) -> torch.Tensor:
        """Score the branch just generated and return the cumulated reward.

        Calls `reasoner._transit()`. Unless that finishes the episode or
        `rollout_length == 0`, the branch is rolled out at `rollout_temperature`.
        The rollout is capped at `rollout_length` extra steps when one is given.
        """
        reasoner = self.reasoner
        rollout_length = self.evaluation.rollout_length
        done = reasoner._transit()
        if not done and rollout_length != 0:
            # roll the branch out so the reward reflects its consequences
            self.sampling.temperature = self.rollout_temperature
            if rollout_length is None:
                _, _ = reasoner._solve()
            else:
                old_max_step = reasoner.max_step
                reasoner.max_step = min(reasoner._n_step + rollout_length, old_max_step)
                try:
                    _, _ = reasoner._solve()
                finally:
                    # restore the step limit even if the rollout fails
                    reasoner.max_step = old_max_step
        return self._cumr.get_results()

    def _collect_branches(self, sess: Session):
        """Sample and score `n_branches` continuations of the current state.

        Each branch runs inside `sess.tryfork(...)`. The reasoner's local state
        is restored after each one.

        Returns:
            `(outputs, r)`. `outputs[j]` is branch j's "__gen__" segment. `r`
            holds the scores with branches on its last axis; presumably shaped
            `(*batch, n_branches)`, confirm with `Context.make_tensor`.
        """
        reasoner = self.reasoner
        ctx = sess.context
        ctx.set_event("__gen__")
        r = ctx.make_tensor(
            (self.n_branches,),
            dtype=self._cumr.reward_model.dtype,
        )
        del ctx

        outputs: list[TokenBuffer] = []
        local_state = reasoner._get_local_state()
        for j in range(self.n_branches):
            with sess.tryfork(True, True, True, True, True,
                              state_device=self.bg_device, restore_model='kv'):
                self.sampling.temperature = self.propose_temperature
                self._clear_evaluation()
                reasoner._generate()
                outputs.append(sess.context.eseg("__gen__", None))
                r[..., j] = self._evalute()
            reasoner._load_local_state(**local_state)
        return outputs, r

    def _record_samples(self, sess: Session, outputs: list[TokenBuffer],
                        r: torch.Tensor, data: list[ReflectiveSample]):
        """Append one `ReflectiveSample` per unterminated prompt, merging identical branches."""
        reasoner = self.reasoner
        terminated = reasoner.terminated
        for i, prompt in sess.context.enumerate():
            if terminated[i]:
                continue

            unique_outputs: list[_OutputInfo] = []
            for j in range(self.n_branches):
                output = outputs[j].tokens_at(*i).cpu()
                r_ij = float(r[i + (j,)])
                for o in unique_outputs:
                    if o.duplicated(output):
                        # merge identical branches, keeping a running mean of their scores
                        o.evaluation = (o.evaluation * o.weight + r_ij) / (o.weight + 1)
                        o.weight = o.weight + 1
                        break
                else:
                    unique_outputs.append(_OutputInfo(output, r_ij, 1))

            sample = ReflectiveSample(
                prompt=prompt.cpu(),
                outputs=unique_outputs,
            )
            data.append(sample)

            if self.verbose >= 2:
                print("Sample No. %d:" % len(data),
                      sample.present(reasoner.llm.preprocessor, False))
            elif self.verbose >= 1:
                print("%d samples have been collected for reflection training." % len(data),
                      end='\r')

    def _solve_and_collect(self, data: list[ReflectiveSample]):
        """Run the reasoner to completion, collecting reflective samples at every step.

        Written as a loop so each step's branch buffers are freed before the next
        step, and long reasoning chains cannot exhaust the recursion limit.

        Returns:
            `(thought, outcome)` as produced by `reasoner._extract()`.
        """
        reasoner = self.reasoner
        while True:
            sess = self._castsess(reasoner.session)
            outputs, r = self._collect_branches(sess)
            self._record_samples(sess, outputs, r, data)
            del outputs, r  # release this step's branch buffers

            # advance the actual trajectory by one step
            self.sampling.temperature = self.solve_temperature
            reasoner._generate()
            reasoner._transit()
            reasoner._n_step += 1
            if reasoner._done():
                reasoner.inference.submit()
                thought, outcome = reasoner._extract()
                return thought, outcome

    def _collect_batch(self, batch: Sequence[CoT], data: list[ReflectiveSample]):
        """Reason over one batch of instances, appending collected samples to `data`."""
        style = self.task.style
        text_inputs = [style.apply_input(**cot.as_dict()) for cot in batch]
        self.__ref_dicts = [self.task.get_ref_dict(cot) for cot in batch]
        try:
            model_input = self.reasoner.preprocess(text_inputs)
            self.reasoner.before_reasoning(model_input)
            thought, outcome = self._solve_and_collect(data)
            self.reasoner.after_reasoning(thought, outcome)
        finally:
            # never leave stale reference dicts behind, even on failure
            del self.__ref_dicts

    @torch.inference_mode()
    def __call__(self,
                 batch_size: int,
                 instances: Sequence[CoT] | None = None,
                 max_n_sample: int | None = None,
                 data: list[ReflectiveSample] | None = None):
        """Collect samples in batches of `batch_size` and return `data`.

        Args:
            batch_size: number of instances reasoned over together.
            instances: defaults to the task's ``'train'`` instances.
            max_n_sample: if truthy, the instances are shuffled (a copy is made)
                and collection stops after the batch that brings `len(data)` to
                at least `max_n_sample`. The result can exceed it by up to one
                batch's worth of samples.
            data: list to append to. A new list is created when `None`.
        """
        if data is None:
            data = []

        if instances is None:
            instances = self.task.get_instances('train')

        if max_n_sample:
            instances = list(instances)
            rd.shuffle(instances)

        for start in range(0, len(instances), batch_size):
            self._collect_batch(instances[start: start + batch_size], data)
            if max_n_sample and len(data) >= max_n_sample:
                break

        return data

    @staticmethod
    def save_data(path: Path | str, data: Sequence[ReflectiveSample], append: bool = True):
        """Write `data` to `path`.

        ``.jsonl``: one compact JSON object per line (valid JSON Lines). The file
        is appended to when `append` is set and it exists.
        Any other suffix: a single indented JSON list. When appending, the
        existing list is read and extended.
        """
        import os

        path = Path(path)
        append &= path.exists()
        if path.suffix == ".jsonl":
            # If an existing file does not end with a newline, start a fresh line first.
            needs_newline = False
            if append and path.stat().st_size > 0:
                with open(path, 'rb') as f:
                    f.seek(-1, os.SEEK_END)
                    needs_newline = f.read(1) != b'\n'
            with open(path, 'at' if append else 'wt', encoding='utf-8') as f:
                if needs_newline:
                    f.write('\n')
                for item in data:
                    f.write(json.dumps(item.as_dict()))
                    f.write('\n')
        else:
            records = [item.as_dict() for item in data]
            if append:
                with open(path, 'rt', encoding='utf-8') as f:
                    old_records = json.load(f)
                assert isinstance(old_records, list)
                records = old_records + records
            with open(path, 'wt', encoding='utf-8') as f:
                json.dump(records, f, indent=4)



@dataclass
class ReflSFTDataModule(DataModule, abc.ABC):
    """litgpt data module serving a `ReflSFTDataset` built from files in `data_path`.

    Files in `data_path` whose name starts with ``train``/``val`` and whose
    suffix is ``.json`` or ``.jsonl`` are loaded into the corresponding split.
    Subclasses implement `_init_dataset` to choose the concrete dataset class.
    """

    data_path: Path | str
    seed: int = 42
    num_workers: int = 4
    max_seq_length: int = field(init=False, default=-1)
    batch_size: int = field(init=False, default=1)
    tokenizer: Tokenizer = field(init=False, repr=False)
    train_data: ReflSFTDataset = field(init=False, repr=False)
    val_data: ReflSFTDataset = field(init=False, repr=False)
    prompt_style: PromptStyle = field(default_factory=Default)

    @abc.abstractmethod
    def _init_dataset(self) -> ReflSFTDataset:
        """Return a new, empty dataset instance for one split."""
        ...

    def __post_init__(self):
        # The dataclass-generated __init__ does not chain to the base class, so
        # run DataModule's (LightningDataModule's) initializer explicitly, as
        # litgpt's own dataclass data modules do.
        super().__init__()
        if not isinstance(self.data_path, Path):
            self.data_path = Path(self.data_path)
        if not self.data_path.is_dir():
            raise FileNotFoundError(f"\"{self.data_path}\" is not a valid directory.")

    def connect(
        self, tokenizer: Tokenizer, batch_size: int = 1, max_seq_length: int | None = None
    ) -> None:
        """Store tokenizer and batching options; `None` max length becomes -1 (no limit)."""
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_length = -1 if max_seq_length is None else max_seq_length

    def _setup_dataset(self, dataset: ReflSFTDataset, split: str):
        """Load every regular ``<split>*.json``/``<split>*.jsonl`` file into `dataset`."""
        data_path = cast(Path, self.data_path)
        # Sort so the sample order (and hence seeded shuffling) is the same on every file system.
        for file_path in sorted(data_path.iterdir()):
            if (file_path.is_file()
                    and file_path.name.startswith(split)
                    and file_path.suffix in (".json", ".jsonl")):
                dataset.load_file(file_path)

    def setup(self, stage: str = "") -> None:
        """Build both splits.

        Raises:
            RuntimeError: if either split ends up empty.
        """
        self.train_data = self._init_dataset()
        self.val_data = self._init_dataset()
        self._setup_dataset(self.train_data, "train")
        self._setup_dataset(self.val_data, "val")
        if len(self.train_data) == 0:
            raise RuntimeError("Empty training set.")
        if len(self.val_data) == 0:
            raise RuntimeError("Empty validation set.")

    def train_dataloader(self) -> DataLoader:
        """Training loader. Shuffles with a `seed`-ed generator unless a sampler is supplied."""
        sampler = self._get_sampler(self.train_data)
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=(sampler is None),
            sampler=sampler,
            generator=torch.Generator().manual_seed(self.seed),
            num_workers=self.num_workers,
            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length),
        )

    def val_dataloader(self) -> DataLoader:
        """Validation loader (no shuffling)."""
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=False,
            sampler=self._get_sampler(self.val_data),
            num_workers=self.num_workers,
            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length),
        )

    def _get_sampler(self, dataset: ReflSFTDataset) -> Sampler | None:
        """Hook for subclasses to supply a custom sampler. `None` uses the default."""
        return None
