import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from omegaconf import OmegaConf
from collections import defaultdict

from deli import save_json, load_json

from rdkit.Chem import RemoveAllHs


from flowdock.utils.posebusters import get_posebusters_tests_updated
from flowdock.utils.metrics import get_final_results_for_df, get_simple_metrics_df
from flowdock.utils.datasets import get_datasets
from flowdock.dataset.pdbbind import complex_collate_fn
from flowdock.utils.spyrmsd import compute_all_isomorphisms
from flowdock.utils.paths import get_dataset_path



def _strip_hydrogens(mol):
    """Return ``mol`` with all hydrogens removed.

    First tries ``RemoveAllHs`` with sanitization. If that raises, it retries
    without sanitization so the molecule is still returned in H-free form.
    """
    try:
        return RemoveAllHs(mol, sanitize=True)
    except Exception:
        # Best-effort fallback: keep the ligand even when sanitization fails.
        return RemoveAllHs(mol, sanitize=False)


def _load_reference_ligands(conf, dataset_name):
    """Load the test split for ``dataset_name`` and prepare reference ligand data.

    Side effects:
        * sets ``conf.test_dataset_types`` to the single metrics dataset name;
        * prints the size of every loaded test dataset;
        * replaces each complex's ``ligand.orig_mol`` with its H-stripped version.

    Returns:
        A ``(name2true_pos, mol2isomorphisms)`` tuple of dicts keyed by complex name.
        ``name2true_pos`` holds ``ligand.pos + protein.full_protein_center``
        (presumably ligand coordinates in the original frame -- TODO confirm), and
        ``mol2isomorphisms`` holds ``compute_all_isomorphisms`` of the H-stripped ligand.
    """
    metrics_dataset_name = dataset_name.split('_conf')[0]
    conf.test_dataset_types = [metrics_dataset_name]
    test_datasets = get_datasets(conf, splits=['test'], return_separately=True,
                                 predicted_ligand_transforms_path=None,
                                 is_train_dataset=False,
                                 complex_collate_fn=complex_collate_fn,
                                 n_preds_to_use=1,
                                 )['test']
    print({ds_name: len(ds) for ds_name, ds in test_datasets.items()})
    test_dataset = test_datasets[metrics_dataset_name]

    name2true_pos = {}
    mol2isomorphisms = {}
    # Iterate the same ``.complexes`` entries that are H-stripped, so the
    # isomorphisms are guaranteed to be computed on the H-free molecules.
    for cmplx in tqdm(test_dataset.complexes, desc='Computing isomorphisms'):
        cmplx.ligand.orig_mol = _strip_hydrogens(cmplx.ligand.orig_mol)
        name2true_pos[cmplx.name] = np.copy(cmplx.ligand.pos) + cmplx.protein.full_protein_center
        mol2isomorphisms[cmplx.name] = compute_all_isomorphisms(cmplx.ligand.orig_mol)
    return name2true_pos, mol2isomorphisms


def _evaluate_experiment(preds_path, exp_folder, dataset_name, dataset_data_dir, score_names):
    """Compute and report metrics for one experiment folder on one dataset.

    Reads ``<preds_path>/<exp_folder>/<dataset_name>_conf_final_preds.npy``.
    Prints metrics without the PoseBusters filter, then runs the PoseBusters
    tests and prints the filtered metrics. The filtered metrics are written to
    ``<dataset_name>_conf_final_metrics.csv`` in the same folder.

    Note: the predictions ``.npy`` file is OVERWRITTEN with the PoseBusters-updated
    metrics, so a rerun will read the updated data rather than the raw predictions.
    """
    exp_path = os.path.join(preds_path, exp_folder)
    final_preds_path = os.path.join(exp_path, f'{dataset_name}_conf_final_preds.npy')
    # allow_pickle is needed because the file stores a Python dict; only load
    # files produced by a trusted local inference run.
    preds = np.load(final_preds_path, allow_pickle=True).item()

    rows_list, _ = get_final_results_for_df(preds, score_names=score_names, posebusters_filter=False)
    results_df = pd.DataFrame(rows_list)
    print(results_df[['ranking', 'SymRMSD < 2A', 'SymRMSD < 5A', 'tr_err < 1A']])

    updated_metrics = get_posebusters_tests_updated(preds,
                                                    dataset_name,
                                                    dataset_data_dir=dataset_data_dir,
                                                    posebusters_config='redock')

    rows_list, _ = get_final_results_for_df(updated_metrics, score_names=score_names, posebusters_filter=True)
    results_df = pd.DataFrame(rows_list)
    results_df.to_csv(os.path.join(exp_path, f'{dataset_name}_conf_final_metrics.csv'),
                      index=False)
    print(results_df[['ranking', 'SymRMSD < 2A', 'SymRMSD < 2A & PB valid', 'SymRMSD < 5A', 'tr_err < 1A']])

    # Saved as a 1-element object array; ``np.load(...).item()`` above reads it back.
    np.save(final_preds_path, [updated_metrics])


def main(config_path='<path_to_config>',
         paths_config_path='<path_to_paths_config>',
         dataset_names=('astex', 'pdbbind', 'posebusters', 'dockgen_full'),
         exp_folders=('PIPELINE_allchains',),
         score_names=('random', 'error_estimate_0', 'symm_rmsd')):
    """Evaluate final predictions for each dataset/experiment pair.

    Args:
        config_path: OmegaConf config file; merged with ``paths_config_path``
            (the latter takes precedence on conflicting keys).
        paths_config_path: OmegaConf file providing dataset/result paths.
        dataset_names: datasets to evaluate.
        exp_folders: experiment sub-folders under ``conf.inference_results_folder``.
        score_names: ranking scores passed to ``get_final_results_for_df``.
    """
    conf = OmegaConf.merge(OmegaConf.load(config_path), OmegaConf.load(paths_config_path))
    # NOTE(review): the default experiment folder is named "allchains" while
    # use_all_chains is forced off here -- confirm this is intended.
    conf.use_all_chains = False

    score_names = list(score_names)
    preds_path = conf.inference_results_folder

    for dataset_name in dataset_names:
        dataset_data_dir = get_dataset_path(dataset_name, conf)

        # NOTE(review): these reference data are not consumed by the metrics
        # below (they are loaded from the saved predictions instead); kept for
        # parity with the original script's side effects -- consider removing.
        name2true_pos, mol2isomorphisms = _load_reference_ligands(conf, dataset_name)

        for exp_folder in exp_folders:
            print(dataset_name, exp_folder)
            _evaluate_experiment(preds_path, exp_folder, dataset_name, dataset_data_dir, score_names)


if __name__ == '__main__':
    main()
