import random
import pickle
import logging
from typing import Dict, Optional, Sequence
from dataclasses import dataclass, field

import json
import numpy as np
import torch
import transformers
from transformers.trainer_pt_utils import LabelSmoother
from fastchat.model.model_adapter import get_conversation_template


# Label id written into `labels` for positions that must not contribute to the
# loss (prompt prefix). Reuses transformers' LabelSmoother.ignore_index so the
# value matches what the training side ignores.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class ModelArguments:
    """Command-line options selecting the base model and its tokenizer length."""

    # Model id or local path; used to load the tokenizer and to pick the
    # FastChat conversation template.
    model_name_or_path: Optional[str] = "lmsys/vicuna-7b-v1.3"

    # Forwarded to the tokenizer as `model_max_length`.
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded "
            "(and possibly truncated)."
        },
    )


@dataclass
class DataArguments:
    """Command-line options naming the input/output files of this script.

    Every string field is a file path. The three prompt fields point at JSON
    files that each contain a list of prompts.
    """

    data_path: str = field(
        metadata={"help": "A JSON file containing the training data."}
    )
    output_data_path: str = field(
        metadata={"help": "The pickle file name to dump the training data."}
    )
    skeleton_prompts: str = field(
        metadata={"help": "A JSON list file containing all the skeleton prompts."}
    )
    point_prompts: str = field(
        metadata={
            "help": "A JSON list file containing all the point-expanding prompts."
        }
    )
    conv_point_prompts: str = field(
        metadata={
            "help": (
                "A JSON list file containing all the conversational point-expanding"
                " prompts."
            )
        }
    )
    # Previously undocumented flag; give it help text so `--help` explains it.
    add_conv_point_data: bool = field(
        default=False,
        metadata={
            "help": (
                "Also generate conversational point-expanding samples (a skeleton"
                " turn followed by a request to expand one point)."
            )
        },
    )


def _finish_conversation(conv, response, tokenizer):
    """Append ``response`` as the assistant turn; return ``(prompt, mask_len)``.

    ``mask_len`` is the token count of the conversation rendered with a
    pending (``None``) assistant turn. These are the prompt-prefix tokens
    whose labels get masked.
    """
    conv.append_message(conv.roles[1], None)
    # NOTE(review): the prefix is tokenized separately from the full prompt,
    # so its length can differ from the matching prefix of the full
    # tokenization if tokens merge across the boundary. Confirm for the
    # tokenizer in use.
    mask_len = len(tokenizer(conv.get_prompt()).input_ids)
    conv.update_last_message(response)
    return conv.get_prompt(), mask_len


def preprocess(
    conv_template_func,
    sources,
    skeleton_prompts,
    point_prompts,
    conv_point_prompts,
    tokenizer: transformers.PreTrainedTokenizer,
    add_conv_point_data=False,
    inspect_masking=False,
) -> Dict:
    """Build tokenized training samples for the skeleton and point stages.

    For every source this emits:
    - one skeleton sample,
    - one point-expanding sample per point,
    - ``num_points`` extra conversational samples if ``add_conv_point_data``
      is true (a skeleton turn followed by a request for a randomly chosen
      point).

    Templates are drawn with ``random.choice`` from the given sequences.

    Args:
        conv_template_func: Zero-argument callable returning a fresh
            conversation object with ``roles``, ``append_message``,
            ``update_last_message`` and ``get_prompt``.
        sources: Iterable of dicts with keys ``"request"``, ``"skeleton"``,
            ``"num_points"`` and ``"contents"``. ``contents`` is indexed by
            point index and must hold at least ``num_points`` entries.
        skeleton_prompts: Non-empty sequence of templates formatted with
            ``request=``.
        point_prompts: Non-empty sequence of templates formatted with
            ``request=``, ``skeleton=`` and ``point=``.
        conv_point_prompts: Templates formatted with ``point=``. Only read
            when ``add_conv_point_data`` is true.
        tokenizer: Used for the prefix lengths and for the final batch. Its
            ``model_max_length``, ``pad_token_id`` and ``unk_token_id`` are
            read.
        add_conv_point_data: Also emit the conversational point samples.
        inspect_masking: Print every decoded label row (masked positions
            shown as the unk token) for manual inspection.

    Returns:
        Dict with ``input_ids`` and ``labels``, both padded/truncated to
        ``tokenizer.model_max_length``, plus a boolean ``attention_mask``
        (``input_ids != pad_token_id``). Prompt-prefix and padding positions
        of ``labels`` are ``IGNORE_TOKEN_ID``.
    """
    samples = []  # (full prompt text, number of prefix tokens to mask)
    for source in sources:
        request = source["request"]
        skeleton = source["skeleton"]
        num_points = source["num_points"]
        contents = source["contents"]

        # Skeleton stage: request -> skeleton.
        conv = conv_template_func()
        ts_template = random.choice(skeleton_prompts)
        conv.append_message(conv.roles[0], ts_template.format(request=request))
        samples.append(_finish_conversation(conv, skeleton, tokenizer))

        # Point-expanding stage: (request, skeleton, point) -> point content.
        # NOTE(review): `point` is passed 0-based; confirm the templates
        # expect that.
        for point_index in range(num_points):
            conv = conv_template_func()
            tp_template = random.choice(point_prompts)
            # TODO: check if extracting and formatting point_outline is needed.
            conv.append_message(
                conv.roles[0],
                tp_template.format(
                    request=request, skeleton=skeleton, point=point_index
                ),
            )
            samples.append(
                _finish_conversation(conv, contents[point_index], tokenizer)
            )

        if add_conv_point_data:
            # Conversational stage: num_points samples, each expanding a
            # randomly sampled point (with replacement) after a skeleton turn.
            for _ in range(num_points):
                conv = conv_template_func()
                point_index = random.randint(0, num_points - 1)
                ts_template = random.choice(skeleton_prompts)
                tcp_template = random.choice(conv_point_prompts)
                conv.append_message(conv.roles[0], ts_template.format(request=request))
                conv.append_message(conv.roles[1], skeleton)
                conv.append_message(
                    conv.roles[0], tcp_template.format(point=point_index)
                )
                samples.append(
                    _finish_conversation(conv, contents[point_index], tokenizer)
                )

    conversations = [prompt for prompt, _ in samples]
    mask_lens = [mask_len for _, mask_len in samples]

    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    targets = input_ids.clone()
    # Padding must not be supervised, otherwise every padded position adds
    # pad-token predictions to the loss. Uses the same criterion as
    # `attention_mask`.
    targets[~attention_mask] = IGNORE_TOKEN_ID

    num_unsupervised = 0
    for mask_len, target in zip(mask_lens, targets):
        # `target` is a view into `targets`: mask the prompt prefix in place
        # so only the assistant response is learned.
        target[:mask_len] = IGNORE_TOKEN_ID
        if not target.ne(IGNORE_TOKEN_ID).any():
            num_unsupervised += 1

        if inspect_masking:
            shown = torch.where(
                target == IGNORE_TOKEN_ID, tokenizer.unk_token_id, target
            )
            print(tokenizer.decode(shown))

    if num_unsupervised:
        logging.warning(
            "%d of %d samples have no supervised tokens after truncation to %d "
            "tokens; consider raising model_max_length.",
            num_unsupervised,
            len(mask_lens),
            tokenizer.model_max_length,
        )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=attention_mask,
    )


if __name__ == "__main__":
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments))
    model_args, data_args = parser.parse_args_into_dataclasses()
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=model_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Pad with the unk token (the tokenizer may not define a pad token).
    tokenizer.pad_token = tokenizer.unk_token

    def load_json(path):
        """Return the parsed contents of the UTF-8 JSON file at ``path``."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_prompt_list(path):
        """Load a JSON list of prompts; reject non-list or empty files."""
        prompts = load_json(path)
        # random.choice on a str would silently pick single characters, and
        # on an empty list it raises an IndexError deep inside preprocess.
        if not isinstance(prompts, list) or not prompts:
            raise ValueError(f"Expected a non-empty JSON list of prompts in {path!r}")
        return prompts

    # conversation template
    def get_conversation_template_func(model_path):
        """Return a factory yielding fresh copies of the model's empty template."""
        conv_template = get_conversation_template(model_path)
        conv_template.messages = []
        print(f"Conversation template: {conv_template}")

        def _func():
            return conv_template.copy()

        return _func

    conv_template_func = get_conversation_template_func(model_args.model_name_or_path)

    # load the raw data
    raw_data = load_json(data_args.data_path)

    # The prompt arguments are file paths. Load them so preprocess samples
    # whole prompts instead of characters of the path string.
    skeleton_prompts = load_prompt_list(data_args.skeleton_prompts)
    point_prompts = load_prompt_list(data_args.point_prompts)
    conv_point_prompts = (
        load_prompt_list(data_args.conv_point_prompts)
        if data_args.add_conv_point_data
        else []
    )

    # preprocess: (1) apply prompt and conversation templates
    #             (2) prepare appropriate labels and attention mask
    data = preprocess(
        conv_template_func,
        raw_data,
        skeleton_prompts,
        point_prompts,
        conv_point_prompts,
        tokenizer,
        add_conv_point_data=data_args.add_conv_point_data,
    )

    print("Data shape:", data["input_ids"].shape)

    with open(data_args.output_data_path, "wb") as wf:
        print(f"Pickle dumping the data to {data_args.output_data_path}")
        pickle.dump(data, wf)
