#!/usr/bin/env python
"""
Inference script.

To run with base.yaml as the config,

> python run_inference.py

To specify a different config,

> python run_inference.py --config-name symmetry

where symmetry can be the filename of any other config (without .yaml extension)
See https://hydra.cc/docs/advanced/hydra-command-line-flags/ for more options.

"""

import re
import os, time, pickle
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from omegaconf import OmegaConf
import hydra
import logging
from rfdiffusion.util import writepdb_multi, writepdb
from rfdiffusion.inference import utils as iu
from hydra.core.hydra_config import HydraConfig
import numpy as np
import random
import glob
import math
import csv


from constraints.constraint import *
from constraints.projection import projection_pdz

from typing import Tuple, Optional
import copy


# ---------------------------------------------------------------------    
def make_deterministic(seed=0):
    """Seed the torch, NumPy and Python ``random`` generators with ``seed``."""
    # Same order as before: torch first, then NumPy, then the stdlib generator.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
# ---------------------------------------------------------------------
#  SMC helper: weight → resample → (optional) mutate
# ---------------------------------------------------------------------
def smc_resample(pop, beta=30.0, inject_sigma=0.0, diffusion_mask=None):
    """
    Resample a particle population with weights softmax(-beta * (score - min)).

    pop           : list[dict] with keys {x_t, seq_t, px0, score}
    beta          : inverse-temperature converting score to log-weight
    inject_sigma  : base Gaussian noise scale
    diffusion_mask: [L] bool tensor, True=mask (fixed), False=diffusable

    Returns a new list with ``len(pop)`` entries; every entry is a deep copy of
    one of the input particles (drawn with replacement), so the inputs are
    never modified. Lower scores receive larger weights.

    Particles whose score is not finite (NaN / inf) get zero weight. If no
    particle has a finite score the draw is uniform; without this, ``inf - inf``
    turns every weight into NaN and ``torch.multinomial`` rejects the input.
    An empty population returns an empty list.

    When ``inject_sigma > 0`` the same Gaussian noise is added to ``x_t`` and,
    if present, ``px0``; rows where ``diffusion_mask`` is True receive no noise.
    """
    if not pop:
        return []

    scores = torch.stack([s["score"] for s in pop])
    finite = torch.isfinite(scores)
    if bool(finite.any()):
        # Identical to -beta * (scores - scores.min()) when every score is finite.
        logw = -beta * (scores - scores[finite].min())
        logw = torch.where(finite, logw, torch.full_like(logw, float("-inf")))
    else:
        logw = torch.zeros_like(scores)
    w_norm = torch.softmax(logw, dim=0)

    idx = torch.multinomial(w_norm, num_samples=len(pop), replacement=True)

    # Multiplier that zeroes the noise on fixed rows: [L] -> [L,1,1].
    movable = None
    if inject_sigma > 0.0 and diffusion_mask is not None:
        movable = ~diffusion_mask.view(-1, 1, 1)

    new_pop = []
    for j in idx.tolist():
        s = copy.deepcopy(pop[j])
        if inject_sigma > 0.0:
            noise = inject_sigma * torch.randn_like(s["x_t"])
            if movable is not None:
                noise = noise * movable.to(device=noise.device, dtype=noise.dtype)
            # In-place is safe here: ``s`` is a private deep copy.
            s["x_t"] += noise
            if s["px0"] is not None:  # px0 is None until the first reverse step
                s["px0"] += noise
        new_pop.append(s)
    return new_pop
# ---------------------------------------------------------------------        
def log_violation(angle_viols, dist_viols):
    """
    Dump ``angle_viols + dist_viols`` to ``violation.csv`` in the working directory.

    The file is overwritten on every call and holds a header row
    (``Index``, ``Value``) followed by one row per element of the combined
    sequence (for two lists, ``+`` concatenates them).
    """
    values = angle_viols + dist_viols
    with open('violation.csv', 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(['Index', 'Value'])
        # enumerate yields (index, value) rows directly
        out.writerows(enumerate(values))
# ---------------------------------------------------------------------
# build neighbour / bond topology from chain layout -------------------
def build_chain_pairs(n_res: int,
                      atoms_per_res: int = 4,
                      include_next2: bool = False
                      ) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Returns (neighbour_pairs, bond_pairs) as (M×2) index tensors.

    * 'bond_pairs'  = links between consecutive atoms of the flattened chain
      (weight them strongly).
    * 'neighbour_pairs' includes bond_pairs plus any extra "soft"
      neighbours (e.g. two residues apart if include_next2=True).

    Indexing is into the *flattened* (n_res*atoms_per_res) array, i.e.
        flat_idx = res_idx * atoms_per_res + atom_idx

    Both tensors are always 2-D: when there are no pairs (e.g. ``n_res == 0``)
    an empty ``(0, 2)`` tensor is returned, so callers can index
    ``pairs[:, 0]`` unconditionally.

    NOTE(review): the inter-residue link joins the *last* atom of residue r to
    the first atom of residue r+1. With the N, CA, C, O ordering annotated in
    ``build_gap_pairs_ra`` that is O -> N, not the C -> N link used there —
    confirm this is intended before treating these as covalent bonds.
    """
    pairs, bonds = [], []

    for r in range(n_res):
        base = r * atoms_per_res
        last = base + atoms_per_res - 1      # flat index of the residue's last atom

        # inside one residue: 0-1, 1-2, 2-3
        for i in range(base, last):
            pairs.append((i, i + 1))
            bonds.append((i, i + 1))

        # link residue r to r+1: last atom -> atom 0 of the next residue
        if r < n_res - 1:
            link = (last, base + atoms_per_res)
            pairs.append(link)
            bonds.append(link)

        # optional "second-next" neighbour for smoother backbone
        if include_next2 and r < n_res - 2:
            pairs.append((last, base + 2 * atoms_per_res))   # NOT a bond

    # reshape with an explicit row count keeps the (M, 2) layout even for M == 0
    neigh = torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2)
    bond = torch.tensor(bonds, dtype=torch.long).reshape(len(bonds), 2)
    return neigh, bond
# ---------------------------------------------------------------------
def build_gap_pairs_ra(
    n_res: int,
    first_res: int = 0,
    atoms_per_res: int = 4,
) -> dict:
    """
    Build bond, angle and dihedral index records for residues
    ``first_res .. n_res - 1``.

    Returns a dict of long tensors:
      * 'bonds'     (B, 5): (res1, atom1, res2, atom2, type)
      * 'angles'    (A, 7): (res1, atom1, res2, atom2, res3, atom3, type)
      * 'dihedrals' (D, 9): (res1, atom1, ..., res4, atom4, type)
    Empty categories come back as ``(0, width)`` tensors.

    Atom indices follow the inline annotations below (0=N, 1=CA, 2=C, 3=O).
    The trailing ``type`` value is a tag interpreted by the consumer; bonds
    use 0 for intra-residue links and 1 for the C-N link between residues.

    When ``first_res > 0`` one extra C-N bond bridges residue
    ``first_res - 1`` to ``first_res``. ``atoms_per_res`` is accepted for
    call compatibility but is not used.
    """
    bonds = []
    angles = []
    dihedrals = []

    # Bridge constraints if needed
    if first_res > 0:
        bonds.append((first_res - 1, 2, first_res, 0, 1))  # C-N peptide

    for r in range(first_res, n_res):
        # BOND constraints
        bonds += [
            (r, 0, r, 1, 0),  # N-CA (covalent)
            (r, 1, r, 2, 0),  # CA-C (covalent)
            (r, 2, r, 3, 0),  # C-O (covalent)
        ]

        # ANGLE constraints: key backbone angles that maintain peptide geometry
        angles += [
            (r, 0, r, 1, r, 2, 0),  # N-CA-C angle (~109°)
            (r, 1, r, 2, r, 3, 1),  # CA-C-O angle (~121°)
        ]

        # Inter-residue connections
        if r + 1 < n_res:
            bonds.append((r, 2, r + 1, 0, 1))  # C-N peptide

            # Critical angles for peptide geometry
            angles += [
                (r, 1, r, 2, r + 1, 0, 2),      # CA-C-N angle (~117°)
                (r, 2, r + 1, 0, r + 1, 1, 3),  # C-N-CA angle (~122°)
            ]

            # Omega dihedral (peptide planarity) - should be ~180° or ~0°
            dihedrals.append((r, 1, r, 2, r + 1, 0, r + 1, 1, 0))  # CA-C-N-CA (omega)

    def _as_records(rows, width):
        # torch.tensor([]) would be 1-D, so build the empty case explicitly
        if rows:
            return torch.tensor(rows, dtype=torch.long)
        return torch.empty((0, width), dtype=torch.long)

    return {
        'bonds': _as_records(bonds, 5),
        'angles': _as_records(angles, 7),
        'dihedrals': _as_records(dihedrals, 9),
    }



# ---------------------------------------------------------------------
# helper 1: differentiable Kabsch (rigid alignment) -------------------
# aligns Y onto X using the points in the anchor_idx list
def kabsch_align(X: torch.Tensor, Y: torch.Tensor, anchor_idx: torch.Tensor):
    """
    Rigidly superimpose ``Y`` onto ``X`` using the points in ``anchor_idx``.

    X, Y       : coordinates as (N, 3), or (R, A, 3) which is flattened to (R*A, 3)
    anchor_idx : indices (into the flattened arrays) of the points used to fit
                 the transform

    Returns the transformed copy of ``Y`` with shape (N, 3): the least-squares
    proper rotation (no reflection) plus translation that maps Y's anchors
    onto X's anchors, applied to every point of ``Y``.

    Points are treated as row vectors, so for ``C = YA^T XA = U S V^T`` the
    rotation is ``U D V^T`` with ``D = diag(1, 1, sign(det(U V^T)))``. The
    previous version applied the transpose of this matrix, subtracted the
    (near-zero) mean of the already centred anchors instead of Y's anchor
    centroid, and never moved the result onto X. Pairwise distances of the
    output are the same either way because the transform is rigid.
    """
    # Accept (R,4,3) or (N,3) transparently
    if Y.ndim == 3:
        Y = Y.reshape(-1, 3)
    if X.ndim == 3:
        X = X.reshape(-1, 3)

    x_centroid = X[anchor_idx].mean(0, keepdim=True)
    y_centroid = Y[anchor_idx].mean(0, keepdim=True)
    XA = X[anchor_idx] - x_centroid
    YA = Y[anchor_idx] - y_centroid

    C = YA.T @ XA
    U, _, Vt = torch.linalg.svd(C)

    # Flip the last singular direction when U V^T would be a reflection.
    flip = torch.ones(3, device=C.device, dtype=C.dtype)
    flip[2] = torch.sign(torch.det(U @ Vt)).detach()
    R = (U * flip) @ Vt          # == U @ diag(flip) @ Vt

    return (Y - y_centroid) @ R + x_centroid
# ---------------------------------------------------------------------
def robust_loss(x, delta=0.1):
    """
    Element-wise Huber loss of the tensor ``x`` with threshold ``delta``.

        0.5 * x**2                      if |x| <= delta
        delta * (|x| - 0.5 * delta)     otherwise

    The linear branch carries the ``delta`` factor so both branches meet at
    ``|x| == delta`` (value ``0.5 * delta**2``); without it the loss jumps
    there whenever ``delta != 1``.
    """
    abs_x = x.abs()
    # torch.where selects per element, so a non-finite value in the branch
    # that is not chosen cannot leak in as NaN (as it would with mask * value).
    return torch.where(abs_x <= delta,
                       0.5 * x ** 2,
                       delta * (abs_x - 0.5 * delta))

def stiff_quadratic(x):
    """Return the plain quadratic penalty ``0.5 * x**2`` (no linear tail)."""
    squared = x ** 2
    return 0.5 * squared
# ---------------------------------------------------------------------
def make_primal_func(
        ref_xyz: torch.Tensor,
        start_index: int,
        neighbour_pairs: torch.LongTensor,
        bond_pairs: torch.LongTensor,
        huber_delta: float = 0.05,
        bond_weight: float = 100.0):
    """
    Build a closure that scores how far a structure's pair distances have
    drifted from those measured on ``ref_xyz``.

    ref_xyz         : reference coordinates, indexed by the flat atom indices
                      in the pair tensors
    start_index     : the first ``start_index`` flat atoms are the anchors
                      handed to ``kabsch_align``
    neighbour_pairs : (M, 2) index pairs scored with ``robust_loss``
    bond_pairs      : (K, 2) index pairs scored with ``stiff_quadratic`` and
                      multiplied by ``bond_weight``
    huber_delta     : ``delta`` forwarded to ``robust_loss``

    The returned ``primal(_, x)`` ignores its first argument, flattens ``x``
    to (-1, 3) and returns the scalar ``soft + bond`` objective.
    """
    def _pair_lengths(coords, pairs):
        # Euclidean length of every (i, j) row in ``pairs``
        return (coords[pairs[:, 0]] - coords[pairs[:, 1]]).norm(dim=1)

    # Reference lengths are fixed once, when the closure is built.
    target_neigh = _pair_lengths(ref_xyz, neighbour_pairs)
    target_bond = _pair_lengths(ref_xyz, bond_pairs)

    anchors = torch.arange(start_index, device=ref_xyz.device)

    def primal(_, x: torch.Tensor) -> torch.Tensor:
        flat = x.reshape(-1, 3)                # safety: accept (R,4,3) too

        # rigid alignment onto the reference
        aligned = kabsch_align(ref_xyz, flat, anchors)

        neigh_delta = _pair_lengths(aligned, neighbour_pairs) - target_neigh
        bond_delta = _pair_lengths(aligned, bond_pairs) - target_bond

        # soft neighbours use the robust loss; bonds use the heavily
        # weighted quadratic
        soft_term = robust_loss(neigh_delta, delta=huber_delta).sum()
        hard_term = stiff_quadratic(bond_delta).sum() * bond_weight

        return soft_term + hard_term

    return primal

# ---------------------------------------------------------------------    
class ConstraintSchedule:
    """
    Angle / distance tolerance schedule over diffusion timesteps.

    Each tolerance moves from its ``*_min`` value at ``t = 0`` to its
    ``*_max`` value at ``t = timesteps``; ``t`` outside that range is clamped.
    Supported schedule types are 'linear' and 'cosine'.
    """

    def __init__(
        self,
        angle_max,
        angle_min,
        distance_max,
        distance_min,
        timesteps=50,
        angle_schedule_type='linear',
        distance_schedule_type='linear'
    ):
        self.angle_max, self.angle_min = angle_max, angle_min
        self.distance_max, self.distance_min = distance_max, distance_min
        self.timesteps = timesteps
        self.angle_schedule_type = angle_schedule_type
        self.distance_schedule_type = distance_schedule_type

    def _interpolate(self, min_val, max_val, t, T, mode='linear'):
        """Blend ``min_val`` -> ``max_val`` by the clamped progress ``t / T``.

        Raises ValueError for a mode other than 'linear' or 'cosine'.
        """
        fraction = np.clip(t / T, 0, 1)
        if mode not in ('linear', 'cosine'):
            raise ValueError(f"Unknown mode: {mode}")
        if mode == 'cosine':
            # half-cosine easing: flat near both ends, steepest mid-way
            fraction = (1 - np.cos(np.pi * fraction)) / 2
        return min_val + (max_val - min_val) * fraction

    def get_tolerance(self, t):
        """
        Return ``(angle_tolerance, distance_tolerance)`` for timestep ``t``.

        ``t = 0`` gives the minima and ``t = timesteps`` the maxima.
        """
        horizon = self.timesteps
        return (
            self._interpolate(self.angle_min, self.angle_max, t, horizon,
                              mode=self.angle_schedule_type),
            self._interpolate(self.distance_min, self.distance_max, t, horizon,
                              mode=self.distance_schedule_type),
        )

# --- simple per-timestep seed schedule for Stage 1 ---
def seeds_for_t_stage1(t: int) -> int:
    """Return the Stage-1 seed count for timestep ``t``: 16 when t > 20, else 8."""
    if t > 20:
        return 16
    return 8

# --- adjust population size using current scores (softmax weights) ---
def adjust_population(evol_states, target_n: int, beta: Optional[float] = None):
    """
    Resize the list of seed states to ``target_n`` entries by sampling.

    evol_states : list[dict]; each dict needs a ``"score"`` entry when
                  ``beta`` is given
    target_n    : requested population size
    beta        : if provided (and the first state's score is not None),
                  sample with weights softmax(-beta * (score - min));
                  otherwise sample uniformly

    Returns ``evol_states`` itself when the size already matches, otherwise a
    new list of deep copies. Shrinking samples without replacement, growing
    samples with replacement.

    Edge cases:
      * non-finite scores (NaN / inf) get zero weight; if no score is finite
        the draw is uniform (``inf - inf`` would otherwise give NaN weights);
      * when shrinking and fewer than ``target_n`` weights are non-zero,
        sampling without replacement is not well defined, so the ``target_n``
        highest-weight states are kept instead;
      * ``target_n == 0`` returns an empty list.

    Raises ValueError for a negative ``target_n`` or when asked to grow an
    empty population.
    """
    cur_n = len(evol_states)
    if cur_n == target_n:
        return evol_states

    import copy
    import torch

    if target_n < 0:
        raise ValueError(f"target_n must be non-negative, got {target_n}")
    if target_n == 0:
        return []
    if cur_n == 0:
        raise ValueError("cannot grow an empty population")

    logw = None
    if beta is not None and evol_states[0]["score"] is not None:
        scores = torch.stack([s["score"] for s in evol_states])
        finite = torch.isfinite(scores)
        if bool(finite.any()):
            # Identical to -beta * (scores - scores.min()) when all are finite.
            logw = -beta * (scores - scores[finite].min())
            logw = torch.where(finite, logw, torch.full_like(logw, float("-inf")))
        else:
            logw = torch.zeros_like(scores)
        w_norm = torch.softmax(logw, dim=0)
    else:
        w_norm = torch.ones(cur_n, dtype=torch.float32) / float(cur_n)

    # shrink without replacement; grow with replacement
    replace = target_n > cur_n
    if not replace and int((w_norm > 0).sum()) < target_n:
        # Only reachable on the weighted path (uniform weights are all > 0),
        # so ``logw`` is set: keep the best-weighted states deterministically.
        idx = torch.topk(logw, target_n).indices
    else:
        idx = torch.multinomial(w_norm, num_samples=target_n, replacement=replace)
    return [copy.deepcopy(evol_states[i]) for i in idx.tolist()]

# ---------------------------------------------------------------------    

import re
from typing import Sequence
try:
    from omegaconf import ListConfig 
except Exception:
    ListConfig = tuple()  

def chain_id_from_contigs_list(contigs: Sequence[str]) -> str:
    """
    Return the chain letters of the last ``<letters><digits>[-<digits>]``
    token found in ``contigs``.

    A list / tuple / ListConfig is joined with spaces; anything else is
    converted with ``str``. Before matching, every ``/`` together with the
    characters after it (up to the next comma, whitespace, ``]`` or ``)``)
    is removed.

    Raises ValueError when no such token remains.
    """
    is_collection = isinstance(contigs, (list, tuple)) or (
        ListConfig and isinstance(contigs, ListConfig)
    )
    if is_collection:
        text = " ".join(str(item) for item in contigs)
    else:
        text = str(contigs)

    # drop "/..." continuations so only the leading segment of each token is left
    text = re.sub(r'/[^,\s\]\)]*', '', text)

    # the single capture group makes findall return just the letter prefixes
    chain_letters = re.findall(r'([A-Za-z]+)\d+(?:-\d+)?', text)
    if not chain_letters:
        raise ValueError(f"Unable to resolve chain ID from contigs：{text!r}")

    return chain_letters[-1]


    

@hydra.main(version_base=None, config_path="../config/inference", config_name="base")
def main(conf: HydraConfig) -> None:
    """
    Sample designs with an SMC particle population and constraint projection.

    For every design index this function
      1. draws ``conf.projection.evol_seeds`` initial particles from the sampler;
      2. repeatedly runs ``conf.projection.evol_interval`` reverse steps per
         particle, scores each particle's predicted backbone (angle + distance
         penalties against the peptide chain) and resamples with ``smc_resample``;
      3. once ``t_current < conf.projection.start_step`` projects each particle's
         predicted backbone with ``projection_pdz`` and rebuilds its ``x_t``;
      4. writes the top-K particles that differ by at least
         ``conf.projection.dup_rmsd`` in CA-RMSD as ``<prefix>_<i>_seed<rank>.pdb``,
         a ``.trb`` metadata pickle and, optionally, trajectory PDBs.
    """
    log = logging.getLogger(__name__)
    if conf.inference.deterministic:
        make_deterministic()

    # Configure constraint
    # NOTE(review): ``penalty_func`` is not read again in this function; the
    # chain is kept because it validates ``conf.projection.penalty_func``.
    if   conf.projection.penalty_func == 'b':
        penalty_func = lambda x, s: joint_penalty(x, start_index=s)

    elif conf.projection.penalty_func == 'd':
        penalty_func = lambda x, s: distance_penalty(x, start_index=s)

    elif conf.projection.penalty_func == 'a':
        penalty_func = lambda x, s: angle_penalty(x, start_index=s)

    else:
        raise NotImplementedError(
            f"Unknown projection.penalty_func {conf.projection.penalty_func!r}; "
            "expected 'b', 'd' or 'a'"
        )

    # Check for available GPU and print result of check
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}")
    else:
        log.info("////////////////////////////////////////////////")
        log.info("///// NO GPU DETECTED! Falling back to CPU /////")
        log.info("////////////////////////////////////////////////")

    # Initialize sampler and target/contig.
    sampler = iu.sampler_selector(conf)

    # Early stopping
    sampler.inf_conf.final_step += 1

    # Loop over number of designs to sample.
    design_startnum = sampler.inf_conf.design_startnum
    if sampler.inf_conf.design_startnum == -1:
        existing = glob.glob(sampler.inf_conf.output_prefix + "*.pdb")
        indices = [-1]
        for e in existing:
            # Designs are written as "<prefix>_<i>_seed<k>.pdb" below; the plain
            # "<prefix>_<i>.pdb" form is still recognised.
            m = re.match(r".*_(\d+)(?:_seed\d+)?\.pdb$", e)
            if not m:
                continue
            indices.append(int(m.group(1)))
        design_startnum = max(indices) + 1

    for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):

        out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"
        log.info(f"Making design {out_prefix}")
        # Check before sampling any seed so a skipped design costs nothing.
        # "<prefix>_seed0.pdb" is the first file this script writes per design.
        if sampler.inf_conf.cautious and (
            os.path.exists(out_prefix + ".pdb")
            or os.path.exists(out_prefix + "_seed0.pdb")
        ):
            log.info(
                f"(cautious mode) Skipping this design because output for {out_prefix} already exists."
            )
            continue

        evol_states = []      # list[dict] – one per seed
        for i_seed in range(conf.projection.evol_seeds):
            make_deterministic(conf.inference.seed + i_seed)        # or any scheme
            x0, seq0 = sampler.sample_init()
            evol_states.append({
                "x_t":   torch.clone(x0),
                "seq_t": torch.clone(seq0),
                "px0": None,
                "score": torch.tensor(float("inf"))
            })

        start_time = time.time()

        x_init, seq_init   = sampler.sample_init()
        denoised_xyz_stack = []
        px0_xyz_stack      = []
        seq_stack          = []
        plddt_stack        = []

        x_t         = torch.clone(x_init)
        seq_t       = torch.clone(seq_init)
        spec = conf.contigmap.contigs
        chain_id = chain_id_from_contigs_list(spec)
        print(f"[DEBUG] parsed chain_id from contigs: {chain_id}")
        p_chain     = extract_backbone_tensor(conf.projection.peptide_path, chain_id=chain_id)

        start_index = conf.projection.start_index

        # True = fixed residue, False = movable residue
        diffusion_mask = sampler.mask_str.squeeze().to(x_init.device)
        print("[DEBUG] diffusion_mask shape:", diffusion_mask.shape)
        print("[DEBUG] num_fixed (True)   =", int(diffusion_mask.sum().item()))
        print("[DEBUG] num_movable (False)=", int((~diffusion_mask).sum().item()))
        num_movable = int((~diffusion_mask).sum().item())

        # neighbour & bond lists for the whole chain
        neigh_pairs, bond_pairs = build_chain_pairs(
                n_res         = x_t.shape[0],
                atoms_per_res = 4,
                include_next2 = True,
        )

        # Initializations for projections
        sched = ConstraintSchedule(
            angle_max=2.0, angle_min=0.0, distance_max=2.0, distance_min=0.0, timesteps=50,
            angle_schedule_type='linear', distance_schedule_type='linear'
        )

        # Evolutionary sampling initialization
        eval_interval = conf.projection.evol_interval
        if eval_interval < 1:
            # with no reverse step per round, t_current would never decrease
            raise ValueError(
                f"projection.evol_interval must be >= 1, got {eval_interval}"
            )
        t_current     = int(sampler.t_step_input)

        while t_current >= sampler.inf_conf.final_step:
            # -----------------------------------------------------------
            # 3a Run 'eval_interval' reverse steps for *each* seed
            # -----------------------------------------------------------

            for _ in range(eval_interval):
                for i_seed, s in enumerate(evol_states):

                    make_deterministic(conf.inference.seed + i_seed)
                    xt, seqt = s["x_t"], s["seq_t"]

                    px0, xt_1, seqt, plddt = sampler.sample_step(
                        t=t_current, x_t=xt, seq_init=seqt,
                        final_step=sampler.inf_conf.final_step
                    )

                    # keep state for scoring
                    s["px0"] = px0.detach()

                    # advance the trajectory
                    s["x_t"], s["seq_t"] = xt_1.detach(), seqt.detach()

                t_current -= 1
                if t_current < sampler.inf_conf.final_step:
                    break

            # 3b  Score every seed & run SMC resample
            with torch.no_grad():
                for s in evol_states:
                    xbb = s["px0"][:, :4]
                    angle_v = torch.stack([angle_penalty_pdz(
                        xbb, p_chain, offset=o, start_index=start_index) for o in range(conf.projection.n_constraints)]).sum()
                    dist_v  = torch.stack([distance_penalty_pdz(
                        xbb, p_chain, offset=o, start_index=start_index) for o in range(conf.projection.n_constraints)]).sum()
                    s["score"] = angle_v + dist_v

            # inverse‑temperature β_t grows as t_current decreases
            beta_t = conf.projection.smc_beta0 * (int(sampler.t_step_input) - t_current + 1) \
                     / int(sampler.t_step_input)

            # --- apply simple seed schedule at this timestep ---
            # target_n = seeds_for_t_stage1(t_current)
            # if len(evol_states) != target_n:
            #     evol_states = adjust_population(evol_states, target_n=target_n, beta=float(beta_t))

            # get the denoiser instance
            denoiser = sampler.denoiser

            # CA-atom and frame noise scales at the current timestep (diagnostics)
            σ_ca_t   = denoiser.noise_schedule_ca(t_current)
            σ_fr_t   = denoiser.noise_schedule_frame(t_current)

            # pick whichever is most representative (often the CA noise is what you want)
            sigma_t = float(σ_ca_t)

            print(f"t={t_current}: σ_ca={σ_ca_t:.3f}, σ_frame={σ_fr_t:.3f}")

            evol_states = smc_resample(
                evol_states,
                beta=beta_t,
                inject_sigma=0.0 * sigma_t,            # noise may not help
                diffusion_mask=diffusion_mask
            )

            # print diagnostics
            best = min([s["score"] for s in evol_states]).item()
            worst = max([s["score"] for s in evol_states]).item()
            print(f"t={t_current:3d}  β={beta_t:4.1f}  best={best:8.4f}  worst={worst:8.4f}")

            # project current states
            if t_current < conf.projection.start_step:

                # one cache per round: particles duplicated by resampling have
                # identical px0 and can reuse the first projection result
                projection_cache = {}

                def get_ca_hash(x):
                    ca_coords = x[:, 1, :].detach().cpu().numpy().round(decimals=3)
                    return hash(ca_coords.tobytes())

                for i_seed, s in enumerate(evol_states):
                    px0 = s["px0"]
                    tolerances = sched.get_tolerance(t_current)

                    # generate hash
                    h = get_ca_hash(px0[:, :4])

                    # see if existing
                    if h in projection_cache:
                        print(f"[INFO] Seed {i_seed}: skipping projection (duplicate); using cached result")
                        s["x_t"], s["px0"] = projection_cache[h]
                        continue

                    # projection
                    xi = px0[:, :4]
                    flat_ref = xi.reshape(-1, 3)

                    primal_func = make_primal_func(
                        ref_xyz=flat_ref,
                        start_index=5 * 4, #87
                        neighbour_pairs=neigh_pairs,
                        bond_pairs=bond_pairs,
                    )

                    gap_pairs = build_gap_pairs_ra(
                        n_res=start_index + 2,
                        first_res=start_index,
                    )

                    # writes the projected backbone into this particle's px0 in place
                    px0[:, :4], verification = projection_pdz(
                        target_rep=xi,
                        p_chain=p_chain,
                        start_index=start_index,
                        first_movable_res=0,
                        last_movable_res=num_movable,
                        primal_func=primal_func,
                        tol=tolerances[0],
                        gap_pairs_ra=gap_pairs,
                        n_constraints=conf.projection.n_constraints,
                        hot_start=xi,
                        verbose=False,
                    )

                    ###########################
                    # Convert P_C(px0) to x_t #
                    ###########################

                    if  t_current >= sampler.inf_conf.final_step:

                        # NOTE(review): ``x_t`` here is not this particle's own
                        # state — it is the tensor left by the previously
                        # processed particle (or the clone of ``x_init``).
                        # Confirm which state and timestep are intended before
                        # changing it.
                        x_t, _ = sampler.denoiser.get_next_pose(
                            xt=x_t,
                            px0=px0,
                            t=t_current,
                            diffusion_mask=sampler.mask_str.squeeze(),
                            align_motif=sampler.inf_conf.align_motif,
                            include_motif_sidechains=sampler.preprocess_conf.motif_sidechain_input
                        )

                    else:

                        # Clone first: the slice assignment below writes in place
                        # and ``x_t.detach()`` (stored on the previous particle)
                        # shares storage with ``x_t``. Without the copy every
                        # particle ends up holding the last particle's result.
                        x_t = torch.clone(x_t)
                        x_t[:, :4], _ = projection_pdz(
                            target_rep            = px0[:, :4],
                            p_chain               = p_chain,
                            start_index           = start_index,
                            first_movable_res     = 0,
                            last_movable_res      = num_movable,
                            primal_func           = primal_func,
                            tol                   = 0.0,
                            gap_pairs_ra          = gap_pairs,
                            n_constraints         = conf.projection.n_constraints,
                            hot_start             = px0[:, :4],
                            verbose               = True,
                        )

                    # keep state for scoring
                    s["px0"] = px0.detach()

                    # advance the trajectory; the particle keeps its own seq_t
                    # (projection only moves coordinates)
                    s["x_t"] = x_t.detach()

                    projection_cache[h] = (s["x_t"], s["px0"])

            # NOTE(review): these record whatever ``px0`` / ``x_t`` / ``plddt`` the
            # last processed particle left behind, not a chosen particle.
            px0_xyz_stack.append(px0)
            denoised_xyz_stack.append(x_t)
            seq_stack.append(seq_t)
            plddt_stack.append(plddt[0])  # remove singleton leading dimension

        # Flip order for better visualization in pymol
        denoised_xyz_stack = torch.stack(denoised_xyz_stack)
        denoised_xyz_stack = torch.flip(
            denoised_xyz_stack,
            [
                0,
            ],
        )
        px0_xyz_stack = torch.stack(px0_xyz_stack)
        px0_xyz_stack = torch.flip(
            px0_xyz_stack,
            [
                0,
            ],
        )

        # For logging -- don't flip
        plddt_stack = torch.stack(plddt_stack)

        # Save outputs
        out_dir = os.path.dirname(out_prefix)
        if out_dir:  # a bare prefix has no directory part; makedirs('') raises
            os.makedirs(out_dir, exist_ok=True)

        # Output glycines, except for motif region
        final_seq = torch.where(
            torch.argmax(seq_init, dim=-1) == 21, 7, torch.argmax(seq_init, dim=-1)
        )  # 7 is glycine

        bfacts = torch.ones_like(final_seq.squeeze())
        # make bfact=0 for diffused coordinates
        bfacts[torch.where(torch.argmax(seq_init, dim=-1) == 21, True, False)] = 0

        # ------------------------------------------------------------------
        #  Save only top-K *unique* particles (duplication test via Ca-RMSD)
        # ------------------------------------------------------------------
        rmsd_thr = getattr(conf.projection, "dup_rmsd", 1.0)   # Å
        K        = getattr(conf.projection, "save_topk", 3)    # default 3

        def ca_rmsd(x, y):
            return torch.sqrt(((x[:,1,:] - y[:,1,:])**2).mean())  # Ca only

        # sort by score (ascending = better)
        particles_sorted = sorted(evol_states, key=lambda p: p["score"].item())

        picked = []
        for cand in particles_sorted:
            xyz_cand = cand["x_t"][:,:4]          # backbone
            # Check Ca RMSD to all previously picked structures
            is_dup = False
            for p in picked:
                if ca_rmsd(xyz_cand, p["x_t"][:,:4]) < rmsd_thr:
                    is_dup = True
                    break
            if not is_dup:
                picked.append(cand)
            if len(picked) >= K: break

        for rank, p in enumerate(picked):
            fname = f"{out_prefix}_seed{rank}.pdb"
            seq_final = torch.argmax(p["seq_t"], -1)
            # separate name: ``bfacts`` above is still needed for the trajectories
            particle_bfacts = torch.ones_like(seq_final).float()
            writepdb(fname, p["x_t"][:,:4], seq_final, sampler.binderlen,
                     chain_idx=sampler.chain_idx, bfacts=particle_bfacts)

        # run metadata
        trb = dict(
            config=OmegaConf.to_container(sampler._conf, resolve=True),
            plddt=plddt_stack.cpu().numpy(),
            device=torch.cuda.get_device_name(torch.cuda.current_device())
            if torch.cuda.is_available()
            else "CPU",
            time=time.time() - start_time,
        )
        if hasattr(sampler, "contig_map"):
            for key, value in sampler.contig_map.get_mappings().items():
                trb[key] = value
        with open(f"{out_prefix}.trb", "wb") as f_out:
            pickle.dump(trb, f_out)

        if sampler.inf_conf.write_trajectory:
            # trajectory pdbs; os.path.join keeps a bare prefix relative
            # instead of producing "/traj/..."
            traj_prefix = os.path.join(out_dir, "traj", os.path.basename(out_prefix))
            os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)

            out = f"{traj_prefix}_Xt-1_traj.pdb"
            writepdb_multi(
                out,
                denoised_xyz_stack,
                bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

            out = f"{traj_prefix}_pX0_traj.pdb"
            writepdb_multi(
                out,
                px0_xyz_stack,
                bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

        log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")


# Script entry point: main() is wrapped by hydra.main, which supplies ``conf``.
if __name__ == "__main__":
    main()
