import numpy as np
from openmm.app import PDBFile
import mdtraj as md
from typing import List, Dict, Tuple, Optional
import warnings

class DihedralExtractor:
    """Extract backbone and side-chain dihedral angles from a protein structure.

    The structure is read with ``openmm.app.PDBFile`` (or supplied directly
    through :meth:`from_topology`).  The ``extract_*`` methods report angles in
    degrees within [-180, 180] and use NaN for any angle that cannot be
    computed: missing atom, chain terminus, capping residue, residue without
    that chi angle, or degenerate geometry.  Each returned list holds exactly
    one entry per valid requested residue, in request order, so all lists line
    up with each other and, when every residue is processed, with the rows of
    :meth:`get_residue_info`.
    """

    # Terminal capping groups; no dihedrals are reported for them.
    _CAP_RESIDUES = frozenset({'ACE', 'NME', 'NHE', 'CHE'})

    # Force-field variant residue names (protonation states, disulfide-bonded
    # cysteine) mapped to the standard name whose chi definitions they share.
    _RESIDUE_ALIASES = {
        'HID': 'HIS', 'HIE': 'HIS', 'HIP': 'HIS',
        'HSD': 'HIS', 'HSE': 'HIS', 'HSP': 'HIS',
        'CYX': 'CYS', 'CYM': 'CYS',
        'ASH': 'ASP', 'GLH': 'GLU', 'LYN': 'LYS',
    }

    def __init__(self, pdb_file: str):
        """Load *pdb_file* with ``openmm.app.PDBFile``.

        Args:
            pdb_file: Path to the PDB file to analyse.
        """
        pdb = PDBFile(pdb_file)
        self._setup(pdb_file, pdb.topology, pdb.positions)

    @classmethod
    def from_topology(cls, topology, positions,
                      pdb_file: Optional[str] = None) -> 'DihedralExtractor':
        """Build an extractor from an already-loaded topology and positions.

        Args:
            topology: Object whose ``residues()`` yields residues exposing
                ``name``, ``id``, ``chain`` and ``atoms()`` (as an OpenMM
                ``Topology`` does).
            positions: Per-atom coordinates indexed by ``atom.index``; an
                OpenMM ``Quantity`` or any array-like of shape (n_atoms, 3).
            pdb_file: Optional source path, stored as ``self.pdb_file``.

        Returns:
            A fully initialised ``DihedralExtractor``.
        """
        extractor = cls.__new__(cls)
        extractor._setup(pdb_file, topology, positions)
        return extractor

    def _setup(self, pdb_file: Optional[str], topology, positions) -> None:
        """Store the structure and build the per-instance definition tables."""
        self.pdb_file = pdb_file
        self.topology = topology
        self.positions = positions

        # Atom names per dihedral. 'prev_'/'next_' address the preceding /
        # following residue, which must lie in the same chain.
        self.dihedral_definitions = {
            'phi': ['prev_C', 'N', 'CA', 'C'],
            'psi': ['N', 'CA', 'C', 'next_N'],
            'omega': ['CA', 'C', 'next_N', 'next_CA'],
            'chi1': ['N', 'CA', 'CB', 'CG'],
            'chi2': ['CA', 'CB', 'CG', 'CD'],
            'chi3': ['CB', 'CG', 'CD', 'CE'],
            'chi4': ['CG', 'CD', 'CE', 'NZ'],
        }

        # Residue-specific side-chain dihedral atoms. These tables (not the
        # generic chi entries above) drive extract_sidechain_dihedrals.
        self.chi_atom_definitions = {
            'ALA': {},
            'GLY': {},
            'SER': {'chi1': ['N', 'CA', 'CB', 'OG']},
            'THR': {'chi1': ['N', 'CA', 'CB', 'OG1']},
            'CYS': {'chi1': ['N', 'CA', 'CB', 'SG']},
            'VAL': {'chi1': ['N', 'CA', 'CB', 'CG1']},
            'ILE': {
                'chi1': ['N', 'CA', 'CB', 'CG1'],
                'chi2': ['CA', 'CB', 'CG1', 'CD1'],
            },
            'LEU': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD1'],
            },
            'ASN': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'OD1'],
            },
            'ASP': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'OD1'],
            },
            'GLN': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD'],
                'chi3': ['CB', 'CG', 'CD', 'OE1'],
            },
            'GLU': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD'],
                'chi3': ['CB', 'CG', 'CD', 'OE1'],
            },
            'MET': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'SD'],
                'chi3': ['CB', 'CG', 'SD', 'CE'],
            },
            'HIS': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'ND1'],
            },
            'PHE': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD1'],
            },
            'TYR': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD1'],
            },
            'TRP': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD1'],
            },
            'LYS': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD'],
                'chi3': ['CB', 'CG', 'CD', 'CE'],
                'chi4': ['CG', 'CD', 'CE', 'NZ'],
            },
            'ARG': {
                'chi1': ['N', 'CA', 'CB', 'CG'],
                'chi2': ['CA', 'CB', 'CG', 'CD'],
                'chi3': ['CB', 'CG', 'CD', 'NE'],
                'chi4': ['CG', 'CD', 'NE', 'CZ'],
            },
            'PRO': {'chi1': ['N', 'CA', 'CB', 'CG']},
        }

        # Chi angle names per residue type; not used by this class's methods.
        self.chi_definitions = {
            'ALA': [],
            'GLY': [],
            'SER': ['chi1'],
            'THR': ['chi1'],
            'CYS': ['chi1'],
            'VAL': ['chi1'],
            'ILE': ['chi1', 'chi2'],
            'LEU': ['chi1', 'chi2'],
            'ASN': ['chi1', 'chi2'],
            'ASP': ['chi1', 'chi2'],
            'GLN': ['chi1', 'chi2', 'chi3'],
            'GLU': ['chi1', 'chi2', 'chi3'],
            'MET': ['chi1', 'chi2', 'chi3'],
            'HIS': ['chi1', 'chi2'],
            'PHE': ['chi1', 'chi2'],
            'TYR': ['chi1', 'chi2'],
            'TRP': ['chi1', 'chi2'],
            'LYS': ['chi1', 'chi2', 'chi3', 'chi4'],
            'ARG': ['chi1', 'chi2', 'chi3', 'chi4'],
            'PRO': ['chi1'],
        }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_float_array(values) -> np.ndarray:
        """Return *values* as a float ndarray, stripping an OpenMM unit if present.

        Dihedral angles do not depend on the length unit, so a unit-bearing
        ``Quantity`` is simply read out in its own unit.
        """
        if hasattr(values, 'value_in_unit'):
            values = values.value_in_unit(values.unit)
        return np.asarray(values, dtype=float)

    def _coordinates(self) -> np.ndarray:
        """Return ``self.positions`` as an ``(n_atoms, 3)`` float array."""
        return self._as_float_array(self.positions).reshape(-1, 3)

    @staticmethod
    def _find_atom(residue, atom_name: str) -> int:
        """Return the index of the first atom in *residue* named *atom_name*, or -1."""
        return next((atom.index for atom in residue.atoms() if atom.name == atom_name), -1)

    @staticmethod
    def _normalize_index(residue_idx: int, n_residues: int) -> Optional[int]:
        """Map a possibly negative residue index into range; None if out of range."""
        if residue_idx < 0:
            residue_idx += n_residues
        return residue_idx if 0 <= residue_idx < n_residues else None

    def _resolve_atoms(self, residues: list, atom_names: List[str],
                       residue_idx: int) -> List[int]:
        """Resolve atom names, optionally 'prev_'/'next_'-prefixed, around one residue.

        A prefixed atom resolves to -1 when the neighbouring residue does not
        exist or belongs to a different chain, so no dihedral spans two chains.
        Gaps inside a single chain (missing residues) are not detected.
        """
        idx = self._normalize_index(residue_idx, len(residues))
        if idx is None:
            return [-1] * len(atom_names)
        residue = residues[idx]

        indices = []
        for atom_name in atom_names:
            if atom_name.startswith('prev_'):
                offset, name = -1, atom_name[len('prev_'):]
            elif atom_name.startswith('next_'):
                offset, name = 1, atom_name[len('next_'):]
            else:
                offset, name = 0, atom_name

            target_idx = idx + offset
            if not 0 <= target_idx < len(residues):
                indices.append(-1)
            elif offset and residues[target_idx].chain is not residue.chain:
                indices.append(-1)
            else:
                indices.append(self._find_atom(residues[target_idx], name))
        return indices

    def _dihedral_degrees(self, coords: np.ndarray, atom_indices: List[int]) -> float:
        """Dihedral in degrees for four atom indices into *coords*; NaN if any is -1."""
        if -1 in atom_indices:
            return float('nan')
        return float(np.degrees(self.calculate_dihedral(*coords[atom_indices])))

    # --------------------------------------------------------------- public API

    def get_atom_indices(self, atom_names: List[str], residue_idx: int) -> List[int]:
        """Return topology atom indices of *atom_names* within one residue.

        Args:
            atom_names: Atom names to look up (no 'prev_'/'next_' prefixes).
            residue_idx: Position of the residue in ``topology.residues()``;
                ordinary Python list indexing, so out of range raises IndexError.

        Returns:
            One index per name; -1 where the residue has no atom of that name.
        """
        residue = list(self.topology.residues())[residue_idx]
        return [self._find_atom(residue, name) for name in atom_names]

    def calculate_dihedral(self, pos1: np.ndarray, pos2: np.ndarray,
                           pos3: np.ndarray, pos4: np.ndarray) -> float:
        """Return the signed dihedral angle, in radians, defined by four points.

        The angle is positive when, looking from *pos2* towards *pos3*, *pos1*
        must be rotated clockwise to eclipse *pos4* (e.g. +pi/2 for
        (1,0,0), (0,0,0), (0,0,1), (0,1,1)).

        Args:
            pos1, pos2, pos3, pos4: 3-D points (array-like or OpenMM Quantity).

        Returns:
            Angle in [-pi, pi], or NaN when three consecutive points are
            collinear and the angle is undefined.
        """
        p1, p2, p3, p4 = (self._as_float_array(p) for p in (pos1, pos2, pos3, pos4))
        b1, b2, b3 = p2 - p1, p3 - p2, p4 - p3

        n1 = np.cross(b1, b2)  # normal of plane (p1, p2, p3)
        n2 = np.cross(b2, b3)  # normal of plane (p2, p3, p4)
        if np.linalg.norm(n1) == 0 or np.linalg.norm(n2) == 0:
            return float('nan')

        # x and y are |n1||n2|cos(angle) and |n1||n2|sin(angle). atan2 is
        # accurate over the whole circle and, unlike arccos(cos) * sign(sin),
        # keeps +/-pi for an exactly planar trans arrangement instead of 0.
        x = np.dot(n1, n2)
        y = np.dot(np.cross(n1, n2), b2) / np.linalg.norm(b2)
        return float(np.arctan2(y, x))

    def get_atom_indices_cross_residue(self, atom_names: List[str], residue_idx: int) -> List[int]:
        """Return atom indices for names that may refer to neighbouring residues.

        Args:
            atom_names: Atom names; a 'prev_' / 'next_' prefix selects the
                preceding / following residue in the same chain.
            residue_idx: Position of the central residue in
                ``topology.residues()``; negative values count from the end.

        Returns:
            One index per name; -1 for a missing atom, a missing or
            different-chain neighbour, or an out-of-range *residue_idx*.
        """
        return self._resolve_atoms(list(self.topology.residues()), atom_names, residue_idx)

    def extract_backbone_dihedrals(self, residue_indices: Optional[List[int]] = None) -> Dict[str, List[float]]:
        """Compute phi, psi and omega (degrees) for the requested residues.

        Args:
            residue_indices: Positions in ``topology.residues()`` to process
                (negative values count from the end); None means all residues.
                Out-of-range indices are skipped.

        Returns:
            ``{'phi': [...], 'psi': [...], 'omega': [...]}`` with one entry per
            processed residue. Capping residues, chain termini and missing
            atoms give NaN.
        """
        residues = list(self.topology.residues())
        coords = self._coordinates()
        if residue_indices is None:
            residue_indices = range(len(residues))

        names = ('phi', 'psi', 'omega')
        dihedrals = {name: [] for name in names}
        for requested in residue_indices:
            res_idx = self._normalize_index(requested, len(residues))
            if res_idx is None:
                continue
            is_cap = residues[res_idx].name in self._CAP_RESIDUES
            for name in names:
                if is_cap:
                    # Placeholder keeps every list aligned with the residues.
                    angle = float('nan')
                else:
                    atoms = self._resolve_atoms(residues, self.dihedral_definitions[name], res_idx)
                    angle = self._dihedral_degrees(coords, atoms)
                dihedrals[name].append(angle)
        return dihedrals

    def extract_sidechain_dihedrals(self, residue_indices: Optional[List[int]] = None) -> Dict[str, List[float]]:
        """Compute side-chain chi angles (degrees) for the requested residues.

        Args:
            residue_indices: Positions in ``topology.residues()`` to process
                (negative values count from the end); None means all residues.
                Out-of-range indices are skipped.

        Returns:
            One list per chi type that occurs among the processed residues
            (keys in first-seen order). Each list has one entry per processed
            residue; NaN where the residue lacks that chi or an atom is missing.
        """
        residues = list(self.topology.residues())
        coords = self._coordinates()
        if residue_indices is None:
            residue_indices = range(len(residues))

        rows = []       # one {chi_type: angle} dict per processed residue
        chi_types = []  # chi names in first-seen order
        for requested in residue_indices:
            res_idx = self._normalize_index(requested, len(residues))
            if res_idx is None:
                continue
            residue = residues[res_idx]
            row = {}
            if residue.name not in self._CAP_RESIDUES:
                res_name = self._RESIDUE_ALIASES.get(residue.name, residue.name)
                for chi_type, atom_names in self.chi_atom_definitions.get(res_name, {}).items():
                    if chi_type not in chi_types:
                        chi_types.append(chi_type)
                    atoms = [self._find_atom(residue, atom_name) for atom_name in atom_names]
                    row[chi_type] = self._dihedral_degrees(coords, atoms)
            rows.append(row)

        nan = float('nan')
        return {chi: [row.get(chi, nan) for row in rows] for chi in chi_types}

    def extract_all_dihedrals(self, residue_indices: Optional[List[int]] = None) -> Dict[str, List[float]]:
        """Return backbone and side-chain dihedrals merged into one dict.

        Args:
            residue_indices: Passed unchanged to both extract_* methods.

        Returns:
            Backbone keys (phi, psi, omega) followed by any chi keys; all lists
            have one entry per processed residue.
        """
        backbone_dihedrals = self.extract_backbone_dihedrals(residue_indices)
        sidechain_dihedrals = self.extract_sidechain_dihedrals(residue_indices)
        return {**backbone_dihedrals, **sidechain_dihedrals}

    def get_residue_info(self) -> List[Dict]:
        """Return one dict per residue with keys 'index', 'name', 'id', 'chain'.

        'index' is the position in ``topology.residues()``; 'chain' is the
        chain id, or None when the residue has no chain.
        """
        return [
            {
                'index': i,
                'name': residue.name,
                'id': residue.id,
                'chain': residue.chain.id if residue.chain else None,
            }
            for i, residue in enumerate(self.topology.residues())
        ]

    def save_dihedrals_to_file(self, dihedrals: Dict[str, List[float]],
                               filename: str, residue_info: Optional[List[Dict]] = None):
        """Write dihedrals as a tab-separated table, one row per residue.

        Args:
            dihedrals: Column name -> per-residue angles (degrees); row *i* of
                the table is entry *i* of every list (NaN if a list is shorter).
            filename: Output path (overwritten).
            residue_info: Optional rows from :meth:`get_residue_info`; when
                given, Chain/ResID/ResName columns are added, left empty for
                rows beyond its length.
        """
        columns = list(dihedrals)
        num_rows = max((len(values) for values in dihedrals.values()), default=0)

        with open(filename, 'w', encoding='utf-8') as f:
            header = ['Residue']
            if residue_info:
                header += ['Chain', 'ResID', 'ResName']
            f.write('\t'.join(header + columns) + '\n')

            for i in range(num_rows):
                fields = [str(i)]
                if residue_info:
                    if i < len(residue_info):
                        info = residue_info[i]
                        fields += [f"{info['chain']}", f"{info['id']}", f"{info['name']}"]
                    else:
                        fields += ['', '', '']  # keep columns aligned
                for name in columns:
                    values = dihedrals[name]
                    angle = values[i] if i < len(values) else np.nan
                    fields.append('NaN' if np.isnan(angle) else f"{angle:.2f}")
                f.write('\t'.join(fields) + '\n')

        print(f"data saved to: {filename}")


def main(pdb_file: Optional[str] = None) -> int:
    """Command-line entry point: extract all dihedrals from one PDB file.

    Prints a short summary and writes ``dihedrals.txt`` in the current
    working directory.

    Args:
        pdb_file: Path to the input PDB file. When omitted, the first
            command-line argument is used.

    Returns:
        Exit status: 0 on success, 1 if extraction failed, 2 if no input
        file was given.
    """
    import sys

    if pdb_file is None:
        pdb_file = sys.argv[1] if len(sys.argv) > 1 else ""
    if not pdb_file:
        print(f"usage: {sys.argv[0] if sys.argv else 'script'} PDB_FILE", file=sys.stderr)
        return 2

    try:
        extractor = DihedralExtractor(pdb_file)

        residue_info = extractor.get_residue_info()
        print("residue info:")
        for res in residue_info[:10]:
            print(f"  {res['index']}: {res['name']} (Chain {res['chain']}, ID {res['id']})")

        print("\nextract backbone dihedrals...")
        backbone_dihedrals = extractor.extract_backbone_dihedrals()

        print("extract sidechain dihedrals...")
        sidechain_dihedrals = extractor.extract_sidechain_dihedrals()

        print("merge all dihedrals...")
        all_dihedrals = {**backbone_dihedrals, **sidechain_dihedrals}

        print(f"\ndihedral types: {list(all_dihedrals.keys())}")
        for dihedral_type, angles in all_dihedrals.items():
            valid_angles = [a for a in angles if not np.isnan(a)]
            if valid_angles:
                print(f"  {dihedral_type}: {len(valid_angles)} valid values")
                print(f"    range: {min(valid_angles):.1f}° to {max(valid_angles):.1f}°")

        print("save to file...")
        extractor.save_dihedrals_to_file(all_dihedrals, "dihedrals.txt", residue_info)
    except Exception as e:  # top-level boundary: report and signal failure
        print(f"error: {e}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    import sys

    sys.exit(main())