import pandas as pd
import requests
import os
import numpy as np
from os.path import join
import argparse


def get_arguments(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before; passing an explicit list
            lets other code or tests use the parser without touching sys.argv.

    Returns:
        argparse.Namespace: string attributes ``train_file``, ``val_file``,
        ``test_file`` and ``PDB_dir``, each defaulting to ``""``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Download AlphaFold PDB files for every protein ID listed in the "
            "train/validation/test files."
        ),
    )
    # The script only reads the "Protein-Id" column of these files, so the
    # help text documents that column (not the ligand SMILES column).
    parser.add_argument(
        "--train_file",
        type=str,
        default="",
        help="Path to train file (.csv or .xlsx) with protein IDs in the column Protein-Id",
    )
    parser.add_argument(
        "--val_file",
        type=str,
        default="",
        help="Path to validation file (.csv or .xlsx) with protein IDs in the column Protein-Id",
    )
    parser.add_argument(
        "--test_file",
        type=str,
        default="",
        help="Path to test file (.csv or .xlsx) with protein IDs in the column Protein-Id",
    )
    parser.add_argument(
        "--PDB_dir",
        type=str,
        default="",
        help="Directory with PDB files",
    )
    return parser.parse_args(argv)

args = get_arguments()
args_dict = vars(args)
# Expose the CLI options (train_file, val_file, test_file, PDB_dir) as module
# globals; the download step at the end of the script reads PDB_dir this way.
globals().update(args_dict)


def _read_uids(path):
    """Return the protein IDs in the ``Protein-Id`` column of *path*.

    An empty *path* yields no IDs. ``.csv`` and ``.xlsx`` files are supported
    (extension matched case-insensitively); any other extension is reported
    and skipped rather than silently ignored.
    """
    if path == "":
        return []
    lowered = path.lower()
    if lowered.endswith(".csv"):
        frame = pd.read_csv(path, sep=",")
    elif lowered.endswith(".xlsx"):
        frame = pd.read_excel(path)
    else:
        print(f"Skipping {path}: unsupported file type (expected .csv or .xlsx)")
        return []
    # Blank cells are read as NaN; without dropping them they would become
    # download requests for a bogus ID "nan".
    ids = frame["Protein-Id"].dropna().astype(str).str.strip()
    return [uid for uid in ids if uid]


all_UIDs = []
for split_file in (args.train_file, args.val_file, args.test_file):
    all_UIDs.extend(_read_uids(split_file))

# De-duplicate; sort so the download order (and progress output) is reproducible
# between runs, which a bare list(set(...)) does not guarantee.
all_UIDs = sorted(set(all_UIDs))

print("Number of UIDs: ", len(all_UIDs))

def download_alphafold_pdbs(uids, pdb_dir, skip_existing=True, *,
                            model_version="v4", timeout=60):
    """
    Download AlphaFold PDB files for given UIDs.

    Each file is fetched from
    ``https://alphafold.ebi.ac.uk/files/AF-{uid}-F1-model_{model_version}.pdb``
    and saved as ``{uid}.pdb`` inside ``pdb_dir``. Request failures (HTTP
    error status, timeout, connection problems) are printed and skipped so one
    bad UID does not abort the whole run.

    Args:
        uids (list): List of UIDs to download PDB files for
        pdb_dir (str): Directory path where PDB files should be saved; it is
            created if missing. An empty string means the current directory.
        skip_existing (bool): If True, skip downloading if PDB file already exists
        model_version (str): Model version suffix used in the AlphaFold DB URL
            (default ``"v4"``, the value previously hard-coded here).
        timeout (float): Seconds to wait for the server on each request; without
            a timeout a stalled connection would block the script indefinitely.

    Returns:
        list: UIDs whose PDB file is present in ``pdb_dir`` after the call,
        either because it already existed (``skip_existing``) or because it
        was downloaded successfully.
    """
    # Create the target directory once, up front, instead of on every download.
    # os.makedirs("") raises, so an empty pdb_dir is treated as the cwd.
    if pdb_dir:
        os.makedirs(pdb_dir, exist_ok=True)
    # A set gives O(1) membership tests instead of scanning a list per UID.
    existing = set(os.listdir(pdb_dir or "."))
    successful_downloads = []
    total = len(uids)

    for k, uid in enumerate(uids, start=1):
        pdb_filename = f"{uid}.pdb"
        if skip_existing and pdb_filename in existing:
            successful_downloads.append(uid)
        else:
            url = (
                "https://alphafold.ebi.ac.uk/files/"
                f"AF-{uid}-F1-model_{model_version}.pdb"
            )
            try:
                r = requests.get(url, timeout=timeout)
                r.raise_for_status()  # Raises an HTTPError for bad responses
            except requests.exceptions.RequestException as e:
                print(f"Failed to download {uid}: {str(e)}")
            else:
                with open(os.path.join(pdb_dir, pdb_filename), "wb") as f:
                    f.write(r.content)
                successful_downloads.append(uid)

        # Report progress over all processed UIDs (skipped ones included). The
        # previous message printed the loop index as a "downloaded" count, which
        # counted failures and never fired for already-existing files.
        if k % 100 == 0 or k == total:
            print(f"Processed {k}/{total} UIDs ({len(successful_downloads)} available)")
    return successful_downloads

# PDB_dir comes from the CLI (exposed as a global via globals().update above).
# Its default is "", for which os.makedirs raises an opaque FileNotFoundError,
# so fail early with an actionable message instead.
if PDB_dir == "":
    raise SystemExit("Please specify --PDB_dir (directory in which to store the PDB files)")
# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(PDB_dir, exist_ok=True)

successful_downloads = download_alphafold_pdbs(all_UIDs, PDB_dir)
print(f"Successfully downloaded {len(successful_downloads)} PDB files")
