import logging
import os
from typing import Union, List
import datasets
import torch
from datasets import load_dataset, concatenate_datasets
import transformers
import random
import numpy as np

# Label value for positions that must not contribute to the loss (prompt
# tokens in build_instruction_dataset). -100 matches the default
# ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100

# Module-level logger named after this module (``__name__``, not the literal
# string "__name__", so log records are attributed to the right module).
logger = logging.getLogger(__name__)
# Paraphrased Alpaca-style prompt templates.
# Entries 0-5 use "### " section headers; entries 6-11 repeat the same
# preambles with bare "Instruction:" / "Input:" / "Response:" headers.
# "prompt_input" expects {instruction} and {input}; "prompt_no_input" expects
# only {instruction}. Every preamble is separated from the first section
# header by a blank line ("\n\n").
PROMPT_TEMPLATE = [
{
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
},
{
    "prompt_input": (
        "You are supposed to follow an instruction, and then the input to generate proper response.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "You are supposed to follow an instruction to generate proper response.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
},
{
    "prompt_input": (
        "Please follow the instruction and input to give a response.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Please follow the instruction to give a response.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
},
{
    "prompt_input": (
        "You are an expert, please listen to human instruction and input to generate the response.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "You are an expert, please listen to human instruction to generate the response.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
},
{
    "prompt_input": (
        "Let's follow the instruction to respond to an input.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Let's follow the instruction to generate a response.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
},
{
    "prompt_input": (
        "The instruction is a description of the task. You need to follow that and respond to the paired input.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "The instruction is a description of the task. You need to follow that and respond.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
},
{
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "Instruction:\n{instruction}\n\nResponse:"
    ),
},
{
    "prompt_input": (
        "You are supposed to follow an instruction, and then the input to generate proper response.\n\n"
        "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
    ),
    "prompt_no_input": (
        "You are supposed to follow an instruction to generate proper response.\n\n"
        "Instruction:\n{instruction}\n\nResponse:"
    ),
},
{
    "prompt_input": (
        "Please follow the instruction and input to give a response.\n\n"
        "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
    ),
    "prompt_no_input": (
        "Please follow the instruction to give a response.\n\n"
        "Instruction:\n{instruction}\n\nResponse:"
    ),
},
{
    "prompt_input": (
        "You are an expert, please listen to human instruction and input to generate the response.\n\n"
        "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
    ),
    "prompt_no_input": (
        "You are an expert, please listen to human instruction to generate the response.\n\n"
        "Instruction:\n{instruction}\n\nResponse:"
    ),
},
{
    "prompt_input": (
        "Let's follow the instruction to respond to an input.\n\n"
        "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
    ),
    "prompt_no_input": (
        "Let's follow the instruction to generate a response.\n\n"
        "Instruction:\n{instruction}\n\nResponse:"
    ),
},
{
    "prompt_input": (
        "The instruction is a description of the task. You need to follow that and respond to the paired input.\n\n"
        "Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
    ),
    "prompt_no_input": (
        "The instruction is a description of the task. You need to follow that and respond.\n\n"
        "Instruction:\n{instruction}\n\nResponse:"
    ),
},
]

# Default Alpaca-style template. build_instruction_dataset renders
# "prompt_no_input" with an {instruction} value; "prompt_input" additionally
# takes an {input} value.
PROMPT_TEMPLATE_SINGLE = dict(
    prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request."
        "\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        "\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)

# DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""
# system_format='<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>'
# user_format='<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
# assistant_format='{content}<|eot_id|>'

def build_instruction_dataset(data_path: Union[List[str],str],
                tokenizer: transformers.PreTrainedTokenizer,
                max_seq_length: int, data_cache_dir = None,
                preprocessing_num_workers = None,
                ):
    """Tokenize one or more JSON instruction files into a single training dataset.

    Each record must provide ``instruction``, ``input`` and ``output`` fields.
    A non-empty ``input`` is appended to ``instruction`` on a new line, the
    result is rendered with ``PROMPT_TEMPLATE_SINGLE["prompt_no_input"]``, and
    ``output`` followed by ``tokenizer.eos_token`` is the target. Prompt
    tokens are labelled ``IGNORE_INDEX`` so only the target contributes to the
    loss; the concatenated prompt+target is truncated to ``max_seq_length``.

    The tokenized result of each file is saved with ``save_to_disk`` under
    ``<cache root>/<file stem>_<max_seq_length>`` and reloaded on later calls.

    Args:
        data_path: A path, or a list/tuple of paths, to files loadable with
            ``load_dataset("json", ...)``.
        tokenizer: Tokenizer applied to both prompts and targets.
        max_seq_length: Maximum number of tokens kept per example; also part
            of the cache directory name.
        data_cache_dir: Cache root shared by all files. When ``None``, each
            file is cached next to itself (in its own directory).
        preprocessing_num_workers: Forwarded as ``num_proc`` to ``map``.

    Returns:
        The ``train`` splits of all files concatenated, in torch format, with
        ``input_ids`` and ``labels`` columns.
    """

    def tokenization(examples):
        # Batched map function: build prompt/target strings, then token ids + labels.
        sources = []
        targets = []
        for instruction, input_text, output in zip(examples['instruction'], examples['input'], examples['output']):
            # Fold an optional input into the instruction so the single
            # no-input template covers every record.
            if input_text is not None and input_text != "":
                instruction = instruction + '\n' + input_text
            sources.append(PROMPT_TEMPLATE_SINGLE["prompt_no_input"].format(instruction=instruction))
            targets.append(f"{output}{tokenizer.eos_token}")

        # No special tokens: the prompt is plain text and EOS is already
        # appended to the target explicitly.
        tokenized_sources = tokenizer(sources, return_attention_mask=False, add_special_tokens=False)
        tokenized_targets = tokenizer(targets, return_attention_mask=False, add_special_tokens=False)

        all_input_ids = []
        all_labels = []
        for s, t in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
            # Truncation may cut the end of the target (including EOS); if the
            # prompt alone exceeds max_seq_length, every label is IGNORE_INDEX.
            all_input_ids.append(torch.LongTensor(s + t)[:max_seq_length])
            all_labels.append(torch.LongTensor([IGNORE_INDEX] * len(s) + t)[:max_seq_length])

        return {'input_ids': all_input_ids, 'labels': all_labels}

    logger.warning("building dataset...")
    all_datasets = []

    if not isinstance(data_path, (list, tuple)):
        data_path = [data_path]
    for file in data_path:
        # Resolve the cache root per file (without mutating data_cache_dir) so
        # files in different directories each default to their own directory.
        cache_root = data_cache_dir if data_cache_dir is not None else str(os.path.dirname(file))
        # splitext strips only the final extension, so e.g. "a.en.json" and
        # "a.zh.json" get distinct caches instead of both mapping to "a".
        stem = os.path.splitext(os.path.basename(file))[0]
        cache_path = os.path.join(cache_root, f"{stem}_{max_seq_length}")
        os.makedirs(cache_path, exist_ok=True)
        try:
            processed_dataset = datasets.load_from_disk(cache_path)
        except Exception as err:
            # Deliberate best-effort: any failure to load the cache (missing,
            # incomplete or incompatible) falls back to re-tokenizing.
            logger.info("no usable cache at %s (%s); tokenizing %s", cache_path, err, file)
            raw_dataset = load_dataset("json", data_files=file, cache_dir=cache_path)
            processed_dataset = raw_dataset.map(
                tokenization,
                batched=True,
                num_proc=preprocessing_num_workers,
                remove_columns=["instruction", "input", "output"],
                keep_in_memory=False,
                desc="preprocessing on dataset",
            )
            processed_dataset.save_to_disk(cache_path)
        else:
            logger.info('training datasets-%s has been loaded from disk', file)
        processed_dataset.set_format('torch')
        # load_dataset("json", ...) puts all rows in the "train" split.
        all_datasets.append(processed_dataset['train'])
    return concatenate_datasets(all_datasets)
