import sys
import os
import time
import random
from shutil import rmtree
from multiprocessing import Manager
from multiprocessing import Process
from multiprocessing import Queue
import subprocess
from openbabel import pybel
from plip.structure.preparation import PDBComplex


class DockingVina(object):
    def __init__(self, target, residue_file_union=True):
        """Configure a docking run for ``target`` and create a scratch directory.

        Args:
            target: Receptor name. A search box is predefined for 'fa7',
                'parp1', '5ht1b', 'jak2' and 'braf'. For any other name
                ``box_center`` and ``box_size`` are set to ``None`` and have
                to be assigned by the caller before docking.
            residue_file_union: When true, ``residue_file`` points at
                ``merged_union.txt`` of the target, otherwise at
                ``merged_intersection.txt``.

        Side effects:
            Creates ``tmp/tmp_<pid>_<timestamp>`` relative to the current
            working directory and prints its name.
        """
        super().__init__()
        self.target = target
        self.file_path = '/BindMol/data'

        # target name -> (box center, box size); both are forwarded to the
        # docking program as --center_* / --size_* options.
        search_boxes = {
            'fa7': ((10.131, 41.879, 32.097), (20.673, 20.198, 21.362)),
            'parp1': ((26.413, 11.282, 27.238), (18.521, 17.479, 19.995)),
            '5ht1b': ((-26.602, 5.277, 17.898), (22.5, 22.5, 22.5)),
            'jak2': ((114.758, 65.496, 11.345), (19.033, 17.929, 20.283)),
            'braf': ((84.194, 6.949, -7.081), (22.032, 19.211, 14.106)),
        }
        # Always define both attributes so an unknown target yields an
        # inspectable ``None`` instead of a missing attribute.
        self.box_center, self.box_size = search_boxes.get(target, (None, None))

        # Set residue file
        residue_name = 'merged_union.txt' if residue_file_union else 'merged_intersection.txt'
        self.residue_file = f'{self.file_path}/{target}/{residue_name}'

        self.vina_program = 'utils_sac/docking/qvina02'
        self.receptor_file = f'utils_sac/docking/{target}.pdbqt'
        self.exhaustiveness = 1      # passed as --exhaustiveness
        self.num_sub_proc = 10       # number of worker processes
        self.num_cpu_dock = 5        # passed as --cpu
        self.num_modes = 10          # passed as --num_modes
        self.timeout_gen3d = 30      # subprocess timeout (seconds) for obabel --gen3D
        self.timeout_dock = 100      # subprocess timeout (seconds) for docking
        self.target_residues = None  # Will be set when calculating interactions

        # Create a unique temporary directory from process ID and timestamp.
        # The directory is created WITHOUT exist_ok: if the name is already
        # taken (another instance created within the same millisecond) the
        # timestamp is bumped, so two instances never share -- and later
        # delete -- each other's scratch files.
        process_id = os.getpid()
        timestamp = int(time.time() * 1000)
        while True:
            tmp_dir = f'tmp/tmp_{process_id}_{timestamp}'
            try:
                # Also creates the parent 'tmp' directory when missing.
                os.makedirs(tmp_dir)
                break
            except FileExistsError:
                timestamp += 1

        self.temp_dir = tmp_dir
        print(f'Docking tmp dir: {tmp_dir}')

    def gen_3d(self, smi, ligand_mol_file):
        """
            generate initial 3d conformation from SMILES
            input :
                SMILES string
                ligand_mol_file (output file)
            raises :
                RuntimeError if obabel exits with an error, times out
                (self.timeout_gen3d seconds) or leaves no non-empty output file
        """
        # Ensure the output directory exists. A bare file name has an empty
        # dirname, which os.makedirs would reject, so only create when needed.
        output_dir = os.path.dirname(ligand_mol_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Pass an argument list instead of splitting a formatted string, so
        # whitespace in the SMILES or in the path cannot change the arguments.
        command = ['obabel', '-:%s' % smi, '--gen3D', '-O', '%s' % ligand_mol_file]
        run_line = ' '.join(command)  # only used in error messages
        try:
            subprocess.check_output(command,
                                    stderr=subprocess.STDOUT,
                                    timeout=self.timeout_gen3d, universal_newlines=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"obabel command failed: {run_line}\nError: {e.output}") from e
        except subprocess.TimeoutExpired as e:
            raise RuntimeError(
                f"obabel command timed out after {self.timeout_gen3d}s: {run_line}") from e

        # A zero exit status is not trusted on its own: check that the file
        # was actually created and is not empty.
        if not os.path.exists(ligand_mol_file) or os.path.getsize(ligand_mol_file) == 0:
            raise RuntimeError(f"obabel failed to create valid output file: {ligand_mol_file}")

    def convert_pdbqt_to_pdb(self, pdbqt_path, output_pdb_path):
        """Convert PDBQT file to PDB format by running ``obabel``.

        Args:
            pdbqt_path: Input file handed to obabel.
            output_pdb_path: Output file handed to obabel via ``-O``.

        Raises:
            subprocess.CalledProcessError: If obabel exits with a non-zero status.
        """
        # Argument list rather than a whitespace-split string, so paths that
        # contain spaces reach obabel as a single argument.
        command = ['obabel', '%s' % pdbqt_path, '-O', '%s' % output_pdb_path]
        subprocess.check_output(command, stderr=subprocess.STDOUT, universal_newlines=True)

    def merge_protein_ligand(self, protein_pdb, ligand_pdb, output_pdb):
        """Write a complex PDB made of the protein records followed by the ligand.

        Only lines starting with ``ATOM`` or ``HETATM`` are copied. Ligand
        ``ATOM`` records are relabelled as ``HETATM``; a final ``END`` line
        closes the file.
        """
        coordinate_tags = ('ATOM', 'HETATM')

        def as_hetatm(record):
            # Swap the 6-character record name, keep the rest of the line.
            if record.startswith('ATOM'):
                return 'HETATM' + record[6:]
            return record

        with open(output_pdb, 'w') as merged:
            # Protein records are copied verbatim.
            with open(protein_pdb, 'r') as protein_src:
                merged.writelines(
                    record for record in protein_src
                    if record.startswith(coordinate_tags)
                )

            # Ligand records are all emitted as HETATM.
            with open(ligand_pdb, 'r') as ligand_src:
                merged.writelines(
                    as_hetatm(record) for record in ligand_src
                    if record.startswith(coordinate_tags)
                )

            merged.write('END\n')

    def calculate_residue_interactions(self, ligand_pdbqt_file, sub_id):
        """Count protein-ligand interactions per target residue for a docked ligand.

        The receptor and the docked ligand are converted to PDB, merged into
        one complex file and analysed with PLIP. Only the interaction types
        listed in the body are counted.

        Args:
            ligand_pdbqt_file: Docked ligand in PDBQT format.
            sub_id: Worker index, used to name the temporary files inside
                ``self.temp_dir``.

        Returns:
            dict mapping every entry of ``self.target_residues`` to its
            interaction count. ``{}`` when ``self.target_residues`` is None;
            all counts are 0 when any step of the analysis fails (the error
            is printed, not raised).
        """
        if self.target_residues is None:
            return {}

        # Temporary files
        protein_pdb = f'{self.temp_dir}/protein_{sub_id}.pdb'
        ligand_pdb = f'{self.temp_dir}/ligand_{sub_id}_complex.pdb'
        complex_pdb = f'{self.temp_dir}/complex_{sub_id}.pdb'

        try:
            # Convert files and merge
            self.convert_pdbqt_to_pdb(self.receptor_file, protein_pdb)
            self.convert_pdbqt_to_pdb(ligand_pdbqt_file, ligand_pdb)
            self.merge_protein_ligand(protein_pdb, ligand_pdb, complex_pdb)

            # PLIP analysis
            protlig = PDBComplex()
            protlig.load_pdb(complex_pdb)
            protlig.analyze()

            # Interaction count per residue, keyed "<chain>_<resnr>_<restype>"
            residue_counts = {}
            for interactions in protlig.interaction_sets.values():
                all_interactions = (
                    interactions.hbonds_ldon + interactions.hbonds_pdon +
                    interactions.hydrophobic_contacts + interactions.pistacking +
                    interactions.water_bridges + interactions.saltbridge_lneg +
                    interactions.saltbridge_pneg + interactions.metal_complexes +
                    interactions.halogen_bonds
                )
                for interaction in all_interactions:
                    full_key = (f"{interaction.reschain}_"
                                f"{str(interaction.resnr)}_{interaction.restype}")
                    residue_counts[full_key] = residue_counts.get(full_key, 0) + 1

            # Targets that already carry one of these chain prefixes
            # (e.g. A_188_SER) are only ever matched exactly.
            chain_prefixes = ('A_', 'B_', 'C_')
            result = {}
            for target in self.target_residues:
                if target in residue_counts:
                    count = residue_counts[target]
                elif target.startswith(chain_prefixes):
                    count = 0
                else:
                    # Target format: 188_SER -- take the first counted residue
                    # whose key ends with "_188_SER", whatever its chain.
                    suffix = f"_{target}"
                    count = next(
                        (res_count for res_key, res_count in residue_counts.items()
                         if res_key.endswith(suffix)),
                        0,
                    )
                result[target] = count

            return result

        except Exception as e:
            # Best effort: a failed analysis is reported as "no interactions".
            print(f"Error calculating interactions: {e}")
            return {target: 0 for target in self.target_residues}
        finally:
            for temp_file in (protein_pdb, ligand_pdb, complex_pdb):
                if os.path.exists(temp_file):
                    os.remove(temp_file)

    def docking(self, receptor_file, ligand_mol_file, ligand_pdbqt_file, docking_pdbqt_file):
        """
            run_docking program using subprocess
            input :
                receptor_file
                ligand_mol_file
                ligand_pdbqt_file (written here from ligand_mol_file)
                docking_pdbqt_file (written by the docking program)
            output :
                (affinity list for a input molecule,
                 path of the saved copy of the docked structure or None)
        """
        ms = list(pybel.readfile("mol", ligand_mol_file))
        m = ms[0]
        m.write("pdbqt", ligand_pdbqt_file, overwrite=True)

        # Argument list instead of a whitespace-split string, so file names
        # containing spaces are passed through intact.
        center_x, center_y, center_z = self.box_center
        size_x, size_y, size_z = self.box_size
        command = [
            '%s' % self.vina_program,
            '--receptor', '%s' % receptor_file,
            '--ligand', '%s' % ligand_pdbqt_file,
            '--out', '%s' % docking_pdbqt_file,
            '--center_x', '%s' % center_x,
            '--center_y', '%s' % center_y,
            '--center_z', '%s' % center_z,
            '--size_x', '%s' % size_x,
            '--size_y', '%s' % size_y,
            '--size_z', '%s' % size_z,
            '--cpu', '%d' % self.num_cpu_dock,
            '--num_modes', '%d' % self.num_modes,
            '--exhaustiveness', '%d' % self.exhaustiveness,
        ]
        result = subprocess.check_output(command,
                                         stderr=subprocess.STDOUT,
                                         timeout=self.timeout_dock, universal_newlines=True)
        affinity_list = self._parse_affinities(result)

        # Save docked ligand structure if docking was successful
        if not affinity_list or not os.path.exists(docking_pdbqt_file):
            return affinity_list, None

        # Generate unique filename for the docked structure
        output_dir = f'docked_ligands_{self.target}'
        timestamp = int(time.time() * 1000)  # Use milliseconds for more precision
        random_suffix = random.randint(1000, 9999)  # Add random suffix
        process_id = os.getpid()  # Get current process ID
        saved_file = f'{output_dir}/docked_ligand_pid{process_id}_{timestamp}_{random_suffix}.pdbqt'

        # Copy the docked structure to output directory
        from shutil import copy2
        try:
            # exist_ok avoids the check-then-create race between the worker
            # processes that all write into the same output directory.
            os.makedirs(output_dir, exist_ok=True)
            copy2(docking_pdbqt_file, saved_file)
        except OSError as e:
            # Saving the copy is best effort; the scores are still returned.
            print(f'Failed to save docked structure: {e}')
            saved_file = None

        return affinity_list, saved_file

    @staticmethod
    def _parse_affinities(output):
        """Extract the affinity column from the docking program's result table.

        Rows are read after the line starting with ``-----+`` until a
        ``Writing output`` / ``Refine time`` line, a blank line, or any row
        whose first field is not a mode number.
        """
        affinities = []
        in_table = False
        for line in output.split('\n'):
            if line.startswith('-----+'):
                in_table = True
                continue
            if not in_table:
                continue
            if line.startswith(('Writing output', 'Refine time')):
                break
            fields = line.split()
            # A blank or truncated row ends the table instead of raising.
            if len(fields) < 2 or not fields[0].isdigit():
                break
            affinities.append(float(fields[1]))
        return affinities

    def creator(self, q, data, num_sub_proc):
        """
            feed the work queue
            input: queue
                data = [(idx1,smi1), (idx2,smi2), ...]
                num_sub_proc (number of end signals to append)
        """
        # Every work item is forwarded as an (index, SMILES) pair.
        for record in data:
            q.put((record[0], record[1]))

        # One 'DONE' marker per worker so each of them can leave its loop.
        for _ in range(num_sub_proc):
            q.put('DONE')

    def docking_subprocess(self, q, return_dict, interaction_dict=None, sub_id=0):
        """
            worker loop: dock every (idx, smiles) item taken from the queue
            until a 'DONE' item is received
            input
                q (queue)
                return_dict: idx -> best affinity (99.9 when anything failed)
                interaction_dict: dict to store interaction results
                    (idx -> {} when 3D generation or docking failed)
                sub_id: subprocess index for temp file
        """
        receptor_file = self.receptor_file
        ligand_mol_file = '%s/ligand_%s.mol' % (self.temp_dir, sub_id)
        ligand_pdbqt_file = '%s/ligand_%s.pdbqt' % (self.temp_dir, sub_id)
        docking_pdbqt_file = '%s/dock_%s.pdbqt' % (self.temp_dir, sub_id)
        scratch_files = (ligand_mol_file, ligand_pdbqt_file, docking_pdbqt_file)

        while True:
            qqq = q.get()
            if qqq == 'DONE':
                break
            (idx, smi) = qqq
            try:
                # The scratch file names are reused for every ligand of this
                # worker. Remove leftovers first so a failed step can never
                # pick up the previous ligand's structure or docked pose.
                for stale_file in scratch_files:
                    if os.path.exists(stale_file):
                        os.remove(stale_file)
                self.gen_3d(smi, ligand_mol_file)
                # Verify that the file was created
                if not os.path.exists(ligand_mol_file):
                    raise FileNotFoundError(f"Failed to create mol file: {ligand_mol_file}")
            except Exception as e:
                print(f"gen_3d error for ligand_mol_file {ligand_mol_file}: {e}")
                print("gen_3d unexpected error:", sys.exc_info())
                print("smiles: ", smi)
                print(f"Working directory: {os.getcwd()}")
                print(f"Temp directory: {self.temp_dir}")
                return_dict[idx] = 99.9
                if interaction_dict is not None:
                    interaction_dict[idx] = {}
                continue
            try:
                affinity_list, saved_file = self.docking(receptor_file, ligand_mol_file,
                                                         ligand_pdbqt_file, docking_pdbqt_file)
            except Exception as e:
                print(e)
                print("docking unexpected error:", sys.exc_info())
                print("smiles: ", smi)
                return_dict[idx] = 99.9
                if interaction_dict is not None:
                    interaction_dict[idx] = {}
                continue

            # The first listed mode is reported; no mode at all counts as failure.
            docked = len(affinity_list) > 0
            return_dict[idx] = affinity_list[0] if docked else 99.9

            # Calculate interactions if requested
            if interaction_dict is None or self.target_residues is None:
                continue
            if docked and os.path.exists(docking_pdbqt_file):
                interaction_dict[idx] = self.calculate_residue_interactions(
                    docking_pdbqt_file, sub_id)
            else:
                # Without a docked pose there is nothing to analyse; report
                # the same empty result as the other failure paths.
                interaction_dict[idx] = {}

            # Clean up saved file after interaction analysis
            if saved_file and os.path.exists(saved_file):
                try:
                    os.remove(saved_file)
                except Exception as e:
                    print(f'Failed to remove saved file {saved_file}: {e}')

    def predict(self, smiles_list):
        """
            input SMILES list
            output affinity list corresponding to the SMILES list
            (same length and order as the input)
            if docking is fail, docking score is 99.9
        """
        data = list(enumerate(smiles_list))
        if not data:
            # Nothing to dock: do not start any process.
            return []

        # Never start more workers than there are molecules.
        num_workers = max(1, min(self.num_sub_proc, len(data)))

        q1 = Queue()
        # The context manager shuts the manager process down when done.
        with Manager() as manager:
            return_dict = manager.dict()
            proc_master = Process(target=self.creator,
                                  args=(q1, data, num_workers))
            proc_master.start()

            procs = []
            for sub_id in range(num_workers):
                proc = Process(target=self.docking_subprocess,
                               args=(q1, return_dict, None, sub_id))
                procs.append(proc)
                proc.start()

            q1.close()
            q1.join_thread()
            proc_master.join()
            for proc in procs:
                proc.join()

            # Copy the results out before the manager goes away.
            scores = dict(return_dict.items())

        # Index by input position so a molecule whose worker died without
        # reporting still gets the failure score instead of shifting the list.
        return [scores.get(idx, 99.9) for idx, _ in data]

    def load_target_residues(self, file_path):
        """Read one residue name per line and store the list on the instance.

        Surrounding whitespace is stripped and blank lines are skipped. The
        resulting list is assigned to ``self.target_residues`` and returned.
        """
        with open(file_path, 'r') as handle:
            stripped_lines = (raw_line.strip() for raw_line in handle)
            residue_names = [name for name in stripped_lines if name]
        self.target_residues = residue_names
        return residue_names

    def predict_with_interactions(self, smiles_list):
        """
        Predict docking scores and residue interactions

        The target residues are (re)loaded from ``self.residue_file`` on
        every call.

        Args:
            smiles_list: List of SMILES strings

        Returns:
            tuple: (affinity_list, interaction_list, interaction_values_list),
            each with one entry per input SMILES, in input order
                - affinity_list: docking scores for each molecule (99.9 on failure)
                - interaction_list: interaction counts for each molecule and residue
                  (dict format, {} on failure)
                - interaction_values_list: interaction counts as lists in residue order
        """
        # Set target residues
        self.load_target_residues(self.residue_file)

        data = list(enumerate(smiles_list))
        if not data:
            # Nothing to dock: do not start any process.
            return [], [], []

        # Never start more workers than there are molecules.
        num_workers = max(1, min(self.num_sub_proc, len(data)))

        q1 = Queue()
        # The context manager shuts the manager process down when done.
        with Manager() as manager:
            return_dict = manager.dict()
            interaction_dict = manager.dict()

            proc_master = Process(target=self.creator,
                                  args=(q1, data, num_workers))
            proc_master.start()

            procs = []
            for sub_id in range(num_workers):
                proc = Process(target=self.docking_subprocess,
                               args=(q1, return_dict, interaction_dict, sub_id))
                procs.append(proc)
                proc.start()

            q1.close()
            q1.join_thread()
            proc_master.join()
            for proc in procs:
                proc.join()

            # Copy the results out before the manager goes away.
            scores = dict(return_dict.items())
            interactions_by_idx = dict(interaction_dict.items())

        affinity_list = []
        interaction_list = []
        interaction_values_list = []
        # Walk the input positions (not the returned keys) so a molecule whose
        # worker died without reporting keeps its slot with failure values.
        for idx, _ in data:
            interactions = interactions_by_idx.get(idx, {})
            affinity_list.append(scores.get(idx, 99.9))
            interaction_list.append(interactions)
            # Extract interaction values in residue order
            interaction_values_list.append(
                [interactions.get(residue, 0) for residue in self.target_residues])

        return affinity_list, interaction_list, interaction_values_list
    
    def __del__(self):
        """Remove the instance's temporary directory, if it was created."""
        # __init__ may have failed before temp_dir was assigned; a finalizer
        # must not raise in that case.
        temp_dir = getattr(self, 'temp_dir', None)
        if not temp_dir or not os.path.exists(temp_dir):
            return
        try:
            rmtree(temp_dir)
        except OSError as e:
            # Cleanup is best effort; report instead of raising from __del__.
            print(f'Failed to remove {temp_dir}: {e}')
        else:
            print(f'{temp_dir} removed')
