import copy
import logging
import os
import sys
import utils
from datasets import load_dataset
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import numpy as np
import torch
import transformers
from torch.utils.data import Dataset, IterableDataset

from tokenizer import IGNORE_INDEX, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN

# Prompt templates, filled in with ``str.format_map`` on a data record.
PROMPT_DICT = dict(
    # Alpaca-style template for records that carry an extra "input" field.
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response: "
    ),
    # Alpaca-style template for records without an "input" field.
    prompt_no_input=(
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response: "
    ),
    # Belle template: the record's "Human" field becomes the instruction.
    prompt_belle="### Instruction:\n{Human}\n\n### Response:",
)


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer,
                 use_pipeline=False) -> Dict:
    """Tokenize a list of strings."""
    padding_method = "longest"
    print(f"using padding {padding_method} {tokenizer.model_max_length}")

    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding=padding_method,
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


@dataclass
class DataCollatorForSupervisedDataset:
    """Pad a list of ``{"input_ids": ..., "labels": ...}`` examples into a batch.

    ``input_ids`` are right-padded with the tokenizer's pad id and ``labels``
    with ``IGNORE_INDEX``. When ``pad_to_max_len`` is set, a zero sentinel row
    of length ``tokenizer.model_max_length`` is padded along with the batch
    and then dropped. Every row therefore ends up at least that long.
    """

    tokenizer: transformers.PreTrainedTokenizer
    pad_to_max_len: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        id_rows = [example["input_ids"] for example in instances]
        label_rows = [example["labels"] for example in instances]

        if self.pad_to_max_len:
            # Sentinel row that forces pad_sequence up to model_max_length.
            sentinel_len = self.tokenizer.model_max_length
            id_rows.append(torch.zeros([sentinel_len], dtype=id_rows[0].dtype))
            label_rows.append(
                torch.zeros([sentinel_len], dtype=label_rows[0].dtype))

        pad_id = self.tokenizer.pad_token_id
        batch_ids = torch.nn.utils.rnn.pad_sequence(
            id_rows, batch_first=True, padding_value=pad_id)
        batch_labels = torch.nn.utils.rnn.pad_sequence(
            label_rows, batch_first=True, padding_value=IGNORE_INDEX)

        if self.pad_to_max_len:
            # Drop the sentinel row again.
            batch_ids, batch_labels = batch_ids[:-1], batch_labels[:-1]

        # NOTE(review): any real token equal to pad_token_id is masked too.
        return dict(
            input_ids=batch_ids,
            labels=batch_labels,
            attention_mask=batch_ids.ne(pad_id),
        )


@dataclass
class DataCollatorForSupervisedIterableDataset:
    """Pad ``(input_ids, labels)`` pairs into a nested-tuple batch.

    Padding works as in the map-style collator: input ids are padded with the
    tokenizer's pad id, labels with ``IGNORE_INDEX``, and an optional zero
    sentinel row of length ``model_max_length`` is added and then removed when
    ``pad_to_max_len`` is set. The result is
    ``((input_ids, attention_mask), labels)``, presumably the layout the
    pipeline-parallel engine expects. Despite the ``Dict`` return annotation,
    it is a tuple.
    """

    tokenizer: transformers.PreTrainedTokenizer
    pad_to_max_len: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        id_rows = [pair[0] for pair in instances]
        label_rows = [pair[1] for pair in instances]

        if self.pad_to_max_len:
            # Sentinel row that forces pad_sequence up to model_max_length.
            sentinel_len = self.tokenizer.model_max_length
            id_rows.append(torch.zeros([sentinel_len], dtype=id_rows[0].dtype))
            label_rows.append(
                torch.zeros([sentinel_len], dtype=label_rows[0].dtype))

        pad_id = self.tokenizer.pad_token_id
        batch_ids = torch.nn.utils.rnn.pad_sequence(
            id_rows, batch_first=True, padding_value=pad_id)
        batch_labels = torch.nn.utils.rnn.pad_sequence(
            label_rows, batch_first=True, padding_value=IGNORE_INDEX)

        if self.pad_to_max_len:
            # Drop the sentinel row again.
            batch_ids, batch_labels = batch_ids[:-1], batch_labels[:-1]

        attention_mask = batch_ids.ne(pad_id)
        return (batch_ids, attention_mask), batch_labels


def preprocess(sources: Sequence[str],
               targets: Sequence[str],
               tokenizer: transformers.PreTrainedTokenizer,
               use_pipeline: bool = False) -> Dict:
    """Tokenize ``source + target`` pairs and mask the source part of labels.

    Args:
        sources: prompt strings.
        targets: response strings, paired with ``sources`` by position.
        tokenizer: tokenizer passed through to ``_tokenize_fn``.
        use_pipeline: forwarded to ``_tokenize_fn``.

    Returns:
        dict with ``input_ids`` (one tensor per pair) and ``labels`` (a copy
        of ``input_ids`` whose first ``len(tokenized source)`` positions are
        set to ``IGNORE_INDEX``).
    """
    full_texts = [source + target for source, target in zip(sources, targets)]
    tokenized_full = _tokenize_fn(full_texts, tokenizer, use_pipeline)
    tokenized_sources = _tokenize_fn(sources, tokenizer, use_pipeline)

    input_ids = tokenized_full["input_ids"]
    # Deep copy: _tokenize_fn returns labels aliased to input_ids.
    labels = copy.deepcopy(input_ids)
    for label, prompt_len in zip(labels, tokenized_sources["input_ids_lens"]):
        # NOTE(review): assumes the prompt tokenized alone yields the same
        # token prefix as inside source+target; confirm for this tokenizer.
        label[:prompt_len] = IGNORE_INDEX
    return {"input_ids": input_ids, "labels": labels}


class SupervisedDataset(Dataset):
    """Map-style dataset over pre-tokenized ``input_ids`` / ``labels`` lists."""

    def __init__(self, data_dict: dict):
        super().__init__()
        # Both lists are indexed in parallel by __getitem__.
        self.input_ids, self.labels = data_dict["input_ids"], data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        ids, label = self.input_ids[i], self.labels[i]
        return {"input_ids": ids, "labels": label}


class SupervisedIterableDataset(IterableDataset):
    """Iterable dataset yielding ``(input_ids, labels)`` tuples in order."""

    # __len__ is deliberately not defined; it does not make sense in DS-PP.

    def __init__(self, data_dict: dict):
        super().__init__()
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __iter__(self):
        # Walk input_ids and look up the label at the same position.
        for position, ids in enumerate(self.input_ids):
            label = self.labels[position]
            yield ids, label


def get_raw_data(tokenizer: transformers.PreTrainedTokenizer,
                 data_path,
                 eval_num=0):
    """Load Belle-format JSON records and render them as prompt/target text.

    Each record's ``Human`` field is substituted into
    ``PROMPT_DICT["prompt_belle"]`` to form the source. Its ``Assistant``
    field, wrapped in the tokenizer's bos/eos tokens, forms the target.

    Args:
        tokenizer: used only for ``bos_token`` / ``eos_token``.
        data_path: passed as ``data_files`` to ``load_dataset("json", ...)``.
        eval_num: if > 0, passed as ``test_size`` to ``train_test_split``
            (seed 42) to hold out an eval split. Otherwise the training data
            is only shuffled (seed 42) and no eval data is produced.

    Returns:
        ``(train_sources, train_targets, eval_sources, eval_targets)``. The
        eval lists are empty when ``eval_num <= 0``.

    Raises:
        ValueError: if the training split is empty, or the records use the
            Alpaca ``instruction`` format, which this loader does not handle.
    """
    logging.warning("Loading data...")
    data = load_dataset("json", data_files=data_path)

    if eval_num > 0:
        split = data["train"].train_test_split(test_size=eval_num,
                                               shuffle=True,
                                               seed=42)
        train_data, eval_data = split["train"], split["test"]
    else:
        train_data = data["train"].shuffle(seed=42)
        eval_data = None

    logging.warning("Formatting inputs...")

    if len(train_data) == 0:
        raise ValueError(f"No training records found in {data_path!r}")
    # Fail loudly instead of via `assert False`: asserts are stripped under
    # `python -O`, which previously led to an UnboundLocalError further down.
    if "instruction" in train_data[0]:
        raise ValueError(
            "Alpaca-style records (with an 'instruction' field) are not "
            "supported by this Belle loader; expected 'Human'/'Assistant' "
            "fields.")

    train_sources, train_targets = _format_belle_examples(train_data, tokenizer)
    eval_sources, eval_targets = [], []
    if eval_data is not None:
        eval_sources, eval_targets = _format_belle_examples(eval_data, tokenizer)

    return train_sources, train_targets, eval_sources, eval_targets


def _format_belle_examples(examples, tokenizer):
    """Render Belle records into (prompt sources, bos+Assistant+eos targets)."""
    prompt_belle = PROMPT_DICT["prompt_belle"]
    sources = [prompt_belle.format_map(example) for example in examples]
    targets = [
        f"{tokenizer.bos_token}{example['Assistant']}{tokenizer.eos_token}"
        for example in examples
    ]
    return sources, targets


def make_train_eval_dataset(tokenizer: transformers.PreTrainedTokenizer,
                            data_path,
                            training_args,
                            eval_num=0):
    """Build ``(train_dataset, eval_dataset)`` from raw or pre-tokenized data.

    Sources, in priority order:
      1. ``data_path`` (not None): raw records loaded by ``get_raw_data`` and
         tokenized by ``preprocess``. An eval set is built only if
         ``eval_num > 0``.
      2. ``training_args.pretokenized_train_data_path`` (non-empty): ``.npy``
         files holding a pickled dict, read via ``np.load(...).item()``. An
         eval set is built from ``pretokenized_eval_data_path`` if that is
         non-empty.
      3. Otherwise ``(None, None)``.

    Datasets are ``SupervisedIterableDataset`` when
    ``training_args.use_pipeline`` is true, else ``SupervisedDataset``. The
    eval entry is ``None`` when no eval data is configured.
    """
    use_pipeline = training_args.use_pipeline
    dataset_cls = SupervisedIterableDataset if use_pipeline else SupervisedDataset

    if data_path is not None:
        print(f"load data from {data_path}")
        train_sources, train_targets, eval_sources, eval_targets = get_raw_data(
            tokenizer, data_path, eval_num)

        logging.warning("Tokenizing inputs... This may take some time...")
        train_data_dict = preprocess(train_sources, train_targets, tokenizer,
                                     use_pipeline)
        eval_dataset = None
        if eval_num > 0:
            # Only tokenize eval data when it will actually be used.
            eval_dataset = dataset_cls(
                preprocess(eval_sources, eval_targets, tokenizer, use_pipeline))
        return dataset_cls(train_data_dict), eval_dataset

    # Truthiness treats both "" and None as "not configured" (a None path
    # previously reached np.load(None)).
    train_path = training_args.pretokenized_train_data_path
    if train_path:
        print(f"load data from {train_path}")
        # NOTE(security): allow_pickle=True unpickles the file, which can run
        # arbitrary code. Only load dumps from a trusted source.
        train_data_dict = np.load(train_path, allow_pickle=True).item()
        eval_dataset = None
        eval_path = training_args.pretokenized_eval_data_path
        if eval_path:
            eval_dataset = dataset_cls(
                np.load(eval_path, allow_pickle=True).item())
        return dataset_cls(train_data_dict), eval_dataset

    return None, None


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args, training_args) -> Dict:
    """Assemble the train/eval datasets and the matching collator.

    The collator is ``DataCollatorForSupervisedIterableDataset`` when
    ``training_args.use_pipeline`` is true, else
    ``DataCollatorForSupervisedDataset``. In both cases it is configured
    with ``training_args.pad_to_max_len``.

    Returns:
        dict with keys ``train_dataset``, ``eval_dataset``, ``data_collator``.
    """
    print('make_supervised_data_module for belle')
    train_dataset, eval_dataset = make_train_eval_dataset(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        training_args=training_args,
        eval_num=data_args.eval_num)

    pad_to_max_len = training_args.pad_to_max_len
    print(f"using pad_to_max_len {pad_to_max_len}")

    collator_cls = (DataCollatorForSupervisedIterableDataset
                    if training_args.use_pipeline
                    else DataCollatorForSupervisedDataset)
    data_collator = collator_cls(tokenizer=tokenizer,
                                 pad_to_max_len=pad_to_max_len)
    return {
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": data_collator,
    }
