#!/usr/bin/env python
"""
Inference script.

To run with base.yaml as the config,

> python run_inference.py

To specify a different config,

> python run_inference.py --config-name symmetry

where symmetry can be the filename of any other config (without .yaml extension)
See https://hydra.cc/docs/advanced/hydra-command-line-flags/ for more options.

"""

import re
import os, time, pickle
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from omegaconf import OmegaConf
import hydra
import logging
from rfdiffusion.util import writepdb_multi, writepdb
from rfdiffusion.inference import utils as iu
from hydra.core.hydra_config import HydraConfig
import numpy as np
import random
import glob
import math
import csv

from constraints.projection import *

from typing import Tuple
import copy


# ---------------------------------------------------------------------    
def make_deterministic(seed=0):
    """Seed the torch, NumPy and Python ``random`` generators with *seed*.

    Args:
        seed: integer seed handed unchanged to all three generators.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
# ---------------------------------------------------------------------
#  SMC helper: weight → resample → (optional) mutate
# ---------------------------------------------------------------------
def smc_resample(pop, beta=30.0, inject_sigma=0.0, diffusion_mask=None):
    """Resample a particle population in proportion to exp(-beta * score).

    Args:
        pop: list of dicts with keys {x_t, seq_t, px0, score}; ``score`` is a
            scalar tensor where lower is better.
        beta: inverse temperature converting a score into a log-weight.
        inject_sigma: standard deviation of Gaussian noise added to ``x_t``
            and ``px0`` of every resampled particle; ``0`` disables it.
        diffusion_mask: [L] bool tensor, True = fixed residue (receives no
            noise), False = diffusable residue. ``None`` means no residue is
            fixed, so noise is applied everywhere.

    Returns:
        A new list of ``len(pop)`` deep-copied particles drawn with
        replacement; the dicts in ``pop`` are not modified.
    """
    # Nothing to weight or draw from (torch.stack would raise on an empty list).
    if not pop:
        return []

    scores = torch.stack([s["score"] for s in pop])
    # Shift by the minimum so the best particle has log-weight 0.
    logw = -beta * (scores - scores.min())
    w_norm = torch.softmax(logw, dim=0)

    idx = torch.multinomial(w_norm, num_samples=len(pop), replacement=True)

    new_pop = []
    for j in idx.tolist():
        # Deep copy so duplicated particles do not share tensors.
        s = copy.deepcopy(pop[j])
        if inject_sigma > 0.0:
            noise = inject_sigma * torch.randn_like(s["x_t"])
            if diffusion_mask is not None:
                # Zero the noise on fixed residues; [L] -> [L,1,1] to broadcast.
                movable = ~diffusion_mask.view(-1, 1, 1)
                noise = noise * movable.to(device=noise.device, dtype=noise.dtype)
            # The same perturbation is applied to the state and its prediction.
            s["x_t"] += noise
            s["px0"] += noise
        new_pop.append(s)
    return new_pop
# ---------------------------------------------------------------------        
def log_violation(angle_viols, dist_viols, path='violation.csv'):
    """Write the combined violation values to a two-column CSV file.

    Args:
        angle_viols: angle violation values.
        dist_viols: distance violation values; combined with ``angle_viols``
            using ``+`` (list concatenation when both are lists).
        path: output file, overwritten if it exists. Defaults to
            ``violation.csv`` in the current working directory.
    """
    data = angle_viols + dist_viols
    # newline='' lets the csv module control line endings itself.
    with open(path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Index', 'Value'])
        writer.writerows(enumerate(data))
# ---------------------------------------------------------------------
# build neighbour / bond topology from chain layout -------------------
def build_chain_pairs(n_res: int,
                      atoms_per_res: int = 4,
                      include_next2: bool = False
                      ) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Returns (neighbour_pairs, bond_pairs) as (M×2) index tensors.

    * 'bond_pairs'  = consecutive atoms inside a residue plus the link from
      the last atom of residue r to the first atom of residue r+1.
    * 'neighbour_pairs' includes bond_pairs plus any extra "soft"
      neighbours (last atom of residue r to first atom of residue r+2 if
      include_next2=True).

    Indexing is into the *flattened* (n_res*atoms_per_res) array, i.e.
        flat_idx = res_idx * atoms_per_res + atom_idx

    Both tensors are always 2-D: when no pair exists (e.g. n_res == 0) the
    result has shape (0, 2) rather than (0,).
    """
    pairs, bonds = [], []
    last_atom = atoms_per_res - 1

    for r in range(n_res):
        base = r * atoms_per_res

        # inside one residue: 0-1, 1-2, ... (consecutive atoms)
        for a in range(last_atom):
            link = (base + a, base + a + 1)
            pairs.append(link)
            bonds.append(link)

        # link residue r to r+1: last atom -> first atom
        if r < n_res - 1:
            link = (base + last_atom, (r + 1) * atoms_per_res)
            pairs.append(link)
            bonds.append(link)

        # optional "second-next" neighbour; deliberately NOT added to bonds
        if include_next2 and r < n_res - 2:
            pairs.append((base + last_atom, (r + 2) * atoms_per_res))

    # Explicit row count: torch.tensor([]) alone would be 1-D with shape (0,),
    # and reshape(-1, 2) is ambiguous for zero elements.
    neigh = torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2)
    bond = torch.tensor(bonds, dtype=torch.long).reshape(len(bonds), 2)
    return neigh, bond
# ---------------------------------------------------------------------

def robust_loss(x, delta=0.1):
    """Element-wise loss: quadratic near zero, linear in the tails.

    Returns ``0.5 * x**2`` where ``|x| <= delta`` and ``|x| - 0.5 * delta``
    elsewhere.

    Args:
        x: tensor of residuals.
        delta: half-width of the quadratic region.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # NOTE(review): the two branches do not meet at |x| == delta unless
    # delta == 1 (0.5*delta**2 vs 0.5*delta). A standard Huber loss scales the
    # linear branch by delta; left unchanged so existing loss values keep
    # their scale -- confirm which form is intended.
    abs_x = x.abs()
    quad = abs_x <= delta
    # torch.where selects a branch instead of multiplying by a 0/1 mask, so an
    # overflowing x**2 (inf * 0 == NaN) can no longer poison the linear tail.
    return torch.where(quad, 0.5 * (x ** 2), abs_x - 0.5 * delta)

def stiff_quadratic(x):
    """Return the plain quadratic penalty ``0.5 * x**2`` (no linear tail)."""
    squared = x ** 2
    return 0.5 * squared
# ---------------------------------------------------------------------

import re
from typing import Sequence
# Optional import: ListConfig is only used for an isinstance() check in
# chain_id_from_contigs_list, so a failed import must not abort the script.
try:
    from omegaconf import ListConfig  # tolerate environments where this import fails
except Exception:
    ListConfig = tuple()  # placeholder: an empty (falsy) tuple disables the isinstance() check

def chain_id_from_contigs_list(contigs: Sequence[str]) -> str:
    """Return the chain letters of the last chain segment found in *contigs*.

    A chain segment is letters followed by a residue number and an optional
    ``-number`` range end, e.g. ``A163-181`` or ``B5``. Numeric-only pieces
    such as ``10-40`` or the ``/0`` marker carry no chain letters and are
    skipped, wherever they appear relative to ``/`` separators.

    Args:
        contigs: a list/tuple/ListConfig of contig strings, a single contig
            string, or any object whose ``str()`` contains the contig text.

    Returns:
        The letter prefix of the last chain segment.

    Raises:
        ValueError: if no chain segment is present.
    """
    # Normalise the input to one string to search.
    if isinstance(contigs, str):
        s = contigs
    elif isinstance(contigs, (list, tuple)) or (ListConfig and isinstance(contigs, ListConfig)):
        s = " ".join(map(str, contigs))
    else:
        s = str(contigs)

    # Search the whole string. The previous implementation first deleted
    # everything from a '/' up to the next space/comma/bracket, which also
    # deleted chain segments following a slash ("10-40/A163-181/10-40").
    chain_ids = re.findall(r'([A-Za-z]+)\d+(?:-\d+)?', s)
    if not chain_ids:
        raise ValueError(f"Unable to resolve chain ID from contigs：{s!r}")

    # The last segment wins.
    return chain_ids[-1]

def _next_design_index(output_prefix):
    """Return one more than the highest design index already on disk.

    Recognises ``<prefix>_<n>.pdb`` as well as the ``<prefix>_<n>_seed<k>.pdb``
    files written by ``main``; returns 0 when nothing matches.
    """
    pattern = re.compile(r".*_(\d+)(?:_seed\d+)?\.pdb$")
    indices = [-1]
    for path in glob.glob(output_prefix + "*.pdb"):
        m = pattern.match(path)
        if m:
            indices.append(int(m.group(1)))
    return max(indices) + 1


def _ca_rmsd(x, y):
    """RMSD between atom index 1 (Ca) of two [L, n_atoms, 3] arrays, no superposition."""
    sq_dist = ((x[:, 1, :] - y[:, 1, :]) ** 2).sum(dim=-1)  # squared distance per residue
    return torch.sqrt(sq_dist.mean())


def _pick_unique_particles(particles, k, rmsd_thr):
    """Return up to *k* best-scoring particles, skipping near-duplicates of earlier picks."""
    picked = []
    # sort by score (ascending = better)
    for cand in sorted(particles, key=lambda p: p["score"].item()):
        xyz_cand = cand["x_t"][:, :4]  # backbone
        is_dup = any(
            _ca_rmsd(xyz_cand, p["x_t"][:, :4]) < rmsd_thr for p in picked
        )
        if not is_dup:
            picked.append(cand)
        if len(picked) >= k:
            break
    return picked


@hydra.main(version_base=None, config_path="../config/inference", config_name="base")
def main(conf: HydraConfig) -> None:
    """Run inference with a population of particles resampled by constraint score.

    For every design, ``conf.projection.evol_seeds`` particles are initialised
    and advanced with ``sampler.sample_step``. After every
    ``conf.projection.evol_interval`` reverse steps each particle's ``px0``
    backbone is scored with ``angle_penalty_pdz`` + ``distance_penalty_pdz``
    against the peptide backbone loaded from ``conf.projection.peptide_path``,
    and the population is resampled with ``smc_resample``.

    Outputs per design (prefix ``<output_prefix>_<i_des>``):
      * ``<prefix>_seed<rank>.pdb`` for up to ``save_topk`` best particles that
        differ by at least ``dup_rmsd`` in Ca RMSD,
      * ``<prefix>.trb`` pickle with config, pLDDT and run metadata,
      * optional trajectory PDBs under ``traj/`` when ``write_trajectory`` is set.

    Raises:
        ValueError: if ``evol_seeds`` or ``evol_interval`` is smaller than 1.
    """
    log = logging.getLogger(__name__)
    if conf.inference.deterministic:
        make_deterministic()

    # Check for available GPU and print result of check
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}")
    else:
        log.info("////////////////////////////////////////////////")
        log.info("///// NO GPU DETECTED! Falling back to CPU /////")
        log.info("////////////////////////////////////////////////")

    n_seeds = conf.projection.evol_seeds
    eval_interval = conf.projection.evol_interval
    if n_seeds < 1:
        raise ValueError(f"projection.evol_seeds must be >= 1, got {n_seeds}")
    if eval_interval < 1:
        # With no reverse steps per round the time index would never advance.
        raise ValueError(f"projection.evol_interval must be >= 1, got {eval_interval}")

    # Initialize sampler and target/contig.
    sampler = iu.sampler_selector(conf)

    # Early stopping
    sampler.inf_conf.final_step += 1

    # Loop over number of designs to sample.
    design_startnum = sampler.inf_conf.design_startnum
    if design_startnum == -1:
        design_startnum = _next_design_index(sampler.inf_conf.output_prefix)

    for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
        start_time = time.time()
        out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"
        log.info(f"Making design {out_prefix}")

        # Checked before any sampling work; structures are saved as
        # <prefix>_seed<rank>.pdb, so look for that name as well.
        if sampler.inf_conf.cautious:
            existing_out = next(
                (p for p in (out_prefix + ".pdb", out_prefix + "_seed0.pdb")
                 if os.path.exists(p)),
                None,
            )
            if existing_out is not None:
                log.info(
                    f"(cautious mode) Skipping this design because {existing_out} already exists."
                )
                continue

        # NOTE(review): the seeds below depend only on conf.inference.seed and
        # the particle index, not on i_des, so every design of one run starts
        # from the same noise -- confirm whether that is intended.
        evol_states = []      # list[dict] - one per seed
        for i_seed in range(n_seeds):
            make_deterministic(conf.inference.seed + i_seed)
            x0, seq0 = sampler.sample_init()
            evol_states.append({
                "x_t":   torch.clone(x0),
                "seq_t": torch.clone(seq0),
                "px0": None,
                "score": torch.tensor(float("inf"))
            })

        x_init, seq_init   = sampler.sample_init()
        denoised_xyz_stack = []
        px0_xyz_stack      = []
        plddt_stack        = []

        # Reference peptide backbone; its chain ID is taken from the contigs.
        chain_id = chain_id_from_contigs_list(conf.contigmap.contigs)
        print(f"[DEBUG] parsed chain_id from contigs: {chain_id}")
        p_chain     = extract_backbone_tensor(conf.projection.peptide_path, chain_id=chain_id)
        start_index = conf.projection.start_index

        # True = fixed residue, False = diffusable (as consumed by smc_resample)
        diffusion_mask = sampler.mask_str.squeeze().to(x_init.device)
        print("[DEBUG] diffusion_mask shape:", diffusion_mask.shape)
        print("[DEBUG] num_fixed (True)   =", int(diffusion_mask.sum().item()))
        print("[DEBUG] num_movable (False)=", int((~diffusion_mask).sum().item()))

        # Evolutionary sampling initialization
        t_start   = int(sampler.t_step_input)
        t_current = t_start

        while t_current >= sampler.inf_conf.final_step:
            # -----------------------------------------------------------
            # 3a Run 'eval_interval' reverse steps for *each* seed
            # -----------------------------------------------------------
            for _ in range(eval_interval):
                for i_seed, s in enumerate(evol_states):
                    make_deterministic(conf.inference.seed + i_seed)
                    xt, seqt = s["x_t"], s["seq_t"]

                    px0, xt_1, seqt, plddt = sampler.sample_step(
                        t=t_current, x_t=xt, seq_init=seqt,
                        final_step=sampler.inf_conf.final_step
                    )
                    # keep state for scoring
                    s["px0"] = px0.detach()

                    # advance the trajectory
                    s["x_t"], s["seq_t"] = xt_1.detach(), seqt.detach()
                    s["plddt"] = plddt.detach()

                t_current -= 1
                if t_current < sampler.inf_conf.final_step:
                    break

            # 3b  Score every seed & run SMC resample
            with torch.no_grad():
                for s in evol_states:
                    xbb = s["px0"][:, :4]
                    angle_v = torch.stack([angle_penalty_pdz(
                        xbb, p_chain, offset=o, start_index=start_index) for o in range(conf.projection.n_constraints)]).sum()
                    dist_v  = torch.stack([distance_penalty_pdz(
                        xbb, p_chain, offset=o, start_index=start_index) for o in range(conf.projection.n_constraints)]).sum()
                    s["score"] = angle_v + dist_v

            # inverse-temperature beta_t anneals with noise (higher when t is small)
            beta_t = conf.projection.smc_beta0 * (t_start - t_current + 1) / t_start

            # noise scales of the denoiser at the current step (diagnostics only,
            # because the injected noise below is multiplied by 0.0)
            denoiser    = sampler.denoiser
            sigma_ca_t  = denoiser.noise_schedule_ca(t_current)
            sigma_fr_t  = denoiser.noise_schedule_frame(t_current)
            sigma_t     = float(sigma_ca_t)

            print(f"t={t_current}: σ_ca={sigma_ca_t:.3f}, σ_frame={sigma_fr_t:.3f}")

            evol_states = smc_resample(
                evol_states,
                beta=beta_t,
                inject_sigma=0.0 * sigma_t,            # noise may not help
                diffusion_mask=diffusion_mask
            )

            # print diagnostics
            scores = torch.stack([s["score"] for s in evol_states])
            best = scores.min().item()
            worst = scores.max().item()
            print(f"t={t_current:3d}  β={beta_t:4.1f}  best={best:8.4f}  worst={worst:8.4f}")

            # record the best surviving particle of this round for the trajectory
            best_state = evol_states[torch.argmin(scores).item()]
            px0_xyz_stack.append(best_state["px0"])       # denoised prediction used for scoring
            denoised_xyz_stack.append(best_state["x_t"])  # x_t along the trajectory
            plddt_stack.append(best_state["plddt"])

        # Flip order for better visualization in pymol
        denoised_xyz_stack = torch.flip(torch.stack(denoised_xyz_stack), [0])
        px0_xyz_stack = torch.flip(torch.stack(px0_xyz_stack), [0])

        # For logging -- don't flip
        plddt_stack = torch.stack(plddt_stack)

        # Save outputs. dirname is '' for a bare prefix, which makedirs rejects.
        out_dir = os.path.dirname(out_prefix)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Output glycines, except for motif region
        init_tokens = torch.argmax(seq_init, dim=-1)
        is_diffused = init_tokens == 21
        final_seq = torch.where(is_diffused, 7, init_tokens)  # 7 is glycine

        # Trajectory b-factors: 1 for motif, 0 for diffused coordinates. Kept
        # in its own variable so the per-particle loop below cannot clobber it.
        traj_bfacts = torch.ones_like(final_seq.squeeze())
        traj_bfacts[is_diffused] = 0

        # ------------------------------------------------------------------
        #  Save only top-K *unique* particles (duplication test via Ca-RMSD)
        # ------------------------------------------------------------------
        rmsd_thr = getattr(conf.projection, "dup_rmsd", 1.0)   # Å
        K        = getattr(conf.projection, "save_topk", 3)    # default 3
        picked = _pick_unique_particles(evol_states, K, rmsd_thr)

        for rank, p in enumerate(picked):
            fname = f"{out_prefix}_seed{rank}.pdb"
            seq_final = torch.argmax(p["seq_t"], -1)
            particle_bfacts = torch.ones_like(seq_final).float()
            writepdb(fname, p["x_t"][:, :4], seq_final, sampler.binderlen,
                     chain_idx=sampler.chain_idx, bfacts=particle_bfacts)

        # run metadata
        trb = dict(
            config=OmegaConf.to_container(sampler._conf, resolve=True),
            plddt=plddt_stack.cpu().numpy(),
            device=torch.cuda.get_device_name(torch.cuda.current_device())
            if torch.cuda.is_available()
            else "CPU",
            time=time.time() - start_time,
        )
        if hasattr(sampler, "contig_map"):
            for key, value in sampler.contig_map.get_mappings().items():
                trb[key] = value
        with open(f"{out_prefix}.trb", "wb") as f_out:
            pickle.dump(trb, f_out)

        if sampler.inf_conf.write_trajectory:
            # trajectory pdbs; os.path.join keeps a bare prefix relative
            # instead of resolving to "/traj/..." at the filesystem root
            traj_prefix = os.path.join(out_dir, "traj", os.path.basename(out_prefix))
            os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)

            out = f"{traj_prefix}_Xt-1_traj.pdb"
            writepdb_multi(
                out,
                denoised_xyz_stack,
                traj_bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

            out = f"{traj_prefix}_pX0_traj.pdb"
            writepdb_multi(
                out,
                px0_xyz_stack,
                traj_bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

        log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")


# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here.
if __name__ == "__main__":
    main()
