# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from fractions import Fraction
from typing import Any, Callable, Hashable, Iterable
from typing_extensions import TypedDict  # Python 3.10+

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Dataset, Subset

from safe_rlhf.datasets.base import CollatorBase, RawDataset, RawSample, TokenizedDataset
from safe_rlhf.datasets.utils import format_prompt, left_padding


# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = [
    'PromptOnlyDataset',
    'PromptOnlyCollator',
    'PromptOnlySample',
    'PromptOnlyBatch',
    'PromptOnlyRawDataset',
]


class PromptOnlySample(TypedDict, total=True):
    """A single prompt-only sample, as returned by ``PromptOnlyDataset.preprocess``."""

    # Token ids of one prompt; `L` is the prompt length.
    input_ids: torch.LongTensor  # size = (L,)


class PromptOnlyBatch(TypedDict, total=True):
    """A batch of prompt-only samples, as returned by ``PromptOnlyCollator.__call__``."""

    # Padded token ids; `B` is the batch size and `L` the padded length.
    input_ids: torch.LongTensor  # size = (B, L)
    # Boolean mask built from all-ones per-sample masks padded with 0.
    attention_mask: torch.BoolTensor  # size = (B, L)


class PromptOnlyDataset(TokenizedDataset):
    """Dataset whose samples consist of a prompt only (no response)."""

    def preprocess(self, raw_sample: RawSample) -> PromptOnlySample:
        """Turn a raw sample into a sample holding the prompt's ``input_ids``.

        The ``'input'`` field of ``raw_sample`` is formatted with
        ``format_prompt`` (using the tokenizer's EOS token) and the resulting
        text is passed through ``self.tokenize``.
        """
        eos_token = self.tokenizer.eos_token
        prompt_text = format_prompt(input=raw_sample['input'], eos_token=eos_token)
        return {'input_ids': self.tokenize(prompt_text)}  # size = (L,)

    def get_collator(self) -> Callable[[list[dict[str, torch.Tensor]]], dict[str, torch.Tensor]]:
        """Return a ``PromptOnlyCollator`` bound to the tokenizer's pad token id."""
        pad_token_id = self.tokenizer.pad_token_id
        return PromptOnlyCollator(pad_token_id)

    def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSample]:
        """Merge multiple raw datasets into one dataset and remove duplicates.

        Samples are considered duplicates when their ``'input'`` fields are
        equal. For each distinct input only the last occurrence in the merged
        dataset is kept; the kept samples stay in ascending index order.
        """
        merged = super()._merge_raw_datasets(seed)

        last_index_of: dict[Hashable, int] = {}
        for index in range(len(merged)):
            prompt = merged[index]['input']
            # Non-string inputs are converted to a tuple so they can be dict keys.
            key = prompt if isinstance(prompt, str) else tuple(prompt)
            last_index_of[key] = index  # later occurrences overwrite earlier ones

        kept_indices = sorted(last_index_of.values())
        return Subset(merged, kept_indices)


class PromptOnlyCollator(CollatorBase):
    """Collate prompt-only samples into a left-padded batch."""

    def __call__(self, samples: list[PromptOnlySample]) -> PromptOnlyBatch:
        """Build a ``PromptOnlyBatch`` from a list of samples.

        Each sample's ``input_ids`` gets an all-ones boolean mask of the same
        size. Both lists are then passed to ``left_padding``: ids are padded
        with ``self.pad_token_id`` and masks with ``0``.
        """
        prompts = [sample['input_ids'] for sample in samples]
        masks = [ids.new_ones(ids.shape, dtype=torch.bool) for ids in prompts]
        return {
            'input_ids': left_padding(prompts, padding_value=self.pad_token_id),  # size = (B, L)
            'attention_mask': left_padding(masks, padding_value=0),  # size = (B, L)
        }

class PromptOnlyRawDataset(Dataset):
    """Dataset of raw prompt samples merged from several raw datasets.

    The named raw datasets are loaded with ``RawDataset.load``, sub-sampled or
    over-sampled according to their proportion, concatenated, and de-duplicated
    on the ``'input'`` field. The surviving samples are materialized into
    ``self.data`` at construction time, so indexing returns raw samples.
    """

    def __init__(
        self,
        dataset_names_and_attributes: (
            dict[str, float | dict[str, Any]] | Iterable[tuple[str, float | dict[str, Any]]]
        ),
        seed: int | None = 42,
    ) -> None:
        """Load, sample, merge and de-duplicate the requested raw datasets.

        Args:
            dataset_names_and_attributes: Either a dict or an iterable of
                ``(name, attributes)`` pairs. ``attributes`` is a float (the
                proportion) or a dict of keyword arguments for
                ``RawDataset.load``; an optional ``'proportion'`` entry
                (default ``1.0``) is removed from it before loading.
                A proportion is a non-negative number (fraction of the dataset
                to use; ``0.0`` skips the dataset without loading it) or a
                negative integer-valued ``Fraction`` ``-N`` that requests
                exactly ``N`` samples.
            seed: Seed for the sampling RNG. It is stored on ``self.seed`` and
                used by ``_merge_raw_datasets`` when no seed is passed to it.

        Raises:
            ValueError: If dataset names given as pairs are not unique, or a
                proportion is invalid.
            TypeError: If an ``attributes`` value is neither a float nor a dict.
        """
        if not isinstance(dataset_names_and_attributes, dict):
            # Materialize first: the pairs are walked twice (uniqueness check
            # here, `dict(...)` below), which a one-shot iterable would not survive.
            dataset_names_and_attributes = tuple(dataset_names_and_attributes)
            dataset_names = [name for name, _ in dataset_names_and_attributes]
            if len(dataset_names) != len(set(dataset_names)):
                raise ValueError(
                    f'Dataset names should be unique, but got {dataset_names}.',
                )

        super().__init__()
        # `_merge_raw_datasets` falls back to this attribute when called with
        # `seed=None`, so it must exist before any merge happens.
        self.seed = seed
        self.dataset_names_and_proportion: dict[str, float | Fraction] = {}
        self.raw_datasets = []
        for name, attributes in dict(dataset_names_and_attributes).items():
            if isinstance(attributes, float):
                kwargs = {'proportion': attributes}
            elif isinstance(attributes, dict):
                kwargs = dict(attributes)  # copy
            else:
                raise TypeError(
                    f'Dataset `{name}` attributes should be a float or a dict, '
                    f'got {type(attributes).__name__}.',
                )
            proportion = kwargs.pop('proportion', 1.0)
            if isinstance(proportion, Fraction):
                # A Fraction encodes an absolute sample count as `-N` (i.e. `N / -1`).
                if not (proportion < 0 and proportion.denominator == 1):
                    raise ValueError(
                        f'Dataset `{name}` proportion should be a negative integer '
                        f'represents `num_samples / -1`, got {proportion}.',
                    )
            else:
                proportion = float(proportion)
                if proportion < 0.0:
                    raise ValueError(
                        f'Dataset `{name}` proportion should be no less than 0.0, '
                        f'got {proportion}.',
                    )
            if proportion == 0.0:
                continue  # nothing requested from this dataset; do not even load it
            raw_dataset = RawDataset.load(name, **kwargs)
            self.dataset_names_and_proportion[raw_dataset.NAME] = proportion
            self.raw_datasets.append(raw_dataset)

        merged_rawdata = self._merge_raw_datasets(seed=seed)
        self.data = [merged_rawdata[i] for i in range(len(merged_rawdata))]

    def __getitem__(self, index: int) -> RawSample:
        """Get a raw data sample by index."""
        return self.data[index]

    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        return len(self.data)

    def _merge_raw_datasets(self, seed: int | None = None) -> Dataset[RawSample]:
        """Merge multiple raw datasets into one dataset and remove duplicates.

        For every raw dataset the number of selected samples is ``N`` for a
        ``Fraction`` proportion ``-N`` and ``round(len(dataset) * proportion)``
        otherwise. Whole copies of the dataset are taken while the request
        covers them; the remainder is drawn without replacement with an RNG
        spawned per dataset from ``seed``. An empty raw dataset contributes no
        samples regardless of its proportion.

        Samples with equal ``'input'`` fields are duplicates; only the last
        occurrence of each is kept, in ascending index order.

        Args:
            seed: Seed for the sampling RNG; ``self.seed`` is used when ``None``.

        Returns:
            A ``Subset`` over the merged dataset holding the de-duplicated samples.
        """

        def to_hashable(raw_sample: RawSample) -> Hashable:
            prompt = raw_sample['input']
            return prompt if isinstance(prompt, str) else tuple(prompt)

        if seed is None:
            seed = self.seed

        num_datasets = len(self.raw_datasets)
        datasets = []
        for raw_dataset, seed_seq in zip(
            self.raw_datasets,
            np.random.SeedSequence(seed).spawn(num_datasets),
        ):
            num_raw_samples = len(raw_dataset)
            proportion = self.dataset_names_and_proportion[raw_dataset.NAME]
            if isinstance(proportion, Fraction):
                num_selected_samples = abs(int(proportion))
            else:
                num_selected_samples = round(num_raw_samples * proportion)

            indices: list[int] = []
            # Guard against empty datasets: there is nothing to copy or draw
            # from, and `rng.choice(0, size=k)` would raise for k > 0.
            if num_raw_samples > 0:
                num_full_copies, num_remaining = divmod(num_selected_samples, num_raw_samples)
                indices.extend(list(range(num_raw_samples)) * num_full_copies)
                if num_remaining > 0:
                    rng = np.random.default_rng(seed_seq)
                    indices.extend(
                        rng.choice(
                            num_raw_samples,
                            size=num_remaining,
                            replace=False,
                        ).tolist(),
                    )
            datasets.append(Subset(raw_dataset, indices))

        if num_datasets == 1:
            merged = datasets[0]
        else:
            merged = ConcatDataset(datasets)

        # Later occurrences overwrite earlier ones, so each distinct input maps
        # to the index of its last occurrence.
        inputs = {to_hashable(merged[i]): i for i in range(len(merged))}
        return Subset(merged, sorted(inputs.values()))