# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import json

from swift.utils import get_logger, use_hf_hub
from .preprocessor import DATASET_TYPE, AutoPreprocessor, MessagesPreprocessor

# Type of a dataset preprocessing callable: any arguments, returns a DATASET_TYPE.
PreprocessFunc = Callable[..., DATASET_TYPE]
# Type of a dataset loading callable: any arguments, returns a DATASET_TYPE.
LoadFunction = Callable[..., DATASET_TYPE]
# Module-level logger obtained from swift.utils.
logger = get_logger()


@dataclass
class SubsetDataset:
    """One subset of a registered dataset.

    ``name`` is the key used when matching subsets of the dataset, while
    ``subset`` is the subset name on the hub. When ``name`` is left as None it
    falls back to ``subset``.
    """
    # Matching key for this subset; defaults to `subset` when None.
    name: Optional[str] = None
    # The subset name on the hub.
    subset: str = 'default'

    # Per-subset overrides with higher priority. A value of None means the
    # corresponding attribute of the owning DatasetMeta is used instead.
    split: Optional[List[str]] = None
    preprocess_func: Optional[PreprocessFunc] = None

    # If the dataset specifies "all", weak subsets are skipped.
    is_weak_subset: bool = False

    def __post_init__(self):
        # Fall back to the hub subset name when no explicit matching name is given.
        self.name = self.subset if self.name is None else self.name

    def set_default(self, dataset_meta: 'DatasetMeta') -> 'SubsetDataset':
        """Return a copy of this subset whose unset fields are filled from ``dataset_meta``.

        ``self`` is not modified. Only ``split`` and ``preprocess_func`` are
        inherited, and only when they are None on this subset; the inherited
        values are deep-copied so the result does not share them with
        ``dataset_meta``.
        """
        resolved = deepcopy(self)
        for attr_name in ('split', 'preprocess_func'):
            if getattr(resolved, attr_name) is not None:
                continue
            inherited = deepcopy(getattr(dataset_meta, attr_name))
            setattr(resolved, attr_name, inherited)
        return resolved


@dataclass
class DatasetMeta:
    """Registration metadata for one dataset: where it lives and how to load/preprocess it.

    After construction, ``subsets`` always holds ``SubsetDataset`` objects (plain
    strings are wrapped as ``SubsetDataset(subset=...)``) and ``load_function`` is
    never None (it defaults to ``DatasetLoader.load``).
    """
    ms_dataset_id: Optional[str] = None
    hf_dataset_id: Optional[str] = None
    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    ms_revision: Optional[str] = None
    hf_revision: Optional[str] = None

    # Strings are normalized to `SubsetDataset(subset=...)` in __post_init__.
    subsets: List[Union[SubsetDataset, str]] = field(default_factory=lambda: ['default'])
    # Applicable to all subsets (a subset may override it).
    split: List[str] = field(default_factory=lambda: ['train'])
    # First perform column mapping, then proceed with the preprocess_func.
    preprocess_func: PreprocessFunc = field(default_factory=lambda: AutoPreprocessor())
    # Replaced by `DatasetLoader.load` in __post_init__ when None.
    load_function: Optional[LoadFunction] = None

    tags: List[str] = field(default_factory=list)
    help: Optional[str] = None
    huge_dataset: bool = False

    def __post_init__(self):
        # Imported lazily; presumably to avoid a circular import with `.loader` — confirm.
        from .loader import DatasetLoader
        if self.load_function is None:
            self.load_function = DatasetLoader.load
        # Build a new list instead of rewriting the given list in place, so a list
        # object held by the caller (or shared between several DatasetMeta
        # instances) is not mutated as a side effect.
        self.subsets = [
            SubsetDataset(subset=subset) if isinstance(subset, str) else subset for subset in self.subsets
        ]


# Global registry of DatasetMeta objects, filled by `register_dataset`. Entries
# are keyed by `dataset_name` when it is set, otherwise by the tuple
# (ms_dataset_id, hf_dataset_id, dataset_path), whose items may be None.
DATASET_MAPPING: Dict[Union[str, Tuple[Optional[str], Optional[str], Optional[str]]], DatasetMeta] = {}


def get_dataset_list():
    """Return the hub ids of all registered datasets for the active hub.

    When ``use_hf_hub()`` is true the Hugging Face dataset ids are listed,
    otherwise the ModelScope ids. Registered datasets that have no id on the
    active hub are skipped.

    The ids are read from each registered ``DatasetMeta`` rather than from the
    registry key: datasets registered with a ``dataset_name`` are keyed by that
    plain string, and indexing it (``key[0]`` / ``key[1]``) would yield single
    characters instead of dataset ids.
    """
    # The active hub does not change while iterating; query it once.
    id_attr = 'hf_dataset_id' if use_hf_hub() else 'ms_dataset_id'
    datasets = []
    for dataset_meta in DATASET_MAPPING.values():
        dataset_id = getattr(dataset_meta, id_attr)
        if dataset_id:
            datasets.append(dataset_id)
    return datasets


def register_dataset(dataset_meta: DatasetMeta, *, exist_ok: bool = False) -> None:
    """Add ``dataset_meta`` to the global ``DATASET_MAPPING``.

    The registry key is ``dataset_meta.dataset_name`` when it is non-empty;
    otherwise it is the tuple ``(ms_dataset_id, hf_dataset_id, dataset_path)``.

    Args:
        dataset_meta: The `DatasetMeta` info of the dataset.
        exist_ok: When False, registering a key that is already present raises;
            when True, the existing entry is silently overwritten.

    Raises:
        ValueError: If the key is already registered and ``exist_ok`` is False.
    """
    dataset_name = dataset_meta.dataset_name or (
        dataset_meta.ms_dataset_id,
        dataset_meta.hf_dataset_id,
        dataset_meta.dataset_path,
    )
    if dataset_name in DATASET_MAPPING and not exist_ok:
        raise ValueError(f'The `{dataset_name}` has already been registered in the DATASET_MAPPING.')
    DATASET_MAPPING[dataset_name] = dataset_meta


def _preprocess_d_info(d_info: Dict[str, Any],
                       *,
                       base_dir: Optional[str] = None,
                       is_subset: bool = False) -> Dict[str, Any]:
    """Turn one raw dataset-info dict into keyword arguments for DatasetMeta / SubsetDataset.

    The input dict is deep-copied and left unmodified. The following keys are rewritten:

    - ``columns`` / ``messages``: consumed and replaced by a ``preprocess_func`` entry.
      ``messages`` (a dict of keyword arguments) builds a ``MessagesPreprocessor``,
      otherwise an ``AutoPreprocessor`` is built; both receive ``columns``
      (None when absent).
    - ``dataset_path``: ``~`` is expanded first, then a relative path is joined onto
      ``base_dir`` (when given) and the result is made absolute.
    - ``subsets``: every dict entry is recursively preprocessed and turned into a
      ``SubsetDataset``; non-dict entries (e.g. plain strings) are kept as-is.

    Args:
        d_info: One dataset entry, e.g. an element of ``dataset_info.json``.
        base_dir: Directory used to resolve a relative ``dataset_path``.
        is_subset: True when ``d_info`` describes a subset. A subset that specifies
            neither ``columns`` nor ``messages`` gets no ``preprocess_func``, so it
            stays None and is inherited from the parent dataset by
            ``SubsetDataset.set_default`` instead of being replaced by a bare
            ``AutoPreprocessor``.

    Returns:
        A new dict suitable for ``DatasetMeta(**...)`` or ``SubsetDataset(**...)``.
    """
    d_info = deepcopy(d_info)

    has_columns = 'columns' in d_info
    columns = d_info.pop('columns', None)

    if 'messages' in d_info:
        d_info['preprocess_func'] = MessagesPreprocessor(**d_info.pop('messages'), columns=columns)
    elif has_columns or not is_subset:
        d_info['preprocess_func'] = AutoPreprocessor(columns=columns)

    if 'dataset_path' in d_info:
        # Expand `~` before the isabs() check; otherwise "~/data" counts as relative
        # and is wrongly joined onto base_dir ("<base_dir>/~/data").
        dataset_path = os.path.expanduser(d_info['dataset_path'])
        if base_dir is not None and not os.path.isabs(dataset_path):
            dataset_path = os.path.join(base_dir, dataset_path)
        d_info['dataset_path'] = os.path.abspath(dataset_path)

    if 'subsets' in d_info:
        # `d_info` is a private deep copy, so replacing entries in place is safe.
        subsets = d_info['subsets']
        for i, subset in enumerate(subsets):
            if isinstance(subset, dict):
                subsets[i] = SubsetDataset(**_preprocess_d_info(subset, base_dir=base_dir, is_subset=True))
    return d_info


def _register_d_info(d_info: Dict[str, Any], *, base_dir: Optional[str] = None) -> DatasetMeta:
    """Build a DatasetMeta from one raw dataset-info entry and register it.

    Args:
        d_info: A single raw dataset entry; it is normalized by `_preprocess_d_info`
            before being passed to `DatasetMeta`.
        base_dir: Directory used to resolve a relative `dataset_path`; when None, a
            relative path is made absolute against the current working directory.

    Returns:
        The newly created and registered DatasetMeta.

    Raises:
        ValueError: Propagated from `register_dataset` if the dataset key is
            already registered.
    """
    meta = DatasetMeta(**_preprocess_d_info(d_info, base_dir=base_dir))
    register_dataset(meta)
    return meta


def register_dataset_info(dataset_info: Union[str, List[Dict[str, Any]], None] = None) -> List[DatasetMeta]:
    """Register datasets from the `dataset_info.json` or a custom dataset info source.

    This is used to deal with the datasets defined in the json info file.

    Args:
        dataset_info: One of
            - None: the bundled ``data/dataset_info.json`` in this module's directory;
            - a path to a JSON file (``~`` is expanded); relative ``dataset_path``
              entries in that file are resolved against the file's directory;
            - a JSON string (used when the string is not an existing file path);
            - an already-parsed list of dataset-info dicts.

    Returns:
        The registered DatasetMeta objects, in input order (an empty list when there
        is nothing to register).

    Raises:
        json.JSONDecodeError: If a string argument is neither an existing file nor valid JSON.
        ValueError: If a dataset is already registered (raised by `register_dataset`).
    """
    if dataset_info is None:
        dataset_info = os.path.join(os.path.dirname(__file__), 'data', 'dataset_info.json')
    assert isinstance(dataset_info, (str, list))
    base_dir = None
    log_msg = None
    if isinstance(dataset_info, str):
        dataset_path = os.path.abspath(os.path.expanduser(dataset_info))
        if os.path.isfile(dataset_path):
            log_msg = dataset_path
            # Relative dataset paths inside the file are resolved against its directory.
            base_dir = os.path.dirname(dataset_path)
            with open(dataset_path, 'r', encoding='utf-8') as f:
                dataset_info = json.load(f)
        else:
            dataset_info = json.loads(dataset_info)  # inline JSON string
    if len(dataset_info) == 0:
        return []
    res = [_register_d_info(d_info, base_dir=base_dir) for d_info in dataset_info]

    if log_msg is None:
        if len(dataset_info) < 5:
            log_msg = dataset_info
        else:
            # `dataset_info` is a list here (it has no `.keys()`); log a compact
            # identifier for each registered dataset instead of the full dicts.
            log_msg = [
                meta.dataset_name or meta.ms_dataset_id or meta.hf_dataset_id or meta.dataset_path for meta in res
            ]
    logger.info(f'Successfully registered `{log_msg}`.')
    return res
