
import random
from multiprocessing import Pool

import fire
import pandas as pd
import ujson as json
from transformers import AutoTokenizer, set_seed

# Seed the global RNGs (transformers' set_seed seeds Python's `random` module,
# which clean() uses via random.sample) so author/post sampling is repeatable.
# NOTE(review): when Pool forks its workers (the Linux default), every worker
# inherits this same RNG state, and which worker handles which record can vary
# between runs, so per-record draws may not be exactly reproducible — confirm
# if bit-for-bit reproducibility is required.
set_seed(43)

def get_unique_indices(lst):
    """Return the index of the first occurrence of each distinct item in *lst*.

    Indices come back in increasing order, so taking ``lst[i]`` for every
    returned ``i`` yields *lst* with later duplicates removed and the original
    order preserved.

    Args:
        lst: Any iterable of hashable items (e.g. a list of strings).

    Returns:
        A list of integer positions into *lst*.
    """
    seen = set()
    unique_indices = []
    for i, item in enumerate(lst):
        if item not in seen:
            seen.add(item)
            unique_indices.append(i)
    return unique_indices

# Number of posts clean() samples per author: the first sampled post becomes
# the "content" post and the remaining N - 1 become "reference" posts.
N: int = 17
# Tokenizer used by clean() only to measure each post's length in tokens;
# loaded once at import time as a module-level global.
tok = AutoTokenizer.from_pretrained("roberta-base")
def clean(elem: dict) -> dict:
    """Build one author record from *elem*, or return ``{}`` to skip the author.

    Steps:
      1. Drop duplicate posts in ``elem["syms"]`` (keeping the first
         occurrence) together with the matching ``elem["action_type"]`` entries.
      2. Keep only posts whose tokenized length under ``tok`` is between 32
         and 128 tokens inclusive.
      3. If fewer than ``N`` posts remain, return ``{}``; otherwise randomly
         sample ``N`` of them.
      4. The first sampled post becomes the content post, the other ``N - 1``
         become reference posts.

    Args:
        elem: Record with at least the keys ``"author_id"``, ``"syms"`` and
            ``"action_type"``. ``syms`` and ``action_type`` are presumably
            parallel per-post lists — verify against the raw data.

    Returns:
        A dict with keys ``author_id``, ``content_text``,
        ``content_action_type``, ``reference_text`` and
        ``reference_action_type``, or ``{}`` when the author has too few
        usable posts.
    """
    text = elem["syms"]
    action_type = elem["action_type"]

    # Deduplicate by post text, keeping action types aligned with their text.
    unique_indices = get_unique_indices(text)
    text = [text[i] for i in unique_indices]
    action_type = [action_type[i] for i in unique_indices]

    # Tokenize each post exactly once; the count includes whatever special
    # tokens the tokenizer adds by default.
    token_counts = [len(tok(t)["input_ids"]) for t in text]
    valid_indices = [
        i for i, count in enumerate(token_counts) if 32 <= count <= 128
    ]

    if len(valid_indices) < N:
        return {}

    # NOTE(review): inside forked Pool workers the module-level RNG state is
    # inherited identically by every worker, so these draws may differ across
    # runs depending on task scheduling — confirm if that matters.
    valid_indices = random.sample(valid_indices, N)
    action_type = [action_type[i] for i in valid_indices]
    text = [text[i] for i in valid_indices]
    return {
        "author_id": elem["author_id"],
        "content_text": text[0],
        "content_action_type": action_type[0],
        "reference_text": text[1:],
        "reference_action_type": action_type[1:],
    }
    
def _read_author_ids(path: str) -> set:
    """Return the set of ``author_id`` values in the JSON-lines file at *path*."""
    with open(path, "r", encoding="utf-8") as fin:
        # Skip blank lines so a trailing empty line does not break json.loads.
        return {json.loads(line)["author_id"] for line in fin if line.strip()}


def _write_jsonl(records: list, path: str) -> None:
    """Write *records* to *path* as JSON lines, one object per line."""
    with open(path, "w", encoding="utf-8") as fout:
        for rec in records:
            fout.write(json.dumps(rec))
            fout.write("\n")


def main(
    num_authors: int = 100_000,
    classifier_datapath: str = None,
    debug: bool = False,
):
    """Collect held-out authors from the raw dump and save them as JSON lines.

    Authors present in the style-paraphraser training file (and, if given, in
    *classifier_datapath*) are excluded. The remaining raw records are cleaned
    in parallel with ``clean``; records for which ``clean`` returns ``{}`` are
    dropped. Reading stops as soon as *num_authors* valid authors are found.

    Args:
        num_authors: Number of valid authors to collect (forced to 100 when
            *debug* is set). Also embedded in the output filename.
        classifier_datapath: Optional JSON-lines file whose ``author_id``s are
            excluded as well; when given, a different output filename is used.
        debug: If true, only collect 100 authors.

    Returns:
        0 on completion.

    Raises:
        TypeError: If *classifier_datapath* is given but is not a string.
    """
    if debug:
        num_authors = 100

    # Find authors in the style paraphraser training dataset so they are not
    # reused in this held-out set.
    stylepara_train_path = "/FOO/BAR/data/style_transfer/mud_32-128tok_63kauth_16post.jsonl"
    authors_in_train = _read_author_ids(stylepara_train_path)

    if classifier_datapath is not None:
        # fire may parse CLI values into non-str types (e.g. numbers); an
        # explicit check survives `python -O`, unlike assert.
        if not isinstance(classifier_datapath, str):
            raise TypeError(
                "classifier_datapath must be a str, got {}".format(
                    type(classifier_datapath).__name__
                )
            )
        print("Reading authors from: {}".format(classifier_datapath))
        authors_in_train.update(_read_author_ids(classifier_datapath))
    print("Found #{} authors in training...".format(len(authors_in_train)))

    MUD_path = "/FOO/BAR/data/raw_all/data.jsonl"
    records = []
    it = 0

    # Context managers shut down the worker pool and close the chunked JSON
    # reader even when the loop exits early via `break` or an exception.
    with Pool(40) as pool, pd.read_json(MUD_path, chunksize=1000, lines=True) as df_chunk:
        for chunk in df_chunk:
            it += len(chunk)  # counts raw rows, before author filtering
            chunk = chunk[~chunk["author_id"].isin(authors_in_train)]
            chunk_clean = pool.map(clean, chunk.to_dict("records"))
            records.extend(elem for elem in chunk_clean if elem)
            print("Processed {} -- num_valid_authors={}".format(it, len(records)))
            if len(records) >= num_authors:
                records = records[:num_authors]
                break

    if len(records) < num_authors:
        # The filename still carries num_authors, so make the shortfall visible.
        print("Warning: only found {} valid authors (requested {})".format(
            len(records), num_authors))

    if classifier_datapath is not None:
        savename = "/FOO/BAR/data/style_transfer/mtd/MTD_reddit_preference_{}_correct.jsonl".format(num_authors)
    else:
        savename = "/FOO/BAR/data/style_transfer/mtd/MTD_reddit_{}_correct.jsonl".format(num_authors)

    print("Saving to: {}".format(savename))
    _write_jsonl(records, savename)
    return 0

if __name__ == "__main__":
    # python-fire exposes main()'s parameters as command-line flags,
    # e.g. `--num_authors=5000 --debug`.
    fire.Fire(main)