import json
import random
from itertools import islice
from typing import Iterable, List, Optional

import numpy as np
from datasets import load_dataset


# Number of texts requested from every dataset loader in main().
N = 500
# Seed passed to the PubMedQA, Bloomberg, arXiv and LEDGAR loaders.
SEED_FIRST4 = 42
# Seed passed to the BigPatent, IMDB, Wikipedia and ARC loaders.
SEED_LAST4 = 423
# Default cap on how many leading rows the scanning loaders read before sampling.
LIMIT_SCAN = 20000


def sample_texts(texts: List[str], n: int, seed: int) -> List[str]:
    """Return up to ``n`` cleaned texts chosen reproducibly from ``texts``.

    Every entry is converted to ``str`` and stripped; ``None`` entries and
    entries that are empty after stripping are discarded. If at most ``n``
    texts survive, all of them are returned in a seeded shuffled order;
    otherwise ``n`` of them are drawn without replacement.

    Args:
        texts: Candidate texts (``None`` entries are tolerated and skipped).
        n: Maximum number of texts to return.
        seed: Seed for the private ``random.Random`` instance.

    Returns:
        A new list of cleaned texts; the input list is not modified.
    """
    cleaned: List[str] = []
    for item in texts:
        if item is None:
            continue
        stripped = str(item).strip()
        if stripped:
            cleaned.append(stripped)

    if not cleaned:
        return []

    rng = random.Random(seed)
    if len(cleaned) > n:
        return rng.sample(cleaned, n)

    # Not enough texts to subsample: return everything, in seeded random order.
    rng.shuffle(cleaned)
    return cleaned


def take_n_texts(stream: Iterable[dict], text_key: str, n: int) -> List[str]:
    """Collect the ``text_key`` field of the first ``n`` records of ``stream``.

    Missing or ``None`` values become the empty string; everything else is
    converted with ``str``. Values are neither stripped nor filtered, so the
    result has one entry per consumed record.

    Args:
        stream: Iterable of dict-like records.
        text_key: Key to read from each record.
        n: Maximum number of records to consume.

    Returns:
        A list with at most ``n`` strings.
    """
    head = islice(stream, n)
    values = (record.get(text_key, "") for record in head)
    return ["" if value is None else str(value) for value in values]


def load_pubmedqa_texts(n: int, seed: int, limit_scan: int = LIMIT_SCAN) -> List[str]:
    """Sample up to ``n`` passages from the PubMedQA ``pqa_artificial`` train split.

    Only the first ``limit_scan`` rows are read. For a row whose ``context``
    is a dict holding a ``contexts`` list, every non-``None`` list entry is
    collected; for any other row the ``long_answer`` field is collected when
    it is truthy. The pooled texts are then sampled with ``sample_texts``.

    Args:
        n: Maximum number of texts to return.
        seed: Seed forwarded to ``sample_texts``.
        limit_scan: Number of leading rows to read.

    Returns:
        Up to ``n`` cleaned texts.
    """
    rows = load_dataset("qiaojin/PubMedQA", "pqa_artificial", split="train")
    pool: List[str] = []

    # A non-positive limit scans nothing.
    for row in islice(rows, max(limit_scan, 0)):
        context = row.get("context")
        passages = context.get("contexts") if isinstance(context, dict) else None

        if isinstance(passages, list):
            for passage in passages:
                if passage is not None:
                    pool.append(str(passage).strip())
            continue

        # No usable context list: fall back to the long answer.
        fallback = row.get("long_answer")
        if fallback:
            pool.append(str(fallback).strip())

    return sample_texts(pool, n=n, seed=seed)


def load_bloomberg_texts(n: int, seed: int, limit_scan: int = LIMIT_SCAN) -> List[str]:
    """Sample up to ``n`` articles from the Bloomberg financial news train split.

    Reads the ``Article`` field of the first ``limit_scan`` rows, keeps the
    truthy values (stripped), and samples them with ``sample_texts``.

    Args:
        n: Maximum number of texts to return.
        seed: Seed forwarded to ``sample_texts``.
        limit_scan: Number of leading rows to read.

    Returns:
        Up to ``n`` cleaned texts.
    """
    rows = load_dataset("danidanou/Bloomberg_Financial_News", split="train")

    # A non-positive limit scans nothing.
    scanned = islice(rows, max(limit_scan, 0))
    articles = (row.get("Article") for row in scanned)
    pool: List[str] = [str(article).strip() for article in articles if article]

    return sample_texts(pool, n=n, seed=seed)


def load_arxiv_texts(n: int, seed: int, limit_scan: int = LIMIT_SCAN) -> List[str]:
    """Sample up to ``n`` abstracts from the arXiv summarization train split.

    Reads the ``abstract`` field of the first ``limit_scan`` rows, strips it,
    drops rows whose abstract is missing or blank, and samples the remainder
    with ``sample_texts``.

    Args:
        n: Maximum number of texts to return.
        seed: Seed forwarded to ``sample_texts``.
        limit_scan: Number of leading rows to read.

    Returns:
        Up to ``n`` cleaned texts.
    """
    rows = load_dataset("ccdv/arxiv-summarization", split="train")

    # A non-positive limit scans nothing.
    scanned = islice(rows, max(limit_scan, 0))
    stripped = ((row.get("abstract") or "").strip() for row in scanned)
    pool: List[str] = [abstract for abstract in stripped if abstract]

    return sample_texts(pool, n=n, seed=seed)


def load_lexglue_texts(n: int, seed: int, limit_scan: int = LIMIT_SCAN) -> List[str]:
    """Sample up to ``n`` texts from the LexGLUE ``ledgar`` train split.

    Reads the ``text`` field of the first ``limit_scan`` rows, keeps the
    truthy values (stripped), and samples them with ``sample_texts``.

    Args:
        n: Maximum number of texts to return.
        seed: Seed forwarded to ``sample_texts``.
        limit_scan: Number of leading rows to read.

    Returns:
        Up to ``n`` cleaned texts.
    """
    rows = load_dataset("lex_glue", "ledgar", split="train")
    pool: List[str] = []

    # A non-positive limit scans nothing.
    for row in islice(rows, max(limit_scan, 0)):
        provision = row.get("text")
        if not provision:
            continue
        pool.append(str(provision).strip())

    return sample_texts(pool, n=n, seed=seed)


def load_bigpatent_texts(
    n: int,
    seed: int,
    split: str = "train",
    code: str = "all",
    use: str = "abstract",
    shuffle_stream: bool = True,
) -> List[str]:
    """Stream up to ``n`` non-empty texts from the ``big_patent`` dataset.

    The dataset is opened in streaming mode and, when ``shuffle_stream`` is
    true, shuffled with ``seed`` and a 10000-example buffer. Rows whose
    selected text is empty after stripping are skipped, and reading continues
    until ``n`` non-empty texts have been collected or the stream ends.

    Args:
        n: Number of non-empty texts to collect.
        seed: Seed for the streaming shuffle.
        split: Dataset split to read.
        code: Dataset configuration name passed to ``load_dataset``.
        use: ``"abstract"`` or ``"description"`` to select that field; any
            other value joins abstract and description with a blank line.
        shuffle_stream: Whether to shuffle the stream before reading.

    Returns:
        A list of at most ``n`` stripped, non-empty texts.
    """
    ds = load_dataset("big_patent", code, split=split, streaming=True)
    if shuffle_stream:
        ds = ds.shuffle(seed=seed, buffer_size=10000)

    def extract(ex: dict) -> str:
        """Return the stripped text selected by ``use`` for one row."""
        # `or ""` maps a missing or None field to "" rather than letting
        # str(None) leak the literal text "None" into the sample.
        abstract = str(ex.get("abstract") or "")
        description = str(ex.get("description") or "")
        if use == "abstract":
            return abstract.strip()
        if use == "description":
            return description.strip()
        return f"{abstract}\n\n{description}".strip()

    # Filter before truncating so empty rows do not count against n.
    non_empty = (text for text in map(extract, ds) if text)
    return list(islice(non_empty, n))


def load_imdb_texts(n: int, seed: int) -> List[str]:
    """Return stripped review texts from a seeded shuffle of the IMDB train split.

    The split is shuffled with ``seed`` and truncated to at most ``n`` rows;
    rows whose ``text`` field is missing or falsy are dropped afterwards.

    Args:
        n: Maximum number of rows to select.
        seed: Seed for the dataset shuffle.

    Returns:
        A list of at most ``n`` stripped texts.
    """
    reviews = load_dataset("stanfordnlp/imdb", split="train")
    keep = min(n, len(reviews))
    subset = reviews.shuffle(seed=seed).select(range(keep))

    texts: List[str] = []
    for row in subset:
        raw = row.get("text")
        if raw:
            texts.append(str(raw).strip())
    return texts


def load_wikipedia_texts(n: int, seed: int) -> List[str]:
    """Return the ``text`` field of ``n`` rows from a shuffled Wikipedia stream.

    Each snapshot configuration is tried in the listed order and the first
    one that loads is used. The stream is shuffled with ``seed`` and a
    10000-example buffer, then read with ``take_n_texts``, so the texts are
    returned as-is (not stripped, and empty values are kept as ``""``).

    Args:
        n: Number of rows to read from the stream.
        seed: Seed for the streaming shuffle.

    Returns:
        A list of at most ``n`` strings.

    Raises:
        RuntimeError: If none of the configurations can be loaded. The last
            loader error is attached as the exception's cause.
    """
    configs = ["20240701.en", "20240401.en", "20240101.en", "20231101.en", "20220301.en"]
    last_err: Optional[Exception] = None
    stream = None

    for cfg in configs:
        try:
            stream = load_dataset("wikimedia/wikipedia", cfg, split="train", streaming=True)
            break
        except Exception as e:
            # Deliberate best-effort: any failure just moves on to the next snapshot.
            last_err = e

    if stream is None:
        # Chain the cause explicitly; this raise is outside the except block,
        # so the original traceback would otherwise be lost.
        raise RuntimeError(f"Failed to load wikipedia. Last error: {last_err}") from last_err

    stream = stream.shuffle(seed=seed, buffer_size=10000)
    return take_n_texts(stream, text_key="text", n=n)


def resolve_arc_answer(ex: dict) -> str:
    """Format an ARC example as a two-line "Question: ... / Answer: ..." string.

    The answer text is looked up in three steps:

    1. ``answerKey`` is matched against the choice labels.
    2. Failing that, a numeric key is treated as a 1-based position.
    3. Failing that, a single letter ``A``-``Z`` is treated as a 0-based
       position (``A`` is the first choice).

    If nothing resolves, the answer is left empty. An empty, missing or
    multi-letter key never resolves positionally.

    Args:
        ex: Example dict with ``question``, ``choices`` (a dict with parallel
            ``label`` and ``text`` lists) and ``answerKey``; any of them may
            be missing or ``None``.

    Returns:
        The question and resolved answer joined by a newline, stripped of
        leading and trailing whitespace.
    """
    q = (ex.get("question") or "").strip()
    choices = ex.get("choices") or {}
    labels = choices.get("label") or []
    texts = choices.get("text") or []
    ak = str(ex.get("answerKey") or "").strip()

    label2text = {
        str(l).strip(): str(t).strip()
        for l, t in zip(labels, texts)
        if l is not None and t is not None
    }

    ans = label2text.get(ak)

    if ans is None:
        abc = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        idx: Optional[int] = None
        if ak.isdecimal():
            # isdecimal (not isdigit) so every accepted key is parseable by int().
            idx = int(ak) - 1
        elif len(ak) == 1 and ak in abc:
            # The length check matters: `in` on a string is a substring test,
            # so "" and runs like "AB" would otherwise match.
            idx = abc.index(ak)

        if idx is not None and 0 <= idx < len(texts) and texts[idx] is not None:
            ans = str(texts[idx]).strip()

    if ans is None:
        ans = ""

    return f"Question: {q}\nAnswer: {ans}".strip()


def load_arc_texts(n: int, seed: int) -> List[str]:
    """Return formatted question/answer strings from the ARC-Challenge train split.

    The split is shuffled with ``seed``, truncated to at most ``n`` rows, and
    each row is rendered with ``resolve_arc_answer``.

    Args:
        n: Maximum number of rows to select.
        seed: Seed for the dataset shuffle.

    Returns:
        One formatted string per selected row.
    """
    arc = load_dataset("ai2_arc", "ARC-Challenge", split="train")
    keep = min(n, len(arc))
    subset = arc.shuffle(seed=seed).select(range(keep))
    return list(map(resolve_arc_answer, subset))


def word_count(s: str) -> int:
    """Return the number of whitespace-separated tokens in ``str(s)``."""
    tokens = str(s).split()
    return len(tokens)


def char_count(s: str) -> int:
    """Return the length of ``str(s)``, whitespace included."""
    text = str(s)
    return len(text)


def main():
    """Load all eight corpora, save them to JSON, and print length statistics.

    The corpora are loaded in a fixed order, written to ``test_merge.json``
    as a list of lists of strings (one inner list per corpus, in load order),
    and then the mean word and character counts are printed per corpus and
    for all texts combined.
    """

    def mean_lengths(texts):
        """Return (mean word count, mean character count) over ``texts``."""
        words = np.array([word_count(t) for t in texts], dtype=np.int32)
        chars = np.array([char_count(t) for t in texts], dtype=np.int32)
        return words.mean(), chars.mean()

    # The list literal is evaluated top to bottom, so the corpora load in this order.
    corpora = [
        ("pubmedqa", load_pubmedqa_texts(n=N, seed=SEED_FIRST4)),
        ("bloomberg", load_bloomberg_texts(n=N, seed=SEED_FIRST4)),
        ("arxiv", load_arxiv_texts(n=N, seed=SEED_FIRST4)),
        ("lexglue", load_lexglue_texts(n=N, seed=SEED_FIRST4)),
        ("big_patent", load_bigpatent_texts(n=N, seed=SEED_LAST4, code="all", use="abstract")),
        ("imdb", load_imdb_texts(n=N, seed=SEED_LAST4)),
        ("wikipedia", load_wikipedia_texts(n=N, seed=SEED_LAST4)),
        ("science", load_arc_texts(n=N, seed=SEED_LAST4)),
    ]
    datasets = [texts for _, texts in corpora]

    with open("test_merge.json", "w", encoding="utf-8") as f:
        json.dump(datasets, f, ensure_ascii=False)

    for name, texts in corpora:
        avg_words, avg_chars = mean_lengths(texts)
        print(f"{name:12s} avg_words={avg_words:.2f}  avg_chars={avg_chars:.2f}  n={len(texts)}")

    all_texts: List[str] = [text for texts in datasets for text in texts]
    total_words, total_chars = mean_lengths(all_texts)

    print("\nTOTAL")
    print(f"avg_words={total_words:.2f}  avg_chars={total_chars:.2f}  n={len(all_texts)}")

# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
