import pickle
import argparse
import torch
import numpy as np
import os
import logging
from torch.utils.data import DataLoader
from typing import Optional
from posebusters import PoseBusters
from tqdm import tqdm
from open_biomed.core.pipeline import InferencePipeline
from open_biomed.data.molecule import calc_mol_diversity
from open_biomed.data.pocket import estimate_ligand_atom_num
from open_biomed.utils.config import Config
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.tasks.aidd_tasks.protein_molecule_docking import pbcheck_single
from open_biomed.tasks.aidd_tasks.structure_based_drug_design import calc_vina_molecule_metrics
from models.wrapper import SBDDModelWrapper
from models.molcraft import MolCRAFT4AdvancedSampling, MolCRAFTWithCFG4AdvancedSampling
from reward import WeightedSuccessReward, WeightedSuccessSmoothedReward

# Registry of sampling models, keyed by the --model name. The pipeline passes
# the part of the name before the first "_" to InferencePipeline as the base
# model (so "molcraft_cfg" is loaded as a "molcraft" model).
SUPPORTED_MODELS = dict(
    molcraft=MolCRAFT4AdvancedSampling,
    molcraft_cfg=MolCRAFTWithCFG4AdvancedSampling,
)

# Registry of reward functions, keyed by the reward-config name (the basename
# of --reward_cfg_path without its extension). Unlisted names mean "no reward".
SUPPORTED_REWARD_FN = dict(
    weighted_success=WeightedSuccessReward,
    weighted_success_smoothed=WeightedSuccessSmoothedReward,
)

# PoseBusters check names that must all pass for a generated ligand to be
# counted as valid on its own ("pbvalid_mol"). Used by the (currently
# disabled) PoseBusters validity check in the sampling pipeline.
pbvalid_mol_keys = """
    mol_pred_loaded
    mol_cond_loaded
    sanitization
    inchi_convertible
    all_atoms_connected
    bond_lengths
    bond_angles
    internal_steric_clash
    aromatic_ring_flatness
    double_bond_flatness
    internal_energy
""".split()

# PoseBusters check names that must all pass for the ligand pose to be
# counted as valid with respect to its environment ("pbvalid_dock").
pbvalid_dock_keys = """
    protein-ligand_maximum_distance
    minimum_distance_to_protein
    minimum_distance_to_organic_cofactors
    minimum_distance_to_inorganic_cofactors
    minimum_distance_to_waters
    volume_overlap_with_protein
    volume_overlap_with_organic_cofactors
    volume_overlap_with_inorganic_cofactors
    volume_overlap_with_waters
""".split()

def select_best(succ_list, vina_list):
    """Pick the best sample: highest success value, ties broken by lowest Vina score.

    Args:
        succ_list: per-sample success values (booleans or numbers); higher is better.
        vina_list: per-sample Vina scores aligned with ``succ_list``; lower is better.

    Returns:
        ``(best_succ, best_vina, idx)``. ``idx`` is the first index that is best
        under the (success desc, vina asc) ordering. ``best_succ`` and ``best_vina``
        are that sample's own values. An empty input returns ``(0, 0, 0)``.
    """
    if not succ_list:
        return 0, 0, 0
    # Lexicographic minimum of (-success, vina). min() returns the first minimal
    # index, matching the strict "better than" updates of a manual scan. Unlike a
    # scan seeded with a fake (0, 0) sentinel, every candidate is a real sample,
    # so samples with success 0 and positive Vina scores are ranked correctly.
    idx = min(range(len(succ_list)), key=lambda j: (-succ_list[j], vina_list[j]))
    return succ_list[idx], vina_list[idx], idx


class InferenceScalingPipeline(InferencePipeline):
    """Inference-time scaling pipeline for structure-based drug design (SBDD).

    Wraps one of ``SUPPORTED_MODELS`` in ``SBDDModelWrapper`` and samples
    ``num_samples`` ligands for every CrossDocked pocket. Sampling is optionally
    guided by a reward function from ``SUPPORTED_REWARD_FN``. Results are pickled
    under ``./data/sample_results/<mode>/<model>_<sample cfg>_<reward cfg>/<version>``.
    In test mode the pipeline also prints Vina-based metrics and diversity.
    """

    # Train mode: number of pockets stored in each ``preds_<shard>.pkl`` file.
    TRAIN_SHARD_SIZE = 2000
    # Train mode: the current shard is additionally flushed every this many pockets.
    TRAIN_SAVE_EVERY = 100

    def __init__(self,
        model: str="",
        model_ckpt: str="",
        sample_cfg_path: str="",
        reward_cfg_path: str="",
        mode: str="test",
        debug: bool=False,
    ) -> None:
        """Load the sampling/reward configs and initialise the base pipeline.

        Args:
            model: key of ``SUPPORTED_MODELS``. The part before the first ``"_"``
                is passed to ``InferencePipeline`` as the base model name.
            model_ckpt: checkpoint path forwarded to ``InferencePipeline``
                (presumably exposed as ``self.cfg.model_ckpt`` — confirm).
            sample_cfg_path: sampling config file. Its basename without extension
                becomes ``self.sample_cfg.name``.
            reward_cfg_path: reward config file. Its basename without extension
                becomes ``self.reward_cfg.name`` and selects the reward class.
            mode: ``"test"`` builds the dataset config with ``debug=True``.
            debug: sample only the first 5 pockets and never write results.
        """
        self.dataset_cfg = Config.from_dict(
            path="./data",
            debug=(mode == "test"),
        )
        self.model_name = model
        self.sample_cfg = Config(config_file=sample_cfg_path)
        self.reward_cfg = Config(config_file=reward_cfg_path)
        self.sample_cfg.name = sample_cfg_path.split("/")[-1].split(".")[0]
        self.reward_cfg.name = reward_cfg_path.split("/")[-1].split(".")[0]
        self.debug = debug
        # The attributes above must exist before super().__init__(). It presumably
        # calls setup_model(), which reads self.model_name / self.sample_cfg — confirm.
        super().__init__(
            task="structure_based_drug_design",
            model=model.split("_")[0],
            model_ckpt=model_ckpt,
            additional_config=None,
            logging_level="info",
            device="cuda:0",
            output_prompt=None,
            retry_limit=10,
        )
        logging.info(f"Sample config: {self.sample_cfg}")
        logging.info(f"Reward config: {self.reward_cfg}")

    def setup_model(self):
        """Build the sampling model, fetch its featurizer/collator and load weights."""
        sampler = SUPPORTED_MODELS[self.model_name](self.cfg.model, self.sample_cfg)
        self.model = SBDDModelWrapper(sampler)
        self.featurizer, self.collator = self.model.get_featurizer()
        logging.info(f"Loading model from {self.cfg.model_ckpt}")
        # Pass the path rather than an open() handle so the file is closed by torch.
        state_dict = torch.load(self.cfg.model_ckpt, map_location="cpu")
        # Lightning-style checkpoints nest the weights under "state_dict".
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        if hasattr(self.model.model, "load_ckpt"):
            self.model.model.load_ckpt(state_dict)
        else:
            logging.info(self.model.load_state_dict(state_dict, strict=False))
        self.model.model.eval()
        self.model.to(self.cfg.device)

    @staticmethod
    def _dump(obj, path: str) -> None:
        """Pickle ``obj`` to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    @staticmethod
    def _load(path: str):
        """Unpickle and return the object stored at ``path`` (a file this script wrote)."""
        with open(path, "rb") as f:
            return pickle.load(f)

    def _load_train_shards(self, save_path: str) -> list:
        """Reload train-mode predictions from consecutive ``preds_<k>.pkl`` shards.

        Loading stops after the first shard holding fewer than ``TRAIN_SHARD_SIZE``
        entries. That is either the last, still-growing shard or a shard whose tail
        was never flushed. Stopping there keeps ``preds[i]`` aligned with dataset
        index ``i``; the missing pockets are resampled.
        """
        preds = []
        shard = 0
        while True:
            shard_path = os.path.join(save_path, f"preds_{shard}.pkl")
            if not os.path.exists(shard_path):
                break
            chunk = self._load(shard_path)
            preds.extend(chunk)
            if len(chunk) < self.TRAIN_SHARD_SIZE:
                break
            shard += 1
        if shard == 0 and not preds:
            logging.warning(f"No saved shards found in {save_path}; starting from scratch.")
        return preds

    def _load_results(self, save_path: str, mode: str):
        """Return ``(preds, metrics, trajs, divs)`` previously saved under ``save_path``."""
        if mode == "test":
            return (
                self._load(os.path.join(save_path, "preds.pkl")),
                self._load(os.path.join(save_path, "metrics.pkl")),
                self._load(os.path.join(save_path, "trajs.pkl")),
                self._load(os.path.join(save_path, "div.pkl")),
            )
        # Non-test modes only ever persist sharded predictions.
        return self._load_train_shards(save_path), [], [], []

    @staticmethod
    def _report_metrics(all_metrics: list, all_divs: list, num_samples: int) -> None:
        """Print pooled per-molecule metrics and, if ``num_samples > 1``, best-of-N metrics.

        Each entry of ``all_metrics`` is the list of metric dicts for one pocket.
        The best molecule per pocket is chosen by ``select_best`` on its
        ``"success"`` and ``"vina_dock"`` entries.
        """
        pooled, best_metrics = {}, {}
        for pocket_metrics in all_metrics:
            if not pocket_metrics:
                continue
            for mol_metrics in pocket_metrics:
                for k, v in mol_metrics.items():
                    pooled.setdefault(k, []).append(v)
            _, _, best_idx = select_best(
                [m["success"] for m in pocket_metrics],
                [m["vina_dock"] for m in pocket_metrics],
            )
            for k, v in pocket_metrics[best_idx].items():
                best_metrics.setdefault(k, []).append(v)

        # Non-Vina metrics are averaged. Vina metrics get both a mean and a median,
        # listed after the plain averages.
        avg_metrics = {k: np.mean(v) for k, v in pooled.items() if "vina" not in k}
        for k, v in pooled.items():
            if "vina" in k:
                avg_metrics[k + "_mean"] = np.mean(v)
                avg_metrics[k + "_median"] = np.median(v)
        # Avoid np.mean([]) (RuntimeWarning) when no pocket had >1 valid molecule.
        avg_metrics["diversity"] = np.mean(all_divs) if all_divs else float("nan")

        print("Average metrics:")
        for k, v in avg_metrics.items():
            print(f"{k}: {v:.4f}")
        if num_samples > 1:
            print("Best metrics:")
            for k, v in best_metrics.items():
                print(f"{k}: {np.mean(v):.4f}")

    # NOTE: batch size is always 1
    # TODO: how to control the sampling hyperparameters to ensure the same inference cost?
    def run(self, num_samples: int=128, resume: bool=False, sample_prior: bool=False, mode: str="test", version: int=0) -> None:
        """Sample ligands for every pocket of the selected split and persist them.

        Args:
            num_samples: molecules sampled per pocket.
            resume: continue from results previously saved under the same path.
            sample_prior: drop any cached ``estimated_num_atoms`` from the pocket.
                Pass one ``estimate_ligand_atom_num`` estimate per sample to the model.
            mode: ``"test"`` uses the test split, scores every molecule with
                ``calc_vina_molecule_metrics`` and saves after every pocket.
                Any other value uses the train split. ``"train"`` saves
                predictions in shards of ``TRAIN_SHARD_SIZE`` pockets.
            version: output sub-directory distinguishing repeated runs.
        """
        # TODO: maybe later we will use a different dataset (e.g. PDB pockets with ligands from 2020 to 2024)
        dataset = CrossDocked(
            cfg=self.dataset_cfg,
            featurizer=self.featurizer,
        )
        train, val, test = dataset.split()
        dataset = test if mode == "test" else train
        save_path = os.path.join("./data/sample_results", mode, f"{self.cfg.model.name}_{self.sample_cfg.name}_{self.reward_cfg.name}", str(version))
        os.makedirs(save_path, exist_ok=True)

        if resume:
            all_preds, all_metrics, all_trajs, all_divs = self._load_results(save_path, mode)
        else:
            all_preds, all_metrics, all_trajs, all_divs = [], [], [], []

        N = len(dataset)
        # all_preds[i] corresponds to dataset index i, so resume at len(all_preds).
        for i in tqdm(range(len(all_preds), N), desc="Sampling"):
            if self.debug and i >= 5:
                break
            # Featurized pocket and non-featurized ligand
            pocket = dataset.pockets[i]
            ligand = dataset.molecules[i]
            if sample_prior:
                if hasattr(pocket, "estimated_num_atoms"):
                    delattr(pocket, "estimated_num_atoms")
                estimated_ligand_num = [estimate_ligand_atom_num(pocket) for _ in range(num_samples)]
            else:
                estimated_ligand_num = None

            if self.reward_cfg.name in SUPPORTED_REWARD_FN and mode == "test":
                reward_fn = SUPPORTED_REWARD_FN[self.reward_cfg.name](pocket, ligand, self.reward_cfg)
            else:
                reward_fn = None
            featurized_pocket = self.featurizer(pocket=pocket)
            featurized_pocket = self.model.transfer_batch_to_device(featurized_pocket, self.cfg.device, 0)
            outputs = self.model.sample(**featurized_pocket, num_samples=num_samples, reward_fn=reward_fn, estimated_ligand_num=estimated_ligand_num)
            # outputs[0] holds the sampled molecules (None for failures); the
            # remaining entries are kept as trajectories.
            all_preds.append(outputs[0])
            if i <= 2:
                for molecule in outputs[0]:
                    print(molecule)
            if mode == "test":
                pocket_metrics = []
                for molecule in outputs[0]:
                    if molecule is None:
                        continue
                    pocket_metrics.append(calc_vina_molecule_metrics(molecule, dataset.proteins[i]))
                    print(pocket_metrics[-1])
                    # PoseBusters validity (pbvalid_mol / pbvalid_dock / pbvalid) is
                    # currently disabled. To re-enable, create `buster = PoseBusters()`
                    # before the loop, call
                    # `pbcheck_single([molecule], ligand, dataset.proteins[i], buster)`
                    # and AND its results over pbvalid_mol_keys / pbvalid_dock_keys.
                all_metrics.append(pocket_metrics)
                all_trajs.append(outputs[1:])
                valid_outputs = [mol for mol in outputs[0] if mol is not None]
                if len(valid_outputs) > 1:
                    all_divs.append(calc_mol_diversity(valid_outputs))

            # Also flush when a train shard is complete; otherwise the tail of
            # every shard (after its last periodic save) would never be written.
            closes_shard = (i + 1) % self.TRAIN_SHARD_SIZE == 0
            if not self.debug and (mode == "test" or i % self.TRAIN_SAVE_EVERY == 0 or closes_shard or i == N - 1):
                print("Saving...")
                if mode == "test":
                    self._dump(all_preds, os.path.join(save_path, "preds.pkl"))
                    self._dump(all_metrics, os.path.join(save_path, "metrics.pkl"))
                    self._dump(all_trajs, os.path.join(save_path, "trajs.pkl"))
                    self._dump(all_divs, os.path.join(save_path, "div.pkl"))
                elif mode == "train":
                    shard = i // self.TRAIN_SHARD_SIZE
                    start = shard * self.TRAIN_SHARD_SIZE
                    self._dump(
                        all_preds[start:start + self.TRAIN_SHARD_SIZE],
                        os.path.join(save_path, f"preds_{shard}.pkl"),
                    )

        if mode == "test":
            self._report_metrics(all_metrics, all_divs, num_samples)

def main():
    """Parse command-line arguments and run ``InferenceScalingPipeline``."""
    parser = argparse.ArgumentParser(
        description="Inference-time scaling sampling for structure-based drug design."
    )
    parser.add_argument("--model", type=str, default="molcraft", choices=sorted(SUPPORTED_MODELS),
                        help="sampling model (a key of SUPPORTED_MODELS)")
    parser.add_argument("--model_ckpt_path", type=str, default="./checkpoints/molcraft.ckpt",
                        help="checkpoint to load into the model")
    parser.add_argument("--sample_cfg_path", type=str, default="./configs/sample/nestedis.yaml",
                        help="sampling config; its basename names the output directory")
    parser.add_argument("--reward_cfg_path", type=str, default="./configs/reward/weighted_success.yaml",
                        help="reward config; its basename selects the reward function")
    parser.add_argument("--resume", action="store_true",
                        help="continue from previously saved results")
    parser.add_argument("--debug", action="store_true",
                        help="sample only the first 5 pockets and do not save")
    parser.add_argument("--num_samples", type=int, default=32,
                        help="molecules sampled per pocket")
    parser.add_argument("--sample_prior", action="store_true",
                        help="estimate the ligand atom count per sample from the pocket")
    # Only "test" and "train" persist results. Any other value used to silently
    # sample the train split and save nothing.
    parser.add_argument("--mode", type=str, default="test", choices=["test", "train"],
                        help="'test': evaluate on the test split; 'train': sharded samples on the train split")
    parser.add_argument("--version", type=str, default="0",
                        help="run identifier used as the output sub-directory")
    args = parser.parse_args()

    pipeline = InferenceScalingPipeline(
        model=args.model,
        model_ckpt=args.model_ckpt_path,
        sample_cfg_path=args.sample_cfg_path,
        reward_cfg_path=args.reward_cfg_path,
        mode=args.mode,
        debug=args.debug,
    )
    pipeline.run(
        num_samples=args.num_samples,
        resume=args.resume,
        sample_prior=args.sample_prior,
        mode=args.mode,
        version=args.version,
    )

if __name__ == "__main__":
    main()