import os
import argparse
import logging
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import torch
from Bio import AlignIO, SeqIO
from Bio.Align.Applications import ClustalOmegaCommandline
from Bio.PDB import PDBIO, PDBParser, Select
from Bio.Seq import Seq
from Bio.SeqIO import SeqRecord
from biopandas.pdb import PandasPdb

warnings.filterwarnings("ignore")
CLUSTAL_OMEGA_EXECUTABLE = shutil.which("clustalo")

AA_MAP = {
    "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F", "GLY": "G",
    "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N",
    "PRO": "P", "GLN": "Q", "ARG": "R", "SER": "S", "THR": "T", "VAL": "V",
    "TRP": "W", "TYR": "Y"
}



# align seq using ClustalOmega
def run_align_clustalomega(clustal_omega_executable: str,
                           seq1: str = None, seq2: str = None,
                           seqs: List[str] = None) -> "List[SeqRecord]":
    """Align sequences with Clustal Omega and return the aligned records.

    Provide either a non-empty list of sequences via `seqs`, or a pair of
    sequences via `seq1` and `seq2`. When both are given, `seqs` wins.

    Args:
        clustal_omega_executable: (str) path to clustal omega executable
            e.g. "/usr/local/bin/clustal-omega"
        seq1: sequence of a chain e.g. seqres sequence
        seq2: sequence of a chain e.g. atmseq sequence
        seqs: e.g. ["seq1", "seq2", ...]

    Returns:
        aln_seq_records: (List) aligned records read back from the FASTA
            alignment written by Clustal Omega; input records are written
            with ids "seq1", "seq2", ...

    Raises:
        ValueError: if neither a non-empty `seqs` nor both `seq1` and `seq2`
            are provided.
        FileNotFoundError: if `clustal_omega_executable` is empty or None.
    """
    # Build the input records; `seqs` takes precedence over the seq1/seq2 pair.
    if seqs:
        seq_rec = [SeqRecord(id=f"seq{i}", seq=Seq(seq), description="")
                   for i, seq in enumerate(seqs, start=1)]
    elif seq1 is not None and seq2 is not None:
        seq_rec = [SeqRecord(id="seq1", seq=Seq(seq1), description=""),
                   SeqRecord(id="seq2", seq=Seq(seq2), description="")]
    else:
        # `NotImplemented` is a comparison sentinel, not an exception class;
        # calling it raised an unrelated TypeError and hid this message.
        raise ValueError("Provide either List of seqs as `seqs` OR a pair of seqs as `seq1` and `seq2`.")

    # Fail early with a clear message instead of an obscure command failure.
    if not clustal_omega_executable:
        raise FileNotFoundError("Clustal Omega executable was not provided or could not be found.")

    with tempfile.TemporaryDirectory() as tmpdir:
        # create input seq fasta file and output file for clustal-omega
        in_file = os.path.join(tmpdir, "seq.fasta")
        out_file = os.path.join(tmpdir, "aln.fasta")
        with open(in_file, "w") as f:
            SeqIO.write(seq_rec, f, "fasta")

        # create and run the Clustal-Omega command (output is not needed)
        clustalomega_cline = ClustalOmegaCommandline(
            cmd=clustal_omega_executable,
            infile=in_file,
            outfile=out_file,
            verbose=True,
            auto=True,
        )
        clustalomega_cline()

        # read aln before the temporary directory is removed
        with open(out_file, "r") as f:
            aln_seq_records = list(AlignIO.read(f, "fasta"))

        return aln_seq_records
    
# align ATOMSEQ to SEQRES
"""
FIXME: []
- keep log of the antigen seqres with alignment error
"""

def get_seqres2atmseq_mask(seqres, atmseq, pdbid):
    """Return a 0/1 mask over SEQRES marking residues present in ATMSEQ.

    The two sequences are aligned with Clustal Omega (seqres first, atmseq
    second). Each column of the aligned atmseq becomes 1 when it holds a
    residue and 0 when it holds a gap ("-").

    Args:
        seqres: SEQRES sequence of the chain.
        atmseq: sequence of the chain derived from ATOM records.
        pdbid: identifier used only to tag the error log entry.

    Returns:
        List of ints (1 => in atmseq; 0 => not in atmseq) with the same
        length as `seqres`, or None when anything fails: the alignment call
        raises, the aligned seqres contains a gap, or the mask length does
        not match `seqres`. Failures are logged via `logging.error`.
    """
    try:
        alignment = run_align_clustalomega(
            seq1=seqres,
            seq2=atmseq,
            clustal_omega_executable=CLUSTAL_OMEGA_EXECUTABLE,
        )

        # A gap in the aligned seqres means atmseq has residues seqres lacks.
        aligned_seqres = str(alignment[0].seq)
        if "-" in aligned_seqres:
            raise ValueError("Error: seqres contains dash")

        # Gaps in the aligned atmseq are residues missing from the structure.
        aligned_atmseq = str(alignment[1].seq)
        residue_mask = [0 if symbol == "-" else 1 for symbol in aligned_atmseq]

        # The mask is indexed by seqres position, so lengths must agree.
        if len(residue_mask) != len(seqres):
            raise ValueError("Error: Length mismatch between seqres2atmseq and seqres")
    except Exception as err:
        # Record which structure failed and signal failure with None.
        logging.error(f"PDB ID {pdbid}: {err}")
        return None

    return residue_mask




def split_complex_reindex_antibody_chains(pdb_path, pt_graphs_dir, pdb_id, ab_out_dir):
    """Write the antibody chains (H, L) of a complex with SEQRES-based numbering.

    For each of chains "H" and "L" present in the ATOM records, the ATMSEQ is
    built from the structure (one residue per unique residue number +
    insertion code, in file order) and aligned to the SEQRES stored in
    `{pt_graphs_dir}/{pdb_id}.pt` under ["seqres"]["ab"][chain]. Every
    residue is then renumbered to the 0-based SEQRES position it aligns to,
    and insertion codes are cleared.

    If the alignment mask cannot be computed for a chain (the helper returns
    None after logging the error), that chain keeps its original numbering
    and insertion codes, and its mask is recorded as None.

    Args:
        pdb_path: path of the input PDB file.
        pt_graphs_dir: directory holding `{pdb_id}.pt` files loadable with
            `torch.load`.
        pdb_id: identifier used to locate the .pt file and name the output.
        ab_out_dir: directory where `{pdb_id}_ab.pdb` is written.

    Returns:
        dict mapping chain id ("H"/"L") to a tuple (seqres, atmseq, mask) for
        every antibody chain found in the structure; mask is None when the
        alignment failed.
    """
    ppdb = PandasPdb().read_pdb(pdb_path)

    mask_data = torch.load(f"{pt_graphs_dir}/{pdb_id}.pt")
    output_path = os.path.join(ab_out_dir, f"{pdb_id}_ab.pdb")

    antibody_chains = ("H", "L")

    # Work on a copy holding antibody chains only, so the renumbering below
    # never touches the DataFrame owned by `ppdb`.
    all_atoms = ppdb.df["ATOM"]
    ab_df = all_atoms[all_atoms["chain_id"].isin(antibody_chains)].copy()

    chain_data = {}
    for chain_type in antibody_chains:
        chain_rows = ab_df[ab_df["chain_id"] == chain_type]
        if chain_rows.empty:
            continue

        # Get SEQRES for the chain
        seqres = str(np.array(mask_data["seqres"]["ab"][chain_type]))

        # Residue identity = residue number + insertion code, so residues
        # sharing a number (e.g. 100, 100A) stay distinct. `assign` returns a
        # new frame, avoiding writes to a slice of `ab_df`.
        chain_df = chain_rows.assign(
            full_residue=chain_rows["residue_number"].astype(str)
            + chain_rows["insertion"].fillna("")
        )

        # Unique residues in file order, and the ATMSEQ built from the first
        # atom of each. Residue names outside AA_MAP become "X" rather than
        # NaN, which would make the join below raise.
        residues_ordered = chain_df["full_residue"].unique()
        first_atoms = chain_df.drop_duplicates("full_residue")
        atmseq = "".join(first_atoms["residue_name"].map(AA_MAP).fillna("X"))

        # Generate alignment mask (None signals a logged alignment failure)
        mask = get_seqres2atmseq_mask(seqres, atmseq, pdb_id)
        chain_data[chain_type] = (seqres, atmseq, mask)
        if mask is None:
            # Nothing to renumber against; leave this chain untouched.
            continue

        # 0-based SEQRES positions that are present in the structure; the
        # k-th structure residue maps to the k-th such position.
        seqres_positions = [i for i, bit in enumerate(mask) if bit == 1]
        new_indices = dict(zip(residues_ordered, seqres_positions))

        # Apply the mapping to every atom row of this chain
        ab_df.loc[chain_df.index, "residue_number"] = chain_df["full_residue"].map(new_indices)
        ab_df.loc[chain_df.index, "insertion"] = ""  # Clear insertion codes

    # Save only antibody chains with new numbering
    ppdb_ab = PandasPdb()
    ppdb_ab.df["ATOM"] = ab_df
    ppdb_ab.to_pdb(path=output_path,
                   records=["ATOM"],
                   gz=False,
                   append_newline=True)

    return chain_data



def main():
    """Renumber antibody chains for every PDB in a directory and save metadata.

    Command-line arguments (all positional): input PDB directory, directory
    of `.pt` graph files, output directory for the processed PDBs, and a
    metadata directory. The metadata directory receives
    `alignment_errors.log` and `seqres2atmseq_mask_ab_HL_chain.csv`, which
    holds one row per structure with the heavy/light SEQRES, ATMSEQ and
    masks plus the combined heavy+light mask.

    A structure that raises during processing is logged and skipped so the
    remaining structures are still processed and the CSV is still written.
    """
    parser = argparse.ArgumentParser(description="Process antibody PDB files with index reset")
    parser.add_argument("input_dir", type=Path, help="Input PDB directory")
    parser.add_argument("pt_graphs_dir", type=Path, help="PyTorch graphs directory")
    parser.add_argument("output_dir", type=Path, help="Output directory for processed PDBs")
    parser.add_argument("metadata_dir", type=Path, help="Output directory for alignment metadata")
    args = parser.parse_args()

    # Both directories are written to below; create them up front so neither
    # the log handler nor the PDB/CSV writers fail on a missing directory.
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.metadata_dir.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(filename=args.metadata_dir/'alignment_errors.log',
                        level=logging.ERROR,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    missing_chain = (None, None, None)
    metadata_list = []
    # Sorted so the CSV row order is reproducible across runs.
    for pdb_file in sorted(args.input_dir.glob("*.pdb")):
        pdb_id = pdb_file.stem.split(".")[0]

        try:
            chain_data = split_complex_reindex_antibody_chains(str(pdb_file), args.pt_graphs_dir,
                                                               pdb_id, args.output_dir)
        except Exception:
            # One bad structure must not discard the metadata of the others.
            logging.exception("PDB ID %s: failed to process %s", pdb_id, pdb_file)
            continue

        if not chain_data:
            continue

        heavy_seqres, heavy_atmseq, heavy_mask = chain_data.get("H", missing_chain)
        light_seqres, light_atmseq, light_mask = chain_data.get("L", missing_chain)

        # Combined mask = heavy followed by light, over the chains that are
        # present. An absent chain contributes nothing; a present chain
        # without a mask makes the combined mask unknown (None).
        present_masks = [chain_data[chain][2] for chain in ("H", "L") if chain in chain_data]
        if any(chain_mask is None for chain_mask in present_masks):
            combined_mask = None
        else:
            combined_mask = [bit for chain_mask in present_masks for bit in chain_mask]

        metadata_list.append({
            "pdb_id": pdb_id,
            "heavy_seqres": heavy_seqres,
            "heavy_atmseq": heavy_atmseq,
            "heavy_seqres2atmseq_mask": heavy_mask,
            "light_seqres": light_seqres,
            "light_atmseq": light_atmseq,
            "light_seqres2atmseq_mask": light_mask,
            "seqres2atmseq_mask": combined_mask,
        })

    # Save metadata
    pd.DataFrame(metadata_list).to_csv(args.metadata_dir/"seqres2atmseq_mask_ab_HL_chain.csv", index=False)
    print(f"Processed {len(metadata_list)} antibody structures")

if __name__ == "__main__":
    main()



