import abc
import logging
import random
import re
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional

import datasets
import transformers
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments
from trl import DataCollatorForCompletionOnlyLM

# Label value meant to be ignored by the loss. It is handed to the
# completion-only collator as ``ignore_index`` and equals the default
# ``ignore_index`` of PyTorch's CrossEntropyLoss.
IGNORE_INDEX = -100

# Module logger: emits INFO-and-above records to stdout through its own
# handler (records still propagate to ancestor loggers as usual).
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
_logger.addHandler(logging.StreamHandler(sys.stdout))


@dataclass
class DataArguments:
    dataset: Literal["oasst1", "unnatural_instruct", "magicoder"] = field(
        default=None,
        metadata={
            "help": "Name of the utilized dataset. Different datasets use different preprocessing steps."
        },
    )
    eval_dataset_size: int = field(
        default=1024, metadata={"help": "Size of validation dataset."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    prompt_format: Literal["guanaco", "chatml"] = field(
        default="guanaco",
        metadata={
            "help": "The prompt template used for instruction tuning. Options: 'chatml', 'guanaco'"
        },
    )
    only_train_on_completions: bool = field(
        default=False,
        metadata={
            "help": "If True, only use the tokens generated by the assistant for loss computation."
        },
    )

class ChatFormat(abc.ABC):
    def __call__(self, messages: Dict) -> str:
        return self.apply_format(messages)

    @abc.abstractmethod
    def apply_format(self, messages: Dict) -> str:
        """format the messages according to the chat format"""
        raise NotImplementedError("Please implement this method")

    @property
    @abc.abstractmethod
    def instruction_template(self) -> str:
        """return the string that signals the start of a user turn"""
        raise NotImplementedError("Please implement this method")

    @property
    @abc.abstractmethod
    def response_template(self) -> str:
        """return the string that signals the start of a user turn"""
        raise NotImplementedError("Please implement this method")


class PlainFormat(ChatFormat):
    def __init__(self, tokenizer, seperator=" "):
        self.tokenizer = tokenizer
        self.seperator = seperator

    def apply_format(self, messages: Dict) -> str:
        text = self.seperator.join(message["content"] for message in messages)
        return self.tokenizer.bos_token + text + self.tokenizer.eos_token

    @property
    def instruction_template(self) -> str:
        raise NotImplementedError(
            "No instruction template defined for PlainFormat. Training on completions only is not supported"
        )

    @property
    def response_template(self) -> str:
        raise NotImplementedError(
            "No response template defined for PlainFormat. Training on completions only is not supported"
        )


class ChatMLFormat(ChatFormat):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def apply_format(self, messages: Dict) -> str:
        text = ""
        for message in messages:
            text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        return self.tokenizer.bos_token + text + self.tokenizer.eos_token

    @property
    def instruction_template(self) -> str:
        return "<|im_start|>user\n"

    @property
    def response_template(self) -> str:
        return "<|im_start|>assistant\n"


class GuanacoFormat(ChatFormat):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def apply_format(self, messages: Dict) -> str:
        text = ""
        for message in messages:
            if message["role"] == "user":
                text += f"### Human: {message['content']} "
            elif message["role"] == "assistant":
                text += f"### Assistant: {message['content']} "
            else:
                raise ValueError(
                    f"This prompt format only supports 'user' and 'assistant' roles"
                )
        return self.tokenizer.bos_token + text[:-1] + self.tokenizer.eos_token

    @property
    def instruction_template(self) -> str:
        return "### Human:"

    @property
    def response_template(self) -> str:
        return "### Assistant:"


class LLama3Format(ChatFormat):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def apply_format(self, messages: Dict) -> str:
        text = ""
        for message in messages:
            text += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"
        return self.tokenizer.bos_token + text + self.tokenizer.eos_token

    @property
    def instruction_template(self) -> str:
        return "<|start_header_id|>user<|end_header_id|>"

    @property
    def response_template(self) -> str:
        return "<|start_header_id|>assistant<|end_header_id|>"


def get_chat_format(name: str, tokenizer) -> ChatFormat:
    """Instantiate the chat format registered under ``name`` for ``tokenizer``.

    Raises:
        NotImplementedError: if ``name`` is not one of the registered formats.
    """
    registry = {
        "chatml": ChatMLFormat,
        "guanaco": GuanacoFormat,
        "llama3": LLama3Format,
    }
    if name not in registry:
        raise NotImplementedError(f"Prompt format not supported: {name}")
    format_cls = registry[name]
    return format_cls(tokenizer)


def guanaco_to_messages(
    dataset, human_prompt="### Human:", assistant_prompt="### Assistant:"
):
    def guanaco_text_to_messages(text: str):
        text = text.strip()
        assert text.startswith(human_prompt)
        parts = re.split(rf"({human_prompt}|{assistant_prompt})", text)[1:]

        assert len(parts) % 2 == 0, f"{parts}"
        assert assistant_prompt in parts

        iter_list = iter(parts)
        return [
            {
                "role": "user" if role == human_prompt else "assistant",
                "content": text.strip(),
            }
            for role, text in zip(iter_list, iter_list)
        ]

    return dataset.map(
        lambda x: {"messages": guanaco_text_to_messages(x["text"])},
        remove_columns=["text"],
    )


def unnatural_instruct_to_messages(dataset):
    """Expand each Unnatural Instructions record into one conversation per example.

    The dataset is mapped with ``batched=True`` and ``batch_size=1`` so that a
    single input row can yield several output rows: one two-turn conversation
    for every entry in ``instances`` plus every entry in ``reformulations``
    (which may be empty/None). The original columns are dropped.
    """

    def explode(batch):
        # batch_size=1, so each column holds exactly one record.
        instances = batch["instances"][0]
        reformulations = batch["reformulations"][0] or []
        return {
            "messages": [
                [
                    dict(role="user", content=item["instruction_with_input"]),
                    dict(role="assistant", content=item["output"]),
                ]
                for item in instances + reformulations
            ]
        }

    return dataset.map(
        explode,
        batched=True,
        batch_size=1,
        remove_columns=["instruction", "instances", "reformulations"],
    )

def magicoder_to_messages(dataset):
    """Turn ``query``/``response`` rows into a two-turn ``messages`` column.

    Every column other than ``messages`` is dropped from the result.
    """

    def as_conversation(example):
        return {
            "messages": [
                {"role": "user", "content": example["query"]},
                {"role": "assistant", "content": example["response"]},
            ]
        }

    converted = dataset.map(as_conversation)
    return converted.select_columns(["messages"])


def _load_magicoder_dataset(hf_home=None):
    # Load datasets
    magicoder = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K", cache_dir=hf_home)

    magicoder = magicoder['train'].filter(lambda x: x['lang'] == 'python')

    magicoder = magicoder.map(lambda x: {
        'query': x['problem'],
        'response': x['solution']
    })

    magicoder = magicoder.train_test_split(train_size=30000, seed=42, shuffle=True)
    magicoder_train_split = magicoder['train']
    magicoder_eval_split = magicoder['test']

    return DatasetDict({
        'train': magicoder_train_split,
        'eval': magicoder_eval_split
    })


def _load_dataset(name, hf_home=None):
    if name == "oasst1":
        dataset = load_dataset("timdettmers/openassistant-guanaco", cache_dir=hf_home)
    elif name == "unnatural_instruct":
        dataset = load_dataset("mrm8488/unnatural-instructions-full", cache_dir=hf_home)
    elif name == "magicoder":
        dataset = _load_magicoder_dataset(hf_home)
    else:
        raise NotImplementedError(f"Dataset {name} not implemented.")
    return dataset


def prepare_data(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args: DataArguments,
    train_args: TrainingArguments,
    hf_home=None,
) -> dict:
    """Load and preprocess the dataset"""
    dataset = _load_dataset(data_args.dataset, hf_home=hf_home)

    # convert to format-agnostic intermediate representation as a dictionary of messages
    if data_args.dataset == "oasst1":
        dataset = guanaco_to_messages(dataset)

    if data_args.dataset == "unnatural_instruct":
        dataset = unnatural_instruct_to_messages(dataset)

    if data_args.dataset == "magicoder":
        dataset = magicoder_to_messages(dataset)

    # apply the target chat format
    chat_format = get_chat_format(data_args.prompt_format, tokenizer)
    dataset = dataset.map(
        lambda x: {"text": chat_format.apply_format(x["messages"])},
        remove_columns=["messages"],
    )

    # Split train/eval, reduce size
    if train_args.do_eval:
        if "eval" in dataset:
            eval_dataset = dataset["eval"]
        else:
            _logger.info(
                "Splitting train dataset in train and validation according to `eval_dataset_size`"
            )
            dataset = dataset["train"].train_test_split(
                test_size=data_args.eval_dataset_size, shuffle=True
            )
            eval_dataset = dataset["test"]
        if (
            data_args.max_eval_samples is not None
            and len(eval_dataset) > data_args.max_eval_samples
        ):
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

    if train_args.do_train:
        train_dataset = dataset["train"]
        if (
            data_args.max_train_samples is not None
            and len(train_dataset) > data_args.max_train_samples
        ):
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if data_args.only_train_on_completions:
        data_collator = DataCollatorForCompletionOnlyLM(
            response_template=chat_format.response_template,
            instruction_template=chat_format.instruction_template,
            tokenizer=tokenizer,
            mlm=False,
            ignore_index=IGNORE_INDEX,
        )
    else:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )

    return dict(
        train_dataset=train_dataset if train_args.do_train else None,
        eval_dataset=eval_dataset if train_args.do_eval else None,
        data_collator=data_collator,
    )
