# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MOSS datasets for supervised instruction fine-tuning."""

from __future__ import annotations

import json
import pathlib
import re
import zipfile

from safe_rlhf.datasets.base import RawDataset, RawSample


__all__ = ['MOSS002SFT', 'MOSS003SFT']


class MOSS002SFT(RawDataset):
    """Raw MOSS-002 SFT dataset backed by a local directory of JSON files.

    Every ``*.json`` file directly inside the directory is parsed and its items
    are concatenated, in sorted file order, into ``self.data``. Each item is
    read through its ``plain_text`` field, a transcript of the form
    ``[Human]: ... <eoh> [MOSS]: ... <eoa>``, which ``__getitem__`` splits into
    alternating human / assistant utterances.
    """

    NAME: str = 'moss-002-sft'
    # Matches exactly one turn at the start of the remaining transcript. The
    # conditional group requires `<eoh>` after a human turn and `<eoa>` after an
    # assistant turn; whitespace around the utterance is left out of `value`.
    PATTERN = re.compile(
        r"""
        ^
        \[(?P<from>(?P<human>Human)|(?P<assistant>MOSS))\]:
        \s*
        (?P<value>.*?)
        \s*
        (?(human)<eoh>|<eoa>)
        \s*
        """,
        flags=re.DOTALL | re.VERBOSE,
    )

    def __init__(self, path: str | None = None) -> None:
        """Read every JSON file found in the dataset directory.

        Args:
            path: Local directory holding the ``*.json`` data files. ``~`` is
                expanded and the path is made absolute before use.

        Raises:
            ValueError: If ``path`` is ``None``, does not exist, or is not a
                directory.
        """
        if path is None:  # fnlp/moss-002-sft-data cannot load with `load_dataset`
            raise ValueError('moss-002-sft dataset requires a local path to the dataset.')

        root = pathlib.Path(path).expanduser().absolute()
        if not root.exists():
            raise ValueError('moss-002-sft dataset path does not exist.')
        if not root.is_dir():
            raise ValueError('moss-002-sft dataset path is not a directory.')

        self.data = []
        # Sorted so that sample indices do not depend on directory listing order.
        for json_path in sorted(root.glob('*.json')):
            with json_path.open(mode='rt', encoding='utf-8') as stream:
                self.data.extend(json.load(stream))

    def __getitem__(self, index: int) -> RawSample:
        """Parse the transcript at ``index`` into a list of utterances.

        Args:
            index: Position of the sample in ``self.data``.

        Returns:
            A ``RawSample`` whose ``dialogue`` alternates human and assistant
            utterances, starting with a human one.

        Raises:
            ValueError: If the transcript does not start with a speaker tag,
                contains a turn that cannot be parsed, or has speakers out of
                the expected human / assistant alternation.
        """
        record = self.data[index]
        plain_text = record['plain_text'].strip()
        if not plain_text.startswith(('[Human]:', '[MOSS]:')):
            raise ValueError(f'Invalid plain text: {plain_text}')

        turns = []
        remaining = plain_text
        # Consume one turn from the front of the transcript per iteration.
        while remaining:
            parsed = self.PATTERN.match(remaining)
            if parsed is None:
                raise ValueError(f'Invalid plain text: {plain_text}')

            # Even positions belong to the human, odd positions to the assistant.
            human_turn_expected = len(turns) % 2 == 0
            if parsed.group('human') is not None and not human_turn_expected:
                raise ValueError(f'Invalid plain text: {plain_text}')
            if parsed.group('assistant') is not None and human_turn_expected:
                raise ValueError(f'Invalid plain text: {plain_text}')

            turns.append(parsed.group('value'))
            remaining = remaining[parsed.end() :]

        return RawSample(dialogue=turns)

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data)


class MOSS003SFT(RawDataset):
    """Raw MOSS-003 SFT dataset (no-tools subset) read from a local directory.

    The directory must contain ``moss-003-sft-no-tools.jsonl`` or the zipped
    form ``moss-003-sft-no-tools.jsonl.zip``; the archive is extracted into the
    directory when the plain file is missing. Each JSON line is read through
    its ``num_turns`` and ``chat`` fields, where ``chat['turn_<i>']`` holds the
    ``Human`` and ``MOSS`` texts of the i-th turn.
    """

    NAME: str = 'moss-003-sft'

    def __init__(self, path: str | None = None) -> None:
        """Load all samples from the JSONL file inside the dataset directory.

        Args:
            path: Local directory holding the dataset file or its zip archive.
                ``~`` is expanded and the path is made absolute before use.

        Raises:
            ValueError: If ``path`` is ``None``, does not exist, is not a
                directory, or contains neither the JSONL file nor the archive.
        """
        if path is None:  # fnlp/moss-003-sft-data cannot load with `load_dataset`
            raise ValueError('moss-003-sft dataset requires a local path to the dataset.')

        path = pathlib.Path(path).expanduser().absolute()
        if not path.exists():
            raise ValueError('moss-003-sft dataset path does not exist.')
        if not path.is_dir():
            raise ValueError('moss-003-sft dataset path is not a directory.')

        data_file = path / 'moss-003-sft-no-tools.jsonl'
        archive_file = path / 'moss-003-sft-no-tools.jsonl.zip'

        if not data_file.exists():
            if not archive_file.exists():
                # A directory was supplied, so report the files that are
                # actually missing instead of asking for a path again.
                raise ValueError(
                    f'moss-003-sft dataset files not found: expected {data_file.name!r} '
                    f'or {archive_file.name!r} in the dataset directory.',
                )
            with zipfile.ZipFile(archive_file, mode='r') as archive:
                archive.extractall(path)

        with data_file.open(mode='rt', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line), which are not
            # valid JSON documents and carry no sample.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __getitem__(self, index: int) -> RawSample:
        """Build the dialogue for the sample at ``index``.

        Args:
            index: Position of the sample in ``self.data``.

        Returns:
            A ``RawSample`` whose ``dialogue`` alternates human and assistant
            utterances with the speaker prefixes and end-of-turn markers
            removed and surrounding whitespace stripped.
        """
        data = self.data[index]
        num_turns = data['num_turns']
        chat = data['chat']
        dialogue = []
        # Turns are keyed `turn_1` .. `turn_<num_turns>`.
        for i in range(1, num_turns + 1):
            turn = chat[f'turn_{i}']
            dialogue.append(turn['Human'].replace('<|Human|>:', '').replace('<eoh>', '').strip())
            # The colon is part of the speaker prefix, exactly as for the human
            # turn; leaving it out would keep a stray ':' at the start of every
            # assistant utterance.
            dialogue.append(turn['MOSS'].replace('<|MOSS|>:', '').replace('<eom>', '').strip())
        return RawSample(dialogue=dialogue)

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data)
