import os
# ---- set thread env BEFORE importing numpy/sklearn/torch ----
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

import argparse
import random
import numpy as np

def _make_deterministic(seed: int) -> None:
    """Seed every RNG this script may touch so repeated runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and, when torch can be
    imported, torch's CPU/CUDA generators plus its deterministic-algorithm and
    cuDNN switches. torch is optional: if it cannot be imported, the torch
    steps are skipped.

    Note: ``PYTHONHASHSEED`` is only read at interpreter start-up, so setting it
    here affects child processes, not ``hash()`` randomization in this process.

    Args:
        seed: Seed value applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    try:
        import torch
    except (ImportError, OSError):
        # torch is optional (OSError covers a broken native install); nothing else to seed.
        return

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Guarded on its own: older torch releases lack this function or its
    # ``warn_only`` keyword, and a failure here must not skip the cuDNN flags below.
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except (AttributeError, TypeError) as err:
        print(f"[WARNING] torch deterministic algorithms not enabled: {err}")

    if hasattr(torch.backends, "cudnn"):
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

from hhmm_lib import (
    load_pt_records,
    build_top_sequences,
    fit_hhmm_fixed_top,
    coerce_labels_to_ids,
    CANON_TAGS,  # Import to check tag list
)
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

def _tag_name(idx: int) -> str:
    """Return the canonical tag for category ``idx``, or a placeholder past the end of CANON_TAGS."""
    return CANON_TAGS[idx] if idx < len(CANON_TAGS) else f"category_{idx}"


def _parse_args() -> argparse.Namespace:
    """Parse and validate CLI arguments, resolving ``--C`` from CANON_TAGS when omitted."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--in_pt", required=True)
    # --C defaults to len(CANON_TAGS) so the top level gets one state per canonical tag.
    ap.add_argument("--C", type=int, default=None,
                    help="#categories (top). Default: auto-detect from CANON_TAGS")
    ap.add_argument("--K", type=int, default=7, help="#regimes per category (bottom)")
    ap.add_argument("--iters", type=int, default=10)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--out_npz", default="hhmm_model.npz")
    ap.add_argument("--label_key", default="sentences_with_labels",
                    help="Per-step labels field used to anchor top categories.")
    ap.add_argument("--pca_dim", type=int, default=64, help="PCA target dim (<= D).")
    ap.add_argument("--subset", choices=["all", "correct", "incorrect"], default="all",
                    help="Which data to train on")
    args = ap.parse_args()

    if args.C is None:
        args.C = len(CANON_TAGS)
        print(f"[INFO] Auto-detected C={args.C} from CANON_TAGS: {CANON_TAGS}")
    elif args.C != len(CANON_TAGS):
        print(f"[WARNING] --C={args.C} but CANON_TAGS has {len(CANON_TAGS)} tags: {CANON_TAGS}")
        print(f"[WARNING] This may cause issues if your data contains all {len(CANON_TAGS)} tags.")

    # Reject non-positive sizes up front instead of failing deep inside sklearn / the HHMM fit.
    for name in ("C", "K", "pca_dim"):
        value = getattr(args, name)
        if value < 1:
            ap.error(f"--{name} must be >= 1 (got {value})")
    return args


def _report_label_distribution(seqs, label_key: str, C: int) -> None:
    """Print how the anchoring labels are distributed over the ``C`` top-level categories.

    Warns once per distinct label id outside ``[0, C)`` and lists categories
    that receive no steps at all.
    """
    print("\n[INFO] Analyzing labels in dataset...")
    label_counts = {i: 0 for i in range(C)}
    out_of_range = {}  # label id -> number of steps carrying it
    total_steps = 0
    for s in seqs:
        for lbl in coerce_labels_to_ids(s.get(label_key)):
            total_steps += 1
            # Negative ids are out of range too (they previously crashed with KeyError).
            if 0 <= lbl < C:
                label_counts[lbl] += 1
            else:
                out_of_range[lbl] = out_of_range.get(lbl, 0) + 1
    for lbl, n in sorted(out_of_range.items()):
        print(f"[WARNING] Found label {lbl} outside [0, {C}) in {n} step(s), "
              "this will cause errors!")

    print(f"Label distribution across {total_steps} steps:")
    for i in range(C):
        count = label_counts[i]
        pct = 100.0 * count / max(total_steps, 1)
        print(f"  {i} ({_tag_name(i):30s}): {count:6d} steps ({pct:5.2f}%)")

    empty_cats = [i for i in range(C) if label_counts[i] == 0]
    if empty_cats:
        print(f"\n[WARNING] The following categories have NO data: {empty_cats}")
        print("[WARNING] Training may fail or produce degenerate results for these categories.")
        n_keep = C - len(empty_cats)
        # Lowering --C only drops the highest ids, so it is a valid fix only when
        # every empty category sits at the end of the id range.
        if n_keep > 0 and empty_cats == list(range(n_keep, C)):
            print(f"[WARNING] Consider using --C={n_keep} to exclude empty categories.")
        elif n_keep > 0:
            print("[WARNING] Empty categories are not the highest ids; lowering --C would "
                  "drop populated categories, so labels would need remapping instead.")


def _fit_preprocessing(seqs, pca_dim: int, seed: int):
    """Fit StandardScaler + PCA on all steps of ``seqs`` and project every step in place.

    Returns:
        The fitted ``(scaler, pca)`` pair, needed to save the preprocessing with the model.

    Raises:
        RuntimeError: if the sequences contain no step rows at all.
    """
    blocks = [x for seq in seqs for x in seq["steps"]]
    if not blocks:
        raise RuntimeError("No step features found in the anchored sequences.")
    # Each step array must be 2-D (rows, D): it is passed to scaler.transform below.
    X = np.concatenate(blocks, axis=0)
    if X.shape[0] == 0:
        raise RuntimeError("No step features found in the anchored sequences.")
    D_in = X.shape[1]
    scaler = StandardScaler(copy=True, with_mean=True, with_std=True).fit(X)
    X_scaled = scaler.transform(X)
    # svd_solver="full" requires n_components <= min(n_samples, n_features).
    k = min(pca_dim, D_in, X.shape[0])
    pca = PCA(n_components=k, svd_solver="full", random_state=seed).fit(X_scaled)

    print("\n[INFO] Feature preprocessing:")
    print(f"  Input dimension: {D_in}")
    print(f"  PCA dimension: {k}")
    if k < pca_dim:
        print(f"  (requested --pca_dim={pca_dim} clamped to min(D, #rows)={k})")
    print(f"  Explained variance ratio (first 5 components): {pca.explained_variance_ratio_[:5]}")

    for seq in seqs:
        seq["steps"] = [pca.transform(scaler.transform(x)).astype(np.float64)
                        for x in seq["steps"]]
    return scaler, pca


def _model_payload(model, scaler, pca, subset: str) -> dict:
    """Collect HHMM parameters, preprocessing state and metadata as arrays for ``np.savez``."""
    return {
        "C": np.array([model.C], dtype=np.int32),
        "K": np.array([model.K], dtype=np.int32),
        "D": np.array([model.D], dtype=np.int32),
        "top_start": model.top.startprob,
        "top_trans": model.top.transmat,
        **{f"b{c}_start": model.bottom[c].startprob for c in range(model.C)},
        **{f"b{c}_trans": model.bottom[c].transmat  for c in range(model.C)},
        **{f"b{c}_means": model.bottom[c].means     for c in range(model.C)},
        **{f"b{c}_vars":  model.bottom[c].variances for c in range(model.C)},
        # preprocessing
        "prep_mean": scaler.mean_.astype(np.float64),
        "prep_scale": scaler.scale_.astype(np.float64),
        "prep_pca_components": pca.components_.astype(np.float64),
        "prep_pca_mean": pca.mean_.astype(np.float64),
        "prep_pca_explained_variance": pca.explained_variance_.astype(np.float64),
        "prep_pca_explained_variance_ratio": pca.explained_variance_ratio_.astype(np.float64),
        "prep_pca_singular_values": pca.singular_values_.astype(np.float64),
        "meta_subset": np.array(subset),
        # A unicode array (not dtype=object) so np.load works without allow_pickle=True.
        "meta_canon_tags": np.array(CANON_TAGS, dtype=str),
    }


def _print_top_params(model) -> None:
    """Print the top-level HMM start probabilities and transition matrix."""
    print("\n[RESULTS] Top-level HMM parameters:")
    print("  Start probabilities:")
    for i in range(model.C):
        print(f"    {i} ({_tag_name(i):30s}): {model.top.startprob[i]:.4f}")

    print("\n  Transition matrix (row = from, col = to):")
    print("      " + "".join(f"{i:7d}" for i in range(model.C)))
    for i in range(model.C):
        print(f"  {i:2d}  " + "".join(f"{model.top.transmat[i, j]:7.3f}" for j in range(model.C)))


def main():
    """Train a label-anchored HHMM on step features and save it to ``--out_npz``.

    Pipeline: parse args -> seed RNGs -> load and subset-filter records ->
    keep sequences carrying ``--label_key`` -> report label statistics ->
    StandardScaler + PCA -> fit HHMM -> save model and preprocessing.

    Raises:
        RuntimeError: if no records survive the subset filter, no sequence has
            labels, or there are no step features.
    """
    args = _parse_args()
    _make_deterministic(args.seed)

    # ---- load + filter records ----
    recs_all = load_pt_records(args.in_pt)
    if args.subset == "all":
        recs = recs_all
    else:
        want_correct = (args.subset == "correct")
        recs = [r for r in recs_all if bool(r.get("is_correct", False)) == want_correct]
    print(f"Loaded {len(recs)} records after subset='{args.subset}' filtering.")
    if len(recs) == 0:
        raise RuntimeError(f"No records left after subset='{args.subset}' filtering of {args.in_pt}.")

    # ---- sequences (always anchored) ----
    seqs = build_top_sequences(recs)
    before = len(seqs)
    seqs = [s for s in seqs if args.label_key in s]
    if not seqs:
        raise RuntimeError(
            f"No sequences contain '{args.label_key}'. "
            "Always-anchored training requires labels."
        )
    print(f"Anchored mode: kept {len(seqs)}/{before} sequences with labels ({args.label_key}).")

    _report_label_distribution(seqs, args.label_key, args.C)

    # ---- feature preprocessing (deterministic); rewrites seq["steps"] in place ----
    scaler, pca = _fit_preprocessing(seqs, args.pca_dim, args.seed)

    # ---- train HHMM (always passes label_key) ----
    print(f"\n[INFO] Training HHMM with C={args.C}, K={args.K}...")
    model = fit_hhmm_fixed_top(
        seqs,
        C=args.C,
        K=args.K,
        label_key=args.label_key,
        n_iter=args.iters,
        seed=args.seed,
        verbose=True,
    )

    # ---- save model + preprocessing ----
    np.savez(args.out_npz, **_model_payload(model, scaler, pca, args.subset))
    print(f"\n[INFO] Saved HHMM model to {args.out_npz} (subset={args.subset})")
    _print_top_params(model)


if __name__ == "__main__":
    main()