from util_preprocess.util_3D_lists import pdb_to_point_dict, point_dict_to_npy, bounding_boxes
import argparse
import logging
from os.path import join
import os
import numpy as np


def get_arguments(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or when driving the script from
            other code). ``None`` (the default) keeps the usual command-line
            behaviour.

    Returns:
        argparse.Namespace with ``save_dir`` (str), ``PDB_dir`` (str),
        ``index`` (int, -1 means "no subsetting"), ``number_of_files`` (int)
        and ``resolution`` (float).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="",
        help="Directory to save the point lists",
    )
    parser.add_argument(
        "--PDB_dir",
        type=str,
        default="",
        help="Directory with PDB files",
    )
    parser.add_argument(
        "--index",
        type=int,
        default=-1,
        help="Index for subset of PDB files",
    )
    parser.add_argument(
        "--number_of_files",
        type=int,
        default=-1,
        help="Number of PDB files to process in each subset",
    )
    parser.add_argument(
        "--resolution",
        type=float,
        default=1.0,
        help="Resolution of the input",
    )

    # parse_args(None) falls back to sys.argv[1:], i.e. the original behaviour.
    return parser.parse_args(argv)

args = get_arguments()
args_dict = vars(args)
# Expose every CLI option (save_dir, PDB_dir, index, number_of_files, resolution)
# as a module-level global; the rest of the script reads them by bare name.
globals().update(args_dict)

# Create the output directories.
# - exist_ok=True avoids a check-then-create race when several subset jobs
#   (different --index values) start concurrently with the same --save_dir.
# - os.makedirs("") raises FileNotFoundError, so the default empty --save_dir
#   (i.e. the current working directory) is never passed to makedirs on its own.
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
log_output_dir = join(save_dir, "log_outputs")
os.makedirs(log_output_dir, exist_ok=True)

# Log file name: one file per subset index, so parallel subset jobs that share
# a save_dir do not interleave their logs in a single file.
setting = "generating_numpy_input_lists_from_PDBs"
if index != -1:
    setting += "_" + str(index)
log_name = join(log_output_dir, setting + ".log")

# Configure the root logger so the module-level logging.info(...) calls used
# below go both to the (appended) log file and to the console.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(levelname)s:%(message)s')
file_handler = logging.FileHandler(log_name, mode='a')
console_handler = logging.StreamHandler()
for handler in (file_handler, console_handler):
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)


# exist_ok=True: concurrent subset jobs may race to create this directory.
os.makedirs(join(save_dir, "numpy_3D_point_lists"), exist_ok=True)

# Collect the PDB IDs (file names without the ".pdb" suffix). Only ".pdb" entries
# are kept. Other files or sub-directories would otherwise become bogus IDs whose
# "<ID>.pdb" path does not exist, and str.replace would also mangle names that
# contain ".pdb" elsewhere. Sorting makes the --index subsets deterministic.
PDBs = sorted(f for f in os.listdir(PDB_dir) if f.endswith(".pdb"))
logging.info(f"Number of PDBs: {len(PDBs)}")
if index != -1:
    # The slice below is only meaningful for a non-negative index and a positive
    # chunk size. With the default number_of_files=-1 it would silently select
    # the wrong files (e.g. index 0 -> PDBs[0:-1] drops the last file).
    if index < 0 or number_of_files <= 0:
        raise SystemExit(
            f"--index {index} requires --index >= 0 and --number_of_files > 0 "
            f"(got --number_of_files {number_of_files})"
        )
    PDBs = PDBs[index * number_of_files:(index + 1) * number_of_files]
    logging.info(f"Number of PDBs in subset {index}: {len(PDBs)}")

PDBs = [f[:-len(".pdb")] for f in PDBs]

# Disabled resume logic: skip PDBs whose point list already exists in
# <save_dir>/numpy_3D_point_lists. Uncomment to re-enable.
# existing_point_lists = os.listdir(join(save_dir, "numpy_3D_point_lists"))
# existing_point_lists = [f for f in existing_point_lists if f.endswith(".npy")]
# existing_point_lists = [f.replace(".npy", "") for f in existing_point_lists]
# logging.info(f"Number of existing point lists: {len(existing_point_lists)}")
# PDBs = [f for f in PDBs if f not in existing_point_lists]
# logging.info(f"Number of PDBs to process (not considering existing point lists): {len(PDBs)}")

# Maps "<ID>.npy" -> str(box) for the first entry of `bounding_boxes` whose size
# strictly exceeds the point list's maximum integer index along each of the
# first three columns.
boxes_dict = {}

for k, ID in enumerate(PDBs):
    name = ID + ".npy"
    try:
        point_dict = pdb_to_point_dict(filename=join(PDB_dir, ID + ".pdb"), resolution=resolution)
        point_list = point_dict_to_npy(point_dict)
        # NOTE(review): assumes point_list is 2-D with columns 0-2 holding grid
        # coordinates and column 4 a float value — confirm against point_dict_to_npy.
        point_list[:, 4] = np.round(point_list[:, 4], 4)
        np.save(join(save_dir, "numpy_3D_point_lists", name), point_list)

        # Find the first fitting bounding box: the max index along each axis
        # must be strictly smaller than the box extent along that axis.
        r = point_list[:, :3].astype(int).max(axis=0)
        for box in bounding_boxes:
            if r[0] < box[0] and r[1] < box[1] and r[2] < box[2]:
                boxes_dict[name] = str(box)
                break
        else:
            # Previously such PDBs were silently missing from boxes_dict.
            logging.warning(f"No bounding box fits {ID} (max indices {r.tolist()})")
    except Exception as e:
        # Best effort: a failing PDB is logged and skipped so the remaining
        # files are still processed.
        logging.error(f"Error with {ID}: {e}")

    processed = k + 1
    if processed % 20 == 0:
        logging.info(f"Processed {processed} PDBs")

# len(PDBs) rather than k + 1: `k` is never bound when PDBs is empty.
logging.info(f"Processed {len(PDBs)} PDBs")

# Save the bounding-box assignment of every PDB processed successfully in this
# run, whether or not its point list already existed on disk.
# np.save wraps the dict in a 0-d object array and pickles it; read it back with
# np.load(path, allow_pickle=True).item().
bounding_boxes_dir = "protein_3D_bounding_boxes"
# An unsubsetted run (--index -1) is written as subset 0, i.e. it shares the file
# name "bounding_boxes_0.npy" with --index 0.
if index == -1:
    index = 0
# exist_ok=True: parallel subset jobs may race to create this directory.
os.makedirs(join(save_dir, bounding_boxes_dir), exist_ok=True)
np.save(join(save_dir, bounding_boxes_dir, "bounding_boxes_" + str(index) + ".npy"), boxes_dict)