import re
import os
import json
import math
import random
import datasets
from tqdm import tqdm
from functools import partial
from glob import glob
from contextlib import nullcontext
from transformers.utils import logging
from src import apply_chat_template, add_eos, split_file_dir_name_ext

logger = logging.get_logger(__name__)


# Candidate settings as a list of 2-tuples of ints.
# NOTE(review): this constant is not referenced in the visible part of the file, so
# the meaning of each tuple field (presumably a size and a count) must be confirmed
# against its consumer. The commented-out line below is a wider alternative list.
# RETRIEVAL_CAND = [(1024,1), (512,2), (256,4), (128,8), (512,1), (256,2), (128,4)]
RETRIEVAL_CAND = [(1024,1)]


class Data:
    """Stateless helpers that turn JSON corpora into tokenized `datasets` objects.

    Every method is a static method and is meant to be called through the class,
    e.g. ``Data.prepare_train_data(...)``.
    """

    @staticmethod
    def _process_language_modeling(data, indices, tokenizer, min_length, max_length):
        """Tokenize a batch of raw documents for language modeling.

        Args:
            data: batch with a ``'text'`` column (one string per document).
            indices: dataset row index of each document in the batch.
            tokenizer: callable applied to each text; its result must be a mapping
                whose values are lists and which contains ``'input_ids'``.
            min_length: documents with fewer tokens than this are dropped.
            max_length: documents with at least this many tokens are cut down to
                their first ``max_length`` tokens.

        Returns:
            A dict of parallel lists with the keys ``input_ids``, ``attention_mask``,
            ``labels`` (a copy of ``input_ids``), ``length`` and ``index``.
        """
        outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

        for i, text in enumerate(data['text']):
            # NOTE: the full text is tokenized; length filtering/truncation happens afterwards.
            encoded = tokenizer(text)
            num_tokens = len(encoded["input_ids"])

            if num_tokens < min_length:
                continue

            if num_tokens < max_length:
                # The document fits entirely, so it is handed to `add_eos`
                # (presumably appends the EOS token -- confirm in `src`).
                encoded = add_eos(encoded, tokenizer.eos_token_id)
            else:
                # Snapshot the items first so that the mapping is not mutated while iterated.
                for k, v in list(encoded.items()):
                    encoded[k] = v[:max_length]

            encoded["labels"] = encoded["input_ids"].copy()

            for k, v in encoded.items():
                outputs[k].append(v)
            # length is required for grouping
            outputs["length"].append(len(encoded['input_ids']))
            outputs["index"].append(indices[i])

        return outputs

    @staticmethod
    def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_length, max_length, eval_mode=False):
        """Tokenize a batch of conversations with a chat template.

        Args:
            data: batch with a ``'conversations'`` column; each conversation is a list
                of turns, each turn a dict with ``'role'`` and ``'content'``.
            indices: dataset row index of each conversation in the batch.
            tokenizer: tokenizer forwarded to ``apply_chat_template``.
            chat_template: template identifier forwarded to ``apply_chat_template``.
            min_length: samples with fewer tokens than this are dropped.
            max_length: samples with more tokens than this are dropped (not truncated).
            eval_mode: when True, only the first turn is encoded (with a generation
                prompt) and ``labels`` holds the content of the second turn as a
                string, or None when there is no second turn.

        Returns:
            A dict of parallel lists with the keys ``input_ids``, ``attention_mask``,
            ``labels``, ``length`` and ``index``.
        """
        outputs = {'input_ids': [], 'attention_mask': [], "labels": [], "length": [], "index": []}

        for i, source in enumerate(data['conversations']):
            if source and source[0]["role"] != 'user':
                # Skip the first one if it is not from user
                source = source[1:]

            # Nothing left to encode (empty conversation, or a lone non-user turn):
            # drop the sample like any other filtered one instead of failing.
            if not source:
                continue

            # NOTE: in evaluation, we only use the first turn in the conversation
            if eval_mode:
                # a string (the expected output from the assistant)
                labels = source[1]['content'] if len(source) > 1 else None
                source = source[:1]

            encoded = apply_chat_template(
                chat_template,
                source,
                tokenizer=tokenizer,
                # only return labels in training mode; evaluation labels are set below
                return_labels=not eval_mode,
                add_generation_prompt=eval_mode,
            ).encoded

            # skip data that does not fall in between min_length and max_length
            num_tokens = len(encoded["input_ids"])
            if num_tokens < min_length or num_tokens > max_length:
                continue

            if eval_mode:
                encoded["labels"] = labels

            for k, v in encoded.items():
                outputs[k].append(v)
            outputs['length'].append(len(encoded['input_ids']))
            outputs['index'].append(indices[i])

        return outputs

    @staticmethod
    def _split_sample_spec(data_file, default_num=None):
        """Split a ``path[N]`` spec into ``(path, N)``; use ``default_num`` when no ``[N]`` is given."""
        pattern = r"\[(\d*)\]"
        match = re.search(pattern, data_file)
        if match is None:
            return data_file, default_num

        digits = match.group(1)
        if not digits:
            raise ValueError(f"Missing sample number inside the brackets of training data {data_file}!")
        return re.sub(pattern, "", data_file), int(digits)

    @staticmethod
    def _build_process_fn(column_names, tokenizer, chat_template, min_length, max_length, eval_mode, data_kind):
        """Pick the batch processor matching the dataset columns and bind its settings."""
        if "text" in column_names:
            return partial(
                Data._process_language_modeling,
                tokenizer=tokenizer,
                min_length=min_length,
                max_length=max_length,
            )
        if "conversations" in column_names:
            return partial(
                Data._process_instruction_tuning,
                tokenizer=tokenizer,
                chat_template=chat_template,
                min_length=min_length,
                max_length=max_length,
                eval_mode=eval_mode,
            )
        raise ValueError(f"Found neither 'text' nor 'conversations' in the {data_kind} data!")

    @staticmethod
    def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_sample_num=None, seed=42, cache_dir=None, load_from_cache_file=None):
        """Load, tokenize and concatenate the training datasets.

        Args:
            data_files: a path or a list of paths. A path may carry a ``[N]`` suffix
                (e.g. ``data.json[1000]``) to keep at most N random samples of that
                file. A directory containing ``dataset_info.json`` is loaded with
                ``datasets.load_from_disk`` and used as is (no tokenization).
            tokenizer: tokenizer forwarded to the batch processors.
            max_length: maximum sample length forwarded to the batch processors.
            min_length: minimum sample length forwarded to the batch processors.
            chat_template: template identifier used for ``'conversations'`` data.
            max_sample_num: per-file sample cap applied to every entry that has no
                explicit ``[N]`` suffix; None means no cap.
            seed: seed for ``random`` and for the sub-sampling split.
            cache_dir: cache directory forwarded to ``datasets.load_dataset``.
            load_from_cache_file: forwarded to ``Dataset.map``.

        Returns:
            The concatenated dataset, or None when ``data_files`` is None.

        Raises:
            ValueError: if ``data_files`` is neither a string nor a list, if a
                ``[N]`` suffix has no digits, or if a JSON file has neither a
                ``'text'`` nor a ``'conversations'`` column.
        """
        if data_files is None:
            return None

        if not isinstance(data_files, (list, str)):
            raise ValueError(f"Invalid training data {data_files}!")
        logger.info(f"Loading training data from {data_files}...")
        if isinstance(data_files, str):
            data_files = [data_files]

        # One cap per path: if the same path is listed twice, the last cap wins.
        data_2_num_sample = {}
        for data_file_spec in data_files:
            # A distinct local name keeps the `max_sample_num` argument intact so that
            # it can serve as the default cap for entries without a `[N]` suffix.
            data_file, sample_cap = Data._split_sample_spec(data_file_spec, max_sample_num)
            data_2_num_sample[data_file] = sample_cap

        random.seed(seed)

        train_datasets = []
        for data_file, sample_cap in data_2_num_sample.items():

            if os.path.isdir(data_file) and os.path.exists(os.path.join(data_file, "dataset_info.json")):
                # the dataset may be save_to_disk in advance
                dataset = datasets.load_from_disk(data_file)

            else:
                # the dataset is a json file
                dataset = datasets.load_dataset('json', data_files=data_file, split='train', cache_dir=cache_dir)

                process_fn = Data._build_process_fn(
                    dataset.column_names,
                    tokenizer=tokenizer,
                    chat_template=chat_template,
                    min_length=min_length,
                    max_length=max_length,
                    eval_mode=False,
                    data_kind="training",
                )

                dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, batch_size=32, with_indices=True, load_from_cache_file=load_from_cache_file)

            if sample_cap is not None and len(dataset) > sample_cap:
                # An integer test_size is an absolute number of rows, so the "test"
                # split is a random subset of exactly `sample_cap` samples.
                dataset = dataset.train_test_split(test_size=sample_cap, seed=seed)["test"]

            # index column is useless in training
            if "index" in dataset.column_names:
                dataset = dataset.remove_columns(["index"])

            train_datasets.append(dataset)

        return datasets.concatenate_datasets(train_datasets)

    @staticmethod
    def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_length=512, chat_template="vicuna", max_eval_num=None, cache_dir=None, seed=42, load_from_cache_file=None):
        """Load and tokenize the evaluation dataset.

        Args:
            data_files: JSON file(s) forwarded to ``datasets.load_dataset``.
            tokenizer: tokenizer forwarded to the batch processors.
            max_length: maximum sample length forwarded to the batch processors.
            min_length: minimum sample length forwarded to the batch processors.
            chat_template: template identifier used for ``'conversations'`` data.
            max_eval_num: when given, only the first ``max_eval_num`` rows are loaded.
            cache_dir: cache directory forwarded to ``datasets.load_dataset``.
            seed: seed for ``random``.
            load_from_cache_file: forwarded to ``Dataset.map``.

        Returns:
            The processed dataset (the ``index`` column is kept), or None when
            ``data_files`` is None.

        Raises:
            ValueError: if the data has neither a ``'text'`` nor a
                ``'conversations'`` column.
        """
        if data_files is None:
            return None

        random.seed(seed)

        # Slice at load time so that only the requested head of the file is processed.
        split = 'train' if max_eval_num is None else f'train[:{max_eval_num}]'
        dataset = datasets.load_dataset('json', data_files=data_files, split=split, cache_dir=cache_dir)

        process_fn = Data._build_process_fn(
            dataset.column_names,
            tokenizer=tokenizer,
            chat_template=chat_template,
            min_length=min_length,
            max_length=max_length,
            eval_mode=True,
            data_kind="evaluation",
        )

        dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True, load_from_cache_file=load_from_cache_file)
        return dataset