#!/usr/bin/env python

from datetime import datetime

from datasets import concatenate_datasets, disable_caching, load_dataset
from transformers import set_seed

from sacremoses import MosesTokenizer
from llm import ParallelCaptionizer

# Threshold, in Moses tokens, separating captions that are kept as-is from
# captions that get an automatically generated description appended.
MIN_CAPTION_LENGTH = 30

# Values of the "origin" field identifying examples produced by a model.
MODELS = {"chatgpt", "gpt4"}

# URIs whose examples must never be selected as test candidates.
TEST_EXCLUDE = {
    "https://arxiv.org/abs/2303.12712",  # contains GPT-4 generated tikzpictures
}

def is_test_candidate(ex, cutoff=datetime(2022, 12, 1)):
    """
    Tell whether an example is eligible for the test split.

    An example qualifies when its ``origin`` is not one of ``MODELS``, its
    ``uri`` is not listed in ``TEST_EXCLUDE`` and its ``date`` is not earlier
    than ``cutoff`` (1 December 2022 by default; the original note describes
    such examples as "newer than llama").
    """
    # model-generated examples are never test candidates
    if ex['origin'] in MODELS:
        return False
    # explicitly blacklisted sources are skipped as well
    if ex['uri'] in TEST_EXCLUDE:
        return False
    return ex['date'] >= cutoff

def train_test_split(dataset, test_size=1000):
    """
    Split ``dataset`` into a ``(train, test)`` pair.

    The test set is drawn only from examples accepted by
    ``is_test_candidate``, stratified by their ``origin``. Every other
    example, plus the candidates not drawn for the test set, ends up in the
    train set. If no example is a candidate, the unchanged ``dataset`` is
    returned together with the (empty) filtered candidate set.
    """
    candidates = dataset.filter(is_test_candidate)
    if not len(candidates):
        return dataset, candidates

    # temporary class-encoded "labels" column mirroring "origin", used only
    # to stratify the split and removed again before returning
    origin_ids = candidates.class_encode_column("origin")['origin']
    labeled = candidates.add_column("labels", origin_ids).class_encode_column("labels")
    splits = labeled.train_test_split(test_size=test_size, stratify_by_column="labels")
    remainder, test = splits.values()

    # everything that was not eligible for testing goes straight to training
    non_candidates = dataset.filter(lambda ex: not is_test_candidate(ex))
    train = concatenate_datasets([non_candidates, remainder.remove_columns("labels")])

    return train, test.remove_columns("labels")

def concat(caption, description):
    """
    Join a caption and a description into a single text.

    Both parts are stripped of surrounding whitespace and newlines inside the
    description become spaces. A non-empty caption gets its first character
    upper-cased and a trailing "." unless it already ends in ".", "!" or "?".
    With an empty caption only the cleaned description is returned.
    """
    head = caption.strip()
    tail = description.replace("\n", " ").strip()
    if not head:
        return tail

    head = head[:1].upper() + head[1:]
    if not head.endswith((".", "!", "?")):
        head += "."
    # the final strip drops the separator when the description is empty
    return f"{head} {tail}".strip()

if __name__ == "__main__":
    set_seed(0)
    disable_caching()
    datikz = load_dataset("datikz", split="train")
    tokenize = MosesTokenizer().tokenize
    captionize = ParallelCaptionizer()

    def from_model(ex):
        # True for examples whose origin is one of the known models
        return ex['origin'] in MODELS

    def has_long_caption(ex):
        # True when the caption has at least MIN_CAPTION_LENGTH Moses tokens
        return len(tokenize(ex['caption'])) >= MIN_CAPTION_LENGTH

    # model-generated tikzpictures are deliberately kept in BOTH subsets:
    # "good" holds them next to sufficiently long captions, "bad" next to
    # the captions that are too short
    good = datikz.filter(lambda ex: from_model(ex) or has_long_caption(ex))
    bad = datikz.filter(lambda ex: from_model(ex) or not has_long_caption(ex))

    # extend every caption of the "bad" subset with a description generated
    # from the corresponding image
    short_captions = bad['caption']
    descriptions = captionize(bad['image'])
    generated_caps = [concat(cap, desc) for cap, desc in zip(short_captions, descriptions)]  # type: ignore
    augmented = bad.remove_columns("caption").add_column("caption", generated_caps)  # type: ignore

    # flag each row so that original and augmented captions stay distinguishable
    good = good.add_column("augmented", [False] * len(good))  # type: ignore
    augmented = augmented.add_column("augmented", [True] * len(augmented))  # type: ignore

    combined_caps = concatenate_datasets([good, augmented])  # type: ignore
    train, test = train_test_split(combined_caps)

    # write the untouched dataset first, then the two derived splits
    outputs = (
        (datikz, "datikz-raw.parquet"),
        (train, "datikz-train.parquet"),
        (test, "datikz-test.parquet"),
    )
    for split, path in outputs:
        split.to_parquet(path, compression="GZIP")  # type: ignore
