"""
Split long conversations based on certain max length.

Usage: python3 -m fastchat.data.split_long_conversation \
    --in sharegpt_clean.json \
    --out sharegpt_split.json \
    --model-name-or-path $<model-name>
"""
import argparse
import json
from typing import Dict, Sequence, Optional

import transformers
import tqdm

from llava import conversation as conversation_lib

DEFAULT_PAD_TOKEN = "[PAD]"
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"


def split_sample(sample, start_idx, end_idx):
    # only ends in the bot because otherwise the last human part is useless.
    end_speaker = sample["conversations"][end_idx]["from"]
    end_idx = end_idx + 1 if end_speaker != "human" else end_idx
    return {
        "id": sample["id"] + "_" + str(start_idx),
        "conversations": sample["conversations"][start_idx:end_idx]
    }


def split_contents(content, begin, end, tokenizer, max_length):
    """
    Keep the maximum round of conversations within the max token length constraint
    """
    content = content[begin:end]
    new_content = []

    for sample in tqdm.tqdm(content):
        tokenized_lens = []

        for c in sample["conversations"]:
            from_str = c["from"]
            if from_str.lower() == "human":
                from_str = conversation_lib.default_conversation.roles[0]
            elif from_str.lower() == "gpt":
                from_str = conversation_lib.default_conversation.roles[1]
            else:
                from_str = 'unknown'

            sentence = (BEGIN_SIGNAL + from_str + ": " + c["value"] +
                        END_SIGNAL)
            length = tokenizer(sentence, return_tensors="pt", padding="longest"
                ).input_ids.ne(tokenizer.pad_token_id).sum().item()
            tokenized_lens.append(length)

        num_tokens = 0
        start_idx = 0
        for idx, l in enumerate(tokenized_lens):
            # TODO: shall we also only starts from a specific speaker?
            if num_tokens + l > max_length:
                new_content.append(split_sample(sample, start_idx, idx))
                start_idx = idx
                num_tokens = l
            else:
                num_tokens += l
                if idx == len(tokenized_lens) - 1:
                    new_content.append(split_sample(sample, start_idx, idx))

    print(f"total: {len(content)}, new: {len(new_content)}")
    return new_content


def main(args):
    content = json.load(open(args.in_file, "r"))
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
    content = split_contents(content, args.begin, args.end,
        tokenizer, args.max_length)
    json.dump(content, open(args.out_file, "w"), indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
    parser.add_argument("--begin", type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument("--model-name-or-path", type=str, required=True)
    parser.add_argument("--max-length", type=int, default=2304)
    args = parser.parse_args()
    main(args)
