import os
import argparse
import csv
from typing import Dict, List, Tuple, Optional, Any
import re
from contextlib import nullcontext
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from config import *

# -------------------- utility: device & AMP --------------------
def pick_device(device_arg: Optional[str] = None) -> torch.device:
    """
    Choose the torch device to run on.

    Args:
        device_arg: Explicit device string (e.g. "cpu", "cuda:0"). When it is
            empty or None the device is auto-selected.

    Returns:
        The requested device, or "cuda" when CUDA is available and "cpu"
        otherwise.
    """
    if not device_arg:
        # No explicit request: prefer the GPU whenever torch can see one.
        device_arg = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_arg)

def pick_autocast(precision: str, device: torch.device):
    """
    Build the autocast context manager matching the requested precision.

    Args:
        precision: "fp16", "bf16" or "fp32" (case-insensitive). Empty/None is
            treated as "fp32".
        device: Device the computation will run on.

    Returns:
        A ``torch.amp.autocast`` context for fp16/bf16 on a CUDA device;
        ``contextlib.nullcontext()`` in every other case (non-CUDA device or
        any other precision string).
    """
    mode = (precision or "fp32").lower()
    if device.type == "cuda":
        reduced_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
        dtype = reduced_dtypes.get(mode)
        if dtype is not None:
            # `device_type` is a required argument of torch.amp.autocast;
            # omitting it raises TypeError before the context is even built.
            return torch.amp.autocast(device_type=device.type, dtype=dtype)
    return nullcontext()

# Best-effort, import-time tweak: ask torch for the "high" float32 matmul
# precision mode. Any failure (presumably a torch build that does not provide
# this function -- TODO confirm) is deliberately ignored so that importing
# this module never fails because of it.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# ---------- STRING-like Alias Resolver ----------
_alias_df_cache: Dict[str, pd.DataFrame] = {}

def _load_alias_df(alias_path: str) -> pd.DataFrame:
    """
    Accepts a 'protein info' TSV with columns including:
      - protein_external_id, preferred_name, [annotation]
    Returns a normalized DF with:
      - protein_external_id
      - alias
      - alias_up
    """
    if not alias_path:
        raise FileNotFoundError("Alias file path not specified (use --alias).")
    if alias_path in _alias_df_cache:
        return _alias_df_cache[alias_path]

    df = pd.read_csv(alias_path, sep="\t")
    required = {"protein_external_id", "preferred_name"}
    if not required.issubset(df.columns):
        raise ValueError(f"{alias_path} must contain at least {required}.")

    rows = []
    pref = df[["protein_external_id", "preferred_name"]].dropna().rename(
        columns={"preferred_name": "alias"}
    )
    rows.append(pref)

    USE_ANNOTATION = True
    if USE_ANNOTATION and "annotation" in df.columns:
        tok_re = re.compile(r"[A-Za-z0-9][A-Za-z0-9+_.-]{1,}")
        ann_rows = []
        for _, r in df[["protein_external_id", "annotation"]].dropna().iterrows():
            toks = set(tok_re.findall(str(r["annotation"])))
            if "preferred_name" in r and isinstance(r["preferred_name"], str):
                toks.discard(r["preferred_name"])
            for t in toks:
                ann_rows.append({"protein_external_id": r["protein_external_id"], "alias": t})
        if ann_rows:
            rows.append(pd.DataFrame(ann_rows))

    alias_df = pd.concat(rows, ignore_index=True)
    alias_df["alias"] = alias_df["alias"].astype(str)
    alias_df["alias_up"] = alias_df["alias"].str.upper()
    alias_df = alias_df.drop_duplicates(subset=["protein_external_id", "alias_up"], keep="first")

    _alias_df_cache[alias_path] = alias_df
    return alias_df

def resolve_h5_key(name_or_id: str, h5_keys: set, alias_path: Optional[str] = None,
                   strict_id_only: bool = False) -> str:
    """
    Map a user-supplied identifier (ID or alias) onto an existing HDF5 key.

    Lookup order:
      1. the stripped input itself, then the input with a '9606.' prefix;
      2. unless strict_id_only=True, every protein ID whose alias matches the
         input case-insensitively in the alias table (again tried bare and
         '9606.'-prefixed).

    Raises KeyError when nothing matches, when strict_id_only=True and the
    input is not a key, or when alias resolution is needed but no alias_path
    was given.
    """
    def _lookup(identifier):
        # Exact key first; otherwise retry with the '9606.' prefix added.
        if identifier in h5_keys:
            return identifier
        if not identifier.startswith("9606.") and f"9606.{identifier}" in h5_keys:
            return f"9606.{identifier}"
        return None

    query = (name_or_id or "").strip()
    direct = _lookup(query)
    if direct is not None:
        return direct

    if strict_id_only:
        raise KeyError(f"'{name_or_id}' is not an HDF5 key and strict_id_only=True.")
    if not alias_path:
        raise KeyError(f"'{name_or_id}' not resolved: provide --alias or use strict IDs.")

    alias_table = _load_alias_df(alias_path)
    matching = alias_table["alias_up"] == query.upper()
    for ext_id in alias_table.loc[matching, "protein_external_id"].unique().tolist():
        via_alias = _lookup(ext_id)
        if via_alias is not None:
            return via_alias

    raise KeyError(f"'{name_or_id}' not resolvable to any HDF5 key.")

# --------- HDF5 embedding I/O ---------
def _ensure_3d(z: np.ndarray) -> np.ndarray:
    if z.ndim == 2:
        z = z[None, ...]
    if z.ndim != 3:
        raise ValueError(f"Expected (1, L, D) or (L, D); got {z.shape}")
    return z

def _as_2d(z: np.ndarray) -> np.ndarray:
    return z[0] if z.ndim == 3 else z

def get_embeddings_from_h5(
    h5_path: str,
    p_name: str,
    kp_names: List[str],
    cand_list: Optional[List[str]] = None,
    alias_path: Optional[str] = None
) -> Tuple[str, np.ndarray, Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """
    Load the query, anchor and candidate embeddings from an HDF5 file.

    Args:
        h5_path: Path to the HDF5 file whose top-level keys are protein IDs.
        p_name: Query protein (HDF5 key, bare ID, or alias).
        kp_names: Anchor proteins KP(p) (HDF5 keys, bare IDs, or aliases).
        cand_list: Candidate identifiers. These are resolved as IDs only (no
            alias lookup). If None, every key in the file other than the
            query and the anchors is used.
        alias_path: Alias table used to resolve `p_name` and `kp_names`.

    Returns:
      p_key, Z_p,
      anchors: {pk_key: Z_pk},
      candidates: {cand_key: Z_c}
    The candidate set CP(p) never contains p or any member of KP(p). When
    `cand_list` is None the candidates appear in the order the file lists
    its keys; otherwise in `cand_list` order.

    Raises:
        KeyError: the query or an anchor cannot be resolved, or one or more
            entries of `cand_list` are not keys of the file.
    """
    with h5py.File(h5_path, "r") as f:
        # Keep the key list as read from the file for ordered iteration and a
        # set for fast membership tests. Iterating the set directly would make
        # the candidate order depend on string hash randomisation.
        ordered_keys = list(f.keys())
        keys = set(ordered_keys)

        p_key = resolve_h5_key(p_name, keys, alias_path=alias_path, strict_id_only=False)
        Z_p = np.array(f[p_key])

        anchors: Dict[str, np.ndarray] = {}
        for name in kp_names:
            pk_key = resolve_h5_key(name, keys, alias_path=alias_path, strict_id_only=False)
            anchors[pk_key] = np.array(f[pk_key])

        forbidden = {p_key} | set(anchors)

        if cand_list is None:
            candidates = {k: np.array(f[k]) for k in ordered_keys if k not in forbidden}
        else:
            candidates = {}
            missing = []
            for c in cand_list:
                try:
                    ck = resolve_h5_key(c, keys, alias_path=alias_path, strict_id_only=True)
                except KeyError:
                    # Collect all failures so they can be reported together.
                    missing.append(c)
                    continue
                if ck not in forbidden:
                    candidates[ck] = np.array(f[ck])
            if missing:
                raise KeyError(
                    f"Candidates not found (not IDs or absent in H5): "
                    f"{missing[:10]}{'...' if len(missing) > 10 else ''}"
                )

    return p_key, Z_p, anchors, candidates