import pickle
import argparse
import torch
import numpy as np
import os
import logging
from torch.utils.data import DataLoader
from typing import Optional
from posebusters import PoseBusters
from tqdm import tqdm
from scipy import spatial
from open_biomed.core.pipeline import InferencePipeline
from open_biomed.data.molecule import calc_mol_diversity
from open_biomed.data.pocket import estimate_ligand_atom_num, Pocket, POCKET_CONFIG
from open_biomed.utils.config import Config
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.tasks.aidd_tasks.structure_based_drug_design import calc_vina_molecule_metrics
from models.wrapper import SBDDModelWrapper
from models.molcraft import MolCRAFT4AdvancedSampling, MolCRAFTWithCFG4AdvancedSampling
from reward import WeightedSuccessReward, WeightedSuccessSmoothedReward

# Registry mapping the ``--model`` command-line name to the sampling model class
# that InferenceScalingPipeline.setup_model() instantiates.
SUPPORTED_MODELS = dict(
    molcraft=MolCRAFT4AdvancedSampling,
    molcraft_cfg=MolCRAFTWithCFG4AdvancedSampling,
)

# Registry mapping a reward name to its reward class.
SUPPORTED_REWARD_FN = dict(
    weighted_success=WeightedSuccessReward,
    weighted_success_smoothed=WeightedSuccessSmoothedReward,
)

def z_score(x):
    """Standardize ``x`` to zero mean and unit (population) standard deviation.

    Args:
        x: Array-like of numbers. The mean and standard deviation are taken
            over all elements.

    Returns:
        A float ``numpy.ndarray`` with the same shape as ``x``. Empty input is
        returned as an empty array, and input with zero spread (e.g. a single
        element or all-equal values) yields zeros instead of NaN.
    """
    values = np.asarray(x, dtype=float)
    if values.size == 0:
        # np.mean / np.std of an empty array only produce NaN plus warnings.
        return values
    centered = values - np.mean(values)
    spread = np.std(values)
    if spread == 0:
        # Avoid 0/0: every element equals the mean, so every z-score is 0.
        return np.zeros_like(centered)
    return centered / spread

def estimate_ligand_atom_num_advanced(pocket: Pocket, idx: int=0) -> int:
    """Return the ``idx``-th ranked ligand atom count for the size bin of ``pocket``.

    The pocket size is the median of the ten largest pairwise Euclidean
    distances between the rows of ``pocket.conformer``. That size selects an
    entry of ``POCKET_CONFIG['bins']`` via ``POCKET_CONFIG['bounds']``; the
    entry's atom counts are ranked by descending ``probs`` value and the one
    at rank ``idx`` is returned (``idx=0`` is the highest-ranked count).

    Args:
        pocket: Pocket whose ``conformer`` holds the coordinates to measure.
        idx: Rank (0-based) of the atom count to return.

    Returns:
        The selected element of the bin's ``num_atoms`` sequence.
    """
    pairwise = spatial.distance.pdist(pocket.conformer, metric='euclidean')
    largest_first = np.sort(pairwise)[::-1]
    space_size = np.median(largest_first[:10])

    bounds = POCKET_CONFIG['bounds']
    # First bound strictly above the pocket size; when no bound qualifies the
    # scan ends on the last regular bin, unless the size exceeds the final
    # bound, in which case the extra overflow bin is used.
    bin_idx = next(
        (pos for pos, bound in enumerate(bounds) if bound > space_size),
        len(bounds) - 1,
    )
    if bounds[-1] < space_size:
        bin_idx = len(bounds)

    num_atoms, probs = POCKET_CONFIG['bins'][bin_idx]
    ranking = np.argsort(probs)[::-1]
    return num_atoms[ranking[idx]]

class InferenceScalingPipeline(InferencePipeline):
    """Sample ligands for CrossDocked pockets and score them.

    The model class is looked up in ``SUPPORTED_MODELS``, wrapped in
    ``SBDDModelWrapper`` and sampled one pocket at a time. Every generated
    molecule is scored with ``calc_vina_molecule_metrics``. Results are
    checkpointed as pickle files under ``./data/sample_results``.
    """

    # The only pocket indices processed when ``until_success`` is enabled.
    _UNTIL_SUCCESS_POCKETS = (10, 15, 56, 67, 75, 98)

    def __init__(self,
        model: str="",
        model_ckpt: str="",
        sample_cfg_path: str="",
        reward_cfg_path: str="",
        mode: str="test",
        debug: bool=False,
    ) -> None:
        """Load the configs and initialise the underlying inference pipeline.

        Args:
            model: Key into ``SUPPORTED_MODELS``. Only the part before the
                first underscore is forwarded to the base pipeline.
            model_ckpt: Path of the checkpoint to load.
            sample_cfg_path: Path of the sampling config file; its file stem
                becomes ``self.sample_cfg.name``.
            reward_cfg_path: Path of the reward config file; its file stem
                becomes ``self.reward_cfg.name``.
            mode: ``"test"`` enables the dataset config's ``debug`` flag.
            debug: When true, ``run`` only processes the first two pockets
                with three rounds each.
        """
        self.dataset_cfg = Config.from_dict(
            path="./data",
            debug=mode == "test",
        )
        self.model_name = model
        self.sample_cfg = Config(config_file=sample_cfg_path)
        self.reward_cfg = Config(config_file=reward_cfg_path)
        self.sample_cfg.name = sample_cfg_path.split("/")[-1].split(".")[0]
        self.reward_cfg.name = reward_cfg_path.split("/")[-1].split(".")[0]
        self.debug = debug
        # The attributes above are assigned before the base constructor runs;
        # presumably it calls setup_model(), which reads them -- TODO confirm.
        super().__init__(
            task="structure_based_drug_design",
            model=model.split("_")[0],
            model_ckpt=model_ckpt,
            additional_config=None,
            logging_level="info",
            device="cuda:0",
            output_prompt=None,
            retry_limit=10,
        )
        logging.info(f"Sample config: {self.sample_cfg}")
        logging.info(f"Reward config: {self.reward_cfg}")

    def setup_model(self):
        """Build the wrapped model, load its checkpoint and move it to the device."""
        self.model = SUPPORTED_MODELS[self.model_name](self.cfg.model, self.sample_cfg)
        self.model = SBDDModelWrapper(self.model)
        self.featurizer, self.collator = self.model.get_featurizer()
        logging.info(f"Loading model from {self.cfg.model_ckpt}")
        # Use a context manager so the checkpoint file handle is always closed.
        with open(self.cfg.model_ckpt, "rb") as ckpt_file:
            state_dict = torch.load(ckpt_file, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        if hasattr(self.model.model, "load_ckpt"):
            self.model.model.load_ckpt(state_dict)
        else:
            logging.info(self.model.load_state_dict(state_dict, strict=False))
        self.model.model.eval()
        self.model.to(self.cfg.device)

    @staticmethod
    def _load_pickle(path):
        """Return the object stored in the pickle file at ``path``."""
        # NOTE: unpickling can execute arbitrary code; only resume from
        # result directories that this pipeline wrote itself.
        with open(path, "rb") as f:
            return pickle.load(f)

    @staticmethod
    def _dump_pickle(obj, path):
        """Write ``obj`` as a pickle file to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    @staticmethod
    def _print_average_metrics(all_metrics, all_divs):
        """Print every metric averaged over all molecules, plus mean diversity.

        Metrics whose name contains ``"vina"`` are reported as a mean and a
        median; all other metrics are reported as a mean.
        """
        avg_metrics = {}
        for pocket_metrics in all_metrics:
            for molecule_metrics in pocket_metrics:
                for key, value in molecule_metrics.items():
                    avg_metrics.setdefault(key, []).append(value)

        # Iterate over a snapshot of the keys because the dict is edited in place.
        for key in list(avg_metrics):
            values = avg_metrics[key]
            if "vina" not in key:
                avg_metrics[key] = np.mean(values)
            else:
                avg_metrics[key + "_mean"] = np.mean(values)
                avg_metrics[key + "_median"] = np.median(values)
                avg_metrics.pop(key)
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        avg_metrics["diversity"] = np.mean(all_divs) if len(all_divs) > 0 else float("nan")
        print("Average metrics:")
        for k, v in avg_metrics.items():
            print(f"{k}: {v:.4f}")

    # NOTE: batch size is always 1
    # TODO: how to control the sampling hyperparameters to ensure the same inference cost?
    def run(self, num_samples: int=128, resume: bool=False, num_rounds: int=8, sample_mode: str="reference", mode: str="test", version: int=0, until_success: bool=False) -> None:
        """Sample molecules for every pocket, score them and checkpoint the results.

        Args:
            num_samples: Number of samples requested from the model per round.
            resume: Load previously saved results from the output directory
                and continue at pocket index ``len(all_preds)``.
            num_rounds: Sampling rounds per pocket. Overridden to 3 in debug
                mode and to 64 when ``until_success`` is set.
            sample_mode: ``"prior"`` draws one atom-count estimate per sample
                with ``estimate_ligand_atom_num``; ``"prior_advanced"`` cycles
                through the three top-ranked atom counts across rounds (and
                requires ``num_rounds`` to be a multiple of 3); any other
                value passes ``estimated_ligand_num=None`` to the model.
            mode: ``"test"`` uses the test split and additionally saves and
                prints metrics; any other value uses the train split.
            version: Name of the sub-directory the results are written to.
            until_success: Only process the pockets in
                ``_UNTIL_SUCCESS_POCKETS`` and stop sampling a pocket after the
                first round in which a molecule has ``metrics["success"] > 0``.
        """
        # TODO: maybe later we will use a different dataset (e.g. PDB pockets with ligands from 2020 to 2024)
        dataset = CrossDocked(
            cfg=self.dataset_cfg,
            featurizer=self.featurizer,
        )
        train, _, test = dataset.split()
        is_test = mode == "test"
        dataset = test if is_test else train
        save_path = os.path.join("./data/sample_results", mode, f"{self.cfg.model.name}_{self.sample_cfg.name}_{self.reward_cfg.name}", str(version))
        os.makedirs(save_path, exist_ok=True)
        preds_file = os.path.join(save_path, "preds.pkl")
        metrics_file = os.path.join(save_path, "metrics.pkl")
        div_file = os.path.join(save_path, "div.pkl")

        # all_preds[p]: flat list of the non-None molecules generated for the
        # p-th processed pocket; all_metrics[p][k] is the metrics dict of
        # all_preds[p][k]; all_divs holds one diversity value per pocket that
        # contributed more than one valid molecule.
        # Always start from empty lists so that resuming outside test mode
        # (where metrics/diversity were never saved) does not hit a NameError.
        all_preds, all_metrics, all_divs = [], [], []
        if resume:
            all_preds = self._load_pickle(preds_file)
            if is_test:
                all_metrics = self._load_pickle(metrics_file)
                all_divs = self._load_pickle(div_file)

        # NOTE(review): the resume offset is len(all_preds), which only equals
        # the dataset index when no pocket was skipped in the earlier run.
        for i in tqdm(range(len(all_preds), len(dataset)), desc="Sampling"):
            if self.debug and i >= 2:
                continue
            if until_success and i not in self._UNTIL_SUCCESS_POCKETS:
                continue
            pocket = dataset.pockets[i]
            if hasattr(pocket, "estimated_num_atoms") and "prior" in sample_mode:
                delattr(pocket, "estimated_num_atoms")

            featurized_pocket = self.featurizer(pocket=pocket)
            featurized_pocket = self.model.transfer_batch_to_device(featurized_pocket, self.cfg.device, 0)
            pocket_preds, pocket_metrics = [], []
            all_preds.append(pocket_preds)
            all_metrics.append(pocket_metrics)
            valid_outputs = []
            if self.debug:
                num_rounds = 3
            elif until_success:
                num_rounds = 64
            for round_idx in range(num_rounds):
                if sample_mode == "prior":
                    estimated_ligand_num = [estimate_ligand_atom_num(pocket) for _ in range(num_samples)]
                elif sample_mode == "prior_advanced":
                    assert num_rounds % 3 == 0
                    estimated_ligand_num = [estimate_ligand_atom_num_advanced(pocket, round_idx % 3) for _ in range(num_samples)]
                else:
                    estimated_ligand_num = None
                print(estimated_ligand_num)
                outputs = self.model.sample(**featurized_pocket, num_samples=num_samples, reward_fn=None, estimated_ligand_num=estimated_ligand_num)
                molecules = outputs[0]
                if i <= 2:
                    for molecule in molecules:
                        print(molecule)
                succeeded = False
                for molecule in molecules:
                    if molecule is None:
                        continue
                    metrics = calc_vina_molecule_metrics(molecule, dataset.proteins[i])
                    pocket_preds.append(molecule)
                    pocket_metrics.append(metrics)
                    if until_success and metrics["success"] > 0:
                        print(f"Success at {round_idx * num_samples}")
                        succeeded = True
                    print(metrics)
                if succeeded:
                    # Molecules of the successful round are intentionally left
                    # out of the diversity pool (existing behaviour, kept as is).
                    break
                valid_outputs.extend(mol for mol in molecules if mol is not None)
            if len(valid_outputs) > 1:
                all_divs.append(calc_mol_diversity(valid_outputs))

            # Checkpoint the results gathered so far.
            if not self.debug or is_test or i % 100 == 0:
                self._dump_pickle(all_preds, preds_file)
                if is_test:
                    self._dump_pickle(all_metrics, metrics_file)
                    self._dump_pickle(all_divs, div_file)

        if is_test:
            self._print_average_metrics(all_metrics, all_divs)

def main():
    """Parse the command-line options, build the pipeline and run it."""
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--model", {"type": str, "default": "molcraft"}),
        ("--model_ckpt_path", {"type": str, "default": "./checkpoints/molcraft.ckpt"}),
        ("--sample_cfg_path", {"type": str, "default": "./configs/sample/nestedis.yaml"}),
        ("--reward_cfg_path", {"type": str, "default": "./configs/reward/weighted_success.yaml"}),
        ("--resume", {"action": "store_true"}),
        ("--debug", {"action": "store_true"}),
        ("--num_rounds", {"type": int, "default": 8}),
        ("--num_samples", {"type": int, "default": 32}),
        ("--sample_mode", {"type": str, "default": "reference"}),
        ("--mode", {"type": str, "default": "test"}),
        ("--until_success", {"action": "store_true"}),
        ("--version", {"type": int, "default": 0}),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    cli = parser.parse_args()

    pipeline = InferenceScalingPipeline(
        model=cli.model,
        model_ckpt=cli.model_ckpt_path,
        sample_cfg_path=cli.sample_cfg_path,
        reward_cfg_path=cli.reward_cfg_path,
        mode=cli.mode,
        debug=cli.debug,
    )
    run_options = {
        "num_samples": cli.num_samples,
        "resume": cli.resume,
        "num_rounds": cli.num_rounds,
        "sample_mode": cli.sample_mode,
        "mode": cli.mode,
        "version": cli.version,
        "until_success": cli.until_success,
    }
    pipeline.run(**run_options)

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()