import argparse
import logging
import os
import re
import wandb

from flexdock.metrics.evaluator import Evaluator


def parse_args(argv=None):
    """Build the evaluation argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior callers that pass no argument already rely on.

    Returns:
        argparse.Namespace: The parsed options. Boolean flags default to
        ``False``; ``--project``, ``--entity`` and ``--run_name`` default to
        ``None`` when not supplied.
    """
    parser = argparse.ArgumentParser()

    # Input data and output locations.
    parser.add_argument("--data_dir", default="/path/to/flexdock/data/PDBBIND_atomCorrected")
    parser.add_argument("--input_csv", default="/path/to/flexdock/examples/inference_pdbbind.csv")
    parser.add_argument(
        "--dataset",
        type=str,
        default="pdbbind",
        choices=["pdbbind", "posebusters", "moad"],
    )
    parser.add_argument("--output_dir", type=str, default="inf_results/dummy_test")

    # Weights & Biases logging options.
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--project", type=str)
    parser.add_argument("--entity", type=str)
    parser.add_argument("--run_name", type=str, default=None)

    # Options forwarded untouched to the Evaluator through the namespace.
    parser.add_argument("--task", default="docking", type=str)
    parser.add_argument("--only_relaxation", action="store_true")
    parser.add_argument("--use_symmetry_correction", action="store_true")
    parser.add_argument("--only_nearby_residues_atomic", action="store_true")
    parser.add_argument("--align_proteins_by", default="nearby_atoms")

    return parser.parse_args(argv)


def main(output_dir=None, input_csv=None, data_dir=None):
    """Parse the command line, apply overrides, and run one evaluation.

    Args:
        output_dir: If not ``None``, replaces the parsed ``--output_dir``.
        input_csv: If not ``None``, replaces the parsed ``--input_csv``.
        data_dir: If not ``None``, replaces the parsed ``--data_dir``.

    When ``--wandb`` is set, a W&B run is started before the evaluation and
    finished afterwards, so that several ``main()`` calls in one process each
    get their own run.
    """
    args = parse_args()

    # Keyword overrides take precedence over whatever the CLI provided.
    if output_dir is not None:
        args.output_dir = output_dir
    if input_csv is not None:
        args.input_csv = input_csv
    if data_dir is not None:
        args.data_dir = data_dir

    # Derive the run name from the (possibly overridden) output directory.
    if args.run_name is None:
        args.run_name = args.output_dir + f"_align{args.align_proteins_by}"

    if args.wandb:
        wandb.init(
            entity=args.entity,
            settings=wandb.Settings(start_method="fork"),
            project=args.project,
            name=args.run_name,
            config=args,
        )

    try:
        evaluator = Evaluator(args=args)
        evaluator.evaluate(input_csv=args.input_csv, output_dir=args.output_dir)
    except BaseException:
        if args.wandb:
            # Close the run as failed so it is not reported as successful.
            wandb.finish(exit_code=1)
        raise
    else:
        if args.wandb:
            # Without an explicit finish, a later wandb.init() in this process
            # may attach to this still-active run instead of creating a new
            # one, mixing results from different output directories.
            wandb.finish()


if __name__ == "__main__":
    logging.getLogger().setLevel("INFO")
    BASE_DIR = "/path/to/flexdock/results_w_proj"

    if not os.path.isdir(BASE_DIR):
        # No results root available: evaluate using CLI arguments / defaults.
        main()
    else:

        def _natural_key(name):
            """Sort key: digit runs compare as ints, other text case-insensitively."""
            return [
                int(token) if token.isdigit() else token.lower()
                for token in re.split(r"(\d+)", name)
            ]

        # Immediate sub-directories of BASE_DIR, in natural order.
        run_dirs = sorted(
            (
                entry
                for entry in os.listdir(BASE_DIR)
                if os.path.isdir(os.path.join(BASE_DIR, entry))
            ),
            key=_natural_key,
        )

        if run_dirs:
            # One evaluation per sub-directory.
            for entry in run_dirs:
                main(output_dir=os.path.join(BASE_DIR, entry))
        else:
            # BASE_DIR has no sub-directories: evaluate BASE_DIR itself.
            main(output_dir=BASE_DIR)
