import copy
import os
import tempfile
import shutil
import json
import logging
from typing import Sequence, Dict, List, Tuple, Union
from dataclasses import dataclass

import pandas as pd
import numpy as np
import torch
import transformers
from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk

logger = logging.getLogger(__name__)

# Label value for token positions that must not contribute to the loss:
# `preprocess` writes it over the prompt tokens and the collator pads labels
# with it. -100 is the default `ignore_index` of torch.nn.CrossEntropyLoss,
# which is presumably how the training loop consumes it (confirm for your model).
IGNORE_INDEX = -100
# Fallback special tokens used by `smart_tokenizer_and_embedding_resize`
# when the tokenizer does not define them.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Alpaca-style prompt templates, filled by `extract_alpaca_dataset` via
# str.format with the example's fields. "prompt_input" is chosen when the
# example has a non-empty "input" field, "prompt_no_input" otherwise.
# Both end with "### Response: " (trailing space included); `preprocess`
# concatenates the target text directly after the rendered prompt.
ALPACA_PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: "
    ),
}


def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Fill in missing special tokens and grow the model embeddings to match.

    Missing EOS/BOS/UNK tokens are registered with
    ``tokenizer.add_special_tokens``. A missing PAD token is only assigned
    (``tokenizer.pad_token = DEFAULT_PAD_TOKEN``) and is not passed to
    ``add_special_tokens``, so it is not counted as a new token here.

    If the tokenizer grew, the model's token embeddings are resized to
    ``len(tokenizer)``. The rows of the new tokens are set to the mean of the
    pre-existing rows. This applies to the input embedding matrix and, when
    present, the output embedding matrix.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.

    Args:
        tokenizer: Tokenizer, modified in place.
        model: Model whose embeddings are resized and initialized in place.
    """
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        # Assigned rather than added via add_special_tokens (see docstring).
        tokenizer.pad_token = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    # Use len(tokenizer), not tokenizer.vocab_size. vocab_size covers only the
    # base vocabulary and is unchanged by add_special_tokens, so comparing it
    # never detected the new tokens. The embeddings were then never resized,
    # and the averaging overwrote the last *existing* embedding rows instead.
    previous_tokenizer_size = len(tokenizer)
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    if num_new_tokens <= 0:
        return

    new_tokenizer_size = len(tokenizer)
    print(
        "increase embedding size from {} to {}".format(
            previous_tokenizer_size, new_tokenizer_size
        )
    )
    model.resize_token_embeddings(new_tokenizer_size)

    # Added tokens are presumably appended after the existing ids, i.e. they
    # occupy [previous_tokenizer_size, new_tokenizer_size).
    new_rows = slice(previous_tokenizer_size, new_tokenizer_size)
    input_embeddings = model.get_input_embeddings().weight.data
    output_module = model.get_output_embeddings()
    # Some models have no output embedding module (get_output_embeddings() -> None).
    output_embeddings = output_module.weight.data if output_module is not None else None

    # Compute every mean before writing any row, so that tied input/output
    # weights (the same tensor) are averaged over the original rows only.
    input_embeddings_avg = input_embeddings[:previous_tokenizer_size].mean(
        dim=0, keepdim=True
    )
    if output_embeddings is not None:
        output_embeddings_avg = output_embeddings[:previous_tokenizer_size].mean(
            dim=0, keepdim=True
        )

    input_embeddings[new_rows] = input_embeddings_avg
    if output_embeddings is not None:
        output_embeddings[new_rows] = output_embeddings_avg


def _tokenize_fn(
    strings: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    add_eos_token=False,
) -> Dict:
    """Tokenize a list of strings, one string at a time.

    Each string is tokenized on its own, so ``padding="longest"`` never adds
    padding. Each result is truncated to ``tokenizer.model_max_length``.

    Args:
        strings: Texts to tokenize.
        tokenizer: Tokenizer applied to every string.
        add_eos_token: If True, append ``tokenizer.eos_token_id`` to each
            sequence that does not already end with it, as long as the
            sequence is shorter than ``tokenizer.model_max_length``.

    Returns:
        Dict with these entries:

        - ``input_ids`` and ``labels``: the *same* list of 1-D int64 tensors,
          not copies.
        - ``input_ids_lens`` and ``labels_lens``: the same list of
          per-sequence token counts.
    """
    eos_token_id = tokenizer.eos_token_id
    input_ids = []
    for text in strings:
        tokenized = tokenizer(
            text,
            return_tensors=None,
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        one_sent_input_ids = list(tokenized.input_ids)
        if (
            add_eos_token
            and eos_token_id is not None
            # Guard the empty case: indexing [-1] on an empty list raises IndexError.
            and (not one_sent_input_ids or one_sent_input_ids[-1] != eos_token_id)
            and len(one_sent_input_ids) < tokenizer.model_max_length
        ):
            one_sent_input_ids.append(eos_token_id)
        # Explicit dtype keeps an empty sequence integer-typed as well.
        input_ids.append(torch.tensor(one_sent_input_ids, dtype=torch.long))
    labels = input_ids
    # Nothing is padded here, so the token count is simply the sequence length.
    # The previous approach counted ids != pad_token_id. That undercounts
    # whenever the pad id equals a real token id (e.g. pad == eos, or a pad
    # string that maps to unk), and it raises when pad_token_id is None.
    # preprocess() masks the prompt using these lengths, so an undercount
    # leaked prompt tokens into the loss.
    input_ids_lens = labels_lens = [len(ids) for ids in input_ids]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt/response pairs and hide the prompt from the loss.

    Each training text is ``source + target``, tokenized with an EOS appended.
    The sources are also tokenized on their own. That gives the number of
    leading label positions to overwrite with ``IGNORE_INDEX``.

    Returns:
        Dict with ``input_ids`` and ``labels``, each a list of 1-D tensors.
        ``labels`` is a deep copy of ``input_ids`` with the prompt positions
        masked.
    """
    full_texts = [source + target for source, target in zip(sources, targets)]
    full_ids = _tokenize_fn(full_texts, tokenizer, add_eos_token=True)["input_ids"]
    prompt_lens = _tokenize_fn(sources, tokenizer)["input_ids_lens"]

    masked_labels = copy.deepcopy(full_ids)
    for label_ids, n_prompt_tokens in zip(masked_labels, prompt_lens):
        # Only the response tokens should contribute to the loss.
        label_ids[:n_prompt_tokens] = IGNORE_INDEX
    return dict(input_ids=full_ids, labels=masked_labels)


class SupervisedDataset(torch.utils.data.Dataset):
    """Dataset for supervised fine-tuning.

    All examples are tokenized eagerly in ``__init__`` via ``preprocess``.
    Each item is a dict of 1-D ``input_ids`` and ``labels`` tensors, where
    the prompt positions of ``labels`` are set to ``IGNORE_INDEX``.
    """

    def __init__(self, dataset: Dataset, tokenizer: transformers.PreTrainedTokenizer):
        """
        Args:
            dataset: Iterable of rows with an ``"input"`` (prompt) and an
                ``"output"`` (response) string field.
            tokenizer: Tokenizer passed through to ``preprocess``.
        """
        super().__init__()
        # Log through the module logger. The module-level `logging.warning`
        # helper goes to the root logger and calls `logging.basicConfig()`
        # when the root logger has no handlers. That reconfigures logging for
        # the whole process as a side effect.
        logger.warning("Loading data...")

        sources = [x["input"] for x in dataset]
        targets = [x["output"] for x in dataset]

        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Right-pads each batch to its longest sequence: ``input_ids`` with
    ``tokenizer.pad_token_id`` and ``labels`` with ``IGNORE_INDEX``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad a list of ``{"input_ids", "labels"}`` dicts into batch tensors.

        Args:
            instances: Items such as those produced by ``SupervisedDataset``.
                Each holds 1-D ``input_ids`` and ``labels`` tensors, assumed
                to have equal length.

        Returns:
            Dict with ``input_ids`` and ``labels`` of shape ``(batch, max_len)``
            and a boolean ``attention_mask`` of the same shape. The mask is
            True exactly on the real (non-padding) positions.
        """
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        lengths = [len(ids) for ids in input_ids]

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Padding positions are masked out below and their labels are
            # ignored, so any id works. Without this, pad_sequence would
            # receive None as padding_value.
            pad_token_id = 0

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # Derive the mask from the true lengths rather than
        # `input_ids != pad_token_id`. The pad id may equal a real token id
        # (e.g. pad == eos, or a pad string mapped to unk), which would hide
        # those real tokens from attention.
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        seq_lengths = torch.tensor(lengths, device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < seq_lengths.unsqueeze(1)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )


def extract_alpaca_dataset(example):
    """Render an Alpaca-style record into a prompt string.

    Args:
        example: Mapping with an ``"instruction"`` field and optionally an
            ``"input"`` field. It is passed as keyword arguments to
            ``str.format``, so extra keys are ignored.

    Returns:
        ``{"input": prompt}``. The prompt uses the "with input" template when
        ``example["input"]`` is non-empty, and the "no input" template
        otherwise. When used with ``Dataset.map``, this overwrites the
        ``"input"`` column with the rendered prompt.
    """
    # Treat a missing, empty, or None "input" alike. Previously a None value
    # selected the "with input" template and rendered the literal text "None".
    if example.get("input"):
        prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
    else:
        prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
    return {"input": prompt_format.format(**example)}


def save_dataset_to_disk(dataset, output_path):
    """Save ``dataset`` to ``output_path``, backing up any existing copy first.

    The dataset is written to a temporary directory and only then copied into
    place. This works around the limitation discussed at
    https://discuss.huggingface.co/t/how-do-i-add-things-rows-to-an-already-saved-dataset/27423
    (presumably: a dataset loaded from ``output_path`` cannot be saved back
    over the files it is read from).

    If ``output_path`` already exists, it is moved to ``"./bak/" + output_path``
    (relative to the current working directory). Any earlier backup at that
    location is removed first.

    Args:
        dataset: Object with a ``save_to_disk(path)`` method, e.g. a
            ``datasets.Dataset``.
        output_path: Destination directory.
    """
    # The context manager removes the temporary directory even if saving or
    # copying raises; the previous explicit cleanup() was skipped on errors.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"dataset.save_to_disk({temp_dir})")
        dataset.save_to_disk(temp_dir)

        if os.path.exists(output_path):
            bak_path = "./bak/" + output_path
            if os.path.exists(bak_path):
                print(f"remove old bakup {bak_path}")
                shutil.rmtree(bak_path)
            # Make sure the backup's parent directories exist before moving.
            bak_parent = os.path.dirname(os.path.normpath(bak_path))
            if bak_parent:
                os.makedirs(bak_parent, exist_ok=True)
            print(f"backup shutil.move({output_path}, {bak_path})")
            shutil.move(output_path, bak_path)

        print(f"shutil.copytree({temp_dir}, {output_path})")
        shutil.copytree(temp_dir, output_path)


def load_dataset_by_name(dataset_name) -> Dataset:
    """Load a training dataset by registry name or by file/directory path.

    Known names:

    - "alpaca" / "alpaca-clean": the ``"input"`` column is replaced by the
      rendered Alpaca prompt. "alpaca" also drops rows with an empty output.
    - "self-instruct": ``prompt``/``completion`` are renamed to
      ``input``/``output``.
    - "oasst1": ``input`` is empty and ``output`` is the ``text`` field.
    - "vicuna": raises NotImplementedError.

    Any other value is treated as a path:

    - ``.json`` / ``.jsonl`` files are loaded with ``Dataset.from_json``.
    - ``.csv`` / ``.tsv`` files are loaded through pandas.
    - Anything else is loaded with ``load_from_disk``.

    Not supported yet: chip2 (laion/OIG, unified_chip2.jsonl), hh-rlhf
    (Anthropic/hh-rlhf), longform (akoksal/LongForm).

    Returns:
        The loaded ``Dataset`` (also printed).
    """
    if dataset_name == "alpaca":
        # we need this filter because some examples have empty output
        dataset = (
            load_dataset("tatsu-lab/alpaca", split="train")
            .map(extract_alpaca_dataset)
            .filter(lambda example: len(example["output"]) > 0)
        )
    elif dataset_name == "alpaca-clean":
        dataset = load_dataset("yahma/alpaca-cleaned", split="train").map(
            extract_alpaca_dataset
        )
    elif dataset_name == "self-instruct":
        dataset = load_dataset(
            "yizhongw/self_instruct", split="train", name="self_instruct"
        )
        for old, new in [["prompt", "input"], ["completion", "output"]]:
            dataset = dataset.rename_column(old, new)
    elif dataset_name == "oasst1":
        dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
        # The whole `text` field becomes the response; the prompt is empty.
        dataset = dataset.map(
            lambda x: {
                "input": "",
                "output": x["text"],
            }
        )
    elif dataset_name == "vicuna":
        raise NotImplementedError("Vicuna data was not released.")
    elif dataset_name.endswith((".json", ".jsonl")):
        # The JSON builder reads both plain JSON and JSON Lines files. The
        # former .jsonl call passed `filename=`/`format=` keywords instead of
        # `path_or_paths`, which Dataset.from_json requires, so it always failed.
        dataset = Dataset.from_json(path_or_paths=dataset_name)
    elif dataset_name.endswith(".csv"):
        dataset = Dataset.from_pandas(pd.read_csv(dataset_name))
    elif dataset_name.endswith(".tsv"):
        dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter="\t"))
    else:
        print(f"load_from_disk({dataset_name})")
        dataset = load_from_disk(dataset_name)
    print(dataset)
    return dataset


def create_one_train_dataset(dataset_name: str, quantity: int = -1, seed=None) -> Dataset:
    """Load a single dataset and optionally subsample it.

    Args:
        dataset_name: Name or path accepted by ``load_dataset_by_name``.
        quantity: Number of rows to keep; ``-1`` keeps everything. When
            ``quantity`` is smaller than the dataset, the rows are shuffled with
            ``seed`` and the first ``quantity`` are kept. Otherwise the dataset
            is returned unshuffled.
        seed: Seed for the shuffle.
    """
    print("create_one_train_dataset:", dataset_name, quantity)
    full_dataset = load_dataset_by_name(dataset_name)
    keep_everything = quantity == -1 or quantity >= len(full_dataset)
    if keep_everything:
        return full_dataset
    return full_dataset.shuffle(seed=seed).select(range(quantity))


def create_merged_train_dataset(dataset_metas: List[Tuple[str, int]], seed=None) -> Dataset:
    """Build every ``(name, quantity)`` dataset and concatenate them in order.

    All sub-datasets are subsampled with the same ``seed``.
    """
    return concatenate_datasets(
        [
            create_one_train_dataset(name, quantity, seed=seed)
            for name, quantity in dataset_metas
        ]
    )


def load_dataset_from_data_meta(data_meta: Union[str, List[Tuple[str, int]]], seed=None):
    """Load a dataset from its meta description.

    A plain string is a single dataset name/path, loaded in full. Anything
    else is a sequence of ``(name, quantity)`` pairs to merge.
    """
    if not isinstance(data_meta, str):
        return create_merged_train_dataset(data_meta, seed=seed)
    return create_one_train_dataset(data_meta, seed=seed)


def load_dataset_from_data_metastr(data_metastr: str, seed=None) -> Dataset:
    """Parse a textual data spec and load the dataset(s) it describes.

    A string starting with ``[`` or ``"`` is decoded as JSON: either a list
    of ``[name, quantity]`` pairs or a quoted single name. Any other string
    is used verbatim as a single dataset name/path.
    """
    looks_like_json = data_metastr.startswith(("[", '"'))
    data_meta = json.loads(data_metastr) if looks_like_json else data_metastr
    return load_dataset_from_data_meta(data_meta, seed=seed)


def calculate_dataset_indicators(
    indicator_name, home_dir="data/indicator_selector", output_dir="data_meta_info"
):
    """Given one indicator, calculate all indicator values for sub-datasets under this indicator.

    Each sub-directory of ``<home_dir>/<indicator_name>/merged_dataset`` is
    loaded with ``load_from_disk``. For every numeric column other than
    ``input``/``output``, the column mean is recorded. The resulting table has
    one row per sub-dataset and is sorted by the ``indicator_name`` column.
    It is written to ``<output_dir>/<indicator_name>.csv`` and printed.

    Args:
        indicator_name: Indicator whose merged datasets are summarized; also
            the column the result is sorted by (KeyError if no sub-dataset
            has it).
        home_dir: Root directory of the indicator datasets.
        output_dir: Directory for the CSV summary; created if missing.
    """
    merged_dir = os.path.join(home_dir, indicator_name, "merged_dataset")
    records = []
    for entry in sorted(os.listdir(merged_dir)):
        dataset_path = os.path.join(merged_dir, entry)
        # Stray files (e.g. .DS_Store) are not saved datasets; load_from_disk
        # would fail on them.
        if not os.path.isdir(dataset_path):
            continue
        dataset = load_from_disk(dataset_path)

        item_dict = {"dataset_name": entry}
        for column in dataset.column_names:
            if column in ("input", "output"):
                continue
            column_data = np.array(dataset[column])
            # Only numeric columns can be averaged; others are skipped.
            if np.issubdtype(column_data.dtype, np.number):
                item_dict[column] = column_data.mean()
        records.append(item_dict)

    result_df = pd.DataFrame.from_records(records).sort_values(indicator_name)
    os.makedirs(output_dir, exist_ok=True)
    result_df.to_csv(os.path.join(output_dir, f"{indicator_name}.csv"))
    print(result_df)


class ListDataset(torch.utils.data.Dataset):
    """Expose an indexable, sized sequence as a map-style torch ``Dataset``.

    Items are returned exactly as stored; nothing is copied or transformed.
    """

    def __init__(self, original_list):
        # Keep a reference to the caller's sequence; it is not copied.
        self.original_list = original_list

    def __getitem__(self, i):
        items = self.original_list
        return items[i]

    def __len__(self):
        items = self.original_list
        return len(items)


if __name__ == "__main__":
    # Placeholder: running this file as a script currently does nothing.
    pass
