import dataclasses as dc
from torch.utils.data import DataLoader
import os
import torch
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, ClassVar, Iterable

from litgpt import Tokenizer
from litgpt.prompts import PromptStyle, Default
from litgpt.data import JSON, DataModule, SFTDataset as _LitGPTSFTDataSet, get_sft_collate_fn
from core.tokenization import Vocabulary
from .formulation import CoT
from .style import Style
from .rm import RewardModel
from .evaluator import Evaluator, AverageReturn


# Name of a data subset. The literals are the conventional splits, but any
# string is accepted (e.g. a custom split name).
type Split = Literal['train', 'test', 'val'] | str
# File-format suffix used when exporting a split to disk (see
# `ReasoningTask.prepare_data`). Other strings are accepted too; how each suffix
# is serialized is presumably decided by `Style.save_instances` — verify there.
type _FormatStr = Literal[
    '.txt',
    '.json', '.jsonl',
    '.cot.json', '.cot.jsonl',
    '.sft.json', '.sft.jsonl',
] | str
# Either a flat list of T or a list of (recursively) nested lists of T.
type NestedList[T] = list[T] | list[NestedList[T]]


class ReasoningTask:
    """
    Base class of a reasoning task.

    A task supplies chain-of-thought (CoT) instances for named data splits and
    turns them into supervised fine-tuning (SFT) samples, data files, data
    modules and evaluators. Subclasses are expected to override
    `get_instances`, `sample` and `reward_model`, whose base versions only raise.
    """

    # configuration
    # Directory used by `root` when no root has been assigned to the instance.
    default_root: ClassVar[Path | None] = None

    # Formatting style: converts CoTs into SFT samples and writes data files.
    # NOTE: this single instance is shared by every task that does not override it.
    style: Style = Style()

    # Default options forwarded to the style; options given per call take
    # precedence. It is only ever read (merged into new dicts), never mutated.
    data_options: dict = {}

    @property
    def root(self) -> Path:
        """
        The root directory of the reasoning data.

        Falls back to `default_root` when no root has been assigned to this
        instance (or `None` was assigned).

        Raises:
            ValueError: neither an instance root nor a default root is set.
        """
        try:
            root = self.__root
        except AttributeError:  # the setter has never been called
            root = None

        if root is not None:
            return root
        default = self.default_root
        if default is not None:
            # Normalize so that a subclass declaring its default as a string
            # still yields a `Path` (callers use `.mkdir`, `/`, ...).
            return default if isinstance(default, Path) else Path(default)
        raise ValueError(f"The root of {self.__class__.__name__} is not specified, and there is no default root.")

    @root.setter
    def root(self, value: str | Path | None):
        # Any path-like value is normalized to `Path`; `None` clears the root.
        self.__root = value if value is None or isinstance(value, Path) else Path(value)

    def get_instances(self, split: Split) -> list[CoT]:
        """
        Obtain the list of CoT (Chain-of-Thought) instances from a certain split.

        Args:
            split: the tag of subset split, e.g. "train", "test", or "val".

        Returns:
            instances (list): the list of instances.

        Raises:
            KeyError: `split` is not a supported subset.
        """

        raise KeyError(
            f"Subset \"{split}\" for {self.__class__.__name__} is not supported. "
        )

    def sample(self, split: Split, require_thought: bool = False) -> CoT:
        """
        Randomly draw an instance of CoT from the distribution of `split`.

        Raises:
            NotImplementedError: the task does not support random sampling.
        """
        raise NotImplementedError

    def samples(self, split: Split, n: int, require_thought: bool = False) -> list[CoT]:
        """
        Randomly draw `n` instances of CoT from the distribution of `split`.
        """
        return [self.sample(split, require_thought) for _ in range(n)]

    def get_sft_data(self, split: Split, verbose: int = 1, **options) -> list[dict[str, str]]:
        """
        Convert the CoT instances of `split` into SFT samples.

        Each CoT may be expanded into any number of samples by the style.

        Args:
            split: the subset to convert.
            verbose: print progress when >= 1.
            **options: style options; they override `data_options`.

        Returns:
            the list of SFT samples.
        """
        cots = self.get_instances(split)
        style = self.style
        options = self.data_options | options
        sft_data: list[dict[str, str]] = []
        for i, cot in enumerate(cots, start=1):
            sft_data.extend(style._sftdata(cot, **options))
            if verbose >= 1:
                # "\r" rewrites the same console line to show progress.
                print(f"{i}/{len(cots)} CoT has been processed into {len(sft_data)} {split} samples.",
                      end="\r")
        if verbose >= 1:
            print("")
        return sft_data

    def export_data(self, data: list[CoT], path: str | Path, **options):
        """
        Save CoT instances to `path` using the task's style.

        Args:
            data: the instances to save.
            path: destination file.
            **options: style options; they override `data_options`.
        """
        if not isinstance(path, Path):
            path = Path(path)
        options = self.data_options | options
        self.style.save_instances(data, path, **options)
        print(f"Successfully exported data to {path}.")

    def prepare_data(
        self,
        files: dict[Split, Iterable[_FormatStr]],
        force_write: bool = False,
        path: Path | str | None = None,
        **options,
    ):
        """
        Export the CoT instances of several splits to files.

        For every split in `files`, one file named ``f"{split}{fmt}"`` is written
        under `path` for each format suffix ``fmt``.

        Args:
            files: maps a split to the format suffixes to export it in, e.g.
                ``{"train": [".sft.json"], "val": [".sft.json"]}``.
            force_write: overwrite files that already exist.
            path: output directory; defaults to `self.root`.
            **options: extra options forwarded to `export_data`.
        """
        path = self.root if path is None else Path(path)
        path.mkdir(exist_ok=True, parents=True)

        skipped: list[Path] = []
        for split, fmts in files.items():
            # Loaded lazily: only when at least one file of the split is written.
            cots: list[CoT] | None = None
            for fmt in fmts:
                # Append the suffix rather than use `Path.with_suffix`, which
                # would replace any dot-part of the split name ("train.small")
                # and reject suffixes lacking a leading dot. This also matches
                # the `f"{split}{suffix}"` lookup of `ReasoningJSON.find_split`.
                f = path / f"{split}{fmt}"
                if not force_write and f.exists():
                    skipped.append(f)
                    continue
                if cots is None:
                    cots = self.get_instances(split)
                self.export_data(cots, f, **options)

        if skipped:
            print("To save computational time, the preparation of these existing files have been ignored:")
            for f in skipped:
                print(f'- \"{f}\"')
            print("Delete the existing files if you hope to update them.")

    def reward_model(
        self,
        supervision: Literal['outcome', 'process', 'success'],
    ) -> RewardModel:
        """
        Build a reward model with the given kind of supervision.

        Raises:
            NotImplementedError: the task provides no reward model.
        """
        raise NotImplementedError

    def get_ref_dict(self, cot: CoT) -> dict[str, Any]:
        """Convert a CoT to a dictionary with information referred by reward models / evaluators."""
        return cot.as_dict()

    def special_tokens(self):
        """The special tokens declared by the task's style."""
        return self.style.special_tokens()

    def evaluator(self, **options) -> Evaluator:
        """Build an `AverageReturn` evaluator over the 'success' reward model."""
        return AverageReturn(self.reward_model('success'), **options)

    def data_module(self, usage: Literal["sft", "pretrain"],
                    seed: int = 42, from_file: bool = False, **options) -> DataModule:
        """
        Build a litgpt data module for training on this task.

        Args:
            usage: "sft" masks the prompt out of the labels; "pretrain" keeps
                prompt and response tokens alike.
            seed: random seed of the data module.
            from_file: read prepared ``.sft.json`` / ``.sft.jsonl`` files from
                `root` instead of generating samples from the task.
            **options: style options for on-the-fly generation (overriding
                `data_options`); ignored when `from_file` is true.
        """
        mask_prompt = usage == "sft"
        if from_file:
            return ReasoningJSON(
                self.root,
                prompt_style=Default(),
                file_suffixes=('.sft.json', '.sft.jsonl'),
                seed=seed,
                mask_prompt=mask_prompt,
                include_eos=False,
            )
        return TaskDataModule(
            self,
            data_options=options,
            seed=seed,
            mask_prompt=mask_prompt,
            include_eos=False,
        )


# Registry mapping task names to task constructors; filled by `register_task`
# and read by `task_from_name`, `is_available_task_name` and `available_task_names`.
NAMED_TASKS: dict[str, Callable[..., ReasoningTask]] = {}


def register_task(name: str, cls: Callable[..., ReasoningTask]):
    """
    Make `cls` constructible through `task_from_name(name, ...)`.

    Registering an already-used name replaces the previous constructor.
    """
    NAMED_TASKS.update({name: cls})


def task_from_name(name: str, *args, **kwargs):
    """
    Instantiate the task registered under `name`.

    Positional and keyword arguments are forwarded to the registered constructor.

    Raises:
        ValueError: no task is registered under `name`.
    """
    try:
        cls = NAMED_TASKS[name]
    except KeyError:
        # Use a distinct loop name so the requested `name` is not shadowed.
        available = ", ".join(f"\"{known}\"" for known in available_task_names())
        # `from None`: the internal KeyError adds nothing to this message.
        raise ValueError(f"\"{name}\" is not an available name of task."
                         f" The available tasks are: {available}") from None
    return cls(*args, **kwargs)


def is_available_task_name(name: str) -> bool:
    """Tell whether some task has been registered under `name`."""
    registered = NAMED_TASKS
    return name in registered

def available_task_names() -> tuple[str, ...]:
    """Names of all registered tasks, in the order they were first registered."""
    names = NAMED_TASKS.keys()
    return tuple(names)


class SFTDataset(_LitGPTSFTDataSet):
    """
    SFT Dataset adapted from litgpt.

    `__getitem__` is overridden so that appending the EOS token to the
    encoded response is controlled by `include_eos`.
    """

    def __init__(
        self,
        data: list[dict[str, str]],
        tokenizer: Tokenizer,
        prompt_style: str | PromptStyle,
        max_seq_length: int = -1,
        mask_prompt: bool = True,
        ignore_index: int = -100,
        transform: Callable[[Any], Any] | None = None,
        include_eos: bool = False,
    ) -> None:
        super().__init__(
            data,
            tokenizer,
            prompt_style,
            max_seq_length,
            mask_prompt,
            ignore_index,
            transform,
        )
        # Whether an EOS token terminates every encoded response.
        self._include_eos = include_eos

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Encode sample `idx` as `input_ids` (prompt + response) and `labels`."""
        record = self.data[idx]
        if self.transform is not None:
            record = self.transform(record)

        limit = self.max_seq_length
        prompt_text = self.prompt_style.apply(prompt=record["instruction"], **record)
        prompt_ids = self.tokenizer.encode(prompt_text, max_length=limit)
        response_ids = self.tokenizer.encode(
            record["output"],
            bos=False,
            eos=self._include_eos,
            max_length=limit,
        )

        input_ids = torch.cat((prompt_ids, response_ids)).type(torch.int64)
        # A non-positive limit (e.g. -1) means "unlimited"; slicing with -1
        # would wrongly drop the last token.
        if limit > 0:
            input_ids = input_ids[:limit]

        # Labels mirror the inputs; when masking, prompt positions get
        # `ignore_index` so that only the response is supervised.
        labels = input_ids.clone()
        if self.mask_prompt:
            labels[: len(prompt_ids)] = self.ignore_index

        return {"input_ids": input_ids, "labels": labels}


class TaskDataModule(DataModule):
    """
    Data module that builds SFT datasets on the fly from a `ReasoningTask`.

    During `setup`, the task's "train" and "val" splits are converted into SFT
    samples with `ReasoningTask.get_sft_data` and wrapped in `SFTDataset`s.
    """

    ignore_index: int
    seed: int
    num_workers: int
    tokenizer: Tokenizer | None = None
    batch_size: int = 1
    max_seq_length: int = -1
    train_dataset: SFTDataset | None = None
    val_dataset: SFTDataset | None = None
    transform: Callable[[Any], dict[str, str]] | None = None
    include_eos: bool = False

    def __init__(
        self,
        task: ReasoningTask,
        data_options: dict | None = None,
        prompt_style: PromptStyle | None = None,
        mask_prompt: bool = False,
        transform: Callable[[Any], dict[str, str]] | None = None,
        include_eos: bool = False,
        verbose: int = 1,
        seed: int = 42,
        num_workers: int = 4,
        ignore_index: int = -100,
    ):
        """
        Args:
            task: the task providing the CoT instances.
            data_options: style options forwarded to `task.get_sft_data`.
            prompt_style: prompt template; defaults to litgpt's `Default`.
            mask_prompt: exclude prompt tokens from the labels.
            transform: optional function applied to every raw sample.
            include_eos: append an EOS token to every response.
            verbose: verbosity of the SFT data conversion.
            seed: seed of the training-data shuffling.
            num_workers: number of DataLoader worker processes.
            ignore_index: label value ignored by the loss.
        """
        # Initialize the Lightning data-module state, as litgpt's own data
        # modules do (they call `super().__init__()` in `__post_init__`).
        super().__init__()
        self.task = task
        # Copy: avoids a shared mutable default and aliasing the caller's dict.
        self.data_options = dict(data_options or {})
        self.prompt_style = prompt_style or Default()
        self.mask_prompt = mask_prompt
        self.transform = transform
        self.include_eos = include_eos
        self.verbose = verbose
        self.seed = seed
        self.num_workers = num_workers
        self.ignore_index = ignore_index
        # Populated by `setup`; the dataloaders check for `None`.
        self.train_dataset = None
        self.val_dataset = None

    def connect(
        self, tokenizer: Tokenizer | None = None, batch_size: int = 1, max_seq_length: int | None = None
    ) -> None:
        """Attach the tokenizer and batching settings; `max_seq_length=None` means no limit (-1)."""
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_length = -1 if max_seq_length is None else max_seq_length

    def setup(self, stage: str = "") -> None:
        """
        Build the train and val datasets from the task.

        Raises:
            RuntimeError: `connect` has not provided a tokenizer.
        """
        if self.tokenizer is None:
            raise RuntimeError("The data module has not been connected to a tokenizer.")
        train_data = self.task.get_sft_data("train", self.verbose, **self.data_options)
        val_data = self.task.get_sft_data("val", self.verbose, **self.data_options)
        self.train_dataset = self._make_dataset(train_data)
        self.val_dataset = self._make_dataset(val_data)

    def _make_dataset(self, data: list[dict[str, str]]) -> SFTDataset:
        """Wrap SFT samples in an `SFTDataset` configured from this module's settings."""
        return SFTDataset(
            data=data,
            tokenizer=self.tokenizer,
            prompt_style=self.prompt_style,
            max_seq_length=self.max_seq_length,
            mask_prompt=self.mask_prompt,
            ignore_index=self.ignore_index,
            transform=self.transform,
            include_eos=self.include_eos,
        )

    def _collate_fn(self):
        """Padding/truncating collate function shared by both dataloaders."""
        return get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader, seeded with `self.seed` for reproducibility."""
        if self.train_dataset is None:
            raise RuntimeError("train dataset has not been set up")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            generator=torch.Generator().manual_seed(self.seed),
            num_workers=self.num_workers,
            collate_fn=self._collate_fn(),
        )

    def val_dataloader(self) -> DataLoader:
        """Validation loader (not shuffled)."""
        if self.val_dataset is None:
            raise RuntimeError("val dataset has not been set up")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn(),
        )


@dc.dataclass
class ReasoningJSON(JSON):
    """
    Extended from litgpt.data.JSON, which allows customization of more settings.

    Additional settings:
        file_suffixes: suffixes tried, in order, when looking up the file
            ``f"{split}{suffix}"`` of a split inside `json_path`.
        transform: optional function applied to every raw sample.
        include_eos: append an EOS token to every encoded response.
    """

    file_suffixes: tuple[str, ...] = ('.json', '.jsonl')
    transform: Callable[[Any], dict[str, str]] | None = dc.field(default=None, repr=False)
    include_eos: bool = False

    def find_split(self, split_name: str) -> Path | None:
        """Return the first existing ``json_path / f"{split_name}{suffix}"``, or None."""
        for suffix in self.file_suffixes:
            if (file := self.json_path / f"{split_name}{suffix}").is_file():
                return file
        return None

    def setup(self, stage: str = "") -> None:
        """
        Load the splits and build the train / test datasets.

        Raises:
            RuntimeError: `connect` has not provided a tokenizer.
        """
        # An explicit raise instead of `assert`, which is stripped under `python -O`.
        if self.tokenizer is None:
            raise RuntimeError("call `connect` first!")

        train_data, test_data = self.get_splits()
        self.train_dataset = self._make_dataset(train_data)
        self.test_dataset = self._make_dataset(test_data)

    def _make_dataset(self, data: list[dict[str, str]]) -> SFTDataset:
        """Wrap samples in an `SFTDataset` configured from this module's settings."""
        return SFTDataset(
            data=data,
            tokenizer=self.tokenizer,
            prompt_style=self.prompt_style,
            max_seq_length=self.max_seq_length,
            mask_prompt=self.mask_prompt,
            ignore_index=self.ignore_index,
            transform=self.transform,
            include_eos=self.include_eos,
        )
