import os
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from tabulate import tabulate
import json
from datetime import datetime
import numpy as np
import torch
from torch.utils.data import DataLoader
from argparse import ArgumentParser
from omegaconf import OmegaConf
from rdkit.Chem import RemoveAllHs
from flowdock.dataset.pdbbind import complex_collate_fn
from flowdock.dataset.pdbbind_scoring import dummy_ranking_collate_fn
from flowdock.models import FlowDockModel
from flowdock.models.scoring_model import FlowDockScoringModel
from flowdock.utils.datasets import get_datasets
from flowdock.utils.posebusters import get_posebusters_tests, get_posebusters_tests_updated
from flowdock.utils.inference import (
    euler, load_from_checkpoint, run_evaluation, scoring_inference)
from flowdock.utils.spyrmsd import compute_all_isomorphisms
from flowdock.utils.metrics import add_score_results, get_simple_metrics_df, get_final_results_for_df



def _build_docking_model(conf, **model_kwargs):
    """Construct a FlowDockModel with the docking architecture used by this script.

    Args:
        conf: Merged config; ``llm_emb_dim``, ``use_time``, ``dropout_rate`` and
            ``num_kernel_pos_encoder`` are read from it.
        **model_kwargs: Extra constructor arguments for a specific pipeline stage.

    Returns:
        A newly constructed FlowDockModel (no checkpoint weights loaded).
    """
    return FlowDockModel(feature_dim=320, num_heads=8, num_transformer_blocks=12,
                         llm_emb_dim=conf.llm_emb_dim,
                         use_time=conf.use_time,
                         dropout_rate=conf.dropout_rate,
                         num_kernel_pos_encoder=conf.num_kernel_pos_encoder,
                         **model_kwargs)


def main():
    """Run the multi-stage docking + scoring inference pipeline from the command line.

    Reads the model config (``-c``), the paths config (``-p``) and a run name
    (``-n``), loads the three docking-stage checkpoints and the scoring
    checkpoint, and then, for every test dataset listed in the config:

    1. runs each docking stage, feeding the previous stage's saved predictions
       into the next stage's dataset;
    2. scores the predictions of each stage with the scoring model;
    3. optionally computes and prints metrics, and writes them to disk.

    All outputs are written below
    ``<conf.inference_results_folder>/<run name>_<timestamp>/``.
    """
    # NOTE(review): presumably a workaround for open-file limits; left disabled.
    # os.system('ulimit -n 2048')
    parser = ArgumentParser(description="Read file form Command line.")
    parser.add_argument("-c", "--config", dest="config_filename",
                        required=True, help="config file with model arguments")
    parser.add_argument("-p", "--paths-config", dest="paths_config_filename",
                        required=True, help="config file with paths")
    parser.add_argument("-n", "--name", dest="run_name",
                        required=True, help="name and the folder of the inference run")
    args = parser.parse_args()

    # Load main model config; values from the paths config win on key clashes.
    conf = OmegaConf.load(args.config_filename)
    paths_conf = OmegaConf.load(args.paths_config_filename)
    conf = OmegaConf.merge(conf, paths_conf)

    # Switch off masking, positional noise and the other randomised options for inference.
    conf.ligand_mask_ratio = 0.
    conf.protein_mask_ratio = 0.
    conf.std_protein_pos = 0
    conf.std_lig_pos = 0
    conf.augm_ligand_transforms = False
    conf.sample_same_complexes_in_batch = False
    conf.randomize_bond_neighbors = False

    torch.multiprocessing.set_sharing_strategy('file_system')

    torch.manual_seed(conf.seed)

    date_time = datetime.now().strftime("%m.%d.%Y-%H:%M:%S")
    run_name = f'{args.run_name}_{date_time}'

    # This instance is only used to report the parameter count. It is still
    # constructed here (right after seeding) so that the amount of random state
    # consumed before the real models are built stays the same as before.
    reference_model = _build_docking_model(conf)
    print('Model parameters (M):', sum(p.numel() for p in reference_model.parameters()) / 1e6)
    print()
    del reference_model

    score_names_for_metrics = ['random', 'error_estimate_0', 'symm_rmsd']

    solver = euler
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_steps = 10
    n_preds_to_use = 100
    print('n_preds_to_use', n_preds_to_use)
    compute_symm_rmsd = True
    compute_metrics = True

    if conf.use_all_chains:
        conf.batch_limit = 15000
        batch_size = 4
    else:
        conf.batch_limit = 30000
        batch_size = 16
    num_workers = 0

    def get_dataloader_docking(dataset):
        # Docking datasets are consumed one item at a time.
        return DataLoader(dataset, batch_size=1, shuffle=False,
                          collate_fn=dummy_ranking_collate_fn, num_workers=num_workers)

    def get_dataloader_scoring(dataset):
        return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                          collate_fn=complex_collate_fn, num_workers=num_workers)

    # Take a plain-list snapshot: conf.test_dataset_types is reassigned inside
    # the dataset loop below, and the iteration must not depend on that.
    dataset_names = list(conf.test_dataset_types)

    print('DATASET NAMES:', dataset_names)

    pipeline = {
        'docking': [
            {
                'model_path': 'pipeline/stage1/',
                'model_kwargs': {},
                'dataset_kwargs': {'n_preds_to_use': n_preds_to_use, 'stage_num': 1},
            },
            {
                'model_path': 'pipeline/stage2/',
                'model_kwargs': {'use_scoring_rollout': True},
                'dataset_kwargs': {'n_preds_to_use': n_preds_to_use, 'stage_num': 2},
            },
            {
                'model_path': 'pipeline/stage3/',
                'model_kwargs': {},
                'dataset_kwargs': {'use_predicted_tr_only': False, 'n_preds_to_use': n_preds_to_use, 'stage_num': 3},
            },
        ],
        'scoring': {
            'model_path': 'pipeline/scoring/',
            'model_kwargs': {'objective': 'ranking'},
            'dataset_kwargs': {'n_preds_to_use': n_preds_to_use},
        }
    }

    # Save the pipeline description to the run folder. This must happen before
    # the model objects are attached to `pipeline`, which would not be JSON-serialisable.
    run_dir = os.path.join(conf.inference_results_folder, run_name)
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, 'config.json'), 'w') as f:
        json.dump(pipeline, f)

    # Load one docking model per stage.
    for module in pipeline['docking']:
        model = _build_docking_model(conf, **module['model_kwargs'])
        model = load_from_checkpoint(model, os.path.join(conf.results_folder, module['model_path']))
        model.to(device)
        model.eval()
        module['model'] = model

    scoring_model = FlowDockScoringModel(feature_dim=192, num_heads=4, num_transformer_blocks=6,
                                         llm_emb_dim=conf.llm_emb_dim, dropout_rate=conf.dropout_rate,
                                         objective=pipeline['scoring']['model_kwargs']['objective'])
    scoring_model = load_from_checkpoint(scoring_model, os.path.join(conf.results_folder,
                                                                     pipeline['scoring']['model_path']))
    scoring_model.to(device)
    scoring_model.eval()
    pipeline['scoring']['model'] = scoring_model

    print('Start inference pipeline', run_name)

    for dataset_name in dataset_names:
        # Stage 1 starts without predictions; later stages read the previous stage's file.
        predicted_ligand_transforms_path = None

        if compute_metrics:
            # Reference data is loaded from the dataset name with any '_conf...' suffix removed.
            metrics_dataset_name = dataset_name.split('_conf')[0]
            conf.test_dataset_types = [metrics_dataset_name]
            test_dataset_for_metrics = get_datasets(conf, splits=['test'], return_separately=True,
                                        predicted_ligand_transforms_path=None,
                                        is_train_dataset=False,
                                        complex_collate_fn=complex_collate_fn,
                                        n_preds_to_use=1,
                                        use_all_chains=False,
                                        stage_num=None,
                                        )['test']
            print({ds_name: len(ds) for ds_name, ds in test_dataset_for_metrics.items()})
            test_dataset_for_metrics = test_dataset_for_metrics[metrics_dataset_name]
            mol2isomorphisms = None
            name2true_pos = {}
            for ref_complex in test_dataset_for_metrics.complexes:
                if compute_symm_rmsd:
                    try:
                        ref_complex.ligand.orig_mol = RemoveAllHs(ref_complex.ligand.orig_mol, sanitize=True)
                    except Exception:
                        # Deliberate best-effort: retry without sanitisation
                        # rather than dropping the complex.
                        ref_complex.ligand.orig_mol = RemoveAllHs(ref_complex.ligand.orig_mol, sanitize=False)
                # The reference position is needed by the metrics code below
                # regardless of whether symmetric RMSD is enabled.
                name2true_pos[ref_complex.name] = (
                    np.copy(ref_complex.ligand.pos) + ref_complex.protein.full_protein_center)
            if compute_symm_rmsd:
                mol2isomorphisms = {item.name: compute_all_isomorphisms(item.ligand.orig_mol) for item in
                                    tqdm(test_dataset_for_metrics, desc='Computing isomorphisms')}

        for stage_idx in [0, 1, 2]:
            # Clamp so that a shorter docking list reuses its last module.
            module = pipeline['docking'][min(stage_idx, len(pipeline['docking']) - 1)]
            model = module['model']
            model.to(device)

            print(f'Stage {stage_idx + 1}; predicted_ligand_transforms_path:', predicted_ligand_transforms_path)

            # Load the docking dataset for this stage.
            conf.use_sorted_batching = True
            conf.test_dataset_types = [dataset_name]
            test_dataset_docking = get_datasets(conf, splits=['test'], return_separately=True,
                                        predicted_ligand_transforms_path=predicted_ligand_transforms_path,
                                        is_train_dataset=False,
                                        complex_collate_fn=complex_collate_fn,
                                        **module['dataset_kwargs'],
                                        )['test']
            print({ds_name: len(ds) for ds_name, ds in test_dataset_docking.items()})
            test_dataset_docking = test_dataset_docking[dataset_name]

            # Dataloaders
            test_loader = get_dataloader_docking(test_dataset_docking)
            metrics = run_evaluation(test_loader, num_steps=num_steps, solver=solver, model=model,
                                        compute_metrics=compute_metrics, do_print=False)

            # Save results; this file is also the input of the scoring dataset
            # below and of the next docking stage.
            predicted_ligand_transforms_path = os.path.join(run_dir, f'stage{stage_idx+1}_{dataset_name}.npy')
            np.save(predicted_ligand_transforms_path, [metrics])
            print(f'Saved metrics to {predicted_ligand_transforms_path}')

            # Load dataset with predicted ligand transforms.
            conf.use_sorted_batching = False
            test_dataset_scoring = get_datasets(conf, splits=['test'],
                                        return_separately=True,
                                        is_ranking_dataset=True,
                                        complex_collate_fn=complex_collate_fn,
                                        predicted_complex_positions_path=predicted_ligand_transforms_path,
                                        is_train_dataset=False,
                                        stage_num=None,
                                        **pipeline['scoring']['dataset_kwargs'],
                                        )['test']
            test_dataset_scoring = test_dataset_scoring[dataset_name]
            print('Scoring', dataset_name, len(test_dataset_scoring))
            test_loader = get_dataloader_scoring(test_dataset_scoring)

            pred_scores, pred_tr_errs = scoring_inference(loader=test_loader, model=scoring_model)
            metrics = add_score_results(metrics, pred_scores, score_name='error_estimate', n_samples=None)
            if len(pred_tr_errs) > 0:
                metrics = add_score_results(metrics, pred_tr_errs, score_name='tr_error_estimate', n_samples=None)
            # Re-save so the stage file also carries the predicted scores.
            np.save(predicted_ligand_transforms_path, [metrics])

            # Compute metrics
            if compute_metrics:
                # Merge entries whose ids differ only by a '_conf...' suffix.
                real_metrics = defaultdict(list)
                for uid, preds in metrics.items():
                    real_metrics[uid.split('_conf')[0]].extend(preds)
                metrics = real_metrics

                for uid, preds in metrics.items():
                    for pred in preds:
                        pred['true_pos'] = name2true_pos[uid]
                        # Shift predictions by the protein center, matching
                        # how the reference positions were built above.
                        pred['transformed_orig'] = np.copy(
                            pred['transformed_orig']) + pred['full_protein_center']

                results_df, _, updated_metrics = get_simple_metrics_df(
                    metrics, compute_symm_rmsd=compute_symm_rmsd,
                    mol2isomorphisms=mol2isomorphisms, score_names=score_names_for_metrics)
                cols = results_df.columns[[0, 5, 6, 11, 10, 13, 15]]
                print(f'Dataset: {dataset_name}; stage: {stage_idx+1}')
                print(tabulate(results_df[cols], headers='keys', tablefmt='psql'))
                # NOTE(review): the file name has no stage index, so each stage
                # overwrites the previous one and only the last stage's table
                # remains on disk. Confirm this is intended.
                results_df.to_csv(os.path.join(run_dir, f'{dataset_name}_metrics.csv'), index=False)
                print()
                print()

        if compute_metrics:
            # `updated_metrics` comes from the last stage processed above.
            final_metrics_path = os.path.join(run_dir, f'{dataset_name}_final_preds.npy')
            np.save(final_metrics_path, [updated_metrics])

        print('****' * 30)
        print()


if __name__ == "__main__":
    # Disallow TF32 for CUDA matmuls before any model code runs, so matrix
    # multiplications are done in regular float32.
    torch.backends.cuda.matmul.allow_tf32 = False
    main()
