import argparse
import json
import math
import os
import sys
from collections import Counter, defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment  # Standard Hungarian algorithm


# -------------------------
# JSONL I/O
# -------------------------

def read_jsonl(path: str) -> Iterable[Dict[str, Any]]:
    """Yield each JSON object stored one-per-line in the file at *path*.

    Blank lines are skipped. Lines that are not valid JSON, or that decode to
    something other than a JSON object (list, number, string, ...), are
    reported on stderr and skipped, so every yielded record supports ``.get``.

    Args:
        path: Path to a UTF-8 JSON Lines file (a leading BOM is tolerated).

    Yields:
        One dict per valid object line, in file order.
    """
    # "utf-8-sig" strips a leading byte-order mark if present; with plain
    # "utf-8" the BOM stays glued to the first line and json.loads rejects it,
    # silently losing the first record.
    with open(path, "r", encoding="utf-8-sig") as f:
        for ln, line in enumerate(f, 1):
            s = line.strip()
            if not s:
                continue
            try:
                rec = json.loads(s)
            except json.JSONDecodeError:
                sys.stderr.write(f"[warn] Invalid JSON at line {ln}\n")
                continue
            # Downstream code calls rec.get(...); a non-object line would crash it.
            if not isinstance(rec, dict):
                sys.stderr.write(f"[warn] Non-object JSON at line {ln}\n")
                continue
            yield rec

def write_jsonl(path: str, records: List[Dict[str, Any]], mode: str = "w") -> None:
    """Write *records* to *path* as JSON Lines (one compact object per line).

    Args:
        path: Destination file; missing parent directories are created.
        records: JSON-serializable dicts, written in order. Non-ASCII text is
            kept as-is (``ensure_ascii=False``).
        mode: ``"w"`` (default) replaces any existing file, so re-running the
            tool does not accumulate duplicate records; ``"a"`` appends.

    Raises:
        ValueError: if *mode* is neither ``"w"`` nor ``"a"``.
    """
    if mode not in ("w", "a"):
        raise ValueError(f"mode must be 'w' or 'a', got {mode!r}")
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, mode, encoding="utf-8") as f:
        # One buffered writelines call instead of a write per record.
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


# -------------------------
# Character n-gram distance (logic preserved)
# -------------------------

def normalize(s: str) -> str:
    """Return *s* with every whitespace character dropped (minimal normalization)."""
    return "".join(ch for ch in s if not ch.isspace())

def char_ngrams(s: str, nk: int = 2):
    """List all contiguous character n-grams of orders 1..nk of normalize(s).

    Grams are grouped by order (all unigrams, then all bigrams, ...), each in
    left-to-right order. Orders longer than the normalized text contribute nothing.
    """
    text = normalize(s)
    size = len(text)
    return [
        text[start:start + width]
        for width in range(1, nk + 1)
        for start in range(size - width + 1)
    ]

def cosine_distance_by_char_ngrams(a: str, b: str, n: int = 2):
    """Return ``1 - cosine`` between the char n-gram count vectors of *a* and *b*.

    Both strings are turned into counts of their 1..n character grams (via
    ``char_ngrams``). Returns 0.0 when neither string yields any gram and 1.0
    when exactly one does.
    """
    grams_a = Counter(char_ngrams(a, n))
    grams_b = Counter(char_ngrams(b, n))
    if not (grams_a or grams_b):
        return 0.0

    norm_a = math.sqrt(sum(cnt * cnt for cnt in grams_a.values()))
    norm_b = math.sqrt(sum(cnt * cnt for cnt in grams_b.values()))
    if norm_a == 0 or norm_b == 0:
        return 1.0

    shared = sum(cnt * grams_b.get(gram, 0) for gram, cnt in grams_a.items())
    similarity = shared / (norm_a * norm_b)
    return 1.0 - similarity


# -------------------------
# DTW with external per-column weights (logic preserved)
# -------------------------

def dtw_with_weight(para: Tuple[List[str], List[str], int, List[float]]) -> float:
    """Weighted, path-normalized DTW distance between two string sequences.

    ``para`` is a single tuple ``(sequence1, sequence2, n_gram, weight)``:
      * sequence1, sequence2: sequences of strings; each pair of elements is
        compared with ``cosine_distance_by_char_ngrams(..., n_gram)``.
      * n_gram: max character n-gram order passed to that distance.
      * weight: one weight per element of ``sequence2``. Every DP step landing
        in column j adds ``weight[j-1] * d`` to the cost accumulator ``D`` and
        ``weight[j-1]`` to the weight accumulator ``W``.

    Returns:
        1.0 if either sequence is empty. Otherwise ``D[n][m] / W[n][m]``, i.e.
        the weighted cost of the selected path divided by the summed weights
        along that same path, or 0.0 if that weight sum is not positive.

    Raises:
        AssertionError: if both sequences are non-empty and
        ``len(weight) != len(sequence2)``. NOTE(review): ``assert`` is
        stripped under ``python -O``, in which case a short ``weight`` raises
        IndexError instead.
    """
    sequence1, sequence2, n_gram, weight = para
    n, m = len(sequence1), len(sequence2)
    if n == 0 or m == 0:
        return 1.0
    assert len(weight) == m, "weight length must equal len(sequence2)"

    # D: cumulative weighted cost (numerator)
    # W: cumulative weight sum on the same path (denominator; a column j may repeat)
    D = [[0.0]*(m+1) for _ in range(n+1)]
    W = [[0.0]*(m+1) for _ in range(n+1)]

    # Initialize first column: align s1[:i] repeatedly to s2[0] (reuse weight[0])
    # Unlike textbook DTW (where borders other than D[0][0] are +inf), the
    # borders here hold real costs of matching one whole prefix against the
    # first element of the other sequence.
    w0 = weight[0]
    for i in range(1, n+1):
        d = cosine_distance_by_char_ngrams(sequence1[i-1], sequence2[0], n_gram)
        D[i][0] = D[i-1][0] + w0 * d
        W[i][0] = W[i-1][0] + w0

    # Initialize first row: advance s2[:j] aligned with s1[0]
    for j in range(1, m+1):
        wj = weight[j-1]
        d = cosine_distance_by_char_ngrams(sequence1[0], sequence2[j-1], n_gram)
        D[0][j] = D[0][j-1] + wj * d
        W[0][j] = W[0][j-1] + wj

    # Main loop: choose predecessor by D only (tie: diag -> left -> up). W follows the same predecessor.
    # Because the choice uses the numerator alone, the resulting path is not
    # guaranteed to minimize the final ratio D/W.
    for i in range(1, n+1):
        si = sequence1[i-1]
        for j in range(1, m+1):
            wj = weight[j-1]
            d  = cosine_distance_by_char_ngrams(si, sequence2[j-1], n_gram)
            # These are cumulative costs of the three candidate predecessors,
            # not per-step distances.
            d_up, d_left, d_diag = D[i-1][j], D[i][j-1], D[i-1][j-1]
            if d_diag <= d_left and d_diag <= d_up:
                prevD, prevW = D[i-1][j-1], W[i-1][j-1]
            elif d_left <= d_up:
                prevD, prevW = D[i][j-1],   W[i][j-1]
            else:
                prevD, prevW = D[i-1][j],   W[i-1][j]
            D[i][j] = prevD + wj * d
            W[i][j] = prevW + wj

    # Normalize by the weight accumulated along the chosen path.
    denom = W[n][m]
    return D[n][m] / denom if denom > 0 else 0.0


# -------------------------
# Pattern parsing (supports English/Chinese schemas)
# -------------------------

def _chain_from_schema(obj: Dict[str, Any], list_key: str, id_key: str,
                       name_key: str, how_key: str, chain_key: str) -> List[str]:
    """Resolve the ordered pattern-name chain for one known schema layout.

    ``obj[list_key]`` holds pattern entries with ``id_key``/``name_key``;
    ``obj[how_key][chain_key]`` optionally lists ids in usage order. If a
    chain list is present (and at least one entry has a name), each id is
    mapped to its name (unknown ids become ``str(id)``). Otherwise the names
    of all entries are returned in list order.
    """
    plist = obj.get(list_key) or []
    # Stray non-dict entries (strings, numbers) carry no id/name; skip them
    # instead of crashing on `.get`.
    entries = [p for p in plist if isinstance(p, dict)] if isinstance(plist, (list, tuple)) else []

    exact: Dict[Any, str] = {}
    by_text: Dict[str, str] = {}
    for p in entries:
        pid, nm = p.get(id_key), p.get(name_key)
        if nm is None:
            continue
        try:
            exact[pid] = str(nm)
        except TypeError:  # unhashable id (e.g. a list): only the text key applies
            pass
        by_text[str(pid)] = str(nm)

    def resolve(pid: Any) -> str:
        try:
            if pid in exact:
                return exact[pid]
        except TypeError:  # unhashable chain element
            pass
        # Ids are often written as 1 in one place and "1" in another.
        return by_text.get(str(pid), str(pid))

    how = obj.get(how_key)
    chain = how.get(chain_key) if isinstance(how, dict) else None
    if isinstance(chain, list) and by_text:
        return [resolve(pid) for pid in chain]
    return [str(p.get(name_key)) for p in entries if p.get(name_key) is not None]


def _collect_names(obj: Any) -> List[str]:
    """Depth-first collect scalar values stored under "name"/"名称" keys."""
    names: List[str] = []

    def walk(x: Any) -> None:
        if isinstance(x, dict):
            for k, v in x.items():
                if k in ("name", "名称") and isinstance(v, (str, int, float)):
                    names.append(str(v))
                else:
                    walk(v)
        elif isinstance(x, list):
            for y in x:
                walk(y)

    walk(obj)
    return names


def extract_pattern_chain(obj: Any) -> List[str]:
    """Return the ordered list of pattern names described by *obj*.

    *obj* may be None (-> ``[]``), a JSON string (parsed first; a string that
    is not valid JSON is returned as a one-element list), or a dict in one of
    two schemas:

    * English: ``pattern_list`` entries with ``id``/``name``, chain under
      ``how_CoT_utilizes_patterns_in_this_case.pattern_chain``.
    * Chinese: ``模式列表`` entries with ``编号``/``名称``, chain under
      ``本题CoT如何利用这些模式.模式链``.

    Without a chain, all pattern names are listed. Other dicts yield every
    scalar found under a ``name``/``名称`` key anywhere inside. Any other value
    becomes ``[str(obj)]``.
    """
    if obj is None:
        return []
    if isinstance(obj, str):
        try:
            obj = json.loads(obj)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            return [obj]

    if isinstance(obj, dict):
        if "pattern_list" in obj:
            return _chain_from_schema(obj, "pattern_list", "id", "name",
                                      "how_CoT_utilizes_patterns_in_this_case",
                                      "pattern_chain")
        if "模式列表" in obj:
            return _chain_from_schema(obj, "模式列表", "编号", "名称",
                                      "本题CoT如何利用这些模式", "模式链")
        return _collect_names(obj)

    return [str(obj)]


# -------------------------
# Entropy chain tokenization for DTW
# -------------------------

def _entropy_value(item: Any) -> float:
    """Coerce one ``output_token_entropies`` entry to a float (0.0 if unusable)."""
    # Accept bare numbers too; bool is excluded because it is an int subclass
    # but not a meaningful entropy.
    if isinstance(item, (int, float)) and not isinstance(item, bool):
        return float(item)
    try:
        return float(item.get("entropy_nats", 0.0))
    except (AttributeError, TypeError, ValueError, OverflowError):
        return 0.0


def extract_entropy_list(rec: Dict[str, Any]) -> List[float]:
    """Return the per-token entropies stored in ``rec["output_token_entropies"]``.

    Each entry is either a dict carrying ``entropy_nats`` (missing -> 0.0) or
    a bare number. Entries that cannot be converted to float become 0.0, so
    the result has one value per entry. A missing/empty field yields ``[]``.
    """
    seq = rec.get("output_token_entropies") or []
    return [_entropy_value(x) for x in seq]

def entropy_sequence_to_tokens(entropies: List[float],
                               scale: float = 1000.0,
                               width: int = 6) -> List[str]:
    """
    Map each entropy to a fixed-width string token for char-gram DTW.

    The value is multiplied by *scale*, rounded to an int, stripped of its
    sign, zero-padded to *width* digits and prefixed with "E"
    (0.1234 -> "E000123" with the defaults). Values that cannot be scaled
    and rounded (None, NaN, inf, ...) are encoded as 0.
    """
    def encode(value) -> str:
        try:
            scaled = int(round(value * scale))
        except Exception:
            scaled = 0
        return f"E{abs(scaled):0{width}d}"

    return [encode(value) for value in entropies]


# -------------------------
# Distance components and combination
# -------------------------

def _core_pattern_weight(raw: Any, epsilon_w: float) -> float:
    """Coerce one raw Core pattern weight to a positive finite float, else *epsilon_w*."""
    if raw is None:
        return epsilon_w
    try:
        v = float(raw)
    except (TypeError, ValueError, OverflowError):
        return epsilon_w
    # NaN/inf would turn the DTW sums (and the Hungarian cost matrix) into NaN;
    # non-positive weights are floored just as before.
    if not math.isfinite(v) or v <= 0:
        return epsilon_w
    return v


def pattern_distance(core_rec: Dict[str, Any],
                     src_rec: Dict[str, Any],
                     ngram_n: int,
                     epsilon_w: float) -> float:
    """
    DTW distance over pattern chains with weights from Core 'pattern_with_weight'.

    sequence1 = source chain, sequence2 = Core chain (both via
    ``extract_pattern_chain`` on the records' "pattern" field);
    weight[k] = Core weight of the k-th Core pattern name. Missing,
    non-numeric, non-finite or non-positive weights are replaced by
    *epsilon_w*. "pattern_with_weight" may be a dict or a JSON string encoding
    one; anything else is treated as "no weights".
    """
    core_chain = extract_pattern_chain(core_rec.get("pattern"))
    src_chain  = extract_pattern_chain(src_rec.get("pattern"))

    wmap = core_rec.get("pattern_with_weight")
    if isinstance(wmap, str):  # mirror "pattern", which may also be stored as JSON text
        try:
            wmap = json.loads(wmap)
        except (ValueError, RecursionError):
            wmap = None
    if not isinstance(wmap, dict):
        wmap = {}

    weights = [_core_pattern_weight(wmap.get(p), epsilon_w) for p in core_chain]
    return dtw_with_weight((src_chain, core_chain, ngram_n, weights))

def entropy_distance(core_rec: Dict[str, Any],
                     src_rec: Dict[str, Any],
                     ngram_n: int,
                     ent_scale: float,
                     ent_width: int) -> float:
    """
    DTW distance between the tokenized entropy traces of two records.

    Each record's entropies are turned into "E…" tokens with
    ``entropy_sequence_to_tokens(…, ent_scale, ent_width)``; the Source
    tokens form sequence1, the Core tokens sequence2, and every Core token
    gets weight 1.0.
    """
    def tokens_of(rec: Dict[str, Any]) -> List[str]:
        return entropy_sequence_to_tokens(extract_entropy_list(rec), ent_scale, ent_width)

    core_tokens = tokens_of(core_rec)
    src_tokens = tokens_of(src_rec)
    uniform = [1.0 for _ in core_tokens]
    return dtw_with_weight((src_tokens, core_tokens, ngram_n, uniform))

def combined_distance(core_rec: Dict[str, Any],
                      src_rec: Dict[str, Any],
                      lambda_weight: float,
                      ngram_n: int,
                      epsilon_w: float,
                      ent_scale: float,
                      ent_width: int) -> float:
    """Blend pattern and entropy DTW distances: ``lam * pattern + (1 - lam) * entropy``.

    *lambda_weight* is clamped into [0, 1] before use.
    """
    clamped = max(0.0, min(1.0, float(lambda_weight)))
    pattern_part = pattern_distance(core_rec, src_rec, ngram_n, epsilon_w)
    entropy_part = entropy_distance(core_rec, src_rec, ngram_n, ent_scale, ent_width)
    return clamped * pattern_part + (1.0 - clamped) * entropy_part


# -------------------------
# CLI
# -------------------------

def parse_args() -> argparse.Namespace:
    """Parse the command line: three required paths plus tuning knobs with defaults."""
    parser = argparse.ArgumentParser(
        description="Hungarian matching between Core and Source CoTs using weighted DTW distances."
    )

    required_paths = (
        ("--core_path", "Path to Core Set JSONL."),
        ("--source_path", "Path to Source Data JSONL."),
        ("--output_path", "Output JSONL for selected Source records."),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    tunables = (
        ("--lambda_weight", float, 0.5, "Lambda in [0,1] for combining distances."),
        ("--ngram_n", int, 2, "Max character n for char-gram distance."),
        ("--pattern_weight_epsilon", float, 1e-6, "Floor for missing/zero Core pattern weights."),
        ("--entropy_scale", float, 1000.0, "Scaling factor for entropy tokenization."),
        ("--entropy_width", int, 6, "Zero-pad width for entropy tokens."),
    )
    for flag, kind, default, help_text in tunables:
        parser.add_argument(flag, type=kind, default=default, help=help_text)

    return parser.parse_args()


# -------------------------
# Main
# -------------------------

def _project_source_record(src: Dict[str, Any]) -> Dict[str, Any]:
    """Return the fields of a Source record that go to the output file (absent keys -> None)."""
    return {
        "identity": src.get("identity"),
        "question": src.get("question"),
        "question_type": src.get("question_type"),
        "answer": src.get("answer"),
        "cot": src.get("cot"),
        "pattern": src.get("pattern"),
        "output_token_entropies": src.get("output_token_entropies"),
    }


def _build_cost_matrix(core_list: List[Dict[str, Any]],
                       source_list: List[Dict[str, Any]],
                       args: argparse.Namespace) -> np.ndarray:
    """Return the (len(core_list), len(source_list)) matrix of combined distances."""
    cost = np.zeros((len(core_list), len(source_list)), dtype=np.float64)
    for i, core_rec in enumerate(core_list):
        for j, src_rec in enumerate(source_list):
            cost[i, j] = float(combined_distance(
                core_rec=core_rec,
                src_rec=src_rec,
                lambda_weight=args.lambda_weight,
                ngram_n=args.ngram_n,
                epsilon_w=args.pattern_weight_epsilon,
                ent_scale=args.entropy_scale,
                ent_width=args.entropy_width,
            ))
    return cost


def main() -> int:
    """Match each Core record to a distinct Source record and write the matches.

    Reads both JSONL inputs and builds an (n_core, n_source) matrix of
    ``combined_distance`` values. It then solves the min-cost assignment with
    SciPy's ``linear_sum_assignment`` and writes the assigned Source records,
    in Core row order, to ``--output_path``.

    Returns:
        Process exit status: 0 on success; 1 when an input is empty, the Source
        set is smaller than the Core set, or the cost matrix contains NaN
        (which ``linear_sum_assignment`` rejects).
    """
    args = parse_args()

    core_list = list(read_jsonl(args.core_path))
    source_list = list(read_jsonl(args.source_path))

    if not core_list:
        sys.stderr.write("[error] Core Set is empty.\n")
        return 1
    if not source_list:
        sys.stderr.write("[error] Source Data is empty.\n")
        return 1
    n, m = len(core_list), len(source_list)
    if m < n:
        sys.stderr.write(f"[error] Source size {m} < Core size {n}.\n")
        return 1

    cost = _build_cost_matrix(core_list, source_list, args)
    if np.isnan(cost).any():
        sys.stderr.write("[error] Cost matrix contains NaN distances.\n")
        return 1

    # Hungarian algorithm via SciPy (min-sum assignment).
    row_ind, col_ind = linear_sum_assignment(cost)

    # Expect |row_ind| == n when m >= n
    if len(row_ind) != n:
        sys.stderr.write(f"[warn] Assignment covered {len(row_ind)} of {n} rows.\n")

    # SciPy returns row indices sorted ascending, so walking col_ind yields the
    # matched Source records in Core order; its column indices are always valid.
    selected = [_project_source_record(source_list[j]) for j in col_ind]

    write_jsonl(args.output_path, selected)
    sys.stderr.write(f"[info] Wrote {len(selected)} records -> {args.output_path}\n")
    return 0


if __name__ == "__main__":
    # Forward main()'s status to the shell so failures are visible to callers;
    # a None return still exits with 0.
    sys.exit(main())
