"""This Python code defines a class Dataset with methods for initializing, loading,
and manipulating datasets from different backends such as Hugging Face and JSON.
 
The `Dataset` class includes methods for loading datasets from a dictionary and a Hugging
Face dataset, mapping datasets, and retrieving the backend dataset and arguments.
"""

# Importing necessary libraries and modules
import copy
import json

from cmath import e
from pathlib import Path
from typing import Optional
import hashlib

from datasets import load_dataset
from datasets import Dataset as HFDataset

# Dataset type names listed as supported.
# NOTE(review): "conversation" has an entry in INSTANCE_FIELDS_MAP below but is
# not listed here -- confirm whether that omission is intentional.
DATASET_TYPES = ["text_only", "text2text", "float_only", "image_text"]

# Required per-instance fields for each dataset type (insertion order matters
# for callers that list the fields in error messages).
INSTANCE_FIELDS_MAP = dict(
    text_only=["text"],
    text2text=["input", "output"],
    # Only "messages" is required; system, tools and conversation_id are optional.
    conversation=["messages"],
    float_only=["value"],
    image_text=["images", "text"],
)

# Top-level keys of a dataset JSON file / dict.
KEY_TYPE, KEY_INSTANCES = "type", "instances"

class Dataset_ft:
    r"""
    Dataset wrapper that loads and manipulates data through a pluggable backend.

    Backends handled by this class: "huggingface" (a ``datasets.Dataset``),
    "dict" (a plain dict, only populated via ``from_dict``) and
    "custom_multi_modal". The "json" backend is accepted but not implemented.

    Parameters
    ------------
    data_args : DatasetArguments object, optional.
        Contains the arguments required to load the dataset. The constructor
        only reads ``data_args.dataset_path``; when ``data_args`` is None or its
        ``dataset_path`` is None, an empty dataset (``backend_dataset is None``)
        is created.

    backend : str,  default="huggingface"
        A string representing the dataset backend. Defaults to "huggingface".

    args : Optional.
        Positional arguments (currently unused).

    kwargs : Optional.
        Keyword arguments (currently unused).

    Raises
    ------------
    ValueError
        For the "huggingface" backend: no ``*.json`` file in the dataset path,
        a file without a top-level "type" field, files with differing types,
        or loaded columns that do not match the declared type.
    NotImplementedError
        If a dataset path is given and ``backend`` is not a supported name.
    """
    def __init__(self, data_args=None, backend: str="huggingface", *args, **kwargs):
        self.data_args = data_args
        self.backend = backend
        self.backend_dataset = None
        self.type = None        # Original type of the dataset
        # ``data_args`` defaults to None, so it must not be dereferenced blindly.
        self.dataset_path = (
            data_args.dataset_path if data_args is not None else None
        )

        if self.dataset_path is None:
            return

        if backend == "huggingface":
            # Sorted so file (and therefore row) order is deterministic; the
            # order produced by glob() depends on the filesystem.
            data_files = sorted(
                x.absolute().as_posix()
                for x in Path(self.dataset_path).glob("*.json")
            )
            if not data_files:
                raise ValueError(
                    f'No ".json" files found in dataset path'
                    f' "{self.dataset_path}"'
                )

            # Iterate through all the files and ensure they have the same data type
            for single_file in data_files:
                with open(single_file, encoding="utf-8") as fin:
                    json_data = json.load(fin)
                if not isinstance(json_data, dict) or KEY_TYPE not in json_data:
                    raise ValueError(
                        f'"{KEY_TYPE}" field must be specified for data, e.g.\n'
                        '{\n'
                        f'   "{KEY_TYPE}": "text_only",\n'
                        f'   "{KEY_INSTANCES}": [\n'
                        '       { "text": "Sentence 1: This is a sentence." },\n'
                        '       { "text": "Sentence 2: This is another sentence." }\n'
                        '   ]\n'
                        '}'
                    )
                file_type = json_data[KEY_TYPE]
                if self.type is None:
                    self.type = file_type
                elif self.type != file_type:
                    raise ValueError(
                        'All task files must have same data types. Previous'
                        f' files have type "{self.type}", but in file'
                        f' {single_file}, it has type "{file_type}".'
                    )

            # Load the dataset using the HuggingFace dataset library; rows are
            # read from the KEY_INSTANCES field of every file.
            raw_dataset = load_dataset(
                "json",
                data_files=data_files,
                field=KEY_INSTANCES,
                split="train",
                use_auth_token=None,
            )
            self.backend_dataset = raw_dataset
            self._check_data_format()
        elif backend == "json":
            # TODO (@Jiachun): the "json" backend is not implemented yet.
            pass
        elif backend == "custom_multi_modal":
            # FIXME refactor the backend name
            # NOTE(review): CustomMultiModalDataset is not among this file's
            # visible imports -- confirm it is in scope at runtime.
            self.backend_dataset = CustomMultiModalDataset(self.dataset_path, data_args)
        else:
            raise NotImplementedError(f'Unsupported dataset backend "{backend}"')

    def __len__(self):
        """Return the number of instances in the backend dataset."""
        return len(self.backend_dataset)

    def _check_data_format(self):
        """Check that the backend dataset's columns match its declared type.

        Raises
        ------
        ValueError
            If the type is missing or unsupported, or the columns differ from
            ``INSTANCE_FIELDS_MAP[type]``. "conversation" datasets only need a
            "messages" column; additional columns are tolerated.
        """
        # Read the type directly instead of via to_dict(), which would copy
        # every instance of the dataset just to look up the type.
        data_type = self.get_type()
        if data_type is None:
            raise ValueError(
                f'"{KEY_TYPE}" must be provided to initialize a dataset'
            )
        if data_type not in INSTANCE_FIELDS_MAP:
            raise ValueError(f'type "{data_type}" is not supported')

        fields = self.get_backend_dataset().features
        correct_fields = INSTANCE_FIELDS_MAP[data_type]
        # TODO: this can not guarantee every instance has correct fields.
        if set(fields) == set(correct_fields):
            return
        if data_type == "conversation":
            if "messages" not in fields:
                raise ValueError(
                    f'Conversation dataset should have "messages" field'
                    f' but got {list(fields)}'
                )
        else:
            raise ValueError(
                f'Data instance fields incorrect'
                f' {list(fields)}: should be {list(correct_fields)}.'
            )

    def from_dict(self, dict_obj: dict, *args, **kwargs):
        r"""
        Create a Dataset object from a dictionary.

        Return a Dataset given a dict with format:
            {
                "type": TYPE,
                "instances": [
                    {
                        "key_1": VALUE_1.1,
                        "key_2": VALUE_1.2,
                        ...
                    },
                    {
                        "key_1": VALUE_2.1,
                        "key_2": VALUE_2.2,
                        ...
                    },
                    ...
                ]
            }

        Parameters
        -----------

        dict_obj : dict.
            A dictionary containing the dataset information.

        args : Optional.
            Positional arguments, forwarded to ``HFDataset.from_dict`` for the
            "huggingface" backend.

        kwargs : Optional.
            Keyword arguments, forwarded to ``HFDataset.from_dict`` for the
            "huggingface" backend.

        Returns
        ---------

        self : Dataset object.

        Raises
        ---------

        ValueError
            ("huggingface" backend) missing "type"/"instances", unsupported
            type, an instance with wrong fields, or a failed conversion.
        NotImplementedError
            If the backend is neither "huggingface" nor "dict".
        """
        if self.backend == "huggingface":
            if KEY_TYPE not in dict_obj:
                raise ValueError(
                    f'"{KEY_TYPE}" must be provided to initialize a dataset,'
                    f' e.g.\n'
                )
            if KEY_INSTANCES not in dict_obj:
                raise ValueError(
                    f'"{KEY_INSTANCES}" must be provided to initialize a'
                    f' dataset, e.g.\n'
                )

            self.type = dict_obj[KEY_TYPE]
            if self.type not in INSTANCE_FIELDS_MAP:
                raise ValueError(f'type "{self.type}" is not supported')

            instances = dict_obj[KEY_INSTANCES]
            correct_fields = INSTANCE_FIELDS_MAP[self.type]
            correct_field_set = set(correct_fields)

            for i, instance in enumerate(instances):
                fields = instance.keys()
                if set(fields) == correct_field_set:
                    continue
                if self.type == "conversation":
                    if "messages" not in fields:
                        raise ValueError(
                            f'Conversation dataset should have "messages" field'
                            f' but got {list(fields)}'
                        )
                else:
                    raise ValueError(
                        f'data instance fields incorrect'
                        f' {list(fields)}: should be {list(correct_fields)}.\n'
                        f'The bad instance triggers the error, the {i}-th instance:\n'
                        f'    {instance}'
                    )

            # Transpose the row-major instances into the column-major mapping
            # HFDataset.from_dict expects. Columns are taken from the first
            # instance.
            # NOTE(review): optional "conversation" keys present only in later
            # instances are dropped, and keys of the first instance missing in
            # later ones raise KeyError -- confirm this is acceptable.
            hf_dict = {}
            if len(instances) > 0:
                hf_dict = {
                    key: [instance[key] for instance in instances]
                    for key in instances[0].keys()
                }

            try:
                self.backend_dataset = HFDataset.from_dict(hf_dict, *args, **kwargs)
            except AttributeError as ex:
                raise ValueError(
                    f"Error occurs: {ex}. Failed to convert dict to"
                    f" \"{self.type}\" dataset," f" the standard format is as"
                    f" follows:\n"
                ) from ex
            self._check_data_format()

            return self
        elif self.backend == "dict":
            self.backend_dataset = dict_obj
            self.type = dict_obj[KEY_TYPE]
            return self
        else:
            raise NotImplementedError(
                f'Currently .from_dict is not supported for backend "{self.backend}"'
            )

    @classmethod
    def create_from_dict(cls, dict_obj, *args, **kwargs):
        r"""
        Build a new, path-less dataset of this class and fill it from a dict.

        Parameters
        -----------

        dict_obj : dict.
            Same format as accepted by ``from_dict``.

        args, kwargs : Optional.
            Forwarded to ``from_dict``.

        Returns
        --------

        Returns a Dataset object given a dict.
        """
        # NOTE(review): DatasetArguments is not among this file's visible
        # imports -- confirm it is in scope at runtime.
        empty_data_args = DatasetArguments(dataset_path=None)
        dataset = cls(empty_data_args)
        return dataset.from_dict(dict_obj, *args, **kwargs)

    def to_dict(self):
        r"""
        Returns
        ---------

        Return a dict represents the dataset:
            {
                "type": TYPE,
                "instances": [
                    {
                        "key_1": VALUE_1.1,
                        "key_2": VALUE_1.2,
                        ...
                    },
                    {
                        "key_1": VALUE_2.1,
                        "key_2": VALUE_2.2,
                        ...
                    },
                    ...
                ]
            }

        A python dict object represents the content of this dataset. For the
        "dict" backend the stored dict itself (not a copy) is returned.

        Raises
        ---------

        NotImplementedError
            If the backend is neither "huggingface" nor "dict".
        """
        if self.backend == "huggingface":
            dataset_type = self.get_type()
            hf_dict = self.backend_dataset.to_dict()
            # HF returns one list per column; transpose into one dict per row.
            columns = list(hf_dict)
            num_instances = len(hf_dict[columns[0]]) if columns else 0
            return {
                KEY_TYPE: dataset_type,
                KEY_INSTANCES: [
                    {key: hf_dict[key][i] for key in columns}
                    for i in range(num_instances)
                ],
            }
        elif self.backend == "dict":
            return self.backend_dataset
        else:
            raise NotImplementedError(
                f'Current .to_dict is not supported for backend "{self.backend}"'
            )

    def to_list(self):
        """Return a list of instances.

        For the "dict" backend a deep copy of the stored instances is returned.

        Raises
        ------
        NotImplementedError
            If the backend is neither "huggingface" nor "dict".
        """
        if self.backend == "huggingface":
            return [
                self.backend_dataset[idx]
                for idx in range(len(self.backend_dataset))
            ]
        elif self.backend == "dict":
            # TODO: should be a list of instances, instance should be huggingface datasets row format
            return copy.deepcopy(self.backend_dataset[KEY_INSTANCES])
        else:
            raise NotImplementedError(
                f'Current .to_list is not supported for backend "{self.backend}"'
            )

    def map(self, *args, **kwargs):
        r"""
        Apply the backend dataset's ``map`` and store the result in place.

        Parameters
        ------------
        args : Optional.
            Positional arguments, forwarded to the backend ``map``.

        kwargs : Optional.
            Keyword arguments, forwarded to the backend ``map``.

        Returns
        ---------

        self : Dataset object (mutated; its backend dataset is replaced).

        Raises
        ---------

        NotImplementedError
            If the backend is not "huggingface".
        """
        if self.backend != "huggingface":
            raise NotImplementedError(
                f'Currently .map is not supported for backend "{self.backend}"'
            )
        self.backend_dataset = self.backend_dataset.map(*args, **kwargs)
        return self

    def get_backend(self) -> Optional[str]:
        r"""
        Returns
        ---------

        self.backend
        """
        return self.backend

    def get_backend_dataset(self):
        r"""
        Returns
        ---------

        self.backend_dataset
        """
        return self.backend_dataset

    def get_fingerprint(self):
        r"""
        Returns
        ---------

        Fingerprint of the backend_dataset which controls the cache.
        Reads the private ``_fingerprint`` attribute, so it only works for
        backends whose dataset object provides it.
        """
        return self.backend_dataset._fingerprint

    def get_data_args(self):
        r"""
        Returns
        ---------

        self.data_args
        """
        return self.data_args

    def get_type(self):
        r"""
        Returns
        ---------

        self.type
        """
        return self.type

def tokenize(dataset, model_args, tokenizer, add_special_tokens=True, *args, **kwargs):
    """
    Tokenize the full dataset.

    For every example the tokenized columns are concatenated in a fixed order
    ("text" for "text_only"; "input" then "output" for "text2text") into
    ``input_ids``, ``attention_mask`` and ``labels``. Tokens coming from
    columns that are not label columns get the label ``-100``. When
    ``data_args.disable_group_texts`` is set, every example is truncated or
    right-padded to ``min(data_args.block_size, tokenizer.model_max_length)``
    (``tokenizer.model_max_length`` alone when ``block_size`` is None).

    Parameters
    ------------
    dataset : lmflow.datasets.Dataset.
        Must use the "huggingface" backend and have type "text_only" or
        "text2text". Tokenization is applied through ``dataset.map``; for
        ``Dataset_ft`` that replaces the backend dataset in place and returns
        the same object.

    model_args :
        Currently unused.

    tokenizer :
        Called as ``tokenizer(list_of_texts, truncation=False)``; expected to
        return per-text "input_ids" and "attention_mask" lists. Its
        ``model_max_length`` and ``pad_token_id`` are read only when
        ``data_args.disable_group_texts`` is set.

    add_special_tokens : bool, default=True
        NOTE(review): currently NOT forwarded to the tokenizer, so the
        tokenizer's own default decides whether special tokens are added.
        Confirm the intended behavior before relying on this flag.

    args : Optional.
        Positional arguments (unused).

    kwargs : Optional.
        Keyword arguments (unused).

    Returns
    ------------
    tokenized_datasets :
        The value returned by ``dataset.map(...)``; the original columns are
        passed as ``remove_columns``.

    Raises
    ------------
    NotImplementedError
        If the backend is not "huggingface" or the dataset type is unsupported.
    """
    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if dataset.get_backend() != "huggingface":
        raise NotImplementedError(
            "tokenization of datasets with non-huggingface backend are"
            " not supported yet"
        )

    dataset_type = dataset.get_type()
    hf_raw_datasets = dataset.get_backend_dataset()
    column_names = list(hf_raw_datasets.features)

    # since this will be pickled to avoid _LazyModule error in Hasher force
    # logger loading before tokenize_function
    data_args = dataset.get_data_args()

    # Requires three types of information for tokenizing different datasets
    #   1) Which fields require tokenization, e.g.
    #        "text2float": "text", but not "float"
    #        "text2text": both "input" and "output"
    #   2) How will there tokenized sequence concatenated together, e.g.
    #        "text_only": "text" -> "text"
    #        "text2text": "input", "output" -> "input" + "output"
    #   3) Which fields require loss in final computation, e.g.
    #        "text_only": "text"
    #        "text2text": "output" only
    # Maps type -> (tokenized column order [handles 1) and 2)],
    #               label columns [handles 3)]).
    supported_types = {
        "text_only": (["text"], ["text"]),
        "text2text": (["input", "output"], ["output"]),
    }
    if dataset_type not in supported_types:
        raise NotImplementedError(
            f"dataset type \"{dataset_type}\" is not supported, currently"
            " only support following data types:\n"
            + "\n".join(f"    {name}" for name in supported_types)
        )
    tokenized_column_order, label_columns = supported_types[dataset_type]

    # Whether to truncate long sequences to fit into max_length.
    use_truncation = False

    def tokenize_function(examples):
        num_example = len(examples[column_names[0]])
        token_dict = {
            "input_ids": [[] for _ in range(num_example)],
            "attention_mask": [[] for _ in range(num_example)],
            "labels": [[] for _ in range(num_example)],
        }

        for column_name in tokenized_column_order:
            encoding = tokenizer(
                examples[column_name],
                truncation=use_truncation,
            )

            if column_name in label_columns:
                labels = encoding["input_ids"]
            else:
                # -100 marks positions that do not contribute to the loss.
                labels = [
                    [-100] * len(encoding["input_ids"][i])
                    for i in range(num_example)
                ]

            for i in range(num_example):
                token_dict["input_ids"][i].extend(encoding["input_ids"][i])
                token_dict["attention_mask"][i].extend(
                    encoding["attention_mask"][i]
                )
                token_dict["labels"][i].extend(labels[i])

        if data_args.disable_group_texts:
            block_size = data_args.block_size
            # min() with None raises TypeError, so fall back to the tokenizer
            # limit when no block size is configured.
            if block_size is None:
                max_length = tokenizer.model_max_length
            else:
                max_length = min(block_size, tokenizer.model_max_length)
            pad_token_id = tokenizer.pad_token_id

            for i in range(num_example):
                pad_length = max_length - len(token_dict["input_ids"][i])
                if pad_length < 0:
                    # Truncates too long samples
                    for key in token_dict:
                        token_dict[key][i] = token_dict[key][i][:max_length]
                else:
                    # Pads too short samples on the right
                    token_dict["input_ids"][i].extend([pad_token_id] * pad_length)
                    token_dict["attention_mask"][i].extend([0] * pad_length)
                    token_dict["labels"][i].extend([-100] * pad_length)

        return token_dict

    if not data_args.streaming:
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    else:
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=column_names,
        )
    return tokenized_datasets

