import copy
from rdkit import Chem

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
from tqdm import tqdm
import numpy as np
from omegaconf import OmegaConf

from flowdock.utils.paths import get_dataset_path


if __name__ == '__main__':
    # Export, for every complex of each benchmark, the single best predicted
    # ligand pose as an SDF file at <save_path>/<uid>/conf_0/lig_0.sdf.
    #
    # Selection rule per complex:
    #   1. keep only the samples with the highest fast-PoseBusters pass count
    #      (samples without that key are treated as passing 0 filters);
    #   2. among those, take the sample with the lowest 'error_estimate_0'.

    conf = OmegaConf.load('<path_to_config>')
    paths_conf = OmegaConf.load('<path_to_paths_config>')
    conf = OmegaConf.merge(conf, paths_conf)

    for dataset_name in ['astex', 'pdbbind', 'posebusters', 'dockgen_full']:
        # The .npy file holds a pickled object array unwrapped with .item() into
        # a dict {uid: sample_data}; allow_pickle is therefore required.
        # Only load prediction files you produced yourself -- unpickling
        # untrusted data can execute arbitrary code.
        predictions = np.load(
            f'<path_to_predictions_path>{dataset_name}_conf_final_preds_fast_metrics.npy',
            allow_pickle=True,
        ).item()

        # 'dockgen_full' predictions are saved under the plain 'dockgen' name.
        real_dataset_name = 'dockgen' if dataset_name == 'dockgen_full' else dataset_name
        save_path = f'<path_to_save_predictions_path>{real_dataset_name}/'

        # NOTE(review): the returned dataset path is not used below; the call
        # is kept only in case get_dataset_path validates dataset_name/conf.
        get_dataset_path(dataset_name, conf)
        os.makedirs(save_path, exist_ok=True)

        for uid, sample_data in tqdm(predictions.items(), desc='Saving predictions'):
            if len(sample_data) == 0:
                continue

            orig_mol = sample_data['orig_mol']
            # Several ligands of one complex share the prefix before '_mol'.
            uid_real = uid.split('_mol')[0]
            out_dir = os.path.join(save_path, uid_real, 'conf_0')
            os.makedirs(out_dir, exist_ok=True)

            samples = sample_data['sample_metrics']
            if len(samples) == 0:
                # max() below would raise ValueError on an empty sequence.
                tqdm.write(f'{uid}: no samples, skipping')
                continue

            best_pb_count = max(
                sample.get('posebusters_filters_passed_count_fast', 0) for sample in samples
            )
            candidates = [
                sample for sample in samples
                if sample.get('posebusters_filters_passed_count_fast', 0) == best_pb_count
            ]
            # Index into `candidates` (the filtered list), not into `samples`.
            best_score_idx = int(np.argmin([sample['error_estimate_0'] for sample in candidates]))
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(f'{uid} {best_pb_count} {best_score_idx} {len(candidates)}')

            best_sample = candidates[best_score_idx]
            pred_positions = best_sample['pred_pos']
            mol = copy.deepcopy(orig_mol)
            sdf_path = os.path.join(out_dir, 'lig_0.sdf')
            try:
                conformer = mol.GetConformer()
                conformer.SetPositions(np.asarray(pred_positions, dtype=np.float64))
                writer = Chem.SDWriter(sdf_path)
                try:
                    # Write the conformer whose positions were just set, rather
                    # than assuming its ID is 0.
                    writer.write(mol, confId=conformer.GetId())
                finally:
                    # Without close() the SDF may never be flushed to disk.
                    writer.close()
            except Exception as err:
                # Best effort: report the failing complex and keep going.
                tqdm.write(f'{uid}: failed to write {sdf_path}: {err!r}')
