

import os, re, csv, ast, sys, datetime

# ============ User Arguments ============ #
class Args:
    """Run configuration. Edit the class attributes below; ``main`` reads them directly."""

    # ---------------- Input files ----------------
    # Text log whose lines contain Genotype(...) records.
    input_txt: str = r""
    # Path to the NAS-Bench-201 benchmark file (.pth).
    api_path: str = r""

    # ---------------- Evaluation protocol ----------------
    # Training protocol to read accuracies from; keep it consistent with how you
    # want to evaluate:
    #   "cifar10-valid" -> both 'valid' and 'test'
    #   "cifar10"       -> full-training protocol, 'test' only (no 'valid' split)
    #   'cifar100' | 'ImageNet16-120' are further NB201 datasets; check
    #   _map_eval_key for which datasets the split mapping implements.
    dataset: str = "cifar10-valid"

    # Splits to export, in column order: 'train' | 'valid' | 'test'.
    eval_sets: list = ["valid", "test"]

    # Training-budget key: '01' | '12' | '90' | '200'.
    # NOTE(review): which keys exist depends on the benchmark file -- confirm.
    hp: str = "200"

    # ---------------- Parsing ----------------
    # Stop reading at the first line starting with "# Frequency", so only the
    # unique-list region above that header is parsed (recommended).
    stop_at_frequency_header: bool = True
    # Drop repeated architectures (compared by their arch_str).
    dedup: bool = True
    # Keep only the first N parsed items; None keeps all of them.
    eval_limit: "int | None" = None

    # ---------------- Output ----------------
    # eval_sets entry to sort by (descending, NaN last), e.g. 'valid' or 'test';
    # None keeps input order.
    sort_by: "str | None" = "test"
    # CSV destination: a directory or a full file path; "" disables CSV export.
    csv_out: str = r""
    # Append a timestamp (formatted with timestamp_fmt) to the CSV file name.
    add_timestamp: bool = True
    timestamp_fmt: str = "%Y%m%d_%H%M%S"
# ======================================= #


# Search-space op name -> NB201 op name (only the two convolutions are renamed).
OP_MAP = dict(
    conv_1x1="nor_conv_1x1",
    conv_3x3="nor_conv_3x3",
    avg_pool_3x3="avg_pool_3x3",
    skip_connect="skip_connect",
    none="none",
)

# NB201 cell edges in arch_str order (the per-op "u" value from the logs is ignored):
# e0: 0->1 | e1: 0->2 | e2: 1->2 | e3: 0->3 | e4: 1->3 | e5: 2->3
def normal_to_arch_str_fixed_edges(normal):
    """Build an NB201 ``arch_str`` from six ``[op, u]`` pairs.

    The i-th pair is placed on the i-th edge listed above; ``u`` is discarded.
    Op names are translated through ``OP_MAP``, and unknown names pass through unchanged.

    Raises:
        ValueError: if *normal* does not have exactly six entries.
    """
    count = len(normal)
    if count != 6:
        raise ValueError(f"expected 6 ops, got {count}")
    names = [OP_MAP.get(op, op) for op, _unused in normal]

    # Target node j (1..3) receives one edge from every earlier node i < j.
    node_groups = []
    cursor = 0
    for target in (1, 2, 3):
        edges = []
        for source in range(target):
            edges.append(f"{names[cursor]}~{source}")
            cursor += 1
        node_groups.append("|" + "|".join(edges) + "|")
    return "+".join(node_groups)


# Matches "Genotype(normal=[...], normal_concat=[...])" (case-insensitive,
# may span lines). group(1) is the `normal` list literal, group(2) the
# `normal_concat` literal. Each lazy `.*?` is extended only until the closing
# "]" is followed by the next keyword / ")", so nested [op, u] pairs are kept.
# re.DOTALL already lets "." match "\n", so no "(?:.|\n)" alternation is needed
# (that ambiguous form can backtrack excessively on non-matching input).
GENO_RE = re.compile(
    r"Genotype\s*\(\s*normal\s*=\s*(\[.*?\])\s*,\s*normal_concat\s*=\s*(\[.*?\])\s*\)",
    re.IGNORECASE | re.DOTALL,
)
# "[epoch=N]" tag; group(1) is the digits of N.
EPOCH_RE = re.compile(r"\[epoch\s*=\s*(\d+)\]", re.IGNORECASE)

def parse_genotypes_from_txt(path, stop_at_freq=True):
    """Collect ``Genotype(...)`` records from a text log.

    The file is read line by line as UTF-8, and undecodable bytes are dropped.
    On each line only the first ``GENO_RE`` match is used. Lines whose
    ``normal`` literal cannot be evaluated are skipped (best effort).

    Args:
        path: text file to read.
        stop_at_freq: if true, stop at the first line that starts (after
            leading whitespace) with "# Frequency".

    Returns:
        A list of dicts in file order:
        ``{'epoch': int | None, 'normal': <parsed literal>, 'raw': <matched text>}``.
        ``epoch`` comes from an ``[epoch=N]`` tag on the same line, if present.
    """
    items = []
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if stop_at_freq and line.lstrip().startswith("# Frequency"):
                break
            m = GENO_RE.search(line)
            if not m:
                continue
            try:
                normal = ast.literal_eval(m.group(1))
            except (ValueError, TypeError, SyntaxError, RecursionError):
                # Malformed literal (e.g. a truncated log line): skip it.
                continue
            mep = EPOCH_RE.search(line)
            # The tag regex only captures \d+, so int() cannot fail here.
            ep = int(mep.group(1)) if mep else None
            items.append({"epoch": ep, "normal": normal, "raw": m.group(0)})
    return items


def _guard_regimen(dataset, eval_sets):
    if dataset == "cifar10" and any(s == "valid" for s in eval_sets):
        raise ValueError("dataset='cifar10' has no 'valid' split. Use dataset='cifar10-valid' for valid+test.")


def _map_eval_key(dataset, setname):
    d, s = dataset, setname
    if d == "cifar10-valid":
        if s == "valid": return "cifar10-valid", "x-valid"
        if s == "test":  return 'cifar10', 'ori-test'
        if s == "train": return "cifar10-valid", "train"
    raise ValueError(f"Unsupported mapping: dataset={dataset}, set={setname}")


def _default_csv_name(args):
    """Default file name: nb201_<dataset>_<evalsets>_hp<hp>_<timestamp>.csv"""
    sets = "-".join(getattr(args, "eval_sets", ["test"]))
    base = f"nb201_{args.dataset}_{sets}_hp{args.hp}"
    if getattr(args, "add_timestamp", True):
        ts = datetime.datetime.now().strftime(getattr(args, "timestamp_fmt", "%Y%m%d_%H%M%S"))
        base = f"{base}_{ts}"
    return base + ".csv"


def _resolve_csv_path(csv_opt, args):

    if not csv_opt:
        return None

    p = csv_opt
    base = os.path.basename(p)
    root, ext = os.path.splitext(base)
    looks_like_dir = (
        p.endswith(("/", "\\")) or
        os.path.isdir(p) or
        (not os.path.exists(p) and ext == "")  # does not exist and no extension → treat as dir
    )

    if looks_like_dir:
        dirpath = p.rstrip("/\\")
        if dirpath and not os.path.exists(dirpath):
            os.makedirs(dirpath, exist_ok=True)
        return os.path.join(dirpath, _default_csv_name(args))

    # Otherwise treat as file path
    parent = os.path.dirname(p)
    if parent and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

    # No extension: prefix + timestamp + .csv
    if ext == "":
        fname = root
        if getattr(args, "add_timestamp", True):
            ts = datetime.datetime.now().strftime(getattr(args, "timestamp_fmt", "%Y%m%d_%H%M%S"))
            fname = f"{fname}_{ts}"
        return os.path.join(parent or ".", fname + ".csv")

    # With extension: insert timestamp before extension (if enabled)
    if getattr(args, "add_timestamp", True):
        ts = datetime.datetime.now().strftime(getattr(args, "timestamp_fmt", "%Y%m%d_%H%M%S"))
        full_root = os.path.join(os.path.dirname(p), os.path.splitext(os.path.basename(p))[0])
        return f"{full_root}_{ts}{ext}"

    return p


def _safe_write_csv(path, header, rows):
    try:
        with open(path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f); w.writerow(header); w.writerows(rows)
        print(f"\nCSV exported -> {path}")
    except PermissionError:
        alt_dir = os.path.join(os.path.expanduser("~"), "nb201_exports")
        os.makedirs(alt_dir, exist_ok=True)
        fb = os.path.join(alt_dir, os.path.basename(path))
        with open(fb, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f); w.writerow(header); w.writerows(rows)
        print(f"\n⚠️ No write permission: {path}\nFallback -> {fb}")


def _collect_archs(items, dedup):
    """Map parsed genotype items to ``[{'epoch', 'arch_str'}]``, keeping first occurrences if *dedup*."""
    archs, seen = [], set()
    for it in items:
        arch_str = normal_to_arch_str_fixed_edges(it["normal"])
        if dedup and arch_str in seen:
            continue
        seen.add(arch_str)
        archs.append({"epoch": it["epoch"], "arch_str": arch_str})
    return archs


def _query_accuracies(api, idx, hp, eval_sets, eval_map):
    """Return one accuracy per split in *eval_sets*; all NaN when *idx* is not a valid index."""
    if idx is None or idx < 0:
        # Possible e.g. when an op name is not a valid NB201 op; keep the row as NaNs.
        return [float("nan")] * len(eval_sets)
    info = api.arch2infos_dict[idx][hp]  # budget-specific results object
    accs = []
    for s in eval_sets:
        dkey, skey = eval_map[s]
        m = info.get_metrics(dkey, skey, iepoch=None, is_random=False)  # averaged across seeds
        accs.append(m.get("accuracy", float("nan")))
    return accs


def _format_acc(x):
    """Format an accuracy with 6 decimals; NaN and non-numeric values print as 'nan'."""
    return f"{x:.6f}" if isinstance(x, (int, float)) and x == x else "nan"  # x == x filters NaN


def main(a: Args):
    """Parse genotypes, look them up in NAS-Bench-201, print a table, optionally export CSV.

    Pipeline:
      1. Parse ``Genotype(...)`` lines from ``a.input_txt``. ``a.eval_limit``
         caps the parsed items before dedup.
      2. Convert them to fixed-edge NB201 arch strings, optionally deduplicated.
      3. Query each arch's accuracy for every split in ``a.eval_sets``.
      4. Optionally sort descending by ``a.sort_by`` (NaN last), then print the
         table and write it to the CSV path derived from ``a.csv_out``.

    Raises:
        ValueError: invalid dataset/split combination, or ``sort_by`` not in
            ``eval_sets``. Both are checked before the benchmark file is loaded.
        ImportError: the ``nas_201_api`` package is missing.
    """
    # Validate configuration first, so mistakes surface before the (slow) API load.
    _guard_regimen(a.dataset, a.eval_sets)
    eval_map = {s: _map_eval_key(a.dataset, s) for s in a.eval_sets}
    if a.sort_by is not None and a.sort_by not in a.eval_sets:
        raise ValueError(f"sort_by='{a.sort_by}' must be one of eval_sets={a.eval_sets}")

    try:
        from nas_201_api import NASBench201API as API
    except ImportError:
        print("Please install nas-bench-201: pip install nas-bench-201", file=sys.stderr)
        raise

    # 1) Parse genotypes
    items = parse_genotypes_from_txt(a.input_txt, stop_at_freq=a.stop_at_frequency_header)
    if a.eval_limit is not None:
        items = items[:int(a.eval_limit)]
    if not items:
        print("[WARN] No Genotype(...) parsed. Please check the txt format.")
        return

    # 2) Convert to arch_str (fixed edges), optionally deduplicate
    archs = _collect_archs(items, a.dedup)
    print(f"[INFO] To evaluate: {len(archs)} architectures (dedup={a.dedup})")

    # 3) Load API
    api = API(a.api_path, verbose=False)

    # 4) Query results; records are (idx, accs, arch_str, epoch) in input order
    records = []
    for meta in archs:
        arch = meta["arch_str"]
        idx = api.query_index_by_arch(arch)
        records.append((idx, _query_accuracies(api, idx, a.hp, a.eval_sets, eval_map), arch, meta["epoch"]))

    # 5) Sort by the chosen metric; the key's first element puts NaNs last
    if a.sort_by is not None:
        k = a.eval_sets.index(a.sort_by)
        records.sort(key=lambda r: (r[1][k] == r[1][k], r[1][k]), reverse=True)

    # 6) Print table. Column "#" is the input position when unsorted, otherwise the rank.
    head = ["#", "epoch", "idx"] + [f"acc[{s}]" for s in a.eval_sets] + ["arch_str"]
    use_rows = [[n, ep, idx] + accs + [arch] for n, (idx, accs, arch, ep) in enumerate(records, 1)]
    print("\t".join(head))
    for r in use_rows:
        cells = [str(c) for c in r[:3]] + [_format_acc(x) for x in r[3:-1]] + [r[-1]]
        print("\t".join(cells))

    # 7) CSV export (raw, unformatted values)
    csv_path = _resolve_csv_path(a.csv_out, a)
    if csv_path:
        _safe_write_csv(csv_path, head, use_rows)


if __name__ == "__main__":
    # The configuration object is the Args class itself (its class attributes), not an instance.
    main(a=Args)
