import os
import copy
from tqdm import tqdm
import torch
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.cluster import DBSCAN

from rdkit.Chem import RemoveAllHs
from rdkit.Chem.rdMolAlign import CalcRMS

from flowdock.utils.datasets import get_datasets
from flowdock.utils.rotation import geodesic_average_rotation_matrices, compute_mean_rotation
from flowdock.flowmatching import rotmat_to_q, q_to_rotmat
from flowdock.utils.rmsd import TimeoutException, time_limit
from flowdock.utils.alignment import restore_atom_order
from flowdock.utils.spyrmsd import compute_all_isomorphisms, get_symmetry_rmsd_with_isomorphisms
from flowdock.utils.transforms import find_rigid_alignment, compute_angle_MAE, get_torsion_angles


def sort_samples_by_scores(all_real_rmsds, score_name):
    """Order every uid's samples by ascending value of ``score_name``.

    Args:
        all_real_rmsds: mapping uid -> sequence of per-sample metric dicts.
        score_name: key read from every sample dict and used as the sort key.

    Returns:
        A new dict uid -> list containing the same sample objects (not
        copies), reordered so that the lowest score comes first.
    """
    ordered = {}
    for uid, samples in all_real_rmsds.items():
        order = np.argsort(np.array([sample[score_name] for sample in samples]))
        ordered[uid] = [samples[position] for position in order]
    return ordered


def get_best_results_by_score(all_results, score_name):
    """Pick, for every uid, the sample with the lowest ``score_name``.

    Args:
        all_results: mapping uid -> dict that holds a ``'sample_metrics'``
            list of per-sample metric dicts.
        score_name: key of the score to minimise. The special value
            ``'random'`` selects the first sample without reading any score.

    Returns:
        dict uid -> the selected sample dict (the same object, not a copy).

    Raises:
        KeyError: if a sample of some uid does not carry ``score_name``.
        ValueError: if a uid has no samples to rank.
    """
    filtered_results = {}

    for uid, metrics in all_results.items():
        samples = metrics['sample_metrics']
        if score_name == 'random':
            best_index = 0
        else:
            # Fail loudly with context instead of dropping into a debugger:
            # continuing past a failure would reuse a stale/undefined index.
            try:
                scores = np.array([sample[score_name] for sample in samples])
            except KeyError as err:
                raise KeyError(
                    f'Score {score_name!r} is missing from a sample of {uid}') from err
            if scores.size == 0:
                raise ValueError(f'{uid} has no samples to rank by {score_name!r}')
            best_index = int(np.argmin(scores))

        filtered_results[uid] = samples[best_index]
    return filtered_results


def get_tr_rot_tor_pred(all_real_rmsds):
    """Collect predicted and reference translations, rotations and torsions.

    For every uid the predictions of all samples are stacked, while the
    reference values are read from the first sample only.

    Args:
        all_real_rmsds: mapping uid -> sequence of sample dicts providing the
            keys ``tr_pred_init``, ``tr_true_init``, ``rot_pred_final``,
            ``rot_true_final``, ``torsion_angles_pred`` and
            ``torsion_angles_true``.

    Returns:
        ``(preds_all, trues_all)``: two dicts keyed by ``'tr'``, ``'rot'`` and
        ``'torsion_angles'``, each mapping uid -> value. Stacked predictions
        are numpy arrays, except rotations which are a torch tensor.
    """
    component_names = ('tr', 'rot', 'torsion_angles')
    preds_all = {name: {} for name in component_names}
    trues_all = {name: {} for name in component_names}

    for uid, samples in all_real_rmsds.items():
        reference = samples[0]

        # translations
        preds_all['tr'][uid] = np.array([sample['tr_pred_init'] for sample in samples])
        trues_all['tr'][uid] = reference['tr_true_init']

        # rotations (stacked through numpy first, then wrapped as a tensor)
        stacked_rotations = np.array([sample['rot_pred_final'] for sample in samples])
        preds_all['rot'][uid] = torch.tensor(stacked_rotations)
        trues_all['rot'][uid] = reference['rot_true_final']

        # torsion angles
        preds_all['torsion_angles'][uid] = np.array(
            [sample['torsion_angles_pred'] for sample in samples])
        trues_all['torsion_angles'][uid] = reference['torsion_angles_true']

    return preds_all, trues_all


def compute_tr_mean(tr_preds, tr_true, N_tr, shuffle_tr):
    """Average ``N_tr`` predicted translations and compare with the reference.

    Args:
        tr_preds: numpy array of predicted translations, one per row.
        tr_true: reference translation.
        N_tr: number of predictions to aggregate.
        shuffle_tr: if true, draw ``N_tr`` rows at random without replacement
            (global numpy RNG); otherwise take the first ``N_tr`` rows.

    Returns:
        ``(tr_mean, tr_std, tr_err)``: the mean of the selected rows, the
        Euclidean norm of their per-coordinate standard deviation, and the
        Euclidean distance between the mean and ``tr_true``.
    """
    if shuffle_tr:
        picked_rows = np.random.choice(
            np.arange(len(tr_preds)), N_tr, replace=False)
        subset = tr_preds[picked_rows]
    else:
        subset = tr_preds[:N_tr]

    spread = np.linalg.norm(np.std(subset, axis=0))
    center = subset.mean(axis=0)
    distance_to_true = np.linalg.norm(center - tr_true)
    return center, spread, distance_to_true


def compute_rot_mean(rot_preds, rot_true, N_rot, shuffle_rot, agg_type):
    """Aggregate ``N_rot`` predicted rotations and compare with the reference.

    Args:
        rot_preds: predicted rotation matrices; must support fancy indexing
            and ``.numpy()`` (a torch tensor in the visible callers' data).
        rot_true: reference rotation matrix.
        N_rot: number of predictions to aggregate.
        shuffle_rot: if true, draw ``N_rot`` predictions at random without
            replacement (global numpy RNG); otherwise take the first ``N_rot``.
        agg_type: ``'quaternion'`` (mean computed via ``compute_mean_rotation``
            on quaternions) or ``'geodesic'`` (via
            ``geodesic_average_rotation_matrices``).

    Returns:
        ``(rot_mean_est, rot_std, rot_sim)`` where ``rot_std`` is the root
        mean square of ``1 - trace(R_est^T R_i) / 3`` over the selected
        predictions and ``rot_sim`` is ``trace(R_est^T R_true) / 3``.

    Raises:
        NotImplementedError: for any other ``agg_type``. Previously this case
            only printed a message and then failed on an undefined variable.
    """
    if agg_type not in ('quaternion', 'geodesic'):
        raise NotImplementedError(f'Not implemented for {agg_type}')

    if shuffle_rot:
        rot_subset = rot_preds[np.random.choice(
            np.arange(len(rot_preds)), N_rot, replace=False)]
    else:
        rot_subset = rot_preds[:N_rot]

    if agg_type == 'quaternion':
        quaternions = rotmat_to_q(rot_subset)
        rot_mean_est = q_to_rotmat(
            compute_mean_rotation(quaternions)[None, :])[0].numpy()
    else:  # 'geodesic'
        rot_mean_est = geodesic_average_rotation_matrices(rot_subset)

    rot_subset = rot_subset.numpy()
    # trace(A^T B) / 3 equals 1 when both rotations coincide.
    rot_std = np.sqrt(np.mean(
        [(1 - np.trace(rot_mean_est.T @ rot) / 3) ** 2 for rot in rot_subset]))

    rot_sim = np.trace(rot_mean_est.T @ rot_true) / 3
    return rot_mean_est, rot_std, rot_sim


def get_clustered_results(full_results, score_name_for_tr_clustering, eps=1, min_samples=5):
    """Keep, per uid, only the samples of the best-scored translation cluster.

    Predicted translations (``'tr_pred'``) are clustered with DBSCAN; the
    cluster whose median ``score_name_for_tr_clustering`` is lowest wins and
    ``full_results[uid]['sample_metrics']`` is replaced by its members.

    Args:
        full_results: mapping uid -> dict with a ``'sample_metrics'`` list.
            Modified in place.
        score_name_for_tr_clustering: sample key used to rank clusters
            (lower is better).
        eps: DBSCAN ``eps`` (defaults to the previously hard-coded 1).
        min_samples: DBSCAN ``min_samples`` (defaults to the previously
            hard-coded 5).

    Returns:
        The same ``full_results`` object.
    """
    for uid in full_results.keys():
        sample_metrics = full_results[uid]['sample_metrics']
        tr_preds = np.array([item['tr_pred'] for item in sample_metrics])
        scores_for_clustering = np.array(
            [item[score_name_for_tr_clustering] for item in sample_metrics])

        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(tr_preds).labels_

        # Label -1 (DBSCAN noise) is only a candidate when it is the sole label.
        candidate_labels = set(labels)
        if len(candidate_labels) > 1:
            candidate_labels.discard(-1)

        # Start from +inf / None rather than the magic 1000 / cluster 0, which
        # selected a wrong or non-existent cluster when every median was >= 1000.
        selected_cluster = None
        top_score = np.inf
        for label in sorted(candidate_labels):
            cluster_score = np.median(scores_for_clustering[labels == label])
            if selected_cluster is None or cluster_score < top_score:
                selected_cluster = label
                top_score = cluster_score

        full_results[uid]['sample_metrics'] = [
            sample for sample, label in zip(sample_metrics, labels)
            if label == selected_cluster]
    return full_results


def filter_results_by_posebusters(full_results, use_separate_samples=True, oracle=False):
    """Keep, per uid, only the samples passing the most PoseBusters filters.

    Args:
        full_results: mapping uid -> either a dict with a ``'sample_metrics'``
            list (``use_separate_samples=True``) or the sample list itself.
            Modified in place.
        use_separate_samples: selects which of the two layouts above is used.
        oracle: rank by ``'all_posebusters_filters_passed_count'`` instead of
            ``'posebusters_filters_passed_count'``.

    Returns:
        The same ``full_results`` object with every sample list reduced to the
        samples that reach the per-uid maximum count.
    """
    count_key = ('all_posebusters_filters_passed_count' if oracle
                 else 'posebusters_filters_passed_count')

    for uid in full_results.keys():
        entry = full_results[uid]
        candidates = entry['sample_metrics'] if use_separate_samples else entry

        top_count = max(np.array([sample[count_key] for sample in candidates]))
        survivors = [sample for sample in candidates if sample[count_key] == top_count]

        if use_separate_samples:
            entry['sample_metrics'] = survivors
        else:
            full_results[uid] = survivors
    return full_results


def _select_fast_samples(samples, buried_quantile, is_buried_range_threshold):
    """Return the samples surviving the fast filter; may raise ``KeyError``."""
    count_key = 'posebusters_filters_passed_count_fast'
    top_count = max(np.array([sample[count_key] for sample in samples]))
    kept = [sample for sample in samples if sample[count_key] == top_count]

    if buried_quantile is not None and len(kept) > 10:
        # The last entry of 'posebusters_filters_fast' is read as the buried score.
        buried_scores = np.array(
            [sample['posebusters_filters_fast'] for sample in kept])[:, -1].astype(float)
        if buried_scores.max() - buried_scores.min() > is_buried_range_threshold:
            cutoff = np.quantile(buried_scores, buried_quantile)
            kept = [sample for sample in kept
                    if float(sample['posebusters_filters_fast'][-1]) >= cutoff]
    return kept


def filter_results_by_fast(full_results, use_separate_samples=True, buried_quantile=None, is_buried_range_threshold=0.5):
    """Keep, per uid, the samples with the highest fast PoseBusters count.

    When ``buried_quantile`` is given, more than 10 samples survive the count
    filter, and their buried scores span more than
    ``is_buried_range_threshold``, samples whose buried score lies below the
    ``buried_quantile`` quantile are dropped as well.

    Args:
        full_results: mapping uid -> either a dict with a ``'sample_metrics'``
            list (``use_separate_samples=True``) or the sample list itself.
            Modified in place.
        use_separate_samples: selects which of the two layouts above is used.
        buried_quantile: optional quantile for the buried-score cut.
        is_buried_range_threshold: minimal max-min spread of the buried scores
            for the quantile cut to be applied.

    Returns:
        The same ``full_results`` object. A uid whose samples lack the fast
        filter keys keeps its samples unfiltered.
    """
    for uid in full_results.keys():
        entry = full_results[uid]
        samples = entry['sample_metrics'] if use_separate_samples else entry

        try:
            survivors = _select_fast_samples(
                samples, buried_quantile, is_buried_range_threshold)
        except KeyError:
            # Missing fast-filter keys: leave this uid untouched.
            survivors = samples

        if use_separate_samples:
            entry['sample_metrics'] = survivors
        else:
            full_results[uid] = survivors
    return full_results


def filter_empty_results_and_keep_necessary_ids(full_results, use_separate_samples=True, ids_to_keep=None):
    """Drop uids that have no samples and, optionally, uids not whitelisted.

    Args:
        full_results: mapping uid -> either a dict with a ``'sample_metrics'``
            list (``use_separate_samples=True``) or the sample list itself.
            Modified in place.
        use_separate_samples: selects which of the two layouts above is used.
        ids_to_keep: optional collection of ids. A key is kept only when the
            part before ``'_mol'`` is in this collection.

    Returns:
        The same ``full_results`` object with the offending keys removed.
    """
    uids_to_pop = []
    if ids_to_keep is not None:
        allowed_ids = set(ids_to_keep)
        # Collect the real keys instead of rebuilding them as f'{uid}_mol0',
        # which raised KeyError for keys with another (or no) '_mol' suffix.
        uids_to_pop = sorted(
            uid for uid in full_results if uid.split('_mol')[0] not in allowed_ids)

    if len(uids_to_pop) > 0:
        print(f'Pop {len(uids_to_pop)} uids')

    for uid, entry in full_results.items():
        if len(entry) == 0:
            print(f'{uid} has no valid samples')
            uids_to_pop.append(uid)
            continue

        samples = entry['sample_metrics'] if use_separate_samples else entry
        if len(samples) == 0:
            print(f'{uid} has no valid samples')
            uids_to_pop.append(uid)

    # A uid can be listed twice (not whitelisted AND empty): de-duplicate so
    # the second pop cannot raise.
    for uid in dict.fromkeys(uids_to_pop):
        full_results.pop(uid, None)

    return full_results
    

def get_final_results_for_df(full_results, score_names, score_name_prefix='', posebusters_filter=False, 
                             fast_filter=False, ids_to_keep=None):
    """Build one summary row per (ranking score, filtering variant).

    ``full_results`` is first cleaned with
    ``filter_empty_results_and_keep_necessary_ids``. For every score name the
    best sample per uid is selected on the unfiltered results and, when
    enabled, on the PoseBusters-filtered (``_posebusters``,
    ``_posebusters_oracle``) and fast-filtered (``_fast``,
    ``_fast_with_buried_01/02/03``) copies of the results.

    Args:
        full_results: mapping uid -> dict with a ``'sample_metrics'`` list.
        score_names: ranking scores to evaluate.
        score_name_prefix: prepended to every ranking name.
        posebusters_filter: add the PoseBusters variants and the
            ``'SymRMSD < 2A & PB valid'`` column.
        fast_filter: add the fast-filter variants.
        ids_to_keep: forwarded to the cleaning step.

    Returns:
        ``(rows_list, all_scored_results)``: the list of row dicts and a dict
        ranking name -> (uid -> selected sample).
    """
    def get_row(results, score_name, full_score_name, posebusters_filter):
        scored_results = get_best_results_by_score(results, score_name)
        chosen = list(scored_results.values())

        def column(key):
            return np.array([item[key] for item in chosen])

        rmsds = column('rmsd')
        sym_rmsds = column('symm_rmsd')
        tr_errs = column('tr_err')
        rot_sims = column('rot_sim')
        tor_errs = column('tor_err')

        row = {
            'ranking': full_score_name,
            'RMSD < 2A': (rmsds <= 2).mean(),
            'RMSD < 5A': (rmsds <= 5).mean(),
            'avg RMSD': rmsds.mean(),
            'median RMSD': np.median(rmsds),
            'SymRMSD < 2A': (sym_rmsds <= 2).mean(),
            'SymRMSD < 5A': (sym_rmsds <= 5).mean(),
            'avg SymRMSD': sym_rmsds.mean(),
            'median SymRMSD': np.median(sym_rmsds),
            'avg tr_err': tr_errs.mean(),
            'median tr_err': np.median(tr_errs),
            'tr_err < 1A': (tr_errs <= 1).mean(),
            'avg rot_sim': rot_sims.mean(),
            'median rot_sim': np.median(rot_sims),
            'avg tor_angle_err': tor_errs.mean(),
            'median tor_angle_err': np.median(tor_errs),
            'num_samples': len(chosen),
        }

        if posebusters_filter:
            posebusters_all = column('all_posebusters_filters_passed_count')
            # NOTE(review): strict '<' here while the columns above use '<=';
            # 27 is presumably the total number of PoseBusters checks - confirm.
            row['SymRMSD < 2A & PB valid'] = np.logical_and(sym_rmsds < 2, posebusters_all == 27).mean()
        return row, scored_results

    full_results = filter_empty_results_and_keep_necessary_ids(full_results, use_separate_samples=True, ids_to_keep=ids_to_keep)

    # (ranking-name suffix, results) pairs in the order their rows are emitted.
    variants = [('', full_results)]

    if posebusters_filter:
        for suffix, oracle in (('_posebusters', False), ('_posebusters_oracle', True)):
            variants.append(
                (suffix, filter_results_by_posebusters(copy.deepcopy(full_results), oracle=oracle)))

    if fast_filter:
        filtered_results_fast = filter_results_by_fast(copy.deepcopy(full_results))
        variants.append(('_fast', filtered_results_fast))
        for tag, range_threshold in (('01', 0.1), ('02', 0.2), ('03', 0.3)):
            with_buried = filter_results_by_fast(copy.deepcopy(filtered_results_fast),
                                                 buried_quantile=0.25,
                                                 is_buried_range_threshold=range_threshold)
            variants.append((f'_fast_with_buried_{tag}', with_buried))

    rows_list = []
    all_scored_results = {}
    for score_name in score_names:
        full_score_name = f'{score_name_prefix}{score_name}'
        for suffix, variant_results in variants:
            real_score_name = f'{full_score_name}{suffix}'
            row, scored_results = get_row(variant_results, score_name, real_score_name,
                                          posebusters_filter=posebusters_filter)
            all_scored_results[real_score_name] = scored_results
            rows_list.append(row)

    return rows_list, all_scored_results


def get_simple_metrics_df(all_real_rmsds, compute_symm_rmsd, mol2isomorphisms, score_names=['random', 'bin_0', 'symm_rmsd'],
                          score_names_for_tr_clustering=None):
    """Compute per-sample pose metrics and aggregate them into a DataFrame.

    For every sample the RMSD, the (optional) symmetry-corrected RMSD, the
    centroid (translation) error, a rotation similarity and the torsion-angle
    error against ``samples[0]['true_pos']`` are computed. Samples whose atom
    count differs from the reference are skipped.

    Args:
        all_real_rmsds: mapping uid -> list of sample dicts providing
            ``'true_pos'`` (first sample), ``'transformed_orig'``,
            ``'bond_properties_for_angles'``, ``'orig_mol'`` (first sample)
            and every requested score.
        compute_symm_rmsd: compute the symmetry-corrected RMSD; after 3
            failures for a uid the plain RMSD is used instead.
        mol2isomorphisms: mapping molecule name (uid up to ``'_conf'``) ->
            isomorphisms, or ``None``. When ``None`` the symmetry-corrected
            RMSD falls back to the plain RMSD.
        score_names: ranking scores. ``'random'``, ``'symm_rmsd'`` and
            ``'agg_error_estimate'`` are produced here; all others are copied
            from the samples (``'model_error_estimate'`` is negated). The
            default list is never mutated.
        score_names_for_tr_clustering: optional scores for which additional
            ``clustered_<score>_`` rows are computed via
            ``get_clustered_results``.

    Returns:
        ``(dataframe, all_scored_results, full_results)``.
    """
    print('TODO add filter for bad samples')

    requested_scores = set(score_names)
    copied_score_names = requested_scores - {'random', 'symm_rmsd', 'agg_error_estimate'}
    with_agg_error_estimate = 'agg_error_estimate' in requested_scores

    full_results = {}
    for uid, samples in tqdm(all_real_rmsds.items(), desc='Computing metrics'):
        samples_results = []
        failed_symm_rmsd_count = 0

        mol_key = uid.split('_conf')[0]
        has_isomorphisms = mol2isomorphisms is not None and mol_key in mol2isomorphisms

        true_pos = samples[0]['true_pos']
        for idx in range(len(samples)):
            pred_pos = samples[idx]['transformed_orig']

            if has_isomorphisms:
                # Trim both poses to the number of atoms covered by the
                # isomorphisms. A malformed entry is reported and the poses
                # are used untrimmed (formerly this dropped into pdb).
                try:
                    n_atoms = len(mol2isomorphisms[mol_key][0][0])
                    true_pos = true_pos[:n_atoms]
                    pred_pos = pred_pos[:n_atoms]
                except (IndexError, KeyError, TypeError) as err:
                    print(f'{uid}_{idx:<8} could not trim positions to the isomorphism size: {err!r}')

            if true_pos.shape[0] != pred_pos.shape[0]:
                print(
                    f'{uid}_{idx:<8} true_pos.shape[0] != pred_pos.shape[0]', true_pos.shape, pred_pos.shape)
                continue

            # Translation error: distance between the two centroids.
            tr_pred = pred_pos.mean(axis=0)
            tr_true = true_pos.mean(axis=0)
            tr_err = np.linalg.norm(tr_pred - tr_true)

            # Rotation similarity: normalised trace of the rotation returned
            # by find_rigid_alignment, compared against the identity.
            R_est, _ = find_rigid_alignment(true_pos, pred_pos)
            R_true = np.eye(3)
            rot_sim = float(np.trace(R_est.T @ R_true) / 3)

            # compute torsion angles
            bond_properties_for_angles = samples[idx]['bond_properties_for_angles']
            if bond_properties_for_angles is not None:
                torsion_angles_true = get_torsion_angles(torch.from_numpy(np.copy(true_pos)),
                                                        bond_atoms_for_angles=bond_properties_for_angles)
                torsion_angles_pred = get_torsion_angles(torch.from_numpy(np.copy(pred_pos)),
                                                        bond_atoms_for_angles=bond_properties_for_angles)
                tor_err = compute_angle_MAE(torsion_angles_pred,
                                            torsion_angles_true,
                                            bond_properties_for_angles['bond_periods'])
            else:
                tor_err = 0

            rmsd = np.sqrt(
                ((true_pos - pred_pos) ** 2).sum(axis=1).sum() / true_pos.shape[0])
            if compute_symm_rmsd and failed_symm_rmsd_count < 3:  # compute symmetry rmsd
                try:
                    mol2iso = mol2isomorphisms.get(mol_key) if mol2isomorphisms is not None else None
                    if mol2iso is None:
                        symm_rmsd = rmsd
                        failed_symm_rmsd_count += 1
                    else:
                        symm_rmsd = get_symmetry_rmsd_with_isomorphisms(
                            true_pos, pred_pos, mol2iso)
                except TimeoutException:
                    symm_rmsd = rmsd
                    failed_symm_rmsd_count += 1
            else:
                symm_rmsd = rmsd

            results = {
                'tr_pred': tr_pred,
                'tr_err': float(tr_err),
                'rot_sim': float(rot_sim),
                'symm_rmsd': float(symm_rmsd),
                'tor_err': float(tor_err),
                'rmsd': float(rmsd),
                'pred_pos': pred_pos,
            }
            for score_name in copied_score_names:
                results[score_name] = float(samples[idx][score_name])
                if score_name == 'model_error_estimate':
                    results[score_name] = -results[score_name]

            if with_agg_error_estimate:
                results['agg_error_estimate'] = results['model_error_estimate'] + results['error_estimate_0']

            samples_results.append(results)

        if len(samples_results) > 0:
            # Aggregates are only computed for non-empty lists, which avoids
            # numpy's mean-of-empty-slice warnings for uids that get dropped.
            full_results[uid] = {
                'sample_metrics': samples_results,
                'tr_std': np.sqrt(np.mean([item['tr_err'] ** 2 for item in samples_results])),
                'rot_std': np.sqrt(np.mean([(1 - item['rot_sim']) ** 2 for item in samples_results])),
                'tor_std': np.sqrt(np.mean([item['tor_err'] ** 2 for item in samples_results])),
                'true_pos': true_pos,
                'orig_mol': samples[0]['orig_mol'],
            }
        else:
            print(f'{uid} has no valid samples')
            print(
                f'{uid} true_pos.shape[0] != pred_pos.shape[0]', true_pos.shape, pred_pos.shape)

    if len(full_results) != len(all_real_rmsds):
        print('Initial length of test_names', len(all_real_rmsds))
        print('Length of full_results', len(full_results))

    rows_list, all_scored_results = get_final_results_for_df(full_results, score_names)

    if score_names_for_tr_clustering is not None:
        for score_name_for_tr_clustering in score_names_for_tr_clustering:
            clustered_results = get_clustered_results(copy.deepcopy(full_results), score_name_for_tr_clustering)
            rows_list_clustered, all_scored_results_clustered = get_final_results_for_df(
                clustered_results, score_names,
                score_name_prefix=f'clustered_{score_name_for_tr_clustering}_')

            rows_list += rows_list_clustered
            all_scored_results = {**all_scored_results, **all_scored_results_clustered}

    return pd.DataFrame(rows_list), all_scored_results, full_results


def add_score_results(all_rmsds_new, score_res, score_name, n_samples=None):
    """Attach averaged, negated scores to every sample.

    For sample ``i`` of ``uid`` the 2-D score table ``score_res[f'{uid}_{i}']``
    is read, rows containing NaN are replaced by fixed fallback values (only
    for ``score_name`` in ``'mult'``, ``'bin'``, ``'reg'``), the table is
    negated, and the column-wise mean of its first rows is stored on the
    sample as ``f'{score_name}_{column}'``.

    Args:
        all_rmsds_new: mapping uid -> list of sample dicts. The dicts are
            updated in place.
        score_res: mapping ``f'{uid}_{i}'`` -> 2-D array-like of scores.
        score_name: prefix of the keys written to the samples.
        n_samples: number of leading rows to average; ``None`` uses all rows
            of each table.

    Returns:
        dict uid -> list of the (updated) sample dicts.
    """
    extended_results = {}
    for uid in tqdm(all_rmsds_new.keys(), desc='Adding score results'):
        new_samples = []
        for i, sample in enumerate(all_rmsds_new[uid]):
            sample_scores = np.array(score_res[f'{uid}_{i}'])
            nan_mask = np.isnan(sample_scores).any(axis=1)
            if nan_mask.any():
                if score_name == 'mult':
                    sample_scores[nan_mask, 2] = 6.
                    sample_scores[nan_mask, 0] = 0.
                    sample_scores[nan_mask, 1] = 0.
                elif score_name == 'bin':
                    sample_scores[nan_mask] = 0.
                elif score_name == 'reg':
                    sample_scores[nan_mask] = 50.

            sample_scores = -sample_scores
            # Resolve the row limit per table. Assigning it back to n_samples
            # made the first table's length silently cap every later table.
            n_rows = len(sample_scores) if n_samples is None else n_samples
            mean_scores = np.mean(sample_scores[:n_rows], axis=0)

            for idx, mean_score in enumerate(mean_scores):
                sample[f'{score_name}_{idx}'] = mean_score

            new_samples.append(sample)
        extended_results[uid] = new_samples
    return extended_results


def get_score_corr(all_rmsds, score_names):
    """Print the Spearman correlation between RMSD and each score.

    All samples of all uids are pooled. For every score name the number of
    NaN values, the rounded score values and the correlation are printed.

    Args:
        all_rmsds: mapping uid -> list of sample dicts holding ``'rmsd'`` and
            every name in ``score_names``.
        score_names: names of the scores to correlate with the RMSD.

    Returns:
        ``(rmsds, scores)``: the pooled RMSD list and the pooled array of the
        LAST entry of ``score_names`` only.
    """
    rmsds = []
    score_arrs = {name: [] for name in score_names}
    for uid in tqdm(all_rmsds.keys()):
        for sample in all_rmsds[uid]:
            rmsds.append(sample['rmsd'])
            for score_name in score_names:
                score_arrs[score_name].append(sample[score_name])

    score_arrs = {name: np.array(values) for name, values in score_arrs.items()}

    for score_name in score_names:
        pooled_scores = score_arrs[score_name]
        print(np.isnan(pooled_scores).sum())
        print(np.round(pooled_scores, 2))
        corr = spearmanr(rmsds, pooled_scores).correlation
        print(score_name, corr)
    # score_name still refers to the last score of the loop above.
    return rmsds, score_arrs[score_name]


def add_transormed_positions(all_rmsds_new):
    """Store the rigidly transformed ligand position on every sample.

    Reads the uid -> samples mapping from ``all_rmsds_new[0]``. For each
    sample the centred ``'orig_pos_before_augm'`` is multiplied by the
    transpose of ``'rot_pred_final'`` and shifted by ``'tr_pred_init'``; the
    result is written to ``sample['transformed_orig']`` in place.

    Returns:
        dict uid -> list of the (updated) sample dicts.
    """
    samples_by_uid = all_rmsds_new[0]
    extended_results = {}
    for uid in tqdm(samples_by_uid.keys()):
        updated = []
        for sample in samples_by_uid[uid]:
            pos = sample['orig_pos_before_augm']
            centered = pos - pos.mean(axis=0)
            rotated = centered @ sample['rot_pred_final'].T
            sample['transformed_orig'] = rotated + sample['tr_pred_init']
            updated.append(sample)
        extended_results[uid] = updated
    return extended_results


def print_scoring_results(scored_all_rmsds, score_key='bin_0', do_print=True):
    """Summarise how well ``score_key`` ranks samples, with and without a
    ``min_distance`` filter.

    For every uid the RMSD of the lowest-scored sample and the Spearman
    correlation between score and RMSD are computed twice: over all samples
    and over the samples with ``1.6 < min_distance < 6.5`` (all samples are
    used when the filter leaves none). Uids without samples are skipped.

    The function previously discarded everything it computed and ignored
    ``do_print``; it now prints a short summary and returns the raw values.

    Args:
        scored_all_rmsds: mapping uid -> list of sample dicts holding
            ``'rmsd'``, ``'min_distance'`` and ``score_key``.
        score_key: sample key of the score (lower is better).
        do_print: print the summary.

    Returns:
        dict with numpy arrays (one entry per non-empty uid) under the keys
        ``'best_rmsds'``, ``'best_filtered_rmsds'``, ``'corr'`` and
        ``'corr_filtered'``.
    """
    corr_all = []
    corr_filtered_all = []
    best_rmsds = []
    best_filtered_rmsds = []

    for uid, all_samples in scored_all_rmsds.items():
        if len(all_samples) == 0:
            continue

        rmsds = np.array([sample['rmsd'] for sample in all_samples])
        scores = np.array([sample[score_key] for sample in all_samples])

        kept = [sample for sample in all_samples if 1.6 < sample['min_distance'] < 6.5]
        if len(kept) == 0:
            kept = all_samples
        filtered_rmsds = np.array([sample['rmsd'] for sample in kept])
        filtered_scores = np.array([sample[score_key] for sample in kept])

        best_rmsds.append(rmsds[np.argmin(scores)])
        best_filtered_rmsds.append(filtered_rmsds[np.argmin(filtered_scores)])

        # A constant score has no defined rank correlation; report 0 instead.
        corr_all.append(
            spearmanr(scores, rmsds).correlation if len(np.unique(scores)) > 1 else 0)
        corr_filtered_all.append(
            spearmanr(filtered_scores, filtered_rmsds).correlation
            if len(np.unique(filtered_scores)) > 1 else 0)

    summary = {
        'best_rmsds': np.array(best_rmsds, dtype=float),
        'best_filtered_rmsds': np.array(best_filtered_rmsds, dtype=float),
        'corr': np.array(corr_all, dtype=float),
        'corr_filtered': np.array(corr_filtered_all, dtype=float),
    }

    if do_print:
        if len(best_rmsds) == 0:
            print(f'{score_key}: no samples to score')
        else:
            for label, best_key, corr_key in (('all', 'best_rmsds', 'corr'),
                                              ('filtered', 'best_filtered_rmsds', 'corr_filtered')):
                best = summary[best_key]
                print(f'{score_key} [{label}] '
                      f'top-1 RMSD <= 2A: {(best <= 2).mean():.3f}, '
                      f'median top-1 RMSD: {np.median(best):.3f}, '
                      f'mean Spearman corr: {np.nanmean(summary[corr_key]):.3f}')

    return summary


def compute_simple_metrics(train_res_path, data_exp, conf, score_exp=None,
                           score_names=[
                               'random', 'error_estimate_0', 'symm_rmsd'],
                           compute_symm_rmsd=True,
                           fname_ref='_pdbbind_all_rmsds_10steps_1runs_777seed.npy',
                           data_path_pred=None, test_names=None, pred_method_name='diffdock', merge_results=False,
                           dataset_type='pdbbind', compute_metrics_for_init_preds=True,
                           score_names_for_tr_clustering=None):
    """Load saved predictions and compute summary metrics for them.

    Args:
        train_res_path: root directory of the saved results.
        data_exp: sub-directory of ``train_res_path`` holding ``fname_ref``.
        conf: dataset configuration; ``conf['test_dataset_types']`` is
            overwritten when ``compute_symm_rmsd`` is true.
        score_exp: optional sub-directory holding the ``*_scoring.npy`` files;
            when given, ``error_estimate_*`` scores are attached.
        score_names: ranking scores to evaluate (never mutated here).
        compute_symm_rmsd: load the test dataset and compute graph
            isomorphisms for the symmetry-corrected RMSD.
        fname_ref: file name of the reference predictions.
        data_path_pred: optional path to predictions of a second method.
        test_names: uids to evaluate; defaults to all uids of the reference
            file (intersected with the second file when given).
        pred_method_name: value of the ``'model'`` column for the second method.
        merge_results: additionally evaluate the union of both sample sets.
            Requires ``score_exp`` and ``data_path_pred``.
        dataset_type: dataset used for the isomorphisms.
        compute_metrics_for_init_preds: evaluate the reference predictions.
        score_names_for_tr_clustering: forwarded to ``get_simple_metrics_df``
            for the reference predictions.

    Returns:
        ``(results_df, all_scored_results_flowdock, all_scored_results_pred,
        all_scored_results_merged)``.

    Raises:
        ValueError: if ``merge_results`` is set without ``score_exp`` or
            ``data_path_pred`` (previously an opaque ``TypeError`` from
            ``os.path`` after the expensive loading steps).
    """
    if merge_results and (score_exp is None or data_path_pred is None):
        raise ValueError('merge_results=True requires both score_exp and data_path_pred')

    mol2isomorphisms = None
    if compute_symm_rmsd:
        conf['test_dataset_types'] = [dataset_type]
        test_dataset = get_datasets(
            conf, splits=['test'], return_separately=False)['test']
        for cplx in test_dataset.complexes:
            try:
                cplx.ligand.orig_mol = RemoveAllHs(cplx.ligand.orig_mol, sanitize=True)
            except Exception:
                # Best effort: retry without sanitization when it fails.
                cplx.ligand.orig_mol = RemoveAllHs(cplx.ligand.orig_mol, sanitize=False)
        mol2isomorphisms = {cplx.name: compute_all_isomorphisms(cplx.ligand.orig_mol) for cplx in
                            tqdm(test_dataset, desc='Computing isomorphisms')}

    # load predictions
    # NOTE: allow_pickle=True executes pickled code; only load trusted files.
    all_rmsds_new = np.load(os.path.join(
        train_res_path, data_exp, fname_ref), allow_pickle=True)[0]
    if data_path_pred is not None:
        all_rmsds_new_pred = np.load(data_path_pred, allow_pickle=True)[0]

    if test_names is None:
        if data_path_pred is not None:
            test_names = sorted(set(all_rmsds_new.keys()).intersection(
                set(all_rmsds_new_pred.keys())))
        else:
            test_names = sorted(set(all_rmsds_new.keys()))
    print('Length of test_names', len(test_names))

    # filter by test_names
    all_rmsds_new = {uid: all_rmsds_new[uid] for uid in test_names}
    if data_path_pred is not None:
        all_rmsds_new_pred = {
            uid: all_rmsds_new_pred[uid] for uid in test_names}

    # Shift positions by 'full_protein_center'; the reference pose always
    # comes from the first reference sample.
    for uid in test_names:
        for i in range(len(all_rmsds_new[uid])):
            all_rmsds_new[uid][i]['true_pos'] = np.copy(
                all_rmsds_new[uid][0]['orig_pos_before_augm']) + all_rmsds_new[uid][0]['full_protein_center']
            all_rmsds_new[uid][i]['transformed_orig'] = np.copy(
                all_rmsds_new[uid][i]['transformed_orig']) + all_rmsds_new[uid][i]['full_protein_center']
        if data_path_pred is not None:
            for i in range(len(all_rmsds_new_pred[uid])):
                all_rmsds_new_pred[uid][i]['true_pos'] = np.copy(
                    all_rmsds_new[uid][0]['orig_pos_before_augm']) + all_rmsds_new[uid][0]['full_protein_center']
                all_rmsds_new_pred[uid][i]['transformed_orig'] = np.copy(
                    all_rmsds_new_pred[uid][i]['transformed_orig']) + all_rmsds_new_pred[uid][i]['full_protein_center']
                all_rmsds_new_pred[uid][i]['bond_properties_for_angles'] = copy.deepcopy(
                    all_rmsds_new[uid][0]['bond_properties_for_angles'])

    if score_exp is not None:
        data_path = os.path.join(data_exp, fname_ref)
        score_results_path = os.path.join(train_res_path, score_exp,
                                          f'{"__".join(data_path.split(".npy")[0].split("/"))}_scoring.npy')
        print('score_results_path', score_results_path)
        score_res = np.load(score_results_path, allow_pickle=True).item()
        all_rmsds_new = add_score_results(
            all_rmsds_new, score_res, score_name='error_estimate', n_samples=None)

    if merge_results:
        score_results_path = os.path.join(train_res_path, score_exp,
                                          f'{os.path.basename(data_path_pred).split(".npy")[0]}_scoring.npy')
        print('score_results_path aggregated', score_results_path)
        score_res = np.load(score_results_path, allow_pickle=True).item()
        all_rmsds_new_pred = add_score_results(
            all_rmsds_new_pred, score_res, score_name='error_estimate', n_samples=None)

    # get_simple_metrics_df returns (dataframe, scored_results, full_results);
    # the third value is not needed here. Unpacking into two names raised
    # "too many values to unpack".
    all_scored_results_flowdock = {}
    results_df = pd.DataFrame()
    if compute_metrics_for_init_preds:
        results_df, all_scored_results_flowdock, _ = get_simple_metrics_df(
            all_rmsds_new, compute_symm_rmsd=compute_symm_rmsd,
            mol2isomorphisms=mol2isomorphisms, score_names=score_names,
            score_names_for_tr_clustering=score_names_for_tr_clustering)

    all_scored_results_merged = {}
    all_scored_results_pred = {}
    if data_path_pred is not None:
        results_df['model'] = 'flowdock'
        results_df_pred, all_scored_results_pred, _ = get_simple_metrics_df(
            all_rmsds_new_pred, compute_symm_rmsd=compute_symm_rmsd,
            mol2isomorphisms=mol2isomorphisms, score_names=score_names)
        results_df_pred['model'] = pred_method_name
        results_df = pd.concat([results_df, results_df_pred])

        if merge_results:
            for uid in test_names:
                all_rmsds_new[uid] += all_rmsds_new_pred[uid]

            print('Merged results', len(all_rmsds_new[test_names[0]]))
            results_df_pred, all_scored_results_merged, _ = get_simple_metrics_df(
                all_rmsds_new, compute_symm_rmsd=compute_symm_rmsd,
                mol2isomorphisms=mol2isomorphisms, score_names=score_names)
            results_df_pred['model'] = 'merged'
            results_df = pd.concat([results_df, results_df_pred])

    return results_df, all_scored_results_flowdock, all_scored_results_pred, all_scored_results_merged
