from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
import os
import sys
from tqdm import tqdm
# project root = parent of this file's directory
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import numpy as np
from flowdock.utils.datasets import get_datasets
from argparse import ArgumentParser
from flowdock.dataset.pdbbind import complex_collate_fn
from omegaconf import OmegaConf
from flowdock.utils.posebusters_utils import calc_posebusters


# Keys pulled out of the calc_posebusters result by calc_posebusters_for_data.
# The order here is the column order of each per-pose result row (keys absent
# from the result are skipped), and the script section below counts passes
# over the first four columns of a row.
# Disabled alternative key set:
# KEYS_VALID = ['not_too_far_away', 'no_clashes','no_volume_clash','no_internal_clash']
KEYS_VALID = ['not_too_far_away','no_internal_clash', 'no_clashes','no_volume_clash','is_buried_fraction']


def calc_posebusters_for_data(data,lig_pos,orig_mol):
    """Run ``calc_posebusters`` for one complex and tabulate the selected checks.

    Args:
        data: Complex sample. The function reads ``data.ligand.x``,
            ``data.protein.all_atom_pos``, ``data.protein.all_atom_names``
            and ``data.name``.
        lig_pos: Predicted ligand positions, forwarded unchanged.
        orig_mol: Ligand molecule object(s), forwarded unchanged.

    Returns:
        ``None`` when ``calc_posebusters`` returns ``None``. Otherwise an
        object-dtype array built from the result entries whose keys appear in
        ``KEYS_VALID`` (in ``KEYS_VALID`` order, missing keys skipped) and
        then transposed -- presumably giving one row per pose and one column
        per check; confirm against ``calc_posebusters``.
    """
    # First ligand feature column minus one; presumably a zero-based
    # atom-type index -- TODO confirm against the featurizer.
    ligand_types = data.ligand.x[:, 0] - 1
    protein_atom_names = data.protein.all_atom_names
    protein_atom_pos = data.protein.all_atom_pos

    results = calc_posebusters(
        lig_pos,
        protein_atom_pos,
        ligand_types,
        protein_atom_names,
        data.name,
        orig_mol,
    )
    if results is None:
        return None

    # Keep only the requested checks, preserving the KEYS_VALID ordering.
    selected = []
    for key in KEYS_VALID:
        if key in results.keys():
            selected.append(results[key])
    return np.array(selected, dtype=object).T



if __name__ == "__main__":
    # Script entry point.
    #
    # For every test dataset: load the saved final predictions, run the fast
    # PoseBusters checks on the predicted ligand poses of each complex, attach
    # the per-pose results to the prediction records and save the augmented
    # predictions next to the input file.
    parser = ArgumentParser(description="Read file form Command line.")
    parser.add_argument("-c", "--config", dest="config_filename",
                        required=True, help="config file with model arguments")
    parser.add_argument("-p", "--paths-config", dest="paths_config_filename",
                        required=True, help="config file with paths")
    parser.add_argument("-i", "--paths_to_predicts", dest="paths_to_predicts",
                        required=True, help="paths to predicts")
    args = parser.parse_args()

    # Merge the model config with the paths config, then override a set of
    # options (masking ratios, position noise std, ligand-transform
    # augmentation, same-complex batching, bond-neighbour randomization)
    # for this evaluation run.
    conf = OmegaConf.load(args.config_filename)
    paths_conf = OmegaConf.load(args.paths_config_filename)
    conf = OmegaConf.merge(conf, paths_conf)
    conf.ligand_mask_ratio = 0.
    conf.protein_mask_ratio = 0.
    conf.std_protein_pos = 0
    conf.std_lig_pos = 0
    conf.augm_ligand_transforms = False
    conf.sample_same_complexes_in_batch = False
    conf.randomize_bond_neighbors = False
    predicted_ligand_transforms_path = None
    use_predicted_tr_only = False
    n_preds_to_use = 1

    all_datasets = get_datasets(conf, splits=['test'], return_separately=True,
                                predicted_ligand_transforms_path=predicted_ligand_transforms_path,
                                use_predicted_tr_only=use_predicted_tr_only,
                                n_preds_to_use=n_preds_to_use,
                                complex_collate_fn=complex_collate_fn,
                                is_train_dataset=False)
    test_datasets = all_datasets['test']

    for data_name, dataset in test_datasets.items():
        number_failed = 0
        print(f"Processing dataset {data_name}")
        base_name = data_name.replace('_conf', '')
        file_name = f"{base_name}_conf_final_preds.npy"
        file_name_save = f"{base_name}_conf_final_preds_fast_metrics.npy"
        # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects; only
        # point --paths_to_predicts at prediction files you trust.
        predicts = np.load(os.path.join(args.paths_to_predicts, file_name), allow_pickle=True)

        for data in tqdm(dataset):
            name = data.name.replace('_conf0', '')
            try:
                # Look the record up once; it is mutated in place below, so
                # the changes end up in `predicts` when it is saved.
                entry = predicts[0][name]
                sample_metrics = entry["sample_metrics"]
                # Positional indexing is kept on purpose: the container type
                # of "sample_metrics" is not visible from this file.
                lig_pos = np.stack([sample_metrics[i]["pred_pos"] for i in range(len(sample_metrics))])
            except Exception as e:
                # Best effort: a complex without usable predictions is
                # reported, counted as failed and skipped.
                print(f"Error in {name}")
                print(e)
                number_failed += 1
                continue

            posebusters_results = calc_posebusters_for_data(data, lig_pos, entry["orig_mol"])
            if posebusters_results is None:
                print(f"Posebusters failed for {name}")
                number_failed += 1
                continue

            for i, r in enumerate(posebusters_results):
                sample_metrics[i]["posebusters_filters_fast"] = r
                # Element-wise comparison on an object array (hence `== True`
                # rather than `is True`); only the first four columns of the
                # row are counted.
                sample_metrics[i]["posebusters_filters_passed_count_fast"] = (r[:4] == True).sum()  # noqa: E712

        print(f"Dataset {data_name} Number of failed: {number_failed}")
        np.save(os.path.join(args.paths_to_predicts, file_name_save), predicts, allow_pickle=True)
