# **********************************************************************************
import os
import argparse
import logging
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import torch
from Bio import AlignIO, SeqIO
from Bio.Align.Applications import ClustalOmegaCommandline
from Bio.PDB import PDBIO, PDBParser, Select
from Bio.Seq import Seq
from Bio.SeqIO import SeqRecord
from biopandas.pdb import PandasPdb

# Install a blanket "ignore" filter: every warning raised in this process is silenced.
# NOTE(review): this also hides deprecation warnings from the imported libraries.
warnings.filterwarnings("ignore")

# Path of the `clustalo` binary as resolved on PATH; None when it cannot be found.
CLUSTAL_OMEGA_EXECUTABLE = shutil.which("clustalo")

# Dictionary for mapping three-letter codes to one-letter amino acid codes.
# Only the 20 residue names listed here are covered; any other name has no entry.
AA_MAP = {
    "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F", "GLY": "G",
    "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N",
    "PRO": "P", "GLN": "Q", "ARG": "R", "SER": "S", "THR": "T", "VAL": "V",
    "TRP": "W", "TYR": "Y"
}



# align seq using ClustalOmega
def run_align_clustalomega(clustal_omega_executable: str,
                           seq1: str = None, seq2: str = None,
                           seqs: List[str] = None) -> List[SeqRecord]:
    """Align sequences with Clustal Omega and return the aligned records.

    Provide either a list of sequences through `seqs`, or a pair through
    `seq1` and `seq2`. A non-empty `seqs` takes precedence over the pair.
    Input records are named "seq1", "seq2", ... in the order given.

    Args:
        clustal_omega_executable: (str) path to clustal omega executable
            e.g. "/usr/local/bin/clustal-omega"
        seq1: sequence of a chain e.g. seqres sequence
        seq2: sequence of a chain e.g. atmseq sequence
        seqs: e.g. ["seq1", "seq2", ...]

    Returns:
        aln_seq_records: (List) the records read back from the FASTA
            alignment written by Clustal Omega.

    Raises:
        ValueError: if neither a non-empty `seqs` nor both `seq1` and `seq2`
            are provided.
        FileNotFoundError: if `clustal_omega_executable` is None or empty
            (e.g. `shutil.which` did not find the binary).
    """
    # Validate input. `NotImplemented` is a constant, not an exception class,
    # so the argument error is reported with ValueError instead.
    if not seqs and (seq1 is None or seq2 is None):
        raise ValueError("Provide either List of seqs as `seqs` OR a pair of seqs as `seq1` and `seq2`.")
    if not clustal_omega_executable:
        raise FileNotFoundError("Clustal Omega executable not found; is `clustalo` installed and on PATH?")

    # Both input styles produce the same record ids (seq1, seq2, ...).
    sequences = list(seqs) if seqs else [seq1, seq2]
    seq_rec = [SeqRecord(id=f"seq{i}", seq=Seq(sequence), description="")
               for i, sequence in enumerate(sequences, start=1)]

    # The temporary directory (input and output FASTA) is removed on exit,
    # so the alignment is fully read into memory before leaving the block.
    with tempfile.TemporaryDirectory() as tmpdir:
        in_file = os.path.join(tmpdir, "seq.fasta")
        out_file = os.path.join(tmpdir, "aln.fasta")
        with open(in_file, "w") as f:
            SeqIO.write(seq_rec, f, "fasta")

        # Build and run the Clustal Omega command line
        clustalomega_cline = ClustalOmegaCommandline(
            cmd=clustal_omega_executable,
            infile=in_file,
            outfile=out_file,
            verbose=True,
            auto=True,
        )
        clustalomega_cline()

        # Read the alignment back as a list of records
        with open(out_file, "r") as f:
            aln_seq_records = list(AlignIO.read(f, "fasta"))

    return aln_seq_records
    
# align ATOMSEQ to SEQRES
"""
FIXME: []
- keep log of the antigen seqres with alignment error
"""

def get_seqres2atmseq_mask(seqres, atmseq, pdbid):
    """Build a per-position presence mask of `atmseq` along `seqres`.

    The two sequences are aligned with Clustal Omega (seqres first, atmseq
    second). Each column of the aligned atmseq becomes 1 when it holds a
    residue and 0 when it holds a gap ("-").

    Args:
        seqres: SEQRES sequence string.
        atmseq: sequence string derived from the ATOM records.
        pdbid: identifier used only in the error log message.

    Returns:
        A list of 0/1 ints with one entry per seqres position, or None when
        anything goes wrong (alignment failure, a gap in the aligned seqres,
        or a length mismatch). Failures are logged via `logging.error`.
    """
    try:
        aligned = run_align_clustalomega(
            clustal_omega_executable=CLUSTAL_OMEGA_EXECUTABLE,
            seq1=seqres,
            seq2=atmseq,
        )

        # A gap in the aligned seqres means atmseq has residues seqres lacks.
        aligned_seqres = str(aligned[0].seq)
        if "-" in aligned_seqres:
            raise ValueError("Error: seqres contains dash")

        # 1 => position present in atmseq; 0 => gap in atmseq
        aligned_atmseq = str(aligned[1].seq)
        presence = [0 if symbol == "-" else 1 for symbol in aligned_atmseq]

        # The mask must line up one-to-one with seqres
        if len(presence) != len(seqres):
            raise ValueError("Error: Length mismatch between seqres2atmseq and seqres")
    except Exception as e:
        # Log the error with the PDB ID and signal failure with None
        logging.error(f"PDB ID {pdbid}: {e}")
        return None

    return presence

    


"""
TODO: []
- re-index atmseq based on seqres2atmseq mask 
    - get atmseq and seqres from the pdb file
    - perform pairwise alignment between atmseq and seqres using clustal omega
    - get the atmseq indices from seqres2atmseq mask
    - create temporary mapping to outside the range of the old indices
    - assign the new mapping to the residue number 
"""


def split_complex_reindex_antigen_chains(pdb_path, pt_graphs_dir, pdb_id, ag_out_dir):
    """Extract the antigen chain of a complex, renumber it by SEQRES position, and save it.

    The antigen chain is taken to be the third distinct chain ID found in
    model 1 of the PDB file.
    NOTE(review): presumably the first two chains are the antibody chains --
    confirm against the dataset layout.

    The chain's ATOM-derived sequence (one residue per unique
    residue_number + insertion code) is aligned to the SEQRES sequence stored
    in ``{pt_graphs_dir}/{pdb_id}.pt``. Each residue then receives, as its new
    residue number, the 0-based position it occupies in SEQRES, and insertion
    codes are cleared. Only the antigen ATOM records are written, to
    ``{ag_out_dir}/{pdb_id}_ag.pdb``.

    Args:
        pdb_path: path of the complex PDB file.
        pt_graphs_dir: directory holding ``{pdb_id}.pt`` files.
        pdb_id: identifier used for file names and log messages.
        ag_out_dir: directory that receives the antigen PDB file (created if missing).

    Returns:
        dict: ``{"Ag": (seqres, atmseq, mask)}`` on success. An empty dict when
        the structure is skipped (fewer than three chains, or the alignment
        failed); in that case the problem is logged and no file is written.
    """
    ppdb = PandasPdb().read_pdb(pdb_path)
    atomic_df = ppdb.get_model(1).df["ATOM"]

    chains = atomic_df["chain_id"].unique()
    if len(chains) < 3:
        logging.error(f"PDB ID {pdb_id}: expected at least 3 chains, found {len(chains)}")
        return {}
    ag_chain = chains[2]

    # NOTE: torch.load unpickles the file; only use it on trusted .pt files.
    mask_data = torch.load(f"{pt_graphs_dir}/{pdb_id}.pt")
    seqres = str(np.array(mask_data["seqres"]["ag"][ag_chain]))

    # Independent copy of the antigen atoms so the edits below neither touch
    # the parsed structure nor write into a view.
    all_atoms = ppdb.df["ATOM"]
    ag_df = all_atoms[all_atoms["chain_id"] == ag_chain].copy()

    # Residue identity = residue number + insertion code, so residues that
    # share a number but differ in insertion code stay distinct in ATMSEQ.
    residue_keys = ag_df["residue_number"].astype(str) + ag_df["insertion"].fillna("")
    first_atoms = ~residue_keys.duplicated()
    residues_ordered = residue_keys[first_atoms].tolist()

    # One letter per residue, in file order; residue names outside AA_MAP
    # become "X" instead of breaking the string join.
    atmseq = "".join(ag_df.loc[first_atoms, "residue_name"].map(AA_MAP).fillna("X"))

    # Alignment failures are already logged by get_seqres2atmseq_mask
    mask = get_seqres2atmseq_mask(seqres, atmseq, pdb_id)
    if mask is None:
        return {}

    # 0-based SEQRES positions that are present in ATMSEQ, in order
    seqres_positions = [pos for pos, bit in enumerate(mask) if bit == 1]
    if len(seqres_positions) != len(residues_ordered):
        logging.error(
            f"PDB ID {pdb_id}: mask marks {len(seqres_positions)} residues "
            f"but the chain has {len(residues_ordered)}"
        )
        return {}
    new_indices = dict(zip(residues_ordered, seqres_positions))

    # Apply the new numbering and clear insertion codes
    ag_df["residue_number"] = residue_keys.map(new_indices)
    ag_df["insertion"] = ""

    # Save only the antigen chain with its new numbering
    os.makedirs(ag_out_dir, exist_ok=True)
    output_path = os.path.join(ag_out_dir, f"{pdb_id}_ag.pdb")
    ppdb_ag = PandasPdb()
    ppdb_ag.df["ATOM"] = ag_df
    ppdb_ag.to_pdb(path=output_path,
                   records=["ATOM"],
                   gz=False,
                   append_newline=True)

    return {"Ag": (seqres, atmseq, mask)}


def main():
    """Command-line entry point: process every ``*.pdb`` file in a directory.

    For each PDB file the antigen chain is split out and renumbered via
    `split_complex_reindex_antigen_chains`; the seqres, atmseq and mask it
    returns are collected and written to
    ``<metadata_dir>/seqres2atmseq_mask_ag.csv``. Structures for which an
    empty result is returned are left out of the CSV. Errors logged through
    `logging` go to ``alignment_errors.log`` in the current working directory.
    """
    parser = argparse.ArgumentParser(description="Split an antigen-antibody complex into separate PDB files.")
    parser.add_argument("input_dir", type=Path, help="Input PDB directory")
    parser.add_argument("pt_graphs_dir", type=Path, help="PyTorch graphs directory")
    parser.add_argument("output_dir", type=Path, help="Output directory for processed antigen PDBs")
    parser.add_argument("metadata_dir", type=Path, help="Output directory for alignment metadata")
    args = parser.parse_args()

    # Configure the logging
    logging.basicConfig(filename='alignment_errors.log', level=logging.ERROR,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    # Create both output directories up front so a missing directory cannot
    # make the final CSV write fail after all structures were processed.
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.metadata_dir.mkdir(parents=True, exist_ok=True)

    metadata_list = []
    # Sorted so the CSV row order does not depend on filesystem listing order
    for pdb_file in sorted(args.input_dir.glob("*.pdb")):
        pdb_id = pdb_file.stem.split(".")[0]

        chain_data = split_complex_reindex_antigen_chains(str(pdb_file), args.pt_graphs_dir,
                                                          pdb_id, args.output_dir)
        if not chain_data:
            continue

        seqres, atmseq, mask = chain_data.get("Ag", (None, None, None))
        metadata_list.append({
            "pdb_id": pdb_id,
            "seqres": seqres,
            "atmseq": atmseq,
            "seqres2atmseq_mask": mask,
        })

    # Save metadata
    pd.DataFrame(metadata_list).to_csv(args.metadata_dir / "seqres2atmseq_mask_ag.csv", index=False)
    print(f"Processed {len(metadata_list)} antigen structures")



# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



