#!/usr/bin/env python
"""
Evolutionary PDZ-tail inference (final-projection stage).

Add to your YAML (if not already present):
projection:
  evol_seeds:     10      # number of parallel seeds
  evol_interval:   5      # evaluate / clone every k reverse steps
  start_step:     25      # last timestep included in evolutionary phase
"""

import os, re, time, pickle, random, glob, csv, logging
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
import warnings
warnings.filterwarnings("ignore")
import math
import numpy as np
import torch
from omegaconf import OmegaConf
import hydra
import copy
from typing import Optional, Tuple
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# ---- all helper functions / classes live in your init script ----------
from scripts_uncond.uncond_init import *   #  projection_pdz, penalties, etc.

# ---------------------------------------------------------------------- #
#  reproducibility helper
# ---------------------------------------------------------------------- #
def make_deterministic(seed=0):
    """Seed the global PyTorch, NumPy and Python ``random`` RNGs with *seed*.

    The three generators are seeded in that order, each with the same value.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    
    
def _dihedral_deg(p0, p1, p2, p3):
    """Return batched dihedral angles p0-p1-p2-p3 in degrees (IUPAC sign).

    Each argument is a ``[..., 3]`` coordinate tensor. The result has shape
    ``[...]`` with values in [-180, 180]. Coincident p1/p2 yields NaN.
    """
    # b0 must point from p1 *back* to p0; using p1 - p0 would negate both
    # atan2 arguments and shift every angle by 180 degrees.
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / torch.linalg.norm(b1, dim=-1, keepdim=True)
    # Project b0 and b2 onto the plane perpendicular to the central bond.
    v = b0 - (b0 * b1).sum(dim=-1, keepdim=True) * b1
    w = b2 - (b2 * b1).sum(dim=-1, keepdim=True) * b1
    x = (v * w).sum(dim=-1)
    y = (torch.cross(b1, v, dim=-1) * w).sum(dim=-1)
    return torch.rad2deg(torch.atan2(y, x))


def breaks_and_beta_from_tensor(
    bb_tensor: torch.Tensor,
    beta_start: int = 90,
    beta_end: int = 100,
    break_thresh: float = 4.5,
    min_beta_len: int = 4,
) -> Tuple[int, int]:
    """Return *(num_breaks, β-strand length in window)* from a backbone tensor.

    Parameters
    ----------
    bb_tensor : torch.Tensor
        Shape ``[n_res, 4, 3]`` (N, CA, C, O) in Å. Only N, CA and C are used.
    beta_start, beta_end : int
        1-indexed, inclusive residue window to examine for β-sheet content.
        The window is clipped to the chain.
    break_thresh : float
        CA–CA distance (Å) above which consecutive residues count as a chain break.
    min_beta_len : int
        Minimum contiguous length (residues) to acknowledge a β-strand.

    Returns
    -------
    (int, int)
        Number of chain breaks over the whole chain, and the length of the
        longest contiguous β run inside the window. The length is 0 if that
        run is shorter than *min_beta_len*. Chains with fewer than 2 residues
        return ``(0, 0)``.

    Notes
    -----
    * **Breaks** – a simple CA distance check between consecutive residues.
    * **β-strand** – approximated purely from φ/ψ angles (IUPAC sign
      convention, Ramachandran criterion). Residues with ``φ ∈ [-180, -30]``
      and ``ψ ∈ [90, 180]`` are considered β. φ/ψ are only defined for
      interior residues, so the first and last residue are never β. This is
      a geometry-only heuristic that works without hydrogen-bond info.
    """
    n_res = bb_tensor.shape[0]
    if n_res < 2:
        return 0, 0

    N, CA, C = bb_tensor[:, 0], bb_tensor[:, 1], bb_tensor[:, 2]

    # --- 1. Chain breaks --------------------------------------------------
    ca_gaps = torch.linalg.norm(CA[1:] - CA[:-1], dim=1)
    num_breaks = int((ca_gaps > break_thresh).sum().item())

    # --- 2. β-strand length ----------------------------------------------
    # phi_i = C(i-1)-N(i)-CA(i)-C(i);  psi_i = N(i)-CA(i)-C(i)-N(i+1).
    # Both are computed for all interior residues 1..n_res-2 at once.
    phi = _dihedral_deg(C[:-2], N[1:-1], CA[1:-1], C[1:-1])
    psi = _dihedral_deg(N[1:-1], CA[1:-1], C[1:-1], N[2:])
    is_beta = torch.zeros(n_res, dtype=torch.bool, device=bb_tensor.device)
    # NaN angles (degenerate geometry) compare False and so are never β.
    is_beta[1:-1] = (phi >= -180) & (phi <= -30) & (psi >= 90) & (psi <= 180)

    window = torch.zeros_like(is_beta)
    window[max(beta_start - 1, 0): min(beta_end, n_res)] = True

    # Longest contiguous run of β residues inside the window.
    longest = current = 0
    for flag in (is_beta & window).tolist():
        current = current + 1 if flag else 0
        longest = max(longest, current)

    beta_len = longest if longest >= min_beta_len else 0
    return num_breaks, beta_len
    
# ---------------------------------------------------------------------
#  SMC helper: weight → resample → (optional) mutate
# ---------------------------------------------------------------------
def smc_resample(pop, beta=30.0, inject_sigma=0.0, diffusion_mask=None):
    """Multinomially resample a particle population with weights ∝ exp(-beta * score).

    Parameters
    ----------
    pop : list[dict]
        Particles with keys ``x_t``, ``seq_t``, ``px0`` and a scalar tensor
        ``score`` (lower is better). Any other keys are carried along unchanged.
    beta : float
        Inverse temperature converting score to log-weight; 0 gives uniform
        resampling.
    inject_sigma : float
        If > 0, Gaussian noise of this scale is added to ``x_t`` (and to
        ``px0`` when it is not None) of every resampled particle.
    diffusion_mask : torch.Tensor or None
        ``[L]`` bool tensor, True = fixed (no noise), False = diffusable.
        None means every residue receives noise. Assumes ``x_t`` is
        ``[L, atoms, 3]`` so the mask broadcasts as ``[L, 1, 1]``.

    Returns
    -------
    list[dict]
        ``len(pop)`` deep copies of the selected particles (duplicates allowed).

    Notes
    -----
    Non-finite scores (NaN / ±inf) receive zero weight. If no score is finite,
    the population is resampled uniformly instead of the softmax producing NaN
    and ``torch.multinomial`` raising.
    """
    scores = torch.stack([s["score"] for s in pop])
    finite = torch.isfinite(scores)
    if bool(finite.any()):
        # Shift by the best finite score so exp() cannot overflow; softmax is
        # shift-invariant, so the normalised weights are unchanged.
        shifted = scores - scores[finite].min()
        logw = torch.where(finite, -beta * shifted,
                           torch.full_like(shifted, float("-inf")))
    else:
        logw = torch.zeros_like(scores)
    w_norm = torch.softmax(logw, dim=0)

    idx = torch.multinomial(w_norm, num_samples=len(pop), replacement=True)
    new_pop = []
    for j in idx.tolist():
        s = copy.deepcopy(pop[j])
        if inject_sigma > 0.0:
            noise = inject_sigma * torch.randn_like(s["x_t"])
            if diffusion_mask is not None:
                # zero out noise on fixed residues (mask == True)
                mask = diffusion_mask.to(noise.device).reshape(-1, 1, 1)
                noise = noise * (~mask).to(noise.dtype)
            s["x_t"] += noise
            if s.get("px0") is not None:
                s["px0"] += noise
        new_pop.append(s)
    return new_pop
# ---------------------------------------------------------------------- #
# Chain-ID resolution from a contig spec: the letters of the final
# "<letters><digits>[-<digits>]" token win.
import re
try:
    from omegaconf import ListConfig  # OmegaConf list configs are joined like lists
except Exception:
    ListConfig = None


def chain_id_from_contigs_list(contigs) -> str:
    """Return the chain letters of the last ``<letters><digits>[-<digits>]`` token.

    *contigs* may be a list/tuple/ListConfig (elements joined with spaces) or
    anything else, which is converted with ``str``. Every ``/...`` fragment
    (up to the next comma, whitespace, ``]`` or ``)``) is discarded before
    matching.

    Raises
    ------
    ValueError
        If no such token remains.
    """
    is_sequence = isinstance(contigs, (list, tuple))
    if not is_sequence and ListConfig is not None:
        is_sequence = isinstance(contigs, ListConfig)
    text = " ".join(str(part) for part in contigs) if is_sequence else str(contigs)

    text = re.sub(r'/[^,\s\]\)]*', '', text)

    last = None
    for last in re.finditer(r'([A-Za-z]+)\d+(?:-\d+)?', text):
        pass
    if last is None:
        raise ValueError(f"Unable to resolve chain ID from contigs: {text!r}")
    return last.group(1)
# ---------------------------------------------------------------------- #
@hydra.main(version_base=None,
            config_path="../config/inference", config_name="base")
def main(conf):
    """Evolutionary (SMC) RFdiffusion inference for a PDZ-binding peptide tail.

    For every design, a population of ``conf.projection.evol_seeds`` particles
    is cloned from one ``sampler.sample_init()`` state. At each reverse step
    ``t = sampler.t_step_input ... sampler.inf_conf.final_step``:

    * every particle takes one ``sampler.sample_step``;
    * its predicted X0 backbone is scored by summing ``angle_penalty_pdz`` and
      ``distance_penalty_pdz`` over ``conf.projection.n_constraints`` offsets;
    * the population is resampled with :func:`smc_resample`.

    The lowest-score particle after resampling is recorded in the trajectory.
    The last recorded X0 is written to ``<output_prefix>_<i>.pdb``, together
    with a ``.trb`` pickle and, if enabled, trajectory PDBs.

    NOTE(review): ``sample_init()`` runs once and the per-particle seeds do
    not depend on the design index, so with ``num_designs > 1`` every design
    starts from an identical population -- confirm this is intended.
    NOTE(review): ``projection.evol_interval`` is documented but not used;
    resampling happens after every step.
    """
    log = logging.getLogger(__name__)
    if conf.inference.deterministic:
        make_deterministic(conf.inference.seed)

    # ------------------------------------------------------------------ #
    #  sampler, peptide chain, PDZ anchor
    # ------------------------------------------------------------------ #
    sampler = iu.sampler_selector(conf)          # RFdiffusion wrapper
    chain_id = chain_id_from_contigs_list(conf.contigmap.contigs)
    print(f"[DEBUG] parsed chain_id from contigs: {chain_id}")
    p_chain = extract_backbone_tensor(conf.projection.peptide_path, chain_id=chain_id)

    def find_anchor_from_contigs(contigs, anchor_chain='A') -> int:
        """Return the start residue number of the anchor segment on *anchor_chain*.

        Patterns are tried in order: ``/<num>/<chain><start>``, then
        ``/<chain><start>``, then the second ``<chain><start>`` occurrence,
        then the first one. Raises ValueError if none match.
        """
        s = ' '.join(contigs) if isinstance(contigs, (list, tuple)) else str(contigs)
        s = s.strip().strip('[]')

        m = re.search(rf'/\s*\d+\s*/\s*{anchor_chain}(\d+)(?:-\d+)?', s)
        if m:
            return int(m.group(1))

        m = re.search(rf'/\s*{anchor_chain}(\d+)(?:-\d+)?', s)
        if m:
            return int(m.group(1))

        occ = re.findall(rf'{anchor_chain}(\d+)(?:-\d+)?', s)
        if len(occ) >= 2:
            return int(occ[1])
        if occ:
            return int(occ[0])

        raise ValueError(f"Anchor like '/<num>/{anchor_chain}<start>' not found in contigs: {contigs}")

    start_index = find_anchor_from_contigs(conf.contigmap.contigs, anchor_chain='A')
    print(f"[DEBUG] start_index={start_index}")

    # one sample_init to set contig lengths & mask; every particle clones it
    base_xt, base_seq = sampler.sample_init()
    L = base_xt.shape[0]
    diffusion_mask = None  # noise injection is disabled below, so no mask is needed
    start_time = time.time()

    evol_seeds = conf.projection.evol_seeds
    n_constraints = conf.projection.n_constraints
    # Start where the sampler's own schedule starts (the same quantity the
    # beta_t annealing is normalised by), so beta_t stays in (0, smc_beta0]
    # even when t_step_input differs from diffuser.T.
    t_start = int(sampler.t_step_input)
    t_stop = sampler.inf_conf.final_step      # usually 1

    # ------------------------------------------------------------------ #
    #  main design loop (usually num_designs==1)
    # ------------------------------------------------------------------ #
    design_start = (sampler.inf_conf.design_startnum
                    if sampler.inf_conf.design_startnum != -1 else 0)
    for d_i in range(design_start,
                     design_start + sampler.inf_conf.num_designs):

        out_prefix = f"{sampler.inf_conf.output_prefix}_{d_i}"
        if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"):
            log.info(f"(cautious) {out_prefix}.pdb exists – skipping.")
            continue

        log.info(f"Design {d_i}  |  mask length = {L}")

        # Seed-specific RNG once; **clone** the same base tensors
        evol_states = []
        for s_i in range(evol_seeds):
            make_deterministic(conf.inference.seed + s_i)
            evol_states.append({
                "x_t":   base_xt.clone(),       # same length as mask
                "seq_t": base_seq.clone(),
                "px0":   None,
                "plddt": None,
                "score": torch.tensor(float("inf")),
            })

        # trajectory collectors
        denoised_xyz_stack, px0_xyz_stack = [], []
        seq_stack, plddt_stack = [], []

        # --------------------------------------------------------------
        #  reverse diffusion with SMC resampling at every step.
        #  t == final_step is included, so the last model call is not
        #  skipped and the last frame is not recorded twice.
        # --------------------------------------------------------------
        for t in range(t_start, t_stop - 1, -1):
            # --- A. one reverse step for every particle
            for i_seed, s in enumerate(evol_states):
                # NOTE(review): reseeds every step regardless of
                # conf.inference.deterministic -- kept from the original design.
                make_deterministic(conf.inference.seed + i_seed)
                px0, xt1, seqt, plddt = sampler.sample_step(
                    t=t, x_t=s["x_t"], seq_init=s["seq_t"],
                    final_step=sampler.inf_conf.final_step,
                )
                s["x_t"], s["seq_t"], s["px0"] = (
                    xt1.detach(), seqt.detach(), px0.detach()
                )
                # Keep pLDDT with its particle so the logged value belongs to
                # the survivor, not to whichever seed ran last.
                s["plddt"] = plddt.detach() if torch.is_tensor(plddt) else plddt

            # --- B. score each particle's X0 backbone against the peptide
            with torch.no_grad():
                for s in evol_states:
                    xbb = s["px0"][:, :4]
                    a_v = torch.stack([
                        angle_penalty_pdz(xbb, p_chain, offset=o,
                                          start_index=start_index)
                        for o in range(n_constraints)]).sum()
                    d_v = torch.stack([
                        distance_penalty_pdz(xbb, p_chain, offset=o,
                                             start_index=start_index)
                        for o in range(n_constraints)]).sum()
                    s["score"] = a_v + d_v

            # --- C. anneal beta_t and resample the population
            with torch.no_grad():
                # inverse temperature grows linearly as t decreases
                beta_t = conf.projection.smc_beta0 * (int(sampler.t_step_input) - t + 1) \
                         / int(sampler.t_step_input)

                denoiser = sampler.denoiser
                sigma_ca_t = denoiser.noise_schedule_ca(t)
                sigma_fr_t = denoiser.noise_schedule_frame(t)
                sigma_t = float(sigma_ca_t)  # CA noise scale drives injected noise
                print(f"t={t}: σ_ca={sigma_ca_t:.3f}, σ_frame={sigma_fr_t:.3f}")

                best = min([s["score"] for s in evol_states]).item()
                worst = max([s["score"] for s in evol_states]).item()
                print(f"t={t:3d}  β={beta_t:4.1f}  best={best:8.4f}  worst={worst:8.4f}")

                evol_states = smc_resample(
                    evol_states,
                    beta=beta_t,
                    inject_sigma=0.0 * sigma_t,            # noise may not help
                    diffusion_mask=diffusion_mask,
                )

            # --- D. record the lowest-score particle after resampling
            survivor = evol_states[int(torch.argmin(
                torch.stack([s["score"] for s in evol_states])))]
            px0_xyz_stack.append(survivor["px0"])
            denoised_xyz_stack.append(survivor["x_t"])
            seq_stack.append(survivor["seq_t"])
            plddt_stack.append(survivor["plddt"][0])  # remove singleton leading dimension

        # ----- save outputs ------------------------------------------------
        # Flip order for better visualization in pymol
        denoised_xyz_stack = torch.flip(torch.stack(denoised_xyz_stack), [0])
        px0_xyz_stack = torch.flip(torch.stack(px0_xyz_stack), [0])

        # For logging -- don't flip
        plddt_stack = torch.stack(plddt_stack)

        out_dir = os.path.dirname(out_prefix)
        if out_dir:  # os.makedirs("") raises for a bare prefix
            os.makedirs(out_dir, exist_ok=True)
        final_seq = torch.where(
            torch.argmax(base_seq, dim=-1) == 21, 7, torch.argmax(base_seq, dim=-1)
        )
        bfacts = torch.ones_like(final_seq.squeeze())
        bfacts[torch.where(torch.argmax(base_seq, dim=-1) == 21, True, False)] = 0

        writepdb(
            f"{out_prefix}.pdb",
            px0_xyz_stack[0][:, :4],
            final_seq,
            sampler.binderlen,
            chain_idx=sampler.chain_idx,
            bfacts=bfacts,
        )

        trb = dict(
            config=OmegaConf.to_container(sampler._conf, resolve=True),
            plddt=plddt_stack,
            device=torch.cuda.get_device_name(torch.cuda.current_device())
            if torch.cuda.is_available() else "CPU",
            time=time.time() - start_time,
        )
        if hasattr(sampler, "contig_map"):
            trb.update(sampler.contig_map.get_mappings())
        with open(f"{out_prefix}.trb", "wb") as fh:
            pickle.dump(trb, fh)

        if sampler.inf_conf.write_trajectory:
            traj_prefix = os.path.join(os.path.dirname(out_prefix), "traj",
                                       os.path.basename(out_prefix))
            os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)

            writepdb_multi(f"{traj_prefix}_Xt-1_traj.pdb",
                           denoised_xyz_stack, bfacts, final_seq.squeeze(),
                           use_hydrogens=False, backbone_only=False,
                           chain_ids=sampler.chain_idx)
            writepdb_multi(f"{traj_prefix}_pX0_traj.pdb",
                           px0_xyz_stack, bfacts, final_seq.squeeze(),
                           use_hydrogens=False, backbone_only=False,
                           chain_ids=sampler.chain_idx)

    log.info(f"Design done in {(time.time()-start_time)/60:.2f} min")


if __name__ == "__main__":
    main()
