#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
pdz_map.py — Flexible residue/chain remapper for PDB

Usage:
  python scripts_ad/pdz_map.py   in.pdb  out.pdb  --map "A:1-5->A:94-98; B:6-92->A:1-87; B:93-99->P:103-109"

Notes:
- The --map string is a semicolon-separated list of mapping rules.
- Each rule is:  SRC_CHAIN:LO-HI -> DST_CHAIN:LO-HI
- Ranges are **inclusive**. Single numbers like "7" are allowed (treated as "7-7").
- The range length (HI - LO + 1) on the LHS must equal that on the RHS; otherwise a ValueError is raised.
- Residues (chain, resSeq) not matched by any rule keep their original chain ID and residue number; they are still sorted together with the remapped residues in the output.
- Output: Sort by (new_chain, new_resSeq, original_order); write TER whenever the chain switches; write END at the end.
"""

import argparse
import re
from typing import List, Tuple, Optional, Dict

# ---------------------------
# PDB helpers (fixed columns)
# ---------------------------

def is_atom_record(line: str) -> bool:
    """Return True when the record name (first 6 columns) is ATOM or HETATM."""
    return line[:6].startswith(("ATOM", "HETATM"))

def get_chain_resseq(line: str) -> Tuple[str, int]:
    """
    Extract (chain_id, resSeq) from a PDB ATOM/HETATM line.

    Chain ID is column 22 (0-based index 21); resSeq is columns 23-26
    (0-based slice 22:26). The insertion code (column 27) is not read.

    Raises:
        ValueError: if the line is too short to contain a chain ID, or the
            resSeq field is not a plain integer (e.g. hybrid-36 encoding).
    """
    if len(line) < 22:
        raise ValueError(f"Line too short to contain a chain ID: {line.rstrip()!r}")
    chain_id = line[21]
    resseq_str = line[22:26]
    # int() already ignores surrounding whitespace, so a single attempt suffices.
    try:
        resseq = int(resseq_str)
    except ValueError:
        raise ValueError(
            f"Cannot parse resSeq {resseq_str!r} (columns 23-26) in line: {line.rstrip()!r}"
        ) from None
    return chain_id, resseq

def set_chain_resseq(line: str, new_chain: str, new_resseq: int) -> str:
    """
    Replace chain (col 22) and resSeq (cols 23-26) in a PDB line.

    Lines shorter than 26 characters are returned unchanged. Only the first
    character of *new_chain* is used; an empty chain becomes a blank.

    Raises:
        ValueError: if *new_resseq* does not fit in the 4-character resSeq
            field (outside -999..9999). Writing a wider value would shift
            every following column and corrupt the record.
    """
    if len(line) < 26:
        return line
    new_resseq_str = f"{new_resseq:4d}"
    if len(new_resseq_str) > 4:
        raise ValueError(
            f"resSeq {new_resseq} does not fit in PDB columns 23-26 (allowed -999..9999)"
        )
    new_chain_char = new_chain[:1] if new_chain else " "
    return f"{line[:21]}{new_chain_char}{new_resseq_str}{line[26:]}"

# ---------------------------
# Mapping DSL
# ---------------------------

# Inclusive residue-number interval: (lo, hi).
Range = Tuple[int, int]
# One mapping rule: (src_chain, (src_lo, src_hi), dst_chain, (dst_lo, dst_hi)).
Rule = Tuple[str, Range, str, Range]

def parse_num_or_range(spec: str) -> "Range":
    """
    Parse a single residue number or an inclusive range.

    '7' -> (7, 7); '1-7' -> (1, 7). Signed numbers are accepted because PDB
    residue numbers can be negative: '-3' -> (-3, -3), '-3--1' -> (-3, -1),
    '-2-4' -> (-2, 4). Surrounding whitespace is ignored.

    Raises:
        ValueError: if *spec* is not a number or range, or if lo > hi.
    """
    spec = spec.strip()
    # First number, optionally followed by '-' and a second (possibly signed)
    # number. Each bound carries its own sign, so '-3--1' is unambiguous.
    m = re.fullmatch(r"([+-]?\d+)(?:\s*-\s*([+-]?\d+))?", spec)
    if not m:
        raise ValueError(f"Invalid number or range '{spec}'")
    lo = int(m.group(1))
    hi = int(m.group(2)) if m.group(2) is not None else lo
    if lo > hi:
        raise ValueError(f"Invalid range '{spec}': lo > hi")
    return (lo, hi)

def parse_one_rule(text: str) -> Rule:
    """
    Parse one mapping rule of the form 'SRC_CHAIN:LO-HI -> DST_CHAIN:LO-HI'.

    Example: 'A:1-7->A:120-126' -> ('A', (1, 7), 'A', (120, 126)).
    Chain IDs are a single ASCII letter or digit. Whitespace around ':' and
    '->' is allowed. Ranges are parsed by parse_num_or_range.

    Raises:
        ValueError: on an empty rule, a malformed rule, an invalid range, or
            when the source and destination ranges differ in length.
    """
    text = text.strip()
    if not text:
        raise ValueError("Empty rule")

    # PDB chain identifiers may be digits as well as letters, so accept both.
    m = re.match(
        r"^\s*([A-Za-z0-9])\s*:\s*([0-9\-]+)\s*->\s*([A-Za-z0-9])\s*:\s*([0-9\-]+)\s*$",
        text,
    )
    if not m:
        raise ValueError(f"Bad rule format: '{text}' (expect 'A:1-7->A:120-126')")
    src_chain, src_rng, dst_chain, dst_rng = m.groups()
    src = parse_num_or_range(src_rng)
    dst = parse_num_or_range(dst_rng)

    # Residues are mapped one-to-one by offset, so both ranges must be equally long.
    src_len = src[1] - src[0] + 1
    dst_len = dst[1] - dst[0] + 1
    if src_len != dst_len:
        raise ValueError(f"Length mismatch in rule '{text}': src({src_len}) != dst({dst_len})")
    return (src_chain, src, dst_chain, dst)

def parse_rules(dsl: str) -> List[Rule]:
    """
    Turn a semicolon-separated mapping DSL into a list of rules.

    Example: 'A:1-7->A:120-126; B:8-95->A:1-88; B:96-102->P:200-206'.
    Blank segments (such as a trailing ';') are skipped. Each remaining
    segment is parsed by parse_one_rule, and the result is validated by
    check_overlap before being returned.
    """
    rules: List[Rule] = []
    for chunk in dsl.split(";"):
        if not chunk.strip():
            continue
        rules.append(parse_one_rule(chunk))
    # Reject rule sets whose source intervals collide on the same chain.
    check_overlap(rules)
    return rules

def _first_overlap(by_chain: "Dict[str, List[Range]]") -> "Optional[Tuple[str, Range, Range]]":
    """Return (chain, earlier_range, later_range) for the first overlap found, or None."""
    for ch, ranges in by_chain.items():
        prev = None
        # After sorting by lo, comparing each range with its predecessor is
        # enough: the first overlap (if any) appears between neighbours.
        for cur in sorted(ranges):
            if prev is not None and cur[0] <= prev[1]:
                return ch, prev, cur
            prev = cur
    return None


def check_overlap(rules: "List[Rule]") -> None:
    """
    Reject rule sets that would produce ambiguous residue numbering.

    Two checks are made:
    - source intervals must not overlap on the same src_chain (otherwise a
      residue would match more than one rule);
    - destination intervals must not overlap on the same dst_chain
      (otherwise two different source residues would be written with the
      same (chain, resSeq) and their atoms would be merged in the output).

    Only the rules themselves are inspected. Collisions between a rule's
    destination and residues that no rule touches are not detected here.

    Raises:
        ValueError: describing the first overlap found.
    """
    src_by_chain: Dict[str, List[Range]] = {}
    dst_by_chain: Dict[str, List[Range]] = {}
    for src_chain, src_rng, dst_chain, dst_rng in rules:
        src_by_chain.setdefault(src_chain, []).append(src_rng)
        dst_by_chain.setdefault(dst_chain, []).append(dst_rng)

    hit = _first_overlap(src_by_chain)
    if hit is not None:
        ch, (prev_lo, prev_hi), (lo, hi) = hit
        raise ValueError(f"Overlapping source ranges on chain {ch}: [{prev_lo},{prev_hi}] vs [{lo},{hi}]")

    hit = _first_overlap(dst_by_chain)
    if hit is not None:
        ch, (prev_lo, prev_hi), (lo, hi) = hit
        raise ValueError(f"Overlapping destination ranges on chain {ch}: [{prev_lo},{prev_hi}] vs [{lo},{hi}]")

def build_mapper(rules: List[Rule]):
    """
    Build a residue-lookup closure from a list of rules.

    The returned callable takes (chain, resSeq). It tries the rules whose
    source chain equals *chain*, in ascending order of their source lower
    bound, and returns (dst_chain, dst_lo + (resSeq - src_lo)) for the first
    rule whose inclusive source range contains resSeq. If no rule matches,
    the input pair is returned unchanged.
    """
    table: Dict[str, List[Rule]] = {}
    for rule in rules:
        table.setdefault(rule[0], []).append(rule)
    for bucket in table.values():
        bucket.sort(key=lambda r: r[1][0])

    def mapper(chain: str, resSeq: int) -> Tuple[str, int]:
        for _, (src_lo, src_hi), dst_chain, (dst_lo, _) in table.get(chain, ()):
            if src_lo <= resSeq <= src_hi:
                return dst_chain, dst_lo + (resSeq - src_lo)
        return chain, resSeq

    return mapper

# ---------------------------
# Core: renumber
# ---------------------------

def _with_newline(line: str) -> str:
    """Append '\\n' to *line* if it does not already end with one."""
    return line if line.endswith("\n") else line + "\n"


def renumber_pdb(in_pdb: str, out_pdb: str, rules: List[Rule]) -> None:
    """
    Apply *rules* to every coordinate record of *in_pdb* and write *out_pdb*.

    Output layout:
      1. non-coordinate lines that precede the first coordinate record
         (header), in their original order;
      2. ATOM/HETATM records, plus their ANISOU/SIGATM/SIGUIJ companions
         (which share the chain/resSeq columns), sorted by
         (new_chain, new_resSeq, original_line_index), with a TER line after
         each chain;
      3. non-coordinate lines that follow the first coordinate record
         (e.g. ENDMDL, CONECT, MASTER), in their original order;
      4. a final END line.
    TER and END lines from the input are dropped because they are regenerated.
    Every written line is newline-terminated.

    NOTE(review): files with several MODELs are not handled specially; atoms
    of all models are merged into one sorted list.

    Raises:
        ValueError: propagated from parsing or rewriting a coordinate record.
    """
    mapper = build_mapper(rules)

    header: List[str] = []
    trailer: List[str] = []
    records: List[Tuple[str, int, int, str]] = []  # (new_chain, new_res, original_idx, new_line)

    with open(in_pdb, "r") as f:
        lines = f.readlines()

    for idx, line in enumerate(lines):
        rec_name = line[:6].strip()
        if is_atom_record(line) or rec_name in ("ANISOU", "SIGATM", "SIGUIJ"):
            old_chain, old_res = get_chain_resseq(line)
            new_chain, new_res = mapper(old_chain, old_res)
            # The last line of a file may lack '\n'; once re-sorted it could
            # land mid-output and be glued to the following line.
            new_line = _with_newline(set_chain_resseq(line, new_chain, new_res))
            # original_idx keeps companion records right after their ATOM line.
            records.append((new_chain, new_res, idx, new_line))
        elif rec_name in ("TER", "END"):
            # Regenerated below; keeping them would put END ahead of the atoms.
            continue
        elif records:
            trailer.append(line)
        else:
            header.append(line)

    records.sort(key=lambda t: (t[0], t[1], t[2]))

    with open(out_pdb, "w") as w:
        w.writelines(_with_newline(h) for h in header)

        last_chain: Optional[str] = None
        for new_chain, _, _, new_line in records:
            if last_chain is not None and new_chain != last_chain:
                w.write("TER\n")
            last_chain = new_chain
            w.write(new_line)

        if records:
            w.write("TER\n")
        w.writelines(_with_newline(t) for t in trailer)
        w.write("END\n")

# ---------------------------
# CLI
# ---------------------------

def main():
    """Command-line entry point: parse arguments, validate --map, then renumber."""
    ap = argparse.ArgumentParser(description="Flexible PDB chain/resSeq remapper")
    ap.add_argument("in_pdb",  type=str, help="input PDB")
    ap.add_argument("out_pdb", type=str, help="output PDB")
    ap.add_argument("--map",   type=str, required=True,
                    help="Mapping DSL, e.g. \"A:1-7->A:120-126; B:8-95->A:1-88; B:96-102->P:200-206\"")
    args = ap.parse_args()

    try:
        rules = parse_rules(args.map)
    except ValueError as e:
        # A malformed --map is a usage error: print usage + message, exit 2.
        ap.error(f"invalid --map: {e}")

    try:
        renumber_pdb(args.in_pdb, args.out_pdb, rules)
    except (OSError, ValueError) as e:
        # Unreadable/unwritable files or unparsable records: report cleanly.
        ap.exit(1, f"{ap.prog}: error: {e}\n")

if __name__ == "__main__":
    main()
    