# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import sys
from typing import TYPE_CHECKING, Literal, Optional, Union
from functools import partial

import numpy as np
from datasets import load_dataset, load_from_disk

# from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_attr
# from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer

from .processors.mmsupervised import (
    preprocess_mmsupervised_dataset,
    print_supervised_dataset_example,
    encode_graph_pyg
)

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments, ModelArguments
    from .parser import DatasetAttr


logger = get_logger(__name__)


def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    logger.info("Loading dataset {}...".format(dataset_attr))

    data_files = []
    assert dataset_attr.load_from == "file"

    data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
    data_files.append(data_path)
    data_path = data_path.split(".")[-1]

    if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
        kwargs = {"trust_remote_code": True}
    else:
        kwargs = {}

    dataset = load_dataset(
        path=data_path,
        name=None,
        data_dir=None,
        data_files=data_files,
        split=data_args.split,
        cache_dir=model_args.cache_dir,
        token=model_args.hf_hub_token,
        streaming=False,
        **kwargs,
    )

    # max_samples = 1000
    # dataset = dataset.select(range(max_samples))
    # print('Temporarily truncate dataset to {} samples.'.format(max_samples))
    
    converted_dataset, mol_id_to_smiles = align_dataset(dataset, dataset_attr, data_args, training_args)
    return converted_dataset, mol_id_to_smiles

def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    tokenizer: "PreTrainedTokenizer",
) -> Union["Dataset", "IterableDataset"]:
    
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")
    print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)

    # Load tokenized dataset
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path)
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.tokenized_path)
            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
            # print_function(next(iter(dataset)))
            data_iter = iter(dataset)
            print_function(next(data_iter))
            return mol_id_to_pyg, dataset
        
    # Load tokenized dataset
    with training_args.main_process_first(desc="load dataset"):
        # current only support one dataset
        dataset_attr = get_dataset_attr(data_args)
        dataset, mol_id_to_smiles = load_single_dataset(dataset_attr, model_args, data_args, training_args)
        print(f"Total unique molecules: {len(mol_id_to_smiles)}")

        print('Finish dataset loading')
        for feature_name, feature in dataset.features.items():
            print(f"Feature Name: {feature_name}")
            print(f"Values: {feature}")
            print()

    with training_args.main_process_first(desc="pre-process dataset"):
        preprocess_func = partial(
            preprocess_mmsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            data_args=data_args,
        )

        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )
        
        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
        
        if data_args.tokenized_path is not None:
            if training_args.should_save:
                dataset.save_to_disk(data_args.tokenized_path)
                mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path, mol_id_to_smiles=mol_id_to_smiles)
                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
            sys.exit(0)
        else:
            mol_id_to_pyg = encode_graph_pyg(mol_id_to_smiles=mol_id_to_smiles)

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Cannot find valid samples.")
                
        return mol_id_to_pyg, dataset
