#!/usr/bin/env python
"""
Evolutionary PDZ-tail inference (final-projection stage).

Add to your YAML (if not already present):
projection:
  evol_seeds:     10      # number of parallel seeds
  evol_interval:   5      # evaluate / clone every k reverse steps
  start_step:     25      # last timestep included in evolutionary phase
"""

import os, re, time, pickle, random, glob, csv, logging
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
import warnings
warnings.filterwarnings("ignore")
import math
import numpy as np
import torch
from omegaconf import OmegaConf
import hydra
import copy
from typing import Optional, Tuple
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# ---- all helper functions / classes live in your init script ----------
from scripts_adapt.c_pdz_eff_i_ref_new import *   #  projection_pdz, penalties, etc.
from constraints.projection import projection_pdz

# ---------------------------------------------------------------------- #
#  reproducibility helper
# ---------------------------------------------------------------------- #
def make_deterministic(seed=0):
    """Seed the PyTorch, NumPy and Python global RNGs with the same value.

    Parameters
    ----------
    seed : int
        Value handed, in this order, to ``torch.manual_seed``,
        ``np.random.seed`` and ``random.seed``.
    """
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(seed)
    
    
def breaks_and_beta_from_tensor(
    bb_tensor: torch.Tensor,
    beta_start: int = 90,
    beta_end: int = 100,
    break_thresh: float = 4.5,
    min_beta_len: int = 4,
) -> Tuple[int, int]:
    """Return *(num_breaks, β‑strand length in window)* from a backbone tensor.

    Parameters
    ----------
    bb_tensor : torch.Tensor
        Shape ``[n_res, 4, 3]`` (N, CA, C, O) in Å.
    beta_start, beta_end : int
        1‑indexed, inclusive residue window to examine for β‑sheet content.
        Values below 1 / above ``n_res`` are clamped to the chain.
    break_thresh : float
        CA–CA distance (Å) over which we consider there to be a chain break.
    min_beta_len : int
        Minimum contiguous length (residues) to acknowledge a β‑strand.

    Returns
    -------
    (int, int)
        Number of consecutive CA pairs further apart than ``break_thresh``,
        and the longest contiguous β run inside the window (0 when that run is
        shorter than ``min_beta_len``).  Chains with fewer than two residues
        yield ``(0, 0)``.

    Notes
    -----
    * **Breaks** – a simple CA distance check between consecutive
      residues.
    * **β‑strand** – we approximate secondary structure purely from
      ϕ/ψ angles (Ramachandran criterion).  Residues with
      ``ϕ ∈ [‑180, ‑30]`` and ``ψ ∈ [90, 180]`` are considered β.
      The first and last residue have no ϕ/ψ pair and are never β.
    * Dihedrals follow the IUPAC sign convention (cis = 0°, trans = ±180°).
      The first bond vector must point from the central bond back towards the
      first atom (``p0 - p1``); using ``p1 - p0`` shifts every angle by 180°
      and makes genuine β residues fail the window above.
    """

    n_res = bb_tensor.shape[0]
    if n_res < 2:
        return 0, 0

    # --- 1. Chain breaks ----------------------------------------------------
    ca = bb_tensor[:, 1, :]  # [n_res, 3]
    ca_steps = torch.linalg.norm(ca[1:] - ca[:-1], dim=1)
    num_breaks = int((ca_steps > break_thresh).sum().item())

    # --- 2. β‑strand length -------------------------------------------------
    def dihedral(p0, p1, p2, p3):
        """Signed dihedral p0-p1-p2-p3 in degrees (IUPAC convention)."""
        b0 = p0 - p1                      # points *away* from the central bond
        b1 = p2 - p1
        b2 = p3 - p2
        b1 = b1 / torch.linalg.norm(b1)   # out-of-place: never touch the input
        v = b0 - (b0 * b1).sum() * b1     # b0 projected onto plane ⟂ b1
        w = b2 - (b2 * b1).sum() * b1     # b2 projected onto plane ⟂ b1
        x = (v * w).sum()
        y = torch.cross(b1, v, dim=0).dot(w)
        return torch.atan2(y, x) * 57.29577951308232  # rad→deg

    # 0-indexed half-open window, clamped to the chain
    first = max(beta_start - 1, 0)
    last = min(beta_end, n_res)

    # Longest contiguous β run inside the window.  Terminal residues are
    # skipped because ϕ (needs i-1) or ψ (needs i+1) is undefined there; they
    # sit at the ends of the chain so skipping them cannot merge two runs.
    longest, current = 0, 0
    for i in range(max(first, 1), min(last, n_res - 1)):
        phi = float(dihedral(bb_tensor[i - 1, 2], bb_tensor[i, 0],
                             bb_tensor[i, 1], bb_tensor[i, 2]))
        psi = float(dihedral(bb_tensor[i, 0], bb_tensor[i, 1],
                             bb_tensor[i, 2], bb_tensor[i + 1, 0]))
        if -180 <= phi <= -30 and 90 <= psi <= 180:
            current += 1
            longest = max(longest, current)
        else:
            current = 0

    beta_len = longest if longest >= min_beta_len else 0

    return num_breaks, beta_len
    
# ---------------------------------------------------------------------
#  SMC helper: weight → resample → (optional) mutate
# ---------------------------------------------------------------------
def smc_resample(pop, beta=30.0, inject_sigma=0.0, diffusion_mask=None):
    """Resample a population in proportion to ``softmax(-beta * score)``.

    Parameters
    ----------
    pop : list[dict]
        Population members; each dict provides ``"score"`` (stackable tensor)
        and, when noise is injected, ``"x_t"`` and ``"px0"`` tensors.
    beta : float
        Inverse temperature turning a score into a log-weight (lower score =
        higher weight).
    inject_sigma : float
        Standard deviation of Gaussian noise added to both ``x_t`` and ``px0``
        of every resampled member; ``0`` disables the mutation.
    diffusion_mask : torch.Tensor or None
        Bool tensor viewed as ``[L, 1, 1]``; rows where it is True receive no
        noise.

    Returns
    -------
    list[dict]
        ``len(pop)`` deep copies drawn with replacement.
    """
    score_vec = torch.stack([member["score"] for member in pop])
    # shift by the minimum so the best member has log-weight 0
    weights = torch.softmax(-beta * (score_vec - score_vec.min()), dim=0)
    picks = torch.multinomial(weights, num_samples=len(pop), replacement=True)

    mutate = inject_sigma > 0.0
    # multiplicative mask: 1 where noise is allowed, 0 where the row is fixed
    free_rows = None
    if mutate and diffusion_mask is not None:
        free_rows = ~diffusion_mask.view(-1, 1, 1)  # [L,1,1]

    resampled = []
    for pick in picks:
        child = copy.deepcopy(pop[pick])
        if mutate:
            jitter = inject_sigma * torch.randn_like(child["x_t"])
            if free_rows is not None:
                jitter = jitter * free_rows.to(jitter.dtype)
            # the same perturbation is applied to both tensors
            child["x_t"] += jitter
            child["px0"] += jitter
        resampled.append(child)
    return resampled


# --- simple per-timestep seed schedule for Stage 1 ---
def seeds_for_t_stage2(t: int) -> int:
    """Return the scheduled population size for reverse step ``t``.

    32 seeds while ``t > 30``; 16 seeds from ``t == 30`` downwards.
    """
    if t > 30:
        return 32
    return 16

# --- adjust population size using current scores (softmax weights) ---
def adjust_population(evol_states, target_n: int, beta: Optional[float] = None):
    """Grow or shrink the population to ``target_n`` members by sampling.

    Parameters
    ----------
    evol_states : list[dict]
        Current population; ``"score"`` is read from each member when ``beta``
        is given.
    target_n : int
        Desired population size.
    beta : float, optional
        If given (and the first member has a score), members are drawn with
        weights ``softmax(-beta * (score - min_score))``; otherwise uniformly.

    Returns
    -------
    list[dict]
        The *same* list object when the size already matches; otherwise a new
        list of deep copies.  Shrinking samples without replacement, growing
        samples with replacement.
    """
    n_current = len(evol_states)
    if n_current == target_n:
        return evol_states

    weighted = (
        beta is not None
        and n_current > 0
        and evol_states[0]["score"] is not None
    )
    if weighted:
        score_vec = torch.stack([state["score"] for state in evol_states])
        probs = torch.softmax(-beta * (score_vec - score_vec.min()), dim=0)
    else:
        probs = torch.ones(n_current, dtype=torch.float32) / float(n_current)

    # duplicates are only unavoidable when the population has to grow
    with_replacement = target_n > n_current
    picks = torch.multinomial(probs, num_samples=target_n,
                              replacement=with_replacement)
    return [copy.deepcopy(evol_states[k]) for k in picks.tolist()]

# ---------------------------------------------------------------------- #

# ---------------------------------------------------------------------
# helper 1: differentiable Kabsch (rigid alignment) -------------------
# aligns Y onto X using the points in the anchor_idx list
def kabsch_align(X: torch.Tensor, Y: torch.Tensor, anchor_idx: torch.Tensor):
    """Rigidly superimpose ``Y`` onto ``X`` using the anchor points (Kabsch).

    Parameters
    ----------
    X, Y : torch.Tensor
        Reference and mobile coordinates, either ``(N, 3)`` or ``(R, A, 3)``;
        3-D inputs are flattened to ``(R*A, 3)``.
    anchor_idx : torch.Tensor
        Indices (into the flattened point lists) of the points that define the
        least-squares fit.

    Returns
    -------
    torch.Tensor
        ``(N, 3)`` copy of ``Y`` after the proper rotation + translation that
        minimises the RMSD between ``Y[anchor_idx]`` and ``X[anchor_idx]``.

    Notes
    -----
    Coordinates are row vectors, so the rotation is applied as ``Y @ R`` with
    ``R = U diag(1, 1, d) Vᵀ`` from the SVD of ``YAᵀ XA``.  A previous version
    built the column-vector rotation ``V D Uᵀ`` but still right-multiplied it
    (i.e. applied the inverse rotation), and neither removed Y's anchor
    centroid nor added X's, so the output was not actually aligned.  Pairwise
    distances of the result are the same either way.
    """
    # Accept (R,4,3) or (N,3) transparently
    if Y.ndim == 3:
        Y = Y.reshape(-1, 3)
    if X.ndim == 3:
        X = X.reshape(-1, 3)

    x_centroid = X[anchor_idx].mean(0, keepdim=True)
    y_centroid = Y[anchor_idx].mean(0, keepdim=True)
    XA = X[anchor_idx] - x_centroid
    YA = Y[anchor_idx] - y_centroid

    # cross-covariance of the centred anchors
    H = YA.T @ XA
    U, _, Vt = torch.linalg.svd(H)

    # Flip the axis of the smallest singular value if the best orthogonal
    # transform would be a reflection, so R is always a proper rotation.
    flip = torch.ones(3, device=H.device, dtype=H.dtype)
    if torch.det(U @ Vt) < 0:
        flip[2] = -1.0
    R = (U * flip) @ Vt               # == U @ diag(flip) @ Vt

    return (Y - y_centroid) @ R + x_centroid
# ---------------------------------------------------------------------
def robust_loss(x, delta=0.1):
    """Element-wise Huber loss of the residual tensor ``x``.

    ``0.5 * x**2``                    where ``|x| <= delta``
    ``delta * (|x| - 0.5 * delta)``   elsewhere

    The two branches meet with matching value and slope at ``|x| == delta``.
    (The linear branch previously lacked the ``delta`` factor, which made the
    loss jump from ``0.5*delta**2`` to ``0.5*delta`` at the threshold.)

    Parameters
    ----------
    x : torch.Tensor
        Residuals.
    delta : float
        Half-width of the quadratic region; also the slope of the linear tails.

    Returns
    -------
    torch.Tensor
        Loss per element, same shape as ``x``.
    """
    abs_x = x.abs()
    quadratic = 0.5 * x ** 2
    linear = delta * (abs_x - 0.5 * delta)
    return torch.where(abs_x <= delta, quadratic, linear)

def stiff_quadratic(x):
    """Return the plain quadratic penalty ``0.5 * x**2`` (no linear tail)."""
    squared = x ** 2
    return 0.5 * squared
# ---------------------------------------------------------------------
def make_primal_func_f(
        ref_xyz: torch.Tensor,
        start_index: int,
        neighbour_pairs: torch.LongTensor,
        bond_pairs: torch.LongTensor,
        huber_delta: float = 0.05,
        bond_weight: float = 100.0,
        init_index: int = 0):
    """Build the primal objective that keeps a structure close to ``ref_xyz``.

    Parameters
    ----------
    ref_xyz : torch.Tensor
        Flat ``(N, 3)`` reference atom coordinates.
    start_index, init_index : int
        Atom indices ``init_index .. start_index - 1`` are the anchors handed
        to ``kabsch_align``.
    neighbour_pairs, bond_pairs : torch.LongTensor
        ``(P, 2)`` atom-index pairs whose reference distances are to be
        preserved softly / stiffly respectively.
    huber_delta : float
        ``delta`` forwarded to ``robust_loss`` for the neighbour term.
    bond_weight : float
        Multiplier of the quadratic bond term.

    Returns
    -------
    callable
        ``primal(_, x) -> scalar tensor``; the first argument is ignored and
        ``x`` may be ``(N, 3)`` or ``(R, 4, 3)``.
    """
    def _pair_lengths(coords, pairs):
        # Euclidean length of every (i, j) pair
        return (coords[pairs[:, 0]] - coords[pairs[:, 1]]).norm(dim=1)

    # reference geometry is measured once, when the closure is created
    target_neigh = _pair_lengths(ref_xyz, neighbour_pairs)
    target_bond = _pair_lengths(ref_xyz, bond_pairs)

    anchor_idx = torch.arange(init_index, start_index, device=ref_xyz.device)

    def primal(_, x: torch.Tensor) -> torch.Tensor:
        # flatten first so (R,4,3) inputs are accepted, then rigidly align
        aligned = kabsch_align(ref_xyz, x.reshape(-1, 3), anchor_idx)

        neigh_residual = _pair_lengths(aligned, neighbour_pairs) - target_neigh
        bond_residual = _pair_lengths(aligned, bond_pairs) - target_bond

        # soft neighbours: robust loss, unit weight
        neigh_term = robust_loss(neigh_residual, delta=huber_delta).sum()
        # *hard* bonds: pure quadratic, heavy weight
        bond_term = stiff_quadratic(bond_residual).sum() * bond_weight

        return neigh_term + bond_term

    return primal

# ---- chain id parser -------------------------------------------------
# robustly parse the last "Letter(s)+digits[-digits]" token and return the letters
import re
try:
    from omegaconf import ListConfig  # for type-checking OmegaConf lists
except Exception:
    ListConfig = None

def chain_id_from_contigs_list(contigs) -> str:
    """Return the chain letters of the last ``<letters><digits>[-<digits>]`` token.

    ``contigs`` may be a list/tuple/ListConfig (items are joined with spaces)
    or anything convertible with ``str``.  Every ``/...`` continuation is
    removed first, so only the leading token of each whitespace-separated
    contig is considered.

    Raises
    ------
    ValueError
        If no chain-prefixed token remains.
    """
    is_sequence = isinstance(contigs, (list, tuple))
    if not is_sequence and ListConfig is not None:
        is_sequence = isinstance(contigs, ListConfig)
    text = " ".join(str(part) for part in contigs) if is_sequence else str(contigs)

    # strip "/…" up to the next comma, whitespace or closing bracket
    text = re.sub(r'/[^,\s\]\)]*', '', text)

    chain_letters = re.findall(r'([A-Za-z]+)\d+(?:-\d+)?', text)
    if not chain_letters:
        raise ValueError(f"Unable to resolve chain ID from contigs: {text!r}")
    return chain_letters[-1]


# ---- contig helpers -------------------------------------------------
import re
try:
    from omegaconf import ListConfig
except Exception:
    ListConfig = None


def _as_text(contigs) -> str:
    if isinstance(contigs, (list, tuple)) or (ListConfig is not None and isinstance(contigs, ListConfig)):
        return " ".join(map(str, contigs))
    return str(contigs)


def parse_last_range_for_chain(contigs, anchor_chain: str = "A"):
    """Return ``(start, end)`` of the last segment belonging to ``anchor_chain``.

    Example: ``'A1-95/6/A102-106/20-40/0 B1-10'`` -> ``(102, 106)``.

    Segments are tokens of the form ``<letters><digits>[-<digits>]``; the
    chain comparison is case-insensitive and a token without ``-<digits>``
    yields ``start == end``.

    Raises
    ------
    ValueError
        If the contigs contain no segment for ``anchor_chain``.
    """
    text = _as_text(contigs)
    wanted = anchor_chain.upper()

    last_range = None
    for match in re.finditer(r'([A-Za-z]+)(\d+)(?:-(\d+))?', text):
        chain, lo, hi = match.groups()
        if chain.upper() != wanted:
            continue
        # keep overwriting: the final value is the last one in order of appearance
        last_range = (int(lo), int(hi if hi else lo))

    if last_range is None:
        raise ValueError(f"No segment for chain {anchor_chain!r} in contigs: {text!r}")
    return last_range

# --- CA-based structure hash (3 decimals) ---
def _get_ca_hash(x_bb: torch.Tensor) -> int:
    """
    x_bb: [L,4,3] backbone (N,CA,C,O)
    Hash on CA coords rounded to 0.001 Å to tolerate float jitter.
    """
    ca = x_bb[:, 1, :].detach().cpu().numpy().round(decimals=3)
    return hash(ca.tobytes())


@hydra.main(version_base=None,
            config_path="../config/inference", config_name="base")
def main(conf):
    """Run evolutionary (SMC) reverse diffusion with PDZ-tail projection.

    For every requested design:

    1. ``evol_seeds`` population members are created as clones of one
       ``sampler.sample_init()`` result.
    2. For each reverse step ``t`` (from ``conf.diffuser.T`` down to
       ``sampler.inf_conf.final_step``), while ``t > final_step``:
       every member takes one ``sampler.sample_step``; members are scored with
       the summed angle + distance PDZ penalties on ``px0``; the population is
       resampled with ``smc_resample``; each distinct ``px0`` is projected with
       ``projection_pdz`` and ``x_t`` is recomputed with
       ``denoiser.get_next_pose``.
    3. The member with the lowest stored score is appended to the trajectory.
    4. After the loop the final ``px0`` backbone is written as
       ``<output_prefix>_<i>.pdb`` together with a ``.trb`` pickle and,
       optionally, trajectory PDBs.

    ``iu``, ``extract_backbone_tensor``, ``build_chain_pairs``,
    ``angle_penalty_pdz``, ``distance_penalty_pdz``, ``writepdb`` and
    ``writepdb_multi`` are not defined in this file; presumably they arrive
    through the wildcard import at the top — TODO confirm.
    """

    log = logging.getLogger(__name__)
    if conf.inference.deterministic:
        make_deterministic(conf.inference.seed)

    # ------------------------------------------------------------------ #
    #  sampler, peptide chain, neighbourhood topology
    # ------------------------------------------------------------------ #
    sampler     = iu.sampler_selector(conf)          # RFdiffusion wrapper
    spec = conf.contigmap.contigs
    # chain letters of the last chain-prefixed contig token select the peptide chain
    chain_id = chain_id_from_contigs_list(spec)
    print(f"[DEBUG] parsed chain_id from contigs: {chain_id}")
    p_chain = extract_backbone_tensor(conf.projection.peptide_path, chain_id=chain_id)
    #new

    # NOTE(review): `re` is already imported at module level; this local import is redundant.
    import re

    def find_anchor_from_contigs(contigs, anchor_chain='A') -> int:
        """Return the start residue number of the anchor segment of ``anchor_chain``.

        Tried in order: a segment preceded by ``/<num>/``; any segment preceded
        by ``/``; the second occurrence of ``<chain><num>``; the first
        occurrence.  Raises ValueError when the chain never appears.
        """
        
        # NOTE(review): only list/tuple are joined; other containers (e.g. an
        # OmegaConf ListConfig) go through str(), whose exact text form is not
        # visible here.  The regexes below use search/findall, so they still
        # scan the whole string.
        s = ' '.join(contigs) if isinstance(contigs, (list, tuple)) else str(contigs)
        s = s.strip().strip('[]')  

        # 1) "/<gap>/A<start>[-<end>]"
        m = re.search(rf'/\s*\d+\s*/\s*{anchor_chain}(\d+)(?:-\d+)?', s)
        if m:
            return int(m.group(1))

        # 2) "/A<start>[-<end>]"
        m = re.search(rf'/\s*{anchor_chain}(\d+)(?:-\d+)?', s)
        if m:
            return int(m.group(1))

        # 3) fall back to the second, then the first, plain occurrence
        occ = re.findall(rf'{anchor_chain}(\d+)(?:-\d+)?', s)
        if len(occ) >= 2:
            return int(occ[1])
        if occ:
            return int(occ[0])

        raise ValueError(f"Anchor like '/<num>/{anchor_chain}<start>' not found in contigs: {contigs}")


    spec = conf.contigmap.contigs
    start_index = find_anchor_from_contigs(conf.contigmap.contigs, anchor_chain='A') #new
    print(f"[DEBUG] start_index={start_index}")

    # start_index = 94                                 # PDZ tail anchor residue

    # one **deterministic** sample_init to set contig lengths & mask
    base_xt, base_seq = sampler.sample_init()        # shape (L,14,3)  L == mask
    L = base_xt.shape[0]
    neigh_pairs, bond_pairs = build_chain_pairs(L, atoms_per_res=4, include_next2=True)
    # passed to smc_resample only; with None every row would be eligible for noise
    diffusion_mask = None
    # NOTE(review): taken once, before the design loop, so the `time` stored in
    # each .trb is cumulative over all designs processed so far.
    start_time = time.time()

    evol_seeds    = conf.projection.evol_seeds
    # NOTE(review): evol_interval is read but not used anywhere below.
    evol_interval = conf.projection.evol_interval
    t_start       = int(conf.diffuser.T)
    t_stop        = sampler.inf_conf.final_step      # usually 1

    # ------------------------------------------------------------------ #
    #  main design loop (usually num_designs==1)
    # ------------------------------------------------------------------ #
    design_start = (sampler.inf_conf.design_startnum
                    if sampler.inf_conf.design_startnum != -1 else 0)
    for d_i in range(design_start,
                     design_start + sampler.inf_conf.num_designs):

        out_prefix = f"{sampler.inf_conf.output_prefix}_{d_i}"
        if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"):
            log.info(f"(cautious) {out_prefix}.pdb exists – skipping.")
            continue

        log.info(f"Design {d_i}  |  mask length = {L}")

        # Every member starts from a clone of the same base tensors; scores
        # start at +inf and are overwritten after the first sample_step.
        evol_states = []
        for s_i in range(evol_seeds):
            make_deterministic(conf.inference.seed + s_i)
            evol_states.append({
                "x_t":   base_xt.clone(),       # same length as mask
                "seq_t": base_seq.clone(),
                "px0":   None,
                "score": torch.tensor(float("inf")),
            })

        # trajectory collectors
        denoised_xyz_stack, px0_xyz_stack = [], []
        seq_stack,          plddt_stack   = [], []

        # ------------------------------------------------------------------
        #  reverse-diffusion with evolutionary sampling
        # ------------------------------------------------------------------
        t = t_start
        while t >= t_stop:
            # --- A. one reverse step for every seed, then score / resample / project.
            # NOTE(review): t_stop is final_step, so on the last iteration
            # (t == final_step) this branch is skipped and the previous
            # survivor is appended to the stacks a second time.
            if t > sampler.inf_conf.final_step:
                for i_seed, s in enumerate(evol_states):
                        
                    # NOTE(review): the global RNGs are reseeded with the same
                    # per-slot value before *every* step, so slot i restarts
                    # from an identical RNG state at each t — confirm this is
                    # intended.
                    make_deterministic(conf.inference.seed + i_seed)
                        
                    xt, seqt = s["x_t"], s["seq_t"]
                    px0, xt1, seqt, plddt = sampler.sample_step(
                        t=t, x_t=xt, seq_init=seqt,
                        final_step=sampler.inf_conf.final_step
                    )
                    
                    s["x_t"], s["seq_t"], s["px0"] = (
                        xt1.detach(), seqt.detach(), px0.detach()
                    )
                    
                
                # Score = summed angle + distance penalties over all constraint offsets
                with torch.no_grad():
                    for s in evol_states:
                        xbb = s["px0"][:, :4]
                        a_v = torch.stack([
                            angle_penalty_pdz(xbb, p_chain, offset=o,
                                              start_index=start_index)
                            for o in range(conf.projection.n_constraints)]).sum() # new
                        d_v = torch.stack([
                            distance_penalty_pdz(xbb, p_chain, offset=o,
                                                 start_index=start_index)
                            for o in range(conf.projection.n_constraints)]).sum() #new

                
                        s["score"] = a_v + d_v 



                # Diagnostics + SMC resampling
                with torch.no_grad():
                    
                    # inverse‑temperature β_t anneals with noise (higher when t is small):
                    # smc_beta0 * (T - t + 1) / T with T = sampler.t_step_input
                    beta_t = conf.projection.smc_beta0 * (int(sampler.t_step_input) - t + 1) \
                             / int(sampler.t_step_input)

                    # # --- apply simple Stage-1 seed schedule at this timestep ---
                    # target_n = seeds_for_t_stage2(t)
                    # if len(evol_states) != target_n:
                    #     evol_states = adjust_population(evol_states, target_n=target_n, beta=float(beta_t))

                    # get the denoiser instance
                    denoiser = sampler.denoiser  

                    # pull the CA‐atom noise scale at t and t–1
                    # NOTE(review): σ_ca_t1 is never used afterwards.
                    σ_ca_t   = denoiser.noise_schedule_ca(t)
                    σ_ca_t1  = denoiser.noise_schedule_ca(t-1)

                    # pull the frame noise scale at t and t–1 if you want to mix them
                    # NOTE(review): σ_fr_t1 is never used; σ_fr_t is only printed.
                    σ_fr_t   = denoiser.noise_schedule_frame(t)
                    σ_fr_t1  = denoiser.noise_schedule_frame(t-1)

                    # pick whichever is most representative (often the CA noise is what you want)
                    sigma_t = float(σ_ca_t)  

                    print(f"t={t}: σ_ca={σ_ca_t:.3f}, σ_frame={σ_fr_t:.3f}")   
                    
                    # print diagnostics
                    best = min([s["score"] for s in evol_states]).item()
                    worst = max([s["score"] for s in evol_states]).item()
                    print(f"t={t:3d}  β={beta_t:4.1f}  best={best:8.4f}  worst={worst:8.4f}")

                    # inject_sigma is multiplied by 0.0, so no noise is injected
                    # and diffusion_mask has no effect here.
                    evol_states = smc_resample(
                        evol_states,
                        beta=beta_t,
                        inject_sigma=0.0 * sigma_t,            # noise may not help
                        diffusion_mask=diffusion_mask
                    )

                
                
                # Resampling with replacement yields identical members; project
                # each distinct px0 once and reuse the result for its duplicates.
                projection_cache = {}   # {hash: (x_t_clone, px0_clone)}
                   
                for i_seed, s in enumerate(evol_states):
                    
                    xt, seqt, px0 = s["x_t"], s["seq_t"], s["px0"]
                    h = _get_ca_hash(px0[:, :4])
                    if h in projection_cache:
                        cached_xt, cached_px0 = projection_cache[h]
                        print(f"[INFO] t={t} seed {i_seed}: duplicate structure, reuse cached projection.")
                        s["x_t"] = cached_xt.clone()
                        s["px0"] = cached_px0.clone()
                        s["seq_t"] = seqt  
                        continue
                    
                    # Projection of the predicted x0 backbone onto the PDZ constraints.
                    # flat_ref is the *unprojected* backbone as a flat list of atoms.
                    xi, flat_ref = px0[:, :4], px0[:, :4].reshape(-1, 3)
                    start_res, end_res = parse_last_range_for_chain(spec,  anchor_chain='A')
                    print("start_res:", start_res, "end_res:", end_res)
                    # Anchor atoms for the alignment are start_res*4 .. (end_res+1)*4 - 1
                    # (4 backbone atoms per residue).
                    # NOTE(review): residue numbers parsed from the contig string are
                    # used directly as 0-based row indices — confirm the numbering.
                    primal_func = make_primal_func_f(
                        ref_xyz=flat_ref,
                        init_index=start_res*4,
                        start_index=(end_res+1)*4,
                        neighbour_pairs=neigh_pairs,
                        bond_pairs=bond_pairs,
                    )
                    # The backbone slice of px0 is overwritten in place; px0 is the
                    # tensor stored in s["px0"].
                    px0[:, :4], _ = projection_pdz(
                        target_rep=xi, p_chain=p_chain,
                        start_index=start_index, primal_func=primal_func,
                        tol=torch.tensor(1e-5), gap_pairs_ra=None,
                        n_constraints=conf.projection.n_constraints, hot_start=xi, verbose=False, #new
                        first_movable_res=start_res-4, last_movable_res=end_res+2,
                        gap_penalty=False
                    )

                    # Recompute x_{t-1} from the current x_t and the projected px0.
                    # NOTE(review): xt is the x_t already returned by sample_step for
                    # this t — confirm get_next_pose expects that as its input.
                    xt1, _ = sampler.denoiser.get_next_pose(
                        xt=xt, px0=px0, t=t,
                        diffusion_mask=sampler.mask_str.squeeze(),
                        align_motif=sampler.inf_conf.align_motif,
                        include_motif_sidechains=sampler.preprocess_conf.motif_sidechain_input,
                    )
                    
                    s["x_t"], s["seq_t"], s["px0"] = (
                        xt1.detach(), seqt.detach(), px0.detach()
                    )
                    projection_cache[h] = (s["x_t"].clone(), s["px0"].clone())
                 
                

                        

                # NOTE(review): cannot trigger — the enclosing while-condition
                # already guarantees t >= t_stop.
                if t < t_stop:
                    break


            # --- C. pick the member with the lowest stored score for the trajectory.
            # NOTE(review): scores were computed before resampling/projection and
            # are not refreshed afterwards, so the ranking reflects the
            # pre-projection px0.
            survivor = evol_states[int(torch.argmin(
            torch.stack([s["score"] for s in evol_states])))]
            x_t   = survivor["x_t"]
            seq_t = survivor["seq_t"]
            px0   = survivor["px0"]

            
            t -= 1
            
            # -----------------------------  end reverse loop  --------------

            px0_xyz_stack.append(px0)
            denoised_xyz_stack.append(x_t)
            seq_stack.append(seq_t)
            # NOTE(review): `plddt` is whatever the most recent sample_step call
            # returned (the last population slot), not necessarily the
            # survivor's; it is unbound if branch A never ran (t_start == t_stop).
            plddt_stack.append(plddt[0])  # remove singleton leading dimension

        # ----- save outputs  (unchanged from your original script) ------
        # Flip order for better visualization in pymol
        denoised_xyz_stack = torch.stack(denoised_xyz_stack)
        denoised_xyz_stack = torch.flip(
            denoised_xyz_stack,
            [
                0,
            ],
        )
        px0_xyz_stack = torch.stack(px0_xyz_stack)
        px0_xyz_stack = torch.flip(
            px0_xyz_stack,
            [
                0,
            ],
        )

        # For logging -- don't flip
        plddt_stack = torch.stack(plddt_stack)

        
        os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
        # The written sequence comes from the *initial* base_seq: class index 21
        # is replaced by 7, and those positions get a B-factor of 0 (others 1).
        final_seq = torch.where(
            torch.argmax(base_seq, dim=-1) == 21, 7, torch.argmax(base_seq, dim=-1)
        )
        bfacts = torch.ones_like(final_seq.squeeze())
        bfacts[torch.where(torch.argmax(base_seq, dim=-1) == 21, True, False)] = 0

        # After the flip, index 0 is the last appended (final) px0; only the
        # first four atoms per residue are written.
        writepdb(
            f"{out_prefix}.pdb",
            px0_xyz_stack[0][:, :4],
            final_seq,
            sampler.binderlen,
            chain_idx=sampler.chain_idx,
            bfacts=bfacts,
        )

        # run metadata, pickled next to the PDB
        trb = dict(
            config=OmegaConf.to_container(sampler._conf, resolve=True),
            plddt=plddt_stack,
            device=torch.cuda.get_device_name(torch.cuda.current_device())
            if torch.cuda.is_available() else "CPU",
            time=time.time() - start_time,
        )
        if hasattr(sampler, "contig_map"):
            trb.update(sampler.contig_map.get_mappings())
        with open(f"{out_prefix}.trb", "wb") as fh:
            pickle.dump(trb, fh)

        if sampler.inf_conf.write_trajectory:
            # trajectories go to <output dir>/traj/<basename>_*.pdb
            traj_prefix = os.path.join(os.path.dirname(out_prefix), "traj",
                                        os.path.basename(out_prefix))
            os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)

            writepdb_multi(f"{traj_prefix}_Xt-1_traj.pdb",
                            denoised_xyz_stack, bfacts, final_seq.squeeze(),
                            use_hydrogens=False, backbone_only=False,
                            chain_ids=sampler.chain_idx)
            writepdb_multi(f"{traj_prefix}_pX0_traj.pdb",
                            px0_xyz_stack, bfacts, final_seq.squeeze(),
                            use_hydrogens=False, backbone_only=False,
                            chain_ids=sampler.chain_idx)

    # logged once, after all designs
    log.info(f"Design done in {(time.time()-start_time)/60:.2f} min")

# Script entry point: main is wrapped by @hydra.main, so it is called without
# arguments here.
if __name__ == "__main__":
    main()
