# PG19 long-context QA dataset, generated following the method of:
# - "Scaling Instruction-tuned LLMs to million token contexts via hierarchical synthetic data generation"
# - https://openreview.net/pdf?id=BkwCrIsTbR

import json
import os
import random
import re
from typing import Literal

import datasets
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


class PG19LongQA(Dataset):
    """Tokenized PG19 long-context QA conversations for Llama 3.x tokenizers.

    Every line of every file under ``<path>/pg19/<split>`` is parsed as a JSON
    object with a ``"conversations"`` list of chat turns. Each turn is rendered
    on its own through ``tokenizer.apply_chat_template`` and the per-turn token
    tensors are concatenated into a single sequence per conversation.

    Labels have the same shape as the tokens. They copy the tokens of turns
    that start with the assistant header (excluding the header itself) and are
    ``-100`` everywhere else, so only assistant replies carry targets.

    The tokenized split is cached with ``torch.save`` under
    ``saves/cache/pg19-longqa`` (relative to the current working directory) and
    reloaded on later constructions instead of re-tokenizing.

    Args:
        tokenizer: tokenizer exposing ``name_or_path``, ``apply_chat_template``
            and ``__call__``; its ``name_or_path`` must contain ``"llama-3."``.
        path: dataset root that contains the ``pg19/<split>`` directory.
        split: which split directory to read.

    Raises:
        NotImplementedError: if the tokenizer is not a Llama 3.x tokenizer.
    """

    # Number of leading tokens dropped from every turn except the first: the
    # system block (which carries the cutoff date) that the chat template
    # prepends to each single-turn rendering. Specific to the Llama 3.x template.
    _SYS_PROMPT_TOKENS = 26

    # Label value marking positions that carry no training target.
    _IGNORE_INDEX = -100

    def __init__(
        self,
        tokenizer,
        path: str,
        split: Literal["train", "validation", "test"] = "train",
    ):
        super().__init__()
        self.tokenizer = tokenizer

        cache_path = "saves/cache/pg19-longqa"

        # the following is specific to all llama 3.x models which use the same
        # tokenizer
        if "llama-3." not in tokenizer.name_or_path.lower():
            raise NotImplementedError(
                f"{tokenizer.name_or_path=} not implemented for PG19 LongQA"
            )

        # Collapse the minor version so every 3.x tokenizer shares one cache
        # file, and keep only the first four dash-separated name components.
        name = tokenizer.name_or_path.replace("/", "-")
        name = re.sub(r"3\.\d+", "3.x", name)
        name = "-".join(name.split("-")[:4])

        os.makedirs(cache_path, exist_ok=True)

        filename = os.path.join(cache_path, f"{name}-{split}.pt")
        if os.path.exists(filename):
            self.data, self.labels = torch.load(filename)
            return

        split_dir = os.path.join(path, "pg19", split)

        # Token ids of the assistant header without the leading BOS token; it
        # is identical for every turn, so tokenize it once.
        begin = tokenizer(
            "<|start_header_id|>assistant<|end_header_id|>",
            return_tensors="pt",
        )["input_ids"][:, 1:]

        data, lbls = [], []
        # Sorted so the cache is built in the same order on every machine.
        for fname in sorted(os.listdir(split_dir)):
            with open(os.path.join(split_dir, fname), "r", encoding="utf-8") as f:
                print(f"opening: {fname}")
                for line in f:
                    if not line.strip():
                        # tolerate blank lines in the JSONL files
                        continue

                    encoded = self._encode_conversation(
                        json.loads(line)["conversations"], begin
                    )
                    if encoded is None:
                        continue

                    tokens, labels = encoded
                    data.append(tokens)
                    lbls.append(labels)

        torch.save((data, lbls), filename)
        self.data = data
        # Store the per-sample list; a single sample's label tensor here would
        # make a freshly built dataset disagree with one loaded from the cache.
        self.labels = lbls

    def _encode_conversation(self, turns, begin):
        """Tokenize one conversation into concatenated tokens and labels.

        Args:
            turns: list of chat-turn dicts, each with a ``"content"`` entry.
            begin: ``(1, k)`` tensor of assistant-header token ids.

        Returns:
            ``(tokens, labels)`` tensors concatenated along the last dimension,
            or ``None`` when every turn was skipped.
        """
        header_len = begin.shape[-1]
        tokens, labels = [], []

        for i, turn in enumerate(turns):
            content = turn["content"]
            # drop turns whose generation failed upstream
            if isinstance(content, str) and "Error generating" in content:
                continue

            # NOTE(review): `truncate` is kept as in the original call; the
            # keyword may have been meant to be `truncation` - confirm.
            tokenized = self.tokenizer.apply_chat_template(
                [turn], return_tensors="pt", truncate=False
            )

            # the first turn of the conversation keeps the system tokens and
            # never contributes targets
            if i == 0:
                tokens.append(tokenized)
                labels.append(torch.full_like(tokenized, self._IGNORE_INDEX))
                continue

            # skip the system tokens which have the cutoff date
            no_sys_tokens = tokenized[:, self._SYS_PROMPT_TOKENS:]
            tokens.append(no_sys_tokens)

            lbl = torch.full_like(no_sys_tokens, self._IGNORE_INDEX)
            # Only turns that open with the assistant header are supervised,
            # and the header itself stays masked.
            if torch.equal(no_sys_tokens[:, :header_len], begin):
                lbl[:, header_len:] = no_sys_tokens[:, header_len:]
            labels.append(lbl)

        if not tokens:
            # nothing usable in this conversation; torch.cat would fail on []
            return None

        return torch.cat(tokens, dim=-1), torch.cat(labels, dim=-1)

    def __len__(self):
        """Return the number of tokenized conversations."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(tokens, labels)`` for sample ``idx`` without the batch dim."""
        x, y = self.data[idx], self.labels[idx]
        return x[0], y[0]


if __name__ == "__main__":
    # Smoke run: build (or load from cache) each split, then report the
    # longest sequence and the range of per-sample target-token counts.
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct"
    )

    for split in ("train", "validation", "test"):
        ds = PG19LongQA(
            tokenizer, "/d1/dataset/pg19-hierarchical-qa", split=split
        )
        print(f"running: {split} {len(ds)=}")

        # One pass over the dataset: (sequence length, number of positions
        # whose label is not masked, i.e. >= 0).
        stats = [
            (seq.shape[-1], (targets >= 0).sum().item())
            for seq, targets in tqdm(ds, total=len(ds))
        ]
        lengths = [n_tokens for n_tokens, _ in stats]
        lbl = [n_targets for _, n_targets in stats]

        print(f"max: {max(lengths)=} {min(lbl)=} {max(lbl)=}")
