import argparse
import json
import math
import os
import sys
from collections import Counter, defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple

# -------------------------
# I/O
# -------------------------

def read_jsonl(path: str) -> Iterable[Dict[str, Any]]:
    """Lazily yield the JSON objects stored one per line in ``path``.

    Blank lines are ignored. A line that is not valid JSON, or whose top-level
    value is not a JSON object, is reported on stderr and skipped, so every
    yielded item is a ``dict``.

    Args:
        path: Path of the JSONL file to read.

    Yields:
        One ``dict`` per usable line, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    # "utf-8-sig" also decodes plain UTF-8 but drops a leading BOM, which
    # json.loads would otherwise reject on the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        for ln, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                sys.stderr.write(f"[warn] Invalid JSON at line {ln}\n")
                continue
            # Valid JSON that is not an object (list, number, string, ...)
            # cannot be treated as a record.
            if not isinstance(rec, dict):
                sys.stderr.write(f"[warn] Non-object JSON at line {ln}\n")
                continue
            yield rec

def write_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """Append ``records`` to ``path`` as JSON Lines, one object per line.

    The parent directory is created when missing. The file is opened in
    append mode, so existing content is kept and repeated calls accumulate.
    Non-ASCII characters are written as-is rather than escaped.
    """
    parent = os.path.dirname(path)
    if not parent:
        # A bare file name has no directory part; use the current directory.
        parent = "."
    os.makedirs(parent, exist_ok=True)

    with open(path, "a", encoding="utf-8") as sink:
        for record in records:
            serialized = json.dumps(record, ensure_ascii=False)
            sink.write(serialized + "\n")

# -------------------------
# Pattern extraction (robust to English/Chinese schemas)
# -------------------------

# Known pattern schemas, checked in this order:
# (list key, id key, name key, usage key, chain key)
_PATTERN_SCHEMAS = (
    ("pattern_list", "id", "name",
     "how_CoT_utilizes_patterns_in_this_case", "pattern_chain"),
    ("模式列表", "编号", "名称", "本题CoT如何利用这些模式", "模式链"),
)

# Keys whose scalar values count as pattern names in an unknown dict schema.
_NAME_KEYS = ("name", "名称")


def _names_from_schema(pat: Dict[str, Any], list_key: str, id_key: str,
                       name_key: str, how_key: str, chain_key: str) -> List[str]:
    """Resolve pattern names for one known schema (chain first, list second)."""
    plist = pat.get(list_key)
    # Tolerate malformed input: a non-list value or non-dict entries are
    # ignored instead of raising AttributeError.
    entries = [p for p in plist if isinstance(p, dict)] if isinstance(plist, (list, tuple)) else []
    listed = [(p.get(id_key), str(p.get(name_key)))
              for p in entries if p.get(name_key) is not None]

    # Ids are compared by their string form so that a chain using "1" still
    # matches a list entry declared with id 1 (and vice versa).
    id2name = {str(pid): nm for pid, nm in listed}

    how = pat.get(how_key)
    chain = how.get(chain_key) if isinstance(how, dict) else None
    if isinstance(chain, list) and id2name:
        resolved = (id2name.get(str(pid)) for pid in chain)
        # Unknown ids and empty names are dropped; repeats are kept.
        return [nm for nm in resolved if nm]

    # No usable chain: every listed name once, in list order.
    return [nm for _, nm in listed]


def _collect_name_values(node: Any) -> List[str]:
    """Depth-first collect scalar values stored under a name key."""
    bag: List[str] = []

    def walk(x: Any) -> None:
        if isinstance(x, dict):
            for k, v in x.items():
                if k in _NAME_KEYS and isinstance(v, (str, int, float)):
                    bag.append(str(v))
                else:
                    walk(v)
        elif isinstance(x, list):
            for y in x:
                walk(y)

    walk(node)
    return bag


def extract_patterns_from_record(rec: Dict[str, Any]) -> List[str]:
    """
    Return the pattern names of one record as a multiset (list) for counting.

    The record's "pattern" field is read as follows:
      * missing, None, or a string that decodes to JSON null -> []
      * a string that is not valid JSON -> that string as the single name
      * a string holding JSON -> decoded, then handled like the decoded value
      * a dict in the English or Chinese schema -> names along the explicit
        pattern chain (ids mapped to names, repeats kept); if there is no
        chain list or no named entries, every listed name once
      * any other dict -> all scalar values found under "name"/"名称" keys at
        any depth, or the JSON dump of the dict if there are none
      * anything else -> its str() form as the single name
    """
    pat = rec.get("pattern", None)
    if pat is None:
        return []

    # If pattern is a JSON-encoded string, try parsing.
    if isinstance(pat, str):
        try:
            pat = json.loads(pat)
        except (ValueError, RecursionError):
            # If not JSON, treat as raw name string.
            return [pat]
        # A JSON "null" means "no pattern", not a pattern called "None".
        if pat is None:
            return []

    if isinstance(pat, dict):
        for schema in _PATTERN_SCHEMAS:
            if schema[0] in pat:
                return _names_from_schema(pat, *schema)

        # Unknown dict schema: best-effort extraction of name values.
        bag = _collect_name_values(pat)
        return bag if bag else [json.dumps(pat, ensure_ascii=False)]

    # Fallback
    return [str(pat)]

# -------------------------
# Character n-gram similarity (provided logic)
# -------------------------

def normalize(s: str) -> str:
    """Return ``s`` with all whitespace removed (leading, trailing, inner)."""
    chunks = s.split()  # no-argument split cuts on runs of whitespace
    return "".join(chunks)

def char_ngrams(s: str, nk: int = 2) -> List[str]:
    """List every character n-gram of the normalized text for n = 1..nk.

    The input goes through ``normalize`` first. Grams are grouped by size
    (all 1-grams, then all 2-grams, ...) and appear left to right within each
    size; duplicates are kept. Sizes longer than the text contribute nothing.
    """
    text = normalize(s)
    length = len(text)
    grams: List[str] = []
    for width in range(1, nk + 1):
        if width > length:
            # Widths only grow, so no later size can fit either.
            break
        last_start = length - width
        grams += [text[start:start + width] for start in range(last_start + 1)]
    return grams

def cosine_similarity_by_char_ngrams(a: str, b: str, n: int = 2) -> float:
    """Cosine similarity of the character n-gram count vectors of two strings.

    Each string is turned into a bag of n-grams via ``char_ngrams(text, n)``.
    Two empty bags are treated as identical (1.0); exactly one empty bag gives
    0.0. Otherwise the result is dot(a, b) / (|a| * |b|).
    """
    vec_a = Counter(char_ngrams(a, n))
    vec_b = Counter(char_ngrams(b, n))

    if not (vec_a or vec_b):
        return 1.0
    if not (vec_a and vec_b):
        return 0.0

    overlap = sum(count * vec_b.get(gram, 0) for gram, count in vec_a.items())
    norm_a = math.sqrt(sum(count * count for count in vec_a.values()))
    norm_b = math.sqrt(sum(count * count for count in vec_b.values()))

    # Defensive guard against division by zero.
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return overlap / (norm_a * norm_b)

# -------------------------
# Union-Find for pattern clustering
# -------------------------

class DSU:
    """Disjoint-set forest over the integers 0..n-1.

    ``p[i]`` is the parent of element i (a root is its own parent) and
    ``sz[r]`` is the number of elements in the set whose root is r.
    """

    def __init__(self, n: int):
        # Every element starts as the root of its own one-element set.
        self.p = [i for i in range(n)]
        self.sz = [1 for _ in range(n)]

    def find(self, x: int) -> int:
        """Return the root of x's set, shortening the walked path on the way."""
        parent = self.p
        while True:
            above = parent[x]
            if above == x:
                return x
            # Re-point x at its grandparent, then continue from there.
            parent[x] = parent[above]
            x = parent[x]

    def union(self, a: int, b: int) -> None:
        """Merge the sets of a and b; the larger set's root stays the root."""
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a != root_b:
            if self.sz[root_a] < self.sz[root_b]:
                keep, absorb = root_b, root_a
            else:
                # On equal sizes the root of the first argument is kept.
                keep, absorb = root_a, root_b
            self.p[absorb] = keep
            self.sz[keep] += self.sz[absorb]

def cluster_patterns_globally(raw_patterns: List[str],
                              sim_threshold: float,
                              ngram_n: int) -> Tuple[Dict[str, int], Dict[int, str]]:
    """
    Group distinct pattern names whose n-gram cosine similarity reaches the threshold.

    Every pair of distinct names is compared; pairs scoring >= sim_threshold
    are linked, and a cluster is a connected component of those links (so two
    names may share a cluster through a chain of similar intermediates).

    Returns:
        name2cid: raw name -> cluster id (the union-find root index)
        cid2rep:  cluster id -> representative name; the member occurring most
                  often in raw_patterns, then the shortest, then the
                  lexicographically smallest
    """
    # De-duplicate while keeping first-seen order.
    distinct = list(dict.fromkeys(raw_patterns))
    total = len(distinct)
    forest = DSU(total)

    # Link every pair that is similar enough.
    for left, first in enumerate(distinct):
        for right in range(left + 1, total):
            score = cosine_similarity_by_char_ngrams(first, distinct[right], n=ngram_n)
            if score >= sim_threshold:
                forest.union(left, right)

    # Read the final cluster of each name off the forest.
    name2cid: Dict[str, int] = {}
    members: Dict[int, List[str]] = {}
    for position, name in enumerate(distinct):
        root = forest.find(position)
        name2cid[name] = root
        members.setdefault(root, []).append(name)

    frequency = Counter(raw_patterns)

    def preference(name: str) -> Tuple[int, int, str]:
        # Smaller tuple wins: more occurrences, then shorter, then alphabetical.
        return (-frequency[name], len(name), name)

    cid2rep: Dict[int, str] = {
        root: min(group, key=preference) for root, group in members.items()
    }
    return name2cid, cid2rep

# -------------------------
# TF-IDF computation
# -------------------------

def compute_tfidf(
    qid_to_counts: Dict[str, Counter],
    qid_to_presence: Dict[str, set]
) -> Dict[str, Dict[str, float]]:
    """
    Weight each question's canonical patterns by TF-IDF.

    For a pattern in a question:
        tf  = occurrences of the pattern / all pattern occurrences in the question
        idf = ln(Q / df), with Q = len(qid_to_presence) and df = number of
              questions whose presence set contains the pattern
    A pattern present in every question therefore gets weight 0. Patterns
    missing from every presence set are left out; a question whose counts sum
    to zero or less maps to an empty dict.

    Returns:
        qid2weights: {qid: {canonical_pattern: tfidf}}
    """
    num_questions = len(qid_to_presence)

    # In how many questions does each pattern occur at least once?
    doc_freq = Counter(
        pat for present in qid_to_presence.values() for pat in present
    )

    qid2weights: Dict[str, Dict[str, float]] = {}
    for qid, counts in qid_to_counts.items():
        occurrences = sum(counts.values())
        if occurrences <= 0:
            qid2weights[qid] = {}
            continue
        qid2weights[qid] = {
            pat: (hits / occurrences) * math.log(num_questions / doc_freq[pat])
            for pat, hits in counts.items()
            if doc_freq.get(pat, 0) > 0
        }
    return qid2weights

# -------------------------
# Main pipeline
# -------------------------

def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script.

    ``--input_path`` and ``--output_dir`` are required. ``--model_path``
    defaults to "", ``--similarity_threshold`` to 0.85 and ``--ngram_n`` to 2.
    """
    parser = argparse.ArgumentParser(
        description="Compute per-question TF-IDF of reasoning patterns over multiple CoTs."
    )
    option_specs = (
        ("--input_path",
         dict(type=str, required=True, help="Input JSONL file.")),
        ("--output_dir",
         dict(type=str, required=True, help="Directory to write JSONL output.")),
        ("--model_path",
         dict(type=str, required=False, default="",
              help="Unused. Kept for interface symmetry.")),
        ("--similarity_threshold",
         dict(type=float, default=0.85,
              help="N-gram cosine similarity threshold for merging patterns.")),
        ("--ngram_n",
         dict(type=int, default=2,
              help="Maximum character n for n-gram similarity.")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def main() -> None:
    """Run the pipeline: load records, cluster pattern names, weight, write.

    Steps:
      1. Read every record of ``--input_path`` and group record indices by
         the stringified "identity" field (records without one share "").
      2. Cluster all raw pattern names by n-gram similarity and replace each
         name by its cluster representative.
      3. Compute TF-IDF weights per question over the representative names.
      4. Write one output record per input record to
         ``<output_dir>/<basename of input_path>``, grouped by question, with
         a 1-based "replica" index and the question's weights under
         "pattern_with_weight".

    If no record yields any pattern, the input records are written unchanged.
    The output file is rebuilt from scratch on every run.

    Raises:
        SystemExit: If the output path resolves to the input file itself.
    """
    args = parse_args()
    out_path = os.path.join(args.output_dir, os.path.basename(args.input_path))

    # The output reuses the input's file name; refuse to run when both resolve
    # to the same file, otherwise results would be written into the input.
    if os.path.realpath(out_path) == os.path.realpath(args.input_path):
        sys.stderr.write(
            f"[error] Output path {out_path} is the input file; "
            "choose a different --output_dir.\n"
        )
        sys.exit(1)

    # Pass 1: load all records, group by question, collect raw pattern strings
    records: List[Dict[str, Any]] = list(read_jsonl(args.input_path))
    by_qid: Dict[str, List[int]] = defaultdict(list)  # qid -> indices in records
    rec_patterns: List[List[str]] = []               # per-record raw pattern names
    all_raw_patterns: List[str] = []

    for idx, rec in enumerate(records):
        qid = str(rec.get("identity", ""))
        by_qid[qid].append(idx)
        pats = extract_patterns_from_record(rec)
        rec_patterns.append(pats)
        all_raw_patterns.extend(pats)

    # write_jsonl appends, so drop whatever an earlier run left behind;
    # otherwise re-running would duplicate every record in the output.
    # Done only after the input was read successfully.
    try:
        os.remove(out_path)
    except FileNotFoundError:
        pass

    # If no patterns at all, exit early with passthrough
    if not all_raw_patterns:
        sys.stderr.write("[info] No patterns found. Writing passthrough output.\n")
        write_jsonl(out_path, records)
        return

    # Global clustering of pattern names by n-gram similarity
    name2cid, cid2rep = cluster_patterns_globally(
        raw_patterns=all_raw_patterns,
        sim_threshold=args.similarity_threshold,
        ngram_n=args.ngram_n
    )

    # Pass 2: aggregate counts and presence per question using canonical names
    qid_to_counts: Dict[str, Counter] = defaultdict(Counter)
    qid_to_presence: Dict[str, set] = defaultdict(set)

    for qid, idx_list in by_qid.items():
        for idx in idx_list:
            for raw_name in rec_patterns[idx]:
                cid = name2cid.get(raw_name)
                if cid is None:
                    continue
                canon = cid2rep[cid]
                qid_to_counts[qid][canon] += 1  # count multiplicity across replicas
                qid_to_presence[qid].add(canon)

    # Compute TF-IDF weights per question
    qid2weights = compute_tfidf(qid_to_counts, qid_to_presence)

    # Pass 3: write output with 1-based replica index and per-question weights.
    # write_jsonl creates the output directory itself, so no separate
    # makedirs call is needed here.
    batch: List[Dict[str, Any]] = []
    written = 0
    for qid, idx_list in by_qid.items():
        weights = qid2weights.get(qid, {})
        # Stable order as in input; assign 1-based replica
        for ridx, rec_idx in enumerate(idx_list, start=1):
            rec = records[rec_idx]
            out = {
                "identity": qid,
                "replica": ridx,
                "question": rec.get("question"),
                "question_type": rec.get("question_type"),
                "answer": rec.get("answer"),
                "cot": rec.get("cot"),
                "pattern": rec.get("pattern"),
                "pattern_with_weight": weights,
            }
            batch.append(out)
            # Flush in chunks so the pending output list stays small.
            if len(batch) >= 512:
                write_jsonl(out_path, batch)
                written += len(batch)
                batch.clear()
    if batch:
        write_jsonl(out_path, batch)
        written += len(batch)

    sys.stderr.write(f"[info] Wrote {written} records -> {out_path}\n")

# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
