
import random; random.seed(43)
import sys

import pandas as pd
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

# NOTE: spacy and transformers are deliberately *not* imported at module
# level; each multiprocessing Pool worker imports and loads them itself in
# init_worker().
# Placeholders for the per-worker spaCy pipeline and RoBERTa tokenizer.
# init_worker() assigns them inside each worker process; main() never calls
# it in the parent process, so there they stay None.
nlp = None
tokenizer = None

def init_worker():
    """
    Pool initializer: load the NLP models into this worker's globals.

    Passed as ``initializer`` to the multiprocessing Pool, so it runs once in
    every worker process.  It loads the spaCy ``en_core_web_sm`` pipeline
    (with the NER component disabled) and the ``roberta-base`` fast
    tokenizer, and stores them in the module-level ``nlp`` and ``tokenizer``
    names used by ``chunk_text`` and ``process_pair``.
    """
    # Imported lazily so only worker processes pay the import/load cost.
    import spacy
    import transformers

    global nlp, tokenizer
    nlp = spacy.load("en_core_web_sm", disable=["ner"])
    tokenizer = transformers.RobertaTokenizerFast.from_pretrained("roberta-base")

def chunk_text(text: str,
               min_tokens: int = 32,
               max_tokens: int = 128) -> list[str]:
    """
    Split *text* into sentence-aligned chunks of min_tokens..max_tokens tokens.

    Sentences come from the worker's spaCy pipeline (``nlp``) and are measured
    with the worker's ``tokenizer``.  Consecutive sentences are packed greedily
    into the current chunk until the next sentence would push it past
    ``max_tokens``.  At that point the chunk is emitted, but only if it holds
    at least ``min_tokens`` tokens, and that sentence starts a new chunk.

    A single sentence longer than ``max_tokens`` is discarded, and it also
    ends the chunk being built.  Chunk sizes are sums of per-sentence token
    counts, so re-tokenizing a joined chunk as a whole may give a slightly
    different number.

    Returns the chunks.  Each chunk is its sentences joined by single spaces,
    with surrounding whitespace stripped.
    """
    kept: list[str] = []
    buffer: list[str] = []
    buffer_len = 0

    def flush() -> None:
        # Emit the pending sentences, but only if they reach the minimum size.
        if buffer_len >= min_tokens:
            kept.append(" ".join(buffer).strip())

    for sentence in nlp(text).sents:
        piece = sentence.text
        n_toks = len(tokenizer.tokenize(piece))
        too_long = n_toks > max_tokens

        if not too_long and buffer_len + n_toks <= max_tokens:
            # Sentence still fits: keep growing the current chunk.
            buffer.append(piece)
            buffer_len += n_toks
            continue

        flush()
        if too_long:
            # Oversized sentence is dropped; restart from an empty chunk.
            buffer, buffer_len = [], 0
        else:
            # Sentence overflows the chunk: it becomes the start of the next one.
            buffer, buffer_len = [piece], n_toks

    flush()
    return kept

def process_pair(pair, min_tokens: int = 32, max_tokens: int = 128) -> dict:
    """
    Filter and chunk one author's posts.

    Args:
        pair: ``(author_id, raw_syms)``.  ``raw_syms`` is either a single post
            string or a list/tuple of post strings.
        min_tokens: Posts, and chunks produced from long posts, with fewer
            tokens than this are dropped.
        max_tokens: Posts with more tokens than this are split into
            sentence-level chunks via ``chunk_text``.

    Returns:
        ``{"author_id": author_id, "syms": [...]}``.  Each post is stripped of
        surrounding whitespace and then handled as follows:

        * posts within [min_tokens, max_tokens] are kept whole;
        * longer posts are replaced by their chunks;
        * shorter posts are dropped;
        * non-string entries (e.g. a missing value that pandas read as
          ``None``/``NaN``) are skipped.

    Relies on the worker-global ``tokenizer`` that ``init_worker`` sets up.
    """
    author_id, raw = pair
    # raw may be a single str or a sequence of str.
    texts = raw if isinstance(raw, (list, tuple)) else [raw]
    out = []

    for s in texts:
        # A missing value would crash on .strip() and abort the whole Pool
        # run; skip it instead.
        if not isinstance(s, str):
            continue
        s = s.strip()
        tok_count = len(tokenizer.tokenize(s))

        if min_tokens <= tok_count <= max_tokens:
            out.append(s)
        elif tok_count > max_tokens:
            out.extend(chunk_text(s, min_tokens, max_tokens))
        # else: too short, drop

    return {"author_id": author_id, "syms": out}

def main(
    input_path: str = "/FOO/BAR/data/blog_authorship_corpus/blogtext_author.jsonl",
    output_dir: str = "/FOO/BAR/data/style_transfer",
    n_posts: int = 16,
) -> int:
    """
    Chunk every author's posts in parallel and write two JSONL files.

    Reads one record per row from ``input_path``.  The row index becomes
    ``author_id`` and the ``text`` column is renamed to ``syms``.  Each row
    goes through ``process_pair`` in a worker pool.  The function then writes
    into ``output_dir``:

    * ``blogs_32-128tok.jsonl``: every author with their filtered/chunked
      posts;
    * ``blogs_32-128tok_{n_posts}post_filtered.jsonl``: only authors with at
      least ``n_posts`` posts, each reduced to a random sample of exactly
      ``n_posts``.  The sample uses the module-level seeded ``random``.

    Returns 0, suitable for ``sys.exit``.
    """
    from pathlib import Path

    df = pd.read_json(input_path, lines=True)
    df = (
        df
        .reset_index()
        .rename(columns={"index": "author_id", "text": "syms"})
        [["author_id", "syms"]]
    )

    pairs = list(zip(df["author_id"], df["syms"]))

    n_workers = max(1, cpu_count() - 1)
    with Pool(processes=n_workers, initializer=init_worker) as pool:
        results = list(tqdm(pool.imap_unordered(process_pair, pairs),
                            total=len(pairs),
                            desc="Chunking texts"))

    # imap_unordered yields results in completion order, which changes from
    # run to run.  Restore input order so the seeded random.sample below
    # picks the same posts for each author on every run.
    results.sort(key=lambda rec: rec["author_id"])

    # Explicit columns keep "syms" present even when there are no results.
    out_df = pd.DataFrame(results, columns=["author_id", "syms"])

    out_dir = Path(output_dir)
    out_df.to_json(out_dir / "blogs_32-128tok.jsonl", lines=True, orient="records")

    # .copy() detaches the filtered frame, so assigning "syms" below does
    # not trigger pandas' SettingWithCopyWarning.
    sampled = out_df[out_df["syms"].apply(len) >= n_posts].copy()
    sampled["syms"] = sampled["syms"].apply(lambda posts: random.sample(posts, n_posts))
    sampled.to_json(out_dir / f"blogs_32-128tok_{n_posts}post_filtered.jsonl",
                    lines=True, orient="records")

    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
