import pickle
import argparse
from tqdm import tqdm
import os
import shutil
import signal
from contextlib import contextmanager
import time
from open_biomed.data.protein import Protein
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.tasks.aidd_tasks.protein_molecule_docking import VinaDockTask
from open_biomed.utils.config import Config
from open_biomed.data.molecule import molecule_fingerprint_similarity

# Command-line flags (all boolean switches):
#   --resume      load previously saved scores and continue after the chunks
#                 they already cover
#   --debug       use a small fixed slice of the predictions, build the dataset
#                 from CrossDocked instead of the cached pickle, run in-process
#                 and skip the per-chunk checkpoint
#   --fixinvalid  load previously saved scores and re-run chunks whose first
#                 saved entry is None
parser = argparse.ArgumentParser()
parser.add_argument("--resume", action="store_true")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--fixinvalid", action="store_true")
args = parser.parse_args()

print("Loading preds...")
# Directory holding both the input predictions and the metrics written later.
file = "./data/sample_results/train/molcraft_Mixed_CG_CFG_weighted_success"
# NOTE: pickle can execute arbitrary code while loading; only point these
# paths at files produced by trusted runs.
with open(f"{file}/filtered_preds.pkl", "rb") as preds_file:
    data = pickle.load(preds_file)

if args.debug:
    # Debug only: keep a fixed 1024-entry window of the predictions.
    # NOTE(review): entries are later paired with dataset.molecules[i] /
    # dataset.proteins[i] by position starting at 0 -- confirm this window
    # lines up with the debug split.
    data = data[51390:51390 + 1024]
    cfg = Config.from_dict(
        path="./datasets/CrossDocked",
        pocket_only=True,
        debug=args.debug,
        remove_hs=False,
    )
    dataset = CrossDocked(
        cfg=cfg,
        featurizer=None
    )
    # Only the first element returned by split() is used.
    dataset, _, _ = dataset.split()
else:
    print("Loading Dataset")
    with open("./data/csd_train.pkl", "rb") as dataset_file:
        dataset = pickle.load(dataset_file)
print(len(dataset))
# Never index past the end of the dataset when pairing predictions with it.
data = data[:len(dataset)]

# Module-level scorer. Within the visible code, process_molecule builds its
# own VinaDockTask instead of using this one.
vina_task = VinaDockTask(
    docking_tool="autodock_vina",
    mode="score",
    save_ligands=True
)
from multiprocessing import Pool, cpu_count

@contextmanager
def _time_limit(seconds):
    """Raise ``TimeoutError`` inside the ``with`` body after *seconds* seconds.

    Implemented with ``SIGALRM``, so it is only usable on Unix and from the
    main thread of the calling process. The previously installed handler is
    put back on exit so the limit does not leak into later code.
    """
    def _on_alarm(signum, frame):
        raise TimeoutError()

    previous_handler = signal.signal(signal.SIGALRM, _on_alarm)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # Cancel any pending alarm
        # signal.signal() returns None when the old handler was not installed
        # from Python; there is nothing to restore in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


def _vina_score_or_zero(vina_task, mol, protein):
    """Return the ``[0][0]`` entry of ``vina_task.run``; 0 if scoring fails.

    ``TimeoutError`` is deliberately re-raised: it is how the surrounding time
    limit aborts the work, and treating it as an ordinary scoring failure would
    silently disable that limit for the rest of the entry.
    """
    try:
        return vina_task.run(mol, protein)[0][0]
    except TimeoutError:
        raise
    except Exception as e:
        print(e)
        return 0


def process_molecule(args):
    """Score every generated molecule of one dataset entry.

    Args:
        args: A 4-tuple ``(i, data_i, dataset_mol_i, protein_file_i)``: the
            entry index, the generated molecules for that entry, the reference
            molecule, and the protein PDB file passed to
            ``Protein.from_pdb_file``. (The parameter shadows the module-level
            ``args`` namespace inside this function.)

    Returns:
        A list with one dict per generated molecule, holding the keys
        ``vina_score``, ``qed``, ``sa``, ``ref_vina_score``, ``ref_qed`` and
        ``ref_sa``. A molecule whose docking call fails gets a vina score of 0.
        Returns ``None`` when ``data_i`` is empty, when the entry exceeds the
        180 second limit, or when any other exception occurs.
    """
    try:
        i, data_i, dataset_mol_i, protein_file_i = args
        print(i, " Start")
        if len(data_i) == 0:
            return None

        try:
            with _time_limit(180):  # 3 minutes = 180 seconds
                protein = Protein.from_pdb_file(protein_file_i)
                vina_task = VinaDockTask(docking_tool="autodock_vina", mode="score")

                # Vina scores for the generated molecules, then the reference
                vina_scores = [
                    _vina_score_or_zero(vina_task, mol, protein) for mol in data_i
                ]
                ref_vina_score = _vina_score_or_zero(vina_task, dataset_mol_i, protein)
                ref_qed = dataset_mol_i.calc_qed()
                ref_sa = dataset_mol_i.calc_sa()

                mol_scores = []
                for j, mol in enumerate(data_i):
                    score_dict = {
                        "vina_score": vina_scores[j],
                        "qed": mol.calc_qed(),
                        "sa": mol.calc_sa(),
                        "ref_vina_score": ref_vina_score,
                        "ref_qed": ref_qed,
                        "ref_sa": ref_sa
                    }
                    mol_scores.append(score_dict)

                    # Log a few entries out of every 1024 as a progress sample
                    if i % 1024 <= 5:
                        print(mol, i, j, score_dict)
                print(i, " Finished")
                return mol_scores

        except TimeoutError:
            print(f"Process {i} timed out after 3 minutes")
            return None
    except Exception as e:
        # Best effort: a failing entry must not take down the whole batch.
        print(e)
        return None

# Imported here, next to their only use, so this section stays self-contained.
import multiprocessing
from datetime import datetime

CHUNK_SIZE = 1024            # entries scored between two checkpoints
NUM_PROCESSES = 16           # fixed worker count for the pool
POOL_TIMEOUT_SECONDS = 1800  # give up on a whole chunk after 30 minutes


def _save_scores(scores):
    """Pickle *scores* to ``{file}/metrics.pkl`` and close the file."""
    with open(f"{file}/metrics.pkl", "wb") as metrics_file:
        pickle.dump(scores, metrics_file)


# Resume
if args.resume or args.fixinvalid:
    # NOTE(review): previous results are read from fixed_metrics.pkl while new
    # results are written to metrics.pkl -- confirm this asymmetry is intended.
    with open(f"{file}/fixed_metrics.pkl", "rb") as metrics_file:
        all_scores = pickle.load(metrics_file)
    print(len(all_scores))
else:
    all_scores = []

if args.resume:
    start = len(all_scores) // CHUNK_SIZE
    # Drop a trailing partial chunk: it is re-scored below, and appending the
    # new results after the stale ones would shift every later index.
    all_scores = all_scores[:start * CHUNK_SIZE]
else:
    start = 0

if args.fixinvalid and len(all_scores) < len(data):
    # Make every index writable; entries never scored count as invalid.
    all_scores.extend([None] * (len(data) - len(all_scores)))

num_chunks = (len(data) - 1) // CHUNK_SIZE + 1

for j in tqdm(range(start, num_chunks)):
    print(f"Current time: {datetime.now().strftime('%m-%d %H:%M:%S.%f')[:-3]}")
    chunk_start = j * CHUNK_SIZE
    chunk_end = min(chunk_start + CHUNK_SIZE, len(data))

    # In --fixinvalid mode a chunk is skipped when its first saved entry is
    # present; a chunk that failed as a whole is stored as all-None.
    # NOTE(review): chunk 0 is always re-run (j > 0) -- confirm this is wanted.
    if args.fixinvalid and all_scores[chunk_start] is not None and j > 0:
        continue

    # Prepare arguments for parallel processing
    process_args = [
        (i, data[i], dataset.molecules[i], dataset.proteins[i])
        for i in range(chunk_start, chunk_end)
    ]

    if args.debug:
        # Run in-process so failures are easy to trace.
        scores = list(map(process_molecule, process_args))
    else:
        with Pool(NUM_PROCESSES) as pool:
            try:
                pending = pool.map_async(process_molecule, process_args)
                scores = pending.get(timeout=POOL_TIMEOUT_SECONDS)
            except multiprocessing.TimeoutError:
                # AsyncResult.get raises multiprocessing.TimeoutError, which is
                # not the builtin TimeoutError.
                pool.terminate()
                pool.join()
                print("Pool processes terminated due to timeout")
                scores = [None] * len(process_args)
            except Exception as e:
                pool.terminate()
                pool.join()
                print(f"Pool processes terminated due to error: {e}")
                scores = [None] * len(process_args)

    # None results (empty or failed entries) are kept so that positions in
    # all_scores keep matching positions in data.
    if args.fixinvalid:
        # Sized by the actual result count: the last chunk may be short.
        all_scores[chunk_start:chunk_start + len(scores)] = scores
    else:
        all_scores.extend(scores)

    if not args.debug:
        print("Saving...")
        _save_scores(all_scores)

_save_scores(all_scores)