# https://huggingface.co/docs/transformers/en/training
import os
import json
import pickle
import argparse
from pathlib import Path
from functools import partial

import requests
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from accelerate import Accelerator

from models import AutoModelForCausalLM, LlamaDraftForCausalLM, LlamaForCausalLM
from models.token import KVCache
from utils import Timer
from utils import distance_to_next_zero_1d


# Placeholder string that stands in for every assistant turn while the chat
# template is rendered. get_tokenizer registers it as an additional special
# token (and asserts it encodes to a single id); preprocess_function then swaps
# each occurrence of that id back for the separately tokenized assistant text.
XXX_TOKEN = "[XXXSPECIALXXX]"


def get_tokenizer(model_name, padding_side="right", max_length=2048):
    """Load the tokenizer for ``model_name`` and register the placeholder token.

    The EOS token is reused as the padding token, and ``XXX_TOKEN`` is added as
    an extra special token (without replacing existing ones) so that it always
    encodes to exactly one id.

    Args:
        model_name: Hugging Face hub id or local path of the tokenizer.
        padding_side: side on which padding is applied.
        max_length: forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.

    Raises:
        AssertionError: if ``XXX_TOKEN`` does not encode to a single id.
    """
    tok = AutoTokenizer.from_pretrained(
        model_name, padding_side=padding_side, max_length=max_length
    )

    # Pad with EOS; the id is assigned before the string, as before.
    tok.pad_token_id = tok.eos_token_id
    tok.pad_token = tok.eos_token
    tok.add_special_tokens(
        {"additional_special_tokens": [XXX_TOKEN]},
        replace_additional_special_tokens=False,
    )

    # Sanity check: the placeholder must survive tokenization as one id.
    placeholder_ids = tok(XXX_TOKEN, add_special_tokens=False)["input_ids"]
    assert len(placeholder_ids) == 1, f"{placeholder_ids} {len(placeholder_ids)}"
    return tok


def preprocess_function(examples, tokenizer, max_length=2048):
    """Tokenize a batch of chat conversations with assistant-only loss masks.

    Intended for ``datasets.Dataset.map(batched=True)``. A fixed system prompt
    is prepended to each conversation. Every assistant turn is replaced by
    ``XXX_TOKEN`` before the chat template is rendered and tokenized. Each
    placeholder id is then swapped back for the assistant text, tokenized on
    its own. Those tokens get ``loss_mask=True``; all system, user and
    template tokens get ``False``.

    Args:
        examples: batched columns. Conversations are read from
            ``"conversations"`` (ShareGPT style, ``from``/``value``) or
            ``"messages"`` (``role``/``content``); the row count comes from
            ``"id"`` or ``"prompt_id"``.
        tokenizer: tokenizer in which ``XXX_TOKEN`` is a single token
            (see ``get_tokenizer``).
        max_length: the templated text is truncated to this many tokens, and
            the expanded sequence is cut to this length.

    Returns:
        dict of lists:
            ``conversation``: templated text, with placeholders.
            ``input_ids``: long tensor of shape (1, L).
            ``loss_mask``: bool tensor of shape (1, L).
            ``length``: L, the stored length after truncation.
        Conversations that are empty or contain no assistant tokens are dropped.

    Raises:
        AssertionError: if user/assistant turns do not alternate.
    """
    new_examples = {
        "conversation": [],
        "input_ids": [],
        "loss_mask": [],
        "length": [],
    }
    id_key = "id" if "id" in examples.keys() else "prompt_id"
    conversation_key = "conversations" if "conversations" in examples.keys() else "messages"
    conversations = examples[conversation_key]

    # Detect the text field from the first non-empty conversation; indexing
    # conversations[0][0] directly raised IndexError when that one was empty.
    value_key = "content"
    for conv in conversations:
        if conv:
            value_key = "value" if "value" in conv[0] else "content"
            break

    convroles = ["user", "assistant"]
    roles = {"human": "user", "gpt": "assistant"}
    # The placeholder id is identical for every row, so look it up once.
    xxx_token_id = tokenizer(XXX_TOKEN, add_special_tokens=False)["input_ids"][0]

    def get_role(s):
        if "from" in s:
            return roles[s["from"]]
        return s["role"]

    for i in range(len(examples[id_key])):
        source = conversations[i]
        if not source:
            # Empty conversation: nothing to train on.
            continue
        if get_role(source[0]) != "user":
            # Skip the first one if it is not from human
            source = source[1:]

        messages = [
            {"role": "system",
                "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
        ]
        # Tokenized assistant turns, in the order their placeholders appear.
        record_input_ids = []
        for j, sentence in enumerate(source):
            role = get_role(sentence)
            assert role == convroles[j % 2], f"{i}"

            content = sentence[value_key]
            if role == "assistant":
                # The leading space is part of the trained target. It is added
                # to a local copy rather than written back into the input batch.
                record_input_ids.append(
                    tokenizer(" " + content, add_special_tokens=False)["input_ids"]
                )
                content = XXX_TOKEN
            messages.append({"role": role, "content": content})

        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            tokenizer_kwargs={"add_special_tokens": False},
        )

        # A plain list of ints: comparing ints is far cheaper than iterating a
        # tensor element by element.
        pre_input_ids = tokenizer(
            conversation,
            max_length=max_length,
            add_special_tokens=False,
            truncation=True,
        )["input_ids"]

        input_ids = []
        loss_mask = []
        for tok in pre_input_ids:
            if tok == xxx_token_id:
                recorded_input_ids = record_input_ids.pop(0)
                input_ids.extend(recorded_input_ids)
                loss_mask.extend([True] * len(recorded_input_ids))
            else:
                input_ids.append(tok)
                loss_mask.append(False)

        if not any(loss_mask):
            continue

        # Expanding placeholders can push the sequence past max_length. Cut it
        # before recording "length" so the column matches the stored arrays.
        input_ids = torch.tensor(input_ids[:max_length], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[:max_length], dtype=torch.bool)

        new_examples["conversation"].append(conversation)
        new_examples["input_ids"].append(input_ids[None])
        new_examples["length"].append(len(input_ids))
        new_examples["loss_mask"].append(loss_mask[None])

    return new_examples


def collate_fn(data):
    """Right-pad preprocessed examples into batch tensors.

    Each element of ``data`` holds ``input_ids`` and ``loss_mask`` wrapped in
    a leading length-1 container; the inner sequence (index 0) is used.

    Returns:
        dict with:
            ``batch_input_ids``: long tensor (B, T_max), zero-padded.
            ``batch_attention_mask``: long tensor (B, T_max), 1 on real tokens.
            ``input_ids``: the unpadded inner sequences.
            ``loss_mask``: the unpadded inner loss masks.
            ``batch_loss_mask``: bool tensor (B, T_max), False on padding.
    """
    seqs = [d["input_ids"][0] for d in data]
    masks = [d["loss_mask"][0] for d in data]
    width = max(len(s) for s in seqs)
    rows = len(data)

    padded_ids = torch.zeros(rows, width, dtype=torch.long)
    attention = torch.zeros(rows, width, dtype=torch.long)
    padded_loss = torch.zeros(rows, width, dtype=torch.bool)

    for row, (seq, mask) in enumerate(zip(seqs, masks)):
        end = len(seq)
        padded_ids[row, :end] = torch.tensor(seq)
        attention[row, :end] = 1
        padded_loss[row, :end] = torch.tensor(mask)

    return {
        "batch_input_ids": padded_ids,
        "batch_attention_mask": attention,
        "input_ids": seqs,
        "loss_mask": masks,
        "batch_loss_mask": padded_loss,
    }


def build_dataset(tokenizer, dataset, max_length=2048,
                  num_proc=16):
    """Run ``preprocess_function`` over ``dataset`` and keep only its outputs.

    Args:
        tokenizer: forwarded to ``preprocess_function``.
        dataset: dataset of raw conversations (anything exposing ``map`` and
            ``column_names``, e.g. the result of ``load_dataset``).
        max_length: forwarded to ``preprocess_function``.
        num_proc: number of worker processes for ``map``.

    Returns:
        The mapped dataset, holding only the columns produced by
        ``preprocess_function`` (``conversation``, ``input_ids``,
        ``loss_mask``, ``length``). Every original column is removed and the
        on-disk map cache is bypassed (``load_from_cache_file=False``).
    """
    return dataset.map(
        partial(preprocess_function, tokenizer=tokenizer, max_length=max_length),
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        load_from_cache_file=False,
    )


def sort_by_length(dataset):
    """Return ``dataset`` ordered by its ``length`` column, longest first."""
    return dataset.sort("length", reverse=True)


def parse_args():
    """Parse the command-line options of the data-generation script.

    Returns:
        argparse.Namespace holding the options declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--num_proc", type=int, default=16)
    # The default must be one of the model types the __main__ block maps to a
    # checkpoint; the former default "llama3" was not, so the script always
    # failed when --model_type was omitted.
    parser.add_argument(
        "--model_type",
        type=str,
        default="llama3_1-8b",
        help=(
            "one of: llama2-7b, llama2-13b, llama3_1-8b, r1-distill-qwen-1_5b, "
            "r1-distill-qwen-7b, r1-distill-llama-8b"
        ),
    )
    # argparse applies `type` to string defaults, so this default becomes a Path.
    parser.add_argument("--output_dir", type=Path, default="preprocessed_data")
    # generate_data asserts exactly one sequence per batch, so any default
    # above 1 failed on the first batch.
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--dataset_id", choices=["sharegpt", "ultrachat", "json"], default="sharegpt")
    parser.add_argument("--dataset_file", type=str, default="generated_data.json")
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--top_draft", type=int, default=4)
    parser.add_argument("--top_node", type=int, default=16)
    # NOTE(review): not read anywhere in the visible script.
    parser.add_argument("--repeat", type=int, default=8)
    args = parser.parse_args()
    return args


@torch.inference_mode()
def generate_data(model, data, max_length=2048, depth=6, top_draft=4, top_node=16):
    """Grow a draft tree with the target model and record its per-level outputs.

    First, the whole sequence goes through ``model`` in one causal pass. Then a
    tree ``depth`` levels deep is grown from every position where
    ``batch_loss_mask`` is True. At each level, the ``top_draft`` children with
    the highest cumulative log-probability along their path become the next
    frontier. The greedy (argmax) child of the first frontier node is always
    included. An additive attention mask limits each new token to its root's
    causal prefix plus its own ancestors in the tree.

    Args:
        model: causal LM exposing ``config`` (num_hidden_layers,
            num_key_value_heads, head_dim, hidden_size), ``device`` and
            ``dtype``. It is called with ``past_key_values`` /
            ``past_key_value_indices`` keyword arguments; that calling
            convention is the project model's and is not visible here.
        data: a batch from ``collate_fn``. It must contain exactly one
            sequence (asserted below).
        max_length: used only to size the KV cache.
        depth: number of drafting levels. Must be >= 1, otherwise
            ``hidden_state`` is never assigned.
        top_draft: children kept per valid position at each level.
        top_node: candidates recorded per valid position at each level.

    Returns:
        dict with:
            ``input_ids`` / ``loss_mask``: passed through from ``data``.
            ``hidden_state``: (1, q_len, hidden_size), last hidden layer of
                the first pass.
        Per-level histories, each with a leading dim of 1 added:
            ``sample_idxs``: (depth, q_len_valid, top_node), flat
                ``parent * vocab + token`` indices.
            ``sample_logits``: (depth, q_len_valid, top_node), per-step
                log-probs of those candidates.
            ``output_topk_ids`` / ``output_topk_probs``: (depth, q_len_valid,
                top_node), top tokens and probabilities of the first frontier
                node only.
            ``sampled_scores``: (depth, q_len_valid, top_draft), cumulative
                log-prob scores of the kept children.
    """
    # Capacity max_length * depth * top_draft is meant to cover every slot
    # allocated below (assumes q_len <= max_length -- TODO confirm). Slot
    # semantics are defined by models.token.KVCache, not visible here.
    past_key_values = KVCache(
        model.config.num_hidden_layers,
        model.config.num_key_value_heads,
        max_length * depth * top_draft,
        model.config.head_dim,
        model.device,
        model.dtype,
        num_seqs=data["batch_input_ids"].shape[0],
    )
    input_ids = data["batch_input_ids"].clone().detach().to(model.device)

    # Every KV slot index allocated so far; passed to the model on each call.
    kv_cache_indices = []
    # min_value marks blocked entries in the additive attention mask;
    # max_value is a sentinel score that forces a candidate into the top-k.
    min_value = torch.finfo(model.dtype).min
    max_value = torch.finfo(model.dtype).max
    loss_masks = data["batch_loss_mask"].clone().detach().to(model.device)

    # Only one sequence per call is supported.
    assert input_ids.size(0) == 1

    q_len = loss_masks.size(1)
    # Positions with loss_mask True; only these root a draft tree.
    q_len_valid = loss_masks.long().sum().item()
    valid_mask = loss_masks.bool().flatten()

    # Current frontier state, fed to the model on the next call.
    last_input_ids = input_ids
    last_position_ids = torch.arange(0, q_len, device=model.device).unsqueeze(0)
    # Cumulative log-probability of the path ending at each frontier node.
    last_output_score = torch.zeros((q_len, 1), device=model.device, dtype=model.dtype)
    last_attention_mask = torch.full(
        (1, 1, q_len, q_len),
        min_value,
        device=model.device,
        dtype=model.dtype,
        requires_grad=False,
    )

    # Causal mask: 0 on and below the diagonal, min_value strictly above it.
    last_attention_mask[0, 0] = torch.triu(
        last_attention_mask[0, 0],
        diagonal=1,
    )

    output_topk_ids_history = []
    output_topk_probs_history = []
    sampled_idxs_history = []
    sampled_logits_history = []
    sampled_scores_history = []

    for i in range(depth):
        # Level 0 processes all q_len prompt positions with a single "parent"
        # each; later levels process top_draft nodes per valid position.
        ql = q_len if i == 0 else q_len_valid
        nc = top_draft if i > 0 else 1

        new_kv_cache_indices = past_key_values.allocate(ql * nc)
        kv_cache_indices.extend(new_kv_cache_indices)

        model_inputs = {}
        model_inputs["input_ids"] = last_input_ids
        model_inputs["position_ids"] = last_position_ids
        model_inputs["attention_mask"] = last_attention_mask
        model_inputs["past_key_values"] = past_key_values
        model_inputs["past_key_value_indices"] = kv_cache_indices
        model_inputs["use_cache"] = True
        model_inputs["output_hidden_states"] = True

        output = model(**model_inputs)

        if i == 0:
            num_hidden_layers = model.config.num_hidden_layers
            # Alternative kept for reference: stack three intermediate layers
            # instead of using only the last one.
            # hidden_state = torch.stack(output["hidden_states"], dim=1)[:, [2, num_hidden_layers // 2, -2]].detach().clone()
            # assert hidden_state.size() == (1, 3, q_len, model.config.hidden_size)
            # Last hidden layer of the full-sequence pass, returned to the caller.
            hidden_state = output["hidden_states"][-1].detach().clone()
            assert hidden_state.size() == (1, q_len, model.config.hidden_size)

        # Log-probs laid out as (position, parent * vocab).
        v = output["logits"].size(-1)
        output_logits = output["logits"].log_softmax(-1).reshape(ql, nc * v)
        assert output_logits.size() == (ql, nc * v)
        assert last_output_score.size() == (ql, nc)
        # Extend each parent's path score by each child token's log-prob.
        last_output_score = last_output_score.reshape(ql, nc, 1) + output_logits.reshape(ql, nc, v)
        last_output_score = last_output_score.reshape(ql, nc * v)

        # Greedy next token of the first parent (index 0) at each position...
        output_ids = output_logits.reshape(ql, nc, v)[:, 0].argmax(-1).unsqueeze(-1)
        assert output_ids.size() == (ql, 1)

        # ...gets the maximum score so that topk always selects it. Only this
        # sampling copy is altered; the real scores stay in last_output_score.
        last_output_score_for_sampling = last_output_score.scatter(
            -1,
            output_ids,
            max_value,
        )

        # Flat (parent * v + token) indices of the children expanded next.
        sampled_idxs = torch.topk(
            last_output_score_for_sampling,
            top_draft,
            dim=-1,
        ).indices

        # Wider candidate set; recorded in the output but not expanded.
        sampled_node_idxs = torch.topk(
            last_output_score_for_sampling,
            top_node,
            dim=-1,
        ).indices
        assert sampled_node_idxs.size() == (ql, top_node)

        # The first parent's own top-top_node tokens and their probabilities.
        output_topk_ids = output_logits.reshape(ql, nc, v)[:, 0].topk(top_node, dim=-1).indices
        output_topk_probs = output_logits.reshape(ql, nc, v)[:, 0].exp().topk(top_node, dim=-1).values

        # Per-step (not cumulative) log-probs of the recorded candidates.
        sampled_node_logits = output_logits.gather(-1, sampled_node_idxs)
        assert sampled_node_logits.size() == (ql, top_node)

        sampled_idxs = sampled_idxs.reshape(1, ql, top_draft)
        assert sampled_idxs.size() == (1, ql, top_draft)
        # Split flat indices into the token id and the parent index.
        sampled_p_ids = sampled_idxs % v
        assert sampled_p_ids.size() == (1, ql, top_draft)
        sampled_p_idxs = sampled_idxs // v
        # NOTE(review): this re-checks sampled_idxs; sampled_p_idxs was
        # probably intended (the shapes are equal, so it still holds).
        assert sampled_idxs.size() == (1, ql, top_draft)

        last_input_ids = sampled_p_ids.reshape(1, ql * top_draft)
        assert last_input_ids.size() == (1, ql * top_draft)

        # A child sits one position after its parent.
        last_position_ids = last_position_ids + 1
        assert last_position_ids.size() == (1, ql * nc)
        last_position_ids = last_position_ids.reshape(ql, nc).gather(
            1, sampled_p_idxs.reshape(ql, top_draft)
        )
        assert last_position_ids.size() == (ql, top_draft)
        last_position_ids = last_position_ids.reshape(1, ql * top_draft)
        assert last_position_ids.size() == (1, ql * top_draft)

        # Keep the real cumulative scores of the selected children.
        last_output_score = last_output_score.reshape(1, ql, nc * v).gather(-1, sampled_idxs)
        assert last_output_score.size() == (1, ql, top_draft)
        last_output_score = last_output_score.reshape(ql, top_draft)
        assert last_output_score.size() == (ql, top_draft)

        if i == 0:
            # After the full pass, drop every position whose loss_mask is
            # False. From here on only q_len_valid positions are tracked.
            assert last_input_ids.size() == (1, q_len * top_draft)
            last_input_ids = last_input_ids.reshape(1, q_len, top_draft)
            last_input_ids = last_input_ids[:, valid_mask]
            assert last_input_ids.size() == (1, q_len_valid, top_draft)
            last_input_ids = last_input_ids.reshape(1, -1)
            assert last_input_ids.size() == (1, q_len_valid * top_draft)

            assert last_attention_mask.size() == (1, 1, q_len, q_len)
            last_attention_mask = last_attention_mask.reshape(1, 1, q_len, q_len)
            last_attention_mask = last_attention_mask[:, :, valid_mask]
            assert last_attention_mask.size() == (1, 1, q_len_valid, q_len)

            assert last_position_ids.size() == (1, q_len * top_draft)
            last_position_ids = last_position_ids.reshape(1, q_len, top_draft)
            last_position_ids = last_position_ids[:, valid_mask]
            assert last_position_ids.size() == (1, q_len_valid, top_draft)
            last_position_ids = last_position_ids.reshape(1, -1)
            assert last_position_ids.size() == (1, q_len_valid * top_draft)

            assert last_output_score.size() == (q_len, top_draft)
            last_output_score = last_output_score[valid_mask]
            assert last_output_score.size() == (q_len_valid, top_draft)

            assert sampled_node_idxs.size() == (q_len, top_node)
            sampled_node_idxs = sampled_node_idxs[valid_mask]
            assert sampled_node_idxs.size() == (q_len_valid, top_node)

            assert sampled_node_logits.size() == (q_len, top_node)
            sampled_node_logits = sampled_node_logits[valid_mask]
            assert sampled_node_logits.size() == (q_len_valid, top_node)

            assert output_topk_ids.size() == (q_len, top_node)
            output_topk_ids = output_topk_ids[valid_mask]
            assert output_topk_ids.size() == (q_len_valid, top_node)

            assert output_topk_probs.size() == (q_len, top_node)
            output_topk_probs = output_topk_probs[valid_mask]
            assert output_topk_probs.size() == (q_len_valid, top_node)

            assert sampled_p_idxs.size() == (1, q_len, top_draft)
            sampled_p_idxs = sampled_p_idxs[:, valid_mask]
            assert sampled_p_idxs.size() == (1, q_len_valid, top_draft)

        # Each new child inherits its parent's mask row (root prefix plus
        # ancestors)...
        past_kv_len = last_attention_mask.size(-1)
        assert last_attention_mask.size() == (1, 1, q_len_valid * nc, past_kv_len)
        last_attention_mask = last_attention_mask.reshape(q_len_valid, nc, past_kv_len).gather(
            1, sampled_p_idxs.reshape(q_len_valid, top_draft, 1).expand(q_len_valid, top_draft, past_kv_len)
        )
        assert last_attention_mask.size() == (q_len_valid, top_draft, past_kv_len)
        last_attention_mask = last_attention_mask.reshape(
            1, 1, q_len_valid * top_draft, past_kv_len,
        )
        assert last_attention_mask.size() == (1, 1, q_len_valid * top_draft, past_kv_len)
        # ...and, among the slots added for this level, may attend only to itself.
        new_attention_mask = torch.full(
            (1, 1, q_len_valid * top_draft, q_len_valid * top_draft),
            min_value,
            device=model.device,
            dtype=model.dtype,
        )
        new_attention_mask[0, 0].diagonal(0).fill_(0)
        assert new_attention_mask.size() == (1, 1, q_len_valid * top_draft, q_len_valid * top_draft)
        last_attention_mask = torch.cat([
            last_attention_mask,
            new_attention_mask,
        ], dim=-1).requires_grad_(False)
        assert last_attention_mask.size() == (1, 1, q_len_valid * top_draft, past_kv_len + q_len_valid * top_draft)

        sampled_logits_history.append(sampled_node_logits)
        sampled_idxs_history.append(sampled_node_idxs)
        output_topk_ids_history.append(output_topk_ids)
        output_topk_probs_history.append(output_topk_probs)
        sampled_scores_history.append(last_output_score)

    # Release every slot allocated in this call. Not reached if anything above
    # raises.
    past_key_values.free(kv_cache_indices)

    sampled_logits_history = torch.stack(sampled_logits_history, dim=0)
    assert sampled_logits_history.size() == (depth, q_len_valid, top_node)
    sampled_idxs_history = torch.stack(sampled_idxs_history, dim=0)
    assert sampled_idxs_history.size() == (depth, q_len_valid, top_node)

    output_topk_ids_history = torch.stack(output_topk_ids_history, dim=0)
    assert output_topk_ids_history.size() == (depth, q_len_valid, top_node)
    output_topk_probs_history = torch.stack(output_topk_probs_history, dim=0)
    assert output_topk_probs_history.size() == (depth, q_len_valid, top_node)

    sampled_scores_history = torch.stack(sampled_scores_history, dim=0)
    assert sampled_scores_history.size() == (depth, q_len_valid, top_draft)

    return {
        "input_ids": data["input_ids"],
        "hidden_state": hidden_state,
        "loss_mask": data["loss_mask"],
        "sample_idxs": sampled_idxs_history.unsqueeze(0),
        "sample_logits": sampled_logits_history.unsqueeze(0),
        "output_topk_ids": output_topk_ids_history.unsqueeze(0),
        "output_topk_probs": output_topk_probs_history.unsqueeze(0),
        "sampled_scores": sampled_scores_history.unsqueeze(0),
    }


if __name__ == "__main__":
    # Script entry point: tokenize a chat dataset, run generate_data on every
    # example, and pickle one file per example into --output_dir. Each
    # accelerate process handles a strided shard of both splits.
    args = parse_args()
    if args.debug:
        import shy
        shy.err_hook()

    # CLI --model_type -> Hugging Face hub checkpoint.
    model_names = {
        "llama2-7b": "meta-llama/Llama-2-7b-chat-hf",
        "llama2-13b": "meta-llama/Llama-2-13b-chat-hf",
        "llama3_1-8b": "meta-llama/Llama-3.1-8B-Instruct",
        "r1-distill-qwen-1_5b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "r1-distill-qwen-7b": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        "r1-distill-llama-8b": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    }
    if args.model_type not in model_names:
        raise ValueError(f"Unknown model type: {args.model_type}")
    model_name = model_names[args.model_type]

    with Timer("Loading tokenizer..."):
        tokenizer = get_tokenizer(model_name, max_length=args.max_length)

    with Timer("Loading dataset..."):
        if args.dataset_id == "sharegpt":
            train_dataset = load_dataset(
                "Aeala/ShareGPT_Vicuna_unfiltered",
                data_files={
                    "train": "ShareGPT_V4.3_unfiltered_cleaned_split.json",
                },
                split="train[:95%]" if not args.debug else "train[:1%]",
            )
        elif args.dataset_id == "ultrachat":
            train_dataset = load_dataset(
                "HuggingFaceH4/ultrachat_200k",
                split="train_sft+train_gen" if not args.debug else "train_sft[:1%]",
            )
        elif args.dataset_id == "json":
            train_dataset = load_dataset(
                "json",
                data_files={
                    "train": args.dataset_file,
                },
                split="train[:95%]" if not args.debug else "train[:1%]",
            )
        else:
            raise ValueError(f"Unknown dataset id: {args.dataset_id}")

    with Timer("Loading dataset..."):
        if args.dataset_id == "sharegpt":
            valid_dataset = load_dataset(
                "Aeala/ShareGPT_Vicuna_unfiltered",
                data_files={
                    "valid": "ShareGPT_V4.3_unfiltered_cleaned_split.json",
                },
                split="valid[95%:]" if not args.debug else "valid[99%:]",
            )
        elif args.dataset_id == "ultrachat":
            valid_dataset = load_dataset(
                "HuggingFaceH4/ultrachat_200k",
                split="test_sft+test_gen" if not args.debug else "test_sft[:1%]",
            )
        elif args.dataset_id == "json":
            valid_dataset = load_dataset(
                "json",
                data_files={
                    "valid": args.dataset_file,
                },
                split="valid[95%:]" if not args.debug else "valid[99%:]",
            )
        else:
            raise ValueError(f"Unknown dataset id: {args.dataset_id}")

    accelerator = Accelerator()
    num_procs = accelerator.num_processes
    proc_idx = accelerator.process_index

    print(f"num_procs: {num_procs}, proc_idx: {proc_idx}")
    print(f"train_dataset: {len(train_dataset)}, valid_dataset: {len(valid_dataset)}")

    # Strided shard per process, then tokenize and order longest-first.
    train_dataset = train_dataset.select(range(proc_idx, len(train_dataset), num_procs))
    valid_dataset = valid_dataset.select(range(proc_idx, len(valid_dataset), num_procs))
    train_dataset = build_dataset(tokenizer, train_dataset, max_length=args.max_length, num_proc=args.num_proc)
    valid_dataset = build_dataset(tokenizer, valid_dataset, max_length=args.max_length, num_proc=args.num_proc)
    train_dataset = sort_by_length(train_dataset)
    valid_dataset = sort_by_length(valid_dataset)

    # NOTE(review): the Qwen-based model types are also loaded through
    # LlamaForCausalLM; confirm the project class handles Qwen2 checkpoints.
    model = LlamaForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="cuda",
    )
    model = model.eval()

    args.output_dir.mkdir(exist_ok=True, parents=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=args.num_proc if args.num_proc is not None else 0,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=args.num_proc if args.num_proc is not None else 0,
    )

    loader = {"train": train_loader, "valid": valid_loader}
    for split, data_loader in loader.items():
        print(f"{split} dataset: {len(data_loader.dataset)}")

        # len(data_loader) counts the final partial batch; the previous
        # `len(dataset) // batch_size` did not.
        for idx, data in tqdm(enumerate(data_loader), desc=f"Generating {split} data", total=len(data_loader)):
            gen_data = generate_data(
                model,
                data,
                max_length=args.max_length,
                depth=args.depth,
                top_draft=args.top_draft,
                # Previously omitted, so --top_node had no effect.
                top_node=args.top_node,
            )
            # Interleave batch indices across processes so that file names
            # are unique within a split.
            base_id = (idx * num_procs + proc_idx) * args.batch_size
            for i in range(len(gen_data["input_ids"])):
                batch_data = {
                    "input_ids": gen_data["input_ids"][i],
                    "hidden_state": gen_data["hidden_state"][i].cpu().clone().detach(),
                    "loss_mask": gen_data["loss_mask"][i],
                    "sample_idxs": gen_data["sample_idxs"][i].cpu().clone().detach(),
                    "sample_logits": gen_data["sample_logits"][i].cpu().clone().detach(),
                    "output_topk_ids": gen_data["output_topk_ids"][i].cpu().clone().detach(),
                    "output_topk_probs": gen_data["output_topk_probs"][i].cpu().clone().detach(),
                    "sampled_scores": gen_data["sampled_scores"][i].cpu().clone().detach(),
                }
                # Close (and flush) each file deterministically instead of
                # leaving the handle to the garbage collector.
                out_path = args.output_dir / f"{split}_{base_id + i}.pkl"
                with out_path.open("wb") as f:
                    pickle.dump(batch_data, f)

    accelerator.wait_for_everyone()
    accelerator.print("All done!")
