from abc import ABC, abstractmethod
import copy
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean
from typing import List, Callable, Tuple
from tqdm import tqdm
from open_biomed.data import Molecule, Pocket
from open_biomed.utils.config import Config
from open_biomed.utils.featurizer import Featurized
from models.molcraft import MolCRAFT
from models.critique_sbdd import CritiqueSBDD

class BFNSampleStrategies(ABC):
    def __init__(self, strategy: Config=None):
        """Bind the sampling routine selected by ``strategy.name`` and load auxiliary models.

        Args:
            strategy: Sampling-strategy config. ``strategy.name`` selects the routine stored
                in ``self.sample_fn``. Value-guided strategies ("GradientGuidance",
                "SVDD_GradientGuidance", "Mixed_CG_CFG") also read ``value_model_config``
                and ``value_model_ckpt``. Classifier-free strategies
                ("ClassifierFreeGuidance", "Mixed_CG_CFG") also read ``bad_model``
                (``"ref"`` or ``"prior"``) and, for ``"ref"``, ``bad_model_config`` and
                ``bad_model_ckpt``.

        Raises:
            ValueError: If ``strategy.name`` or ``strategy.bad_model`` is not supported.
        """
        self.strategy = strategy
        # Method names are resolved only for the selected strategy, so routines defined
        # elsewhere in the class are not touched unless they are requested.
        sample_fn_names = {
            "reference": "sample_reference",
            "SVDD": "sample_SVDD",
            "NestedIS": "sample_controlled_NestedIS",
            "IntermediateMonteCarloValue": "sample_intermediate_with_monte_carlo_value",
            "GradientGuidance": "sample_value_guidance",
            "SVDD_GradientGuidance": "sample_svdd_with_value_guidance",
            "Mixed_CG_CFG": "sample_mixed_cg_cfg",
            "ClassifierFreeGuidance": "sample_classifier_free_guidance",
        }
        fn_name = sample_fn_names.get(self.strategy.name)
        if fn_name is None:
            raise ValueError(f"Unsupported sample strategy: {self.strategy.name}")
        self.sample_fn = getattr(self, fn_name)

        if self.strategy.name in ("GradientGuidance", "SVDD_GradientGuidance", "Mixed_CG_CFG"):
            value_model_cfg = Config(self.strategy.value_model_config).model
            self.value_model = CritiqueSBDD(value_model_cfg)
            # Load onto CPU so GPU-saved checkpoints also load on CPU-only hosts;
            # load_state_dict copies the values into the model's own tensors.
            state_dict = torch.load(self.strategy.value_model_ckpt, map_location="cpu")["state_dict"]
            self.value_model.load_state_dict(self._rename_checkpoint_keys(state_dict))
            self.value_model.eval()
        if self.strategy.name in ("ClassifierFreeGuidance", "Mixed_CG_CFG"):
            if self.strategy.bad_model == "ref":
                model_cfg = Config(self.strategy.bad_model_config).model
                self.bad_model = MolCRAFT(model_cfg)
                state_dict = torch.load(self.strategy.bad_model_ckpt, map_location="cpu")["state_dict"]
                self.bad_model.load_state_dict(self._rename_checkpoint_keys(state_dict))
                self.bad_model.eval()
            elif self.strategy.bad_model != "prior":
                raise ValueError(f"Unsupported bad model: {self.strategy.bad_model}")

    @staticmethod
    def _rename_checkpoint_keys(state_dict):
        """Return a new mapping whose keys have every ``"model."`` substring removed.

        Checkpoint keys carry a ``model.`` component that the bare model's parameter
        names lack (presumably added by a training wrapper -- TODO confirm). Renamed keys
        are kept even when they still contain ``"model"`` (e.g. ``model.x.model_y`` ->
        ``x.model_y``), so such parameters are not silently dropped before loading.
        """
        return {key.replace("model.", ""): value for key, value in state_dict.items()}

    def prepare_constants(self):
        """Precompute the time grid and coordinate schedule used by the ODE solver steps.

        Sets:
            steps: integers ``num_sample_steps, num_sample_steps - 1, ..., 0``.
            times: ``steps`` in float64, divided by ``num_sample_steps`` and scaled by
                ``1 - eta`` (so it lies in ``[0, 1 - eta]``).
            beta_s_coord: ``sigma1_coord ** (-2 * times) - 1``.
            gamma_t_coord / alpha_t_coord: ``1 - sigma1_coord ** (2 * (1 - times))``
                (same values, stored as separate tensors).
            sigma_t_coord: ``sqrt(alpha * (1 - alpha))``.
            lambda_t_coord: ``log(alpha) - log(sigma)``.
        """
        eta = 1e-5
        num_steps = self.config.num_sample_steps
        sigma1 = self.config.sigma1_coord

        self.steps = torch.arange(num_steps + 1).flip(0)
        grid = self.steps.double() / num_steps * (1 - eta)
        self.times = grid

        self.beta_s_coord = sigma1 ** (-2 * grid) - 1
        schedule = 1 - sigma1 ** (2 * (1 - grid))
        self.gamma_t_coord = schedule
        self.alpha_t_coord = schedule.clone()
        self.sigma_t_coord = torch.sqrt(self.alpha_t_coord * (1 - self.alpha_t_coord))
        self.lambda_t_coord = self.alpha_t_coord.log() - self.sigma_t_coord.log()

    @abstractmethod
    def bayesian_update_step(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor, rollout_noise: Tuple[torch.Tensor, torch.Tensor]=None) -> Featurized[Molecule]:
        """Perform one Bayesian update of the ligand parameters, starting at time ``t``.

        Implemented by subclasses.

        Args:
            molecules: Batched ligand parameters.
            pockets: Batched pocket features matching ``molecules``.
            t: Per-atom time tensor; the samplers build it with shape ``(num_atoms, 1)``.
            rollout_noise: Optional pre-drawn noise for this step. ``rollout`` passes one
                element of its ``rollout_noise`` list here; other callers omit it, so
                implementations must accept both forms.

        Returns:
            The updated ligand parameters.
        """
        pass

    @abstractmethod
    def estimate_outcome_molecule(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor) -> Featurized[Molecule]:
        """Estimate the final outcome ligand from the parameters ``molecules`` at time ``t``.

        Implemented by subclasses. The samplers decode the returned batch with
        ``decode_ith_molecule``, and they read ``mu_pos`` / ``theta_h`` from it to drive the
        next Bayesian update.

        NOTE(review): ``sample_classifier_free_guidance`` also passes a classifier-input
        tensor positionally and a ``softmax`` keyword, so implementations used with that
        strategy must accept them.
        """

    @abstractmethod
    def decode_ith_molecule(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], i: int) -> Molecule:
        """Decode sample ``i`` of the batched ``molecules`` into a ``Molecule``.

        Implemented by subclasses. The samplers pass the decoded molecule to ``reward_fn``
        and return decoded molecules as their final output.
        """

    @abstractmethod
    def select_ith_molecule(self, molecules: Featurized[Molecule], i: int) -> Featurized[Molecule]:
        """Extract sample ``i`` from the batched ``molecules`` as a single featurized molecule.

        Implemented by subclasses. The samplers re-batch lists of these selections with
        ``self.collators["molecule"]``, e.g. when resampling.
        """

    @torch.no_grad()
    def sample_reference(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], list, list]:
        """Sample ligands with plain (unguided) Bayesian updates.

        Args:
            pocket: Featurized pocket, replicated ``num_samples`` times.
            num_samples: Number of ligands sampled in parallel.
            reward_fn: Unused; accepted so every strategy shares one signature.
            estimated_ligand_num: Optional per-sample ligand atom counts, stored on the
                pocket batch as ``estimated_ligand_num_atoms``.

        Returns:
            ``(molecules, in_traj, out_traj)``:
            - ``molecules``: the decoded ligands.
            - ``in_traj``: the parameters passed into each update.
            - ``out_traj``: the result of each update, followed by the final outcome
              estimate at ``t = 1``.
        """
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        num_steps = self.config.num_sample_steps
        in_traj, out_traj = [], []
        # NOTE(review): this performs num_steps - 1 updates (the last starts at
        # t = (num_steps - 2) / num_steps), whereas e.g. sample_value_guidance iterates up
        # to num_steps. Confirm that this is intended.
        # NOTE(review): if bayesian_update_step mutates its input in place, the trajectory
        # entries alias one another. Confirm against the subclass implementation.
        for step in tqdm(range(1, num_steps), desc="Sampling"):
            in_traj.append(molecules)
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / num_steps
            molecules = self.bayesian_update_step(molecules, pockets, t)
            out_traj.append(molecules)
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device)
        molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        out_traj.append(molecules)
        return [self.decode_ith_molecule(molecules, pockets, i) for i in range(num_samples)], in_traj, out_traj

    def ode_bfnsolver1_coord(self, x_s: torch.Tensor, x0_pred: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
        """Take one first-order (DPM-Solver-1 style) ODE step for coordinates, from grid index ``i`` to ``i + 1``.

        Uses the schedule built by ``prepare_constants``. ``x_s`` is the coordinate mean
        at grid point ``i`` and ``x0_pred`` is the model's clean-coordinate prediction for it.

        Assuming ``x_s = alpha_s * x0 + sigma_s * eps``, the implied noise must be recovered
        with the source coefficients ``alpha_s`` / ``sigma_s``. With that choice the step
        maps ``x_s`` exactly to ``alpha_t * x0 + sigma_t * eps`` when ``x0_pred == x0``.

        Returns:
            The coordinate mean at grid point ``i + 1``.
        """
        lambda_s, lambda_t = self.lambda_t_coord[i], self.lambda_t_coord[i + 1]
        alpha_s, alpha_t = self.alpha_t_coord[i], self.alpha_t_coord[i + 1]
        sigma_s, sigma_t = self.sigma_t_coord[i], self.sigma_t_coord[i + 1]
        h = lambda_t - lambda_s

        # Noise implied by x_s at the *source* grid point.
        noise_pred = (x_s - x0_pred * alpha_s) / sigma_s
        return (alpha_t / alpha_s) * x_s - sigma_t * (torch.exp(h) - 1.0) * noise_pred

    def ode_bfnsolver1_type(self, z_s: torch.Tensor, z0_pred: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
        """Take one first-order ODE step for atom-type logits, from grid index ``i`` to ``i + 1``.

        Computes ``(1 - t_t) / (1 - t_s) * z_s + beta1 * (1 - t_t) * (t_t - t_s) * (1 - K * z0_pred)``,
        where ``t_s = times[i]``, ``t_t = times[i + 1]`` and ``K = ligand_atom_feature_dim``.
        """
        time_src = self.times[i]
        time_dst = self.times[i + 1]
        decay = (1 - time_dst) / (1 - time_src)
        drift = self.config.beta1 * (1 - time_dst) * (time_dst - time_src) * (1 - self.config.ligand_atom_feature_dim * z0_pred)
        return decay * z_s + drift

    @torch.no_grad()
    def sample_SVDD(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], list, list, list, list]:
        """Sample ligands with reward-weighted (SVDD-style) resampling of the batch.

        Each step estimates the outcome molecules at ``t = (step - 1) / N`` and then
        Bayesian-updates the parameters to ``t = step / N`` from that estimate.

        At resampling steps, the outcome estimates are first decoded and scored with
        ``reward_fn``. The batch is then resampled with probabilities proportional to
        ``exp((reward - max_reward) / temp)``, where ``temp`` shrinks as ``step`` grows.

        NOTE(review): the resampling condition (``step % 1000 == 0`` with
        ``40 <= step <= 70``) can never be true. Resampling is therefore switched off and
        ``reward_traj`` / ``sample_indice_traj`` come back empty.
        ``sample_svdd_with_value_guidance`` uses ``step % 10`` for the same window.

        Args:
            pocket: Featurized pocket, replicated ``num_samples`` times.
            num_samples: Number of ligands sampled (and resampled) in parallel.
            reward_fn: Maps a decoded ``Molecule`` to a scalar reward. It is only called
                at resampling steps.
            estimated_ligand_num: Optional per-sample ligand atom counts.

        Returns:
            ``(molecules, in_traj, out_traj, reward_traj, sample_indice_traj)``.
            ``in_traj`` holds a shallow snapshot of the parameters at the start of each step.
        """
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        in_traj, out_traj, reward_traj, sample_indice_traj = [], [], [], []

        for step in tqdm(range(1, self.config.num_sample_steps), desc="Sampling"):
            # Shallow copy: the update at the end of the loop rebinds entries of
            # ``molecules`` in place. Appending the object itself would make every
            # entry alias the final state.
            in_traj.append(copy.copy(molecules))
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
            if step % 1000 == 0 and step >= 40 and step <= 70:
                outcome_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
                rewards = []
                for i in range(num_samples):
                    cur_molecule = self.decode_ith_molecule(outcome_molecules, pockets, i)
                    rewards.append(reward_fn(cur_molecule))
                rewards = torch.tensor(rewards, dtype=torch.float, device=device)
                reward_traj.append(rewards)
                # The temperature is very large at step 40 and shrinks as step grows.
                temp = self.strategy.temperature / ((step - 40) / (self.config.num_sample_steps - 40) + 1e-6)
                # Subtracting the max keeps exp() from overflowing.
                weights = torch.exp((rewards - torch.max(rewards).item()) / temp).cpu().numpy()
                weights = weights / weights.sum()
                final_sample_indices = (np.random.choice(range(num_samples), size=num_samples, p=weights))
                final_sample_indices.sort()
                sample_indice_traj.append(final_sample_indices)
                molecules = [self.select_ith_molecule(outcome_molecules, i) for i in final_sample_indices]
                molecules = self.collators["molecule"](molecules)
            outcome_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
            out_traj.append(outcome_molecules)
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / self.config.num_sample_steps
            molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, outcome_molecules['mu_pos'])
            molecules['theta_h'] = self.discrete_var_bayesian_update(t, outcome_molecules['theta_h'], self.config.ligand_atom_feature_dim)

        # Decode the outcome estimate at t = 1.
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device)
        outcome_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        decoded_outcome_molecules = [self.decode_ith_molecule(outcome_molecules, pockets, i) for i in range(num_samples)]
        return decoded_outcome_molecules, in_traj, out_traj, reward_traj, sample_indice_traj

    # NOTE: to avoid reward hacking on VinaDock, we fix the number of atoms
    # NOTE: we keep the batch size as 1
    # NOTE: the reward function is bound to a protein and a reference molecule
    @torch.no_grad()
    def sample_controlled_NestedIS(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], list, list, list, list]:
        """Sample ligands with nested importance sampling over candidate Bayesian updates.

        A search step does the following:
        - proposes ``strategy.num_candidates`` Bayesian updates;
        - scores each candidate's outcome estimate with ``reward_fn``;
        - keeps the best candidate per sample;
        - resamples the batch with weights ``exp(reward / max|reward|)``.

        Every other step estimates the outcome at ``t = step / N``, updates the parameters
        from it with the same ``t``, and logs the rewards of the new outcome estimates.

        NOTE(review): the search condition ``step % 10 == 0 and step <= 0`` is never true
        (``step >= 1``), so only the plain update path currently runs.

        Args:
            pocket: Featurized pocket, replicated ``num_samples`` times.
            num_samples: Number of ligands sampled in parallel.
            reward_fn: Maps a decoded ``Molecule`` to a scalar reward.
            estimated_ligand_num: Optional per-sample ligand atom counts.

        Returns:
            ``(molecules, in_traj, out_traj, reward_traj, sample_indice_traj)``.
            ``in_traj`` holds a shallow snapshot of the parameters at the start of each step.
        """
        # Create num_samples copies of the pocket
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        in_traj, out_traj, reward_traj, sample_indice_traj = [], [], [], []

        for step in tqdm(range(1, self.config.num_sample_steps + 1), desc="Sampling"):
            # Shallow copy: the plain update path rebinds entries of ``molecules`` in place.
            in_traj.append(copy.copy(molecules))
            if step % 10 == 0 and step <= 0:
                # Propose multiple candidate updates and score each one's outcome estimate.
                updated_molecules, outcome_molecules = [], []
                decoded_outcome_molecules, rewards = [[] for _ in range(num_samples)], [[] for _ in range(num_samples)]
                for i in range(self.strategy.num_candidates):
                    t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
                    updated_molecules.append(self.bayesian_update_step(molecules, pockets, t))

                    # NOTE: ideally, we should develop a value model for estimating V_\pi[t+1:T]=E_{\pi[t+1:T]} r(x_T|x_t)
                    # estimate the outcome molecule E_\pi[t+1:T] x_T
                    t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / self.config.num_sample_steps
                    outcome_molecules.append(self.estimate_outcome_molecule(updated_molecules[-1], pockets, t))
                    for j in range(num_samples):
                        cur_molecule = self.decode_ith_molecule(outcome_molecules[-1], pockets, j)
                        decoded_outcome_molecules[j].append(cur_molecule)
                        rewards[j].append(reward_fn(cur_molecule))

                out_traj.append(updated_molecules)
                rewards = torch.tensor(rewards, dtype=torch.float, device=device)
                reward_traj.append(rewards)
                print(rewards)

                # Keep the best candidate for every sample.
                best_indices = torch.argmax(rewards, dim=-1)
                updated_molecules = [self.select_ith_molecule(updated_molecules[best_indices[i]], i) for i in range(num_samples)]
                decoded_outcome_molecules = [decoded_outcome_molecules[i][best_indices[i]] for i in range(num_samples)]
                rewards = torch.max(rewards, dim=-1)[0]

                # Global resampling.
                # TODO: add a time-dependent hyperparameter here to encourage exploration at early steps
                rewards = torch.exp(rewards / torch.max(torch.abs(rewards)))
                rewards = rewards.cpu().numpy()
                weights = rewards / rewards.sum()
                final_sample_indices = (np.random.choice(range(num_samples), size=num_samples, p=weights))
                final_sample_indices.sort()
                sample_indice_traj.append(final_sample_indices)
                updated_molecules = [updated_molecules[i] for i in final_sample_indices]
                updated_molecules = self.collators["molecule"](updated_molecules)
                molecules = updated_molecules
            else:
                # NOTE(review): unlike sample_SVDD, the outcome is estimated at t = step / N
                # rather than (step - 1) / N, and the update reuses the same t. Confirm.
                t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / self.config.num_sample_steps
                out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
                out_traj.append(out_molecules)
                molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'])
                molecules['theta_h'] = self.discrete_var_bayesian_update(t, out_molecules['theta_h'], self.config.ligand_atom_feature_dim)
                # ``step % 1`` is always 0, so rewards are logged after every update;
                # raise the modulus to log less often.
                if step % 1 == 0:
                    rewards = []
                    outcome_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
                    for i in range(num_samples):
                        cur_molecule = self.decode_ith_molecule(outcome_molecules, pockets, i)
                        rewards.append(reward_fn(cur_molecule))
                    rewards = torch.tensor(rewards, dtype=torch.float, device=device)
                    reward_traj.append(rewards)
                    print(rewards)

        # Decode the outcome estimate at t = 1.
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device)
        outcome_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        decoded_outcome_molecules = [self.decode_ith_molecule(outcome_molecules, pockets, i) for i in range(num_samples)]
        return decoded_outcome_molecules, in_traj, out_traj, reward_traj, sample_indice_traj
    
    @torch.no_grad()
    def rollout(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], cur_step: int, reward_fn: Callable=None, rollout_noise: List[Tuple[torch.Tensor, torch.Tensor]]=None) -> Tuple[List[Molecule], list]:
        """Finish sampling from ``cur_step`` and score every molecule in the batch.

        Runs Bayesian updates for steps ``cur_step + 1 .. num_sample_steps``. It then decodes
        each molecule directly from the final parameters and scores it with ``reward_fn``.

        Args:
            molecules: Batched ligand parameters at ``cur_step``.
            pockets: Batched pockets matching ``molecules``.
            cur_step: Step index the rollout starts from.
            reward_fn: Maps a decoded ``Molecule`` to a scalar reward.
            rollout_noise: Optional per-step noise. Entry ``k`` is passed to the ``k``-th
                update. When ``None``, the update is called without a noise argument.

        Returns:
            ``(decoded_molecules, rewards)``, with one entry per molecule in the batch.
        """
        num_atoms = molecules['mu_pos_batch'].shape[0]
        device = molecules['mu_pos'].device
        for step in range(cur_step + 1, self.config.num_sample_steps + 1):
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
            if rollout_noise is None:
                molecules = self.bayesian_update_step(molecules, pockets, t)
            else:
                molecules = self.bayesian_update_step(molecules, pockets, t, rollout_noise[step - cur_step - 1])
        # Decode every molecule first, then score them in the same order.
        num_molecules = torch.max(molecules['mu_pos_batch']).item() + 1
        decoded_outcome_molecules = [self.decode_ith_molecule(molecules, pockets, i) for i in range(num_molecules)]
        rewards = [reward_fn(decoded) for decoded in decoded_outcome_molecules]
        return decoded_outcome_molecules, rewards

    @torch.no_grad()
    def sample_intermediate_with_monte_carlo_value(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], list, list, list, list, list]:
        """Sample ligands while logging Monte Carlo value estimates at selected steps.

        At every step listed in ``strategy.rollout_steps``:
        - each sample is duplicated ``strategy.num_rollout_samples`` times;
        - the copies are rolled out to the end with noise from ``self.prepare_noise``;
        - the mean rollout reward is recorded as that sample's value;
        - the reward of the current one-shot outcome estimate is recorded as well.

        Sampling itself is unguided (resampling is still a TODO).

        Args:
            pocket: Featurized pocket, replicated ``num_samples`` times.
            num_samples: Number of ligands sampled in parallel.
            reward_fn: Maps a decoded ``Molecule`` to a scalar reward.
            estimated_ligand_num: Optional per-sample ligand atom counts.

        Returns:
            ``(molecules, in_traj, out_traj, value_traj, rollout_rewards_traj,
            outcome_rewards_traj)``. The molecules are decoded directly from the final
            parameters. ``in_traj`` holds a shallow snapshot of the parameters at the start
            of each step.
        """
        # Create num_samples copies of the pocket
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        in_traj, out_traj, value_traj, outcome_rewards_traj, rollout_rewards_traj = [], [], [], [], []

        for step in tqdm(range(1, self.config.num_sample_steps + 1), desc="Sampling"):
            # Shallow copy: the update at the end of the loop rebinds entries of ``molecules``.
            in_traj.append(copy.copy(molecules))
            is_rollout_step = step in self.strategy.rollout_steps
            if is_rollout_step:
                value_traj.append([])
                rollout_rewards_traj.append([])
                # Shoot strategy.num_rollout_samples trajectories from every sample.
                for i in tqdm(range(num_samples), desc="Rollout"):
                    molecule = self.select_ith_molecule(molecules, i)
                    rollout_molecule = self.collators["molecule"]([molecule] * self.strategy.num_rollout_samples)
                    rollout_pockets = self.collators["pocket"]([pocket] * self.strategy.num_rollout_samples)
                    # to reduce variance, we fix noise the same for each molecule
                    rollout_noise = self.prepare_noise(step, rollout_molecule)
                    _, rollout_rewards = self.rollout(rollout_molecule, rollout_pockets, step, reward_fn, rollout_noise)
                    value_traj[-1].append(np.mean(rollout_rewards))
                    rollout_rewards_traj[-1].append(rollout_rewards)
                    print(rollout_rewards)
                # TODO: resample the trajectories
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
            out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
            out_traj.append(out_molecules)
            if is_rollout_step:
                outcome_molecule_rewards = []
                for i in range(num_samples):
                    cur_molecule = self.decode_ith_molecule(out_molecules, pockets, i)
                    outcome_molecule_rewards.append(reward_fn(cur_molecule))
                outcome_molecule_rewards = torch.tensor(outcome_molecule_rewards, dtype=torch.float)
                outcome_rewards_traj.append(outcome_molecule_rewards)
                print(outcome_molecule_rewards)
            # NOTE(review): the update reuses t = (step - 1) / N, whereas sample_SVDD updates
            # with t = step / N. Confirm that this is intended.
            molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'])
            molecules['theta_h'] = self.discrete_var_bayesian_update(t, out_molecules['theta_h'], self.config.ligand_atom_feature_dim)
        output_molecules = [self.decode_ith_molecule(molecules, pockets, i) for i in range(num_samples)]
        return output_molecules, in_traj, out_traj, value_traj, rollout_rewards_traj, outcome_rewards_traj

    @torch.no_grad()
    def sample_value_guidance(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], list, list, list, list, list]:
        """Sample ligands with value-model gradient guidance on the outcome estimates.

        Each step estimates the outcome at ``t = (step - 1) / N``. Once
        ``step > strategy.exploit_step``, the estimate is shifted along a gradient:
        - the objective is ``0.5 * log(v[:, 0]) + 0.5 * log(v[:, 2])``, with
          ``v = self.value_model.predict(pockets, molecules, t)``;
        - the gradient is taken w.r.t. the current ``mu_pos`` / ``theta_h``;
        - the shift is scaled by ``strategy.guide_weight``, and additionally by
          ``sigma1_coord ** (2 * t)`` for coordinates.

        The parameters are then Bayesian-updated to ``t = step / N``, passing
        ``torch.randn_like(...) * strategy.noise_scale`` as ``rollout_noise``. A final
        guided estimate is decoded.

        NOTE(review): the final estimate uses ``t = (N - 1) / N`` (the loop's last
        ``(step - 1) / N``), although the parameters were already updated to
        ``t = step / N = 1``. ``sample_svdd_with_value_guidance`` uses ``t = 1`` there.
        The final step also relies on the loop variable ``step`` surviving the loop, so
        ``num_sample_steps < 1`` raises ``NameError``.

        Args:
            pocket: Featurized pocket, replicated ``num_samples`` times.
            num_samples: Number of ligands sampled in parallel.
            reward_fn: Unused by this strategy.
            estimated_ligand_num: Optional per-sample ligand atom counts.

        Returns:
            ``(molecules, in_traj, out_traj, value_traj, grad_traj_theta_h,
            grad_traj_mu_pos)``. ``in_traj`` is always empty (its append is commented out),
            and ``out_traj`` only holds the final guided estimate.
        """
        # import pdb; pdb.set_trace()
        pockets = self.collators["pocket"]([pocket] * num_samples)
        # print(pockets["estimated_ligand_num_atoms"])
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        in_traj, out_traj, value_traj, grad_traj_theta_h, grad_traj_mu_pos = [], [], [], [], []
        for step in tqdm(range(1, self.config.num_sample_steps + 1), desc="Sampling"):
            # in_traj.append(molecules)
            
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
            out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
            # out_traj.append(out_molecules)
            if step > self.strategy.exploit_step:
                # Re-create the parameters as grad-requiring leaves so that the value
                # objective can be differentiated w.r.t. them inside this no_grad method.
                with torch.enable_grad():
                    molecules["mu_pos"] = molecules["mu_pos"].clone().detach().requires_grad_(True)
                    molecules["theta_h"] = molecules["theta_h"].clone().detach().requires_grad_(True)
                    estimated_value = self.value_model.predict(pockets, molecules, t)
                    # Equal-weight sum of the log of value-model outputs 0 and 2.
                    estimated_value = 0.5 * estimated_value.log()[:, 0] + 0.5 * estimated_value.log()[:, 2]
                    grad_theta_h = torch.autograd.grad(estimated_value, molecules["theta_h"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
                    grad_mu_pos = torch.autograd.grad(estimated_value, molecules["mu_pos"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
                    value_traj.append(estimated_value.detach().cpu())
                    grad_traj_theta_h.append(grad_theta_h.detach().cpu())
                    grad_traj_mu_pos.append(grad_mu_pos.detach().cpu())
                # print(torch.cuda.memory_allocated() / 1024 ** 3, torch.cuda.memory_reserved() / 1024 ** 3)
                # (1 - gamma_coord) == sigma1_coord ** (2 * t): coordinate guidance is scaled by it.
                gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
                out_molecules['mu_pos'] = out_molecules['mu_pos'] + self.strategy.guide_weight * grad_mu_pos * (1 - gamma_coord)
                out_molecules['theta_h'] = out_molecules['theta_h'] + self.strategy.guide_weight * grad_theta_h
            # Bayesian update to t = step / N from the (possibly guided) outcome estimate.
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / self.config.num_sample_steps
            molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'], rollout_noise=torch.randn_like(out_molecules['mu_pos']) * self.strategy.noise_scale)
            molecules['theta_h'] = self.discrete_var_bayesian_update(t, out_molecules['theta_h'], self.config.ligand_atom_feature_dim, rollout_noise=torch.randn_like(out_molecules['theta_h']) * self.strategy.noise_scale)
            
        # final step
        # NOTE(review): ``step`` here is the last loop value (N), so t = (N - 1) / N.
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
        out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        with torch.enable_grad():
            molecules["mu_pos"] = molecules["mu_pos"].clone().detach().requires_grad_(True)
            molecules["theta_h"] = molecules["theta_h"].clone().detach().requires_grad_(True)
            estimated_value = self.value_model.predict(pockets, molecules, t)
            # estimated_value = (torch.tensor(self.strategy.coef, device=device) * estimated_value.log()).sum(dim=-1)
            estimated_value = 0.5 * estimated_value.log()[:, 0] + 0.5 * estimated_value.log()[:, 2]
            grad_theta_h = torch.autograd.grad(estimated_value, molecules["theta_h"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
            grad_mu_pos = torch.autograd.grad(estimated_value, molecules["mu_pos"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
        # Final coordinate guidance is scaled by sigma1_coord ** (step / N), i.e. sigma1_coord ** 1.
        out_molecules['mu_pos'] = out_molecules['mu_pos'] + self.strategy.guide_weight * grad_mu_pos * self.sigma1_coord**(step / self.config.num_sample_steps)
        out_molecules['theta_h'] = out_molecules['theta_h'] + self.strategy.guide_weight * grad_theta_h
        out_traj.append(out_molecules)
        grad_traj_theta_h.append(grad_theta_h.detach().cpu())
        grad_traj_mu_pos.append(grad_mu_pos.detach().cpu())
        output_molecules = [self.decode_ith_molecule(out_molecules, pockets, i) for i in range(num_samples)]
        return output_molecules, in_traj, out_traj, value_traj, grad_traj_theta_h, grad_traj_mu_pos

    @torch.no_grad()
    def sample_svdd_with_value_guidance(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], list, list, list, list, list]:
        """Sample ligands with value-model gradient guidance plus SVDD-style resampling.

        Every step does the following:
        - estimate the outcome at ``t = (step - 1) / N``;
        - shift it along the gradient (w.r.t. the current ``mu_pos`` / ``theta_h``) of
          ``0.5 * log(v[:, 0]) + 0.5 * log(v[:, 2])``, with
          ``v = self.value_model.predict(pockets, molecules, t)``;
        - on every 10th step within ``[40, 70]``, resample the batch by ``reward_fn``
          with an annealed temperature;
        - Bayesian-update the parameters to ``t = step / N``.

        A final guided estimate at ``t = 1`` is decoded.

        NOTE(review): the final step weights the value terms 0.35 / 0.65 and scales the
        coordinate gradient by ``sigma1_coord ** 2``. The loop uses 0.5 / 0.5 and
        ``sigma1_coord ** (step / N)``. Confirm that this is intended.

        Args:
            pocket: Featurized pocket, replicated ``num_samples`` times.
            num_samples: Number of ligands sampled (and resampled) in parallel.
            reward_fn: Maps a decoded ``Molecule`` to a scalar reward (used for resampling).
            estimated_ligand_num: Optional per-sample ligand atom counts.

        Returns:
            ``(molecules, in_traj, out_traj, value_traj, svdd_rewards_traj,
            svdd_sample_indice_traj)``. ``in_traj`` holds a shallow snapshot of the
            parameters at the start of each step.
        """
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        in_traj, out_traj, value_traj, svdd_rewards_traj, svdd_sample_indice_traj = [], [], [], [], []
        for step in tqdm(range(1, self.config.num_sample_steps + 1), desc="Sampling"):
            # Shallow copy: entries of ``molecules`` are rebound in place below, so
            # appending the object itself would make every entry alias the final state.
            in_traj.append(copy.copy(molecules))
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / self.config.num_sample_steps
            out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
            out_traj.append(out_molecules)
            # Value guidance: differentiate the value objective w.r.t. fresh leaf copies
            # of the current parameters (grad is re-enabled inside this no_grad method).
            with torch.enable_grad():
                molecules["mu_pos"] = molecules["mu_pos"].clone().detach().requires_grad_(True)
                molecules["theta_h"] = molecules["theta_h"].clone().detach().requires_grad_(True)
                estimated_value = self.value_model.predict(pockets, molecules, t)
                estimated_value = 0.5 * estimated_value.log()[:, 0] + 0.5 * estimated_value.log()[:, 2]
                value_traj.append(estimated_value.detach().cpu())
                grad_theta_h = torch.autograd.grad(estimated_value, molecules["theta_h"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
                grad_mu_pos = torch.autograd.grad(estimated_value, molecules["mu_pos"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
            out_molecules['mu_pos'] = out_molecules['mu_pos'] + self.strategy.guide_weight * grad_mu_pos * self.sigma1_coord**(step / self.config.num_sample_steps)
            out_molecules['theta_h'] = out_molecules['theta_h'] + self.strategy.guide_weight * grad_theta_h

            # SVDD: reward-weighted resampling on every 10th step within [40, 70].
            if step % 10 == 0 and step >= 40 and step <= 70:
                rewards = []
                for i in range(num_samples):
                    cur_molecule = self.decode_ith_molecule(out_molecules, pockets, i)
                    rewards.append(reward_fn(cur_molecule))
                rewards = torch.tensor(rewards, dtype=torch.float, device=device)
                svdd_rewards_traj.append(rewards)
                # The temperature is very large at step 40 and shrinks as step grows.
                temp = self.strategy.temperature / ((step - 40) / (self.config.num_sample_steps - 40) + 1e-6)
                # Subtracting the max keeps exp() from overflowing.
                weights = torch.exp((rewards - torch.max(rewards).item()) / temp).cpu().numpy()
                weights = weights / weights.sum()
                final_sample_indices = (np.random.choice(range(num_samples), size=num_samples, p=weights))
                final_sample_indices.sort()
                svdd_sample_indice_traj.append(final_sample_indices)
                out_molecules = [self.select_ith_molecule(out_molecules, i) for i in final_sample_indices]
                out_molecules = self.collators["molecule"](out_molecules)

            # Bayesian update to t = step / N from the guided (possibly resampled) estimate.
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / self.config.num_sample_steps
            molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'])
            molecules['theta_h'] = self.discrete_var_bayesian_update(t, out_molecules['theta_h'], self.config.ligand_atom_feature_dim)

        # Final step: guided outcome estimate at t = 1.
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device)
        out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        with torch.enable_grad():
            molecules["mu_pos"] = molecules["mu_pos"].clone().detach().requires_grad_(True)
            molecules["theta_h"] = molecules["theta_h"].clone().detach().requires_grad_(True)
            estimated_value = self.value_model.predict(pockets, molecules, t)
            estimated_value = 0.35 * estimated_value.log()[:, 0] + 0.65 * estimated_value.log()[:, 2]
            grad_theta_h = torch.autograd.grad(estimated_value, molecules["theta_h"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
            grad_mu_pos = torch.autograd.grad(estimated_value, molecules["mu_pos"], grad_outputs=torch.ones_like(estimated_value), retain_graph=True)[0]
        out_molecules['mu_pos'] = out_molecules['mu_pos'] + self.strategy.guide_weight * grad_mu_pos * self.sigma1_coord**2
        out_molecules['theta_h'] = out_molecules['theta_h'] + self.strategy.guide_weight * grad_theta_h
        output_molecules = [self.decode_ith_molecule(out_molecules, pockets, i) for i in range(num_samples)]
        return output_molecules, in_traj, out_traj, value_traj, svdd_rewards_traj, svdd_sample_indice_traj
    
    @torch.no_grad()
    def sample_classifier_free_guidance(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], List[dict], List[dict], List[dict]]:
        """Sample ligands for ``pocket`` with classifier-free guidance (CFG).

        At every step the conditional prediction (classifier input set to
        ``strategy.classifier_input_scale``, default ``[1, 1, 1]``) is pushed away
        from a "bad" prediction via ``out + (w - 1) * (out - bad)`` with
        ``w = strategy.guide_weight``; the coordinate term is additionally scaled
        by ``gamma_coord``. Guidance is skipped when ``w`` is 1.

        The "bad" prediction comes from ``self.bad_model`` when
        ``strategy.bad_model == "ref"``, otherwise from this model with a
        down-weighted classifier input (``strategy.bad_classifier_input_scale``,
        default 0.6). All guided quantities are raw (``softmax=False``) outputs so
        that ``theta_h`` terms are combined on the same scale.

        Optional strategy fields:
            solver: ``"ode"`` uses the ODE solvers on ``mu_pos`` and type logits
                ``z_h``; ``"ode_brute_force"`` adds Gaussian rollout noise scaled by
                ``strategy.noise_scale`` to the Bayesian update; anything else uses
                the plain Bayesian update.
            explore_end_step: from this step on, the base prediction is taken from
                ``self.bad_model`` instead of this model.

        Args:
            pocket: featurized pocket; replicated ``num_samples`` times.
            num_samples: number of ligands to sample.
            reward_fn: not used by this strategy; accepted for signature compatibility.
            estimated_ligand_num: optional per-sample ligand atom counts, stored on
                the pocket batch as ``estimated_ligand_num_atoms``.

        Returns:
            ``(molecules, in_traj, out_traj, bad_traj)``. Each trajectory entry is a
            dict of CPU copies of ``mu_pos`` and ``theta_h``: ``in_traj`` holds the
            per-step inputs plus the final state, ``out_traj`` the guided per-step
            predictions plus the final prediction, and ``bad_traj`` the "bad"
            predictions (empty when ``w`` is 1).
        """
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        solver = getattr(self.strategy, "solver", None)
        if solver == "ode":
            self.prepare_constants()
            # The ODE solver evolves type logits ``z_h`` (softmaxed into theta_h below);
            # initialise them with small random values.
            std = self.config.ligand_atom_feature_dim ** -0.5 * self.config.beta1 * 1e-5
            molecules["z_h"] = torch.randn_like(molecules["theta_h"]) * std
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        num_steps = self.config.num_sample_steps
        w = self.strategy.guide_weight
        use_guidance = abs(w - 1) > 1e-6
        explore_end_step = getattr(self.strategy, "explore_end_step", None)

        def _snapshot(mols):
            # CPU copies so later reassignment of keys cannot alias trajectory entries.
            return {
                "mu_pos": mols["mu_pos"].cpu().clone(),
                "theta_h": mols["theta_h"].cpu().clone(),
            }

        def _reference_outcome(t, gamma_coord):
            # Raw (softmax=False) prediction of ``self.bad_model`` for the current state.
            coord_pred, p0_h, _ = self.bad_model.interdependency_modeling(
                time=t,
                protein_pos=pockets['pos'],
                protein_v=pockets['atom_feature'],
                batch_protein=pockets['pos_batch'],
                batch_ligand=molecules['mu_pos_batch'],
                theta_h_t=molecules['theta_h'],
                mu_pos_t=molecules['mu_pos'],
                gamma_coord=gamma_coord,
                softmax=False,
            )
            return {"mu_pos": coord_pred, "theta_h": p0_h}

        in_traj, out_traj, bad_traj = [], [], []
        for step in tqdm(range(1, num_steps + 1), desc="Sampling"):
            in_traj.append(_snapshot(molecules))
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / num_steps
            gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
            if explore_end_step and step >= explore_end_step:
                out_molecules = _reference_outcome(t, gamma_coord)
            else:
                classifier_input = torch.tensor(getattr(self.strategy, "classifier_input_scale", [1, 1, 1]), dtype=torch.float, device=device).repeat(num_atoms, 1)
                out_molecules = self.estimate_outcome_molecule(molecules, pockets, t, classifier_input, softmax=False)

            if use_guidance:
                if self.strategy.bad_model == "ref":
                    bad_molecules = _reference_outcome(t, gamma_coord)
                else:
                    bad_scale = getattr(self.strategy, "bad_classifier_input_scale", 0.6)
                    bad_input = bad_scale * torch.ones(molecules["mu_pos"].shape[0], 3, device=molecules["mu_pos"].device)
                    # softmax=False keeps theta_h on the same (raw) scale as out_molecules.
                    bad_molecules = self.estimate_outcome_molecule(molecules, pockets, t, bad_input, softmax=False)
                bad_traj.append(_snapshot(bad_molecules))
                out_molecules["mu_pos"] = out_molecules["mu_pos"] + (w - 1) * gamma_coord * (out_molecules["mu_pos"] - bad_molecules["mu_pos"])
                out_molecules["theta_h"] = out_molecules["theta_h"] + (w - 1) * (out_molecules["theta_h"] - bad_molecules["theta_h"])
            out_traj.append(_snapshot(out_molecules))

            # Bayesian update
            if solver == "ode":
                molecules['mu_pos'] = self.ode_bfnsolver1_coord(molecules['mu_pos'], out_molecules['mu_pos'], step - 1)
                molecules['z_h'] = self.ode_bfnsolver1_type(molecules['z_h'], out_molecules['theta_h'], step - 1)
                molecules['theta_h'] = F.softmax(molecules['z_h'], dim=-1)
            elif solver == "ode_brute_force":
                t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / num_steps
                molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'], rollout_noise=torch.randn_like(out_molecules['mu_pos']) * self.strategy.noise_scale)
                molecules['theta_h'] = self.discrete_var_bayesian_update(t, F.softmax(out_molecules['theta_h'], dim=-1), self.config.ligand_atom_feature_dim, rollout_noise=torch.randn_like(out_molecules['theta_h']) * self.strategy.noise_scale)
            else:
                t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / num_steps
                molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'])
                molecules['theta_h'] = self.discrete_var_bayesian_update(t, F.softmax(out_molecules['theta_h'], dim=-1), self.config.ligand_atom_feature_dim)

        # Final step: unguided prediction at t = 1, decoded into molecules.
        in_traj.append(_snapshot(molecules))
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device)
        out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        out_traj.append(_snapshot(out_molecules))

        output_molecules = [self.decode_ith_molecule(out_molecules, pockets, i) for i in range(num_samples)]
        return output_molecules, in_traj, out_traj, bad_traj

    @torch.no_grad()
    def sample_mixed_cg_cfg(self, pocket: Featurized[Pocket], num_samples: int=1, reward_fn: Callable=None, estimated_ligand_num: List[int]=None) -> Tuple[List[Molecule], List[dict], List[dict], List[dict]]:
        """Sample ligands with CFG for early steps and value-model guidance afterwards.

        Phase 1, for ``step < strategy.cfg_end_step``: the prediction is pushed away
        from a "bad" prediction via ``out + (w - 1) * (out - bad)`` with
        ``w = strategy.cfg_guide_weight``. This is skipped when ``w`` is 1. The
        "bad" prediction comes from ``self.bad_model`` when
        ``strategy.bad_model == "ref"``, otherwise from this model with a
        down-weighted classifier input (``strategy.bad_classifier_input_scale``,
        default 0.8).

        Phase 2, for the remaining steps and once more on the final prediction: the
        gradient of the value model's score (see ``_mixed_cg_value_grad``), scaled
        by ``strategy.cg_guide_weight``, is added to the prediction. Every Bayesian
        update is perturbed with Gaussian noise of scale ``strategy.noise_scale``.

        Args:
            pocket: featurized pocket; replicated ``num_samples`` times.
            num_samples: number of ligands to sample.
            reward_fn: not used by this strategy; accepted for signature compatibility.
            estimated_ligand_num: optional per-sample ligand atom counts, stored on
                the pocket batch as ``estimated_ligand_num_atoms``.

        Returns:
            ``(molecules, in_traj, out_traj, bad_traj)``. The trajectories hold
            detached per-step copies of the input state, the guided prediction and
            the "bad" prediction (CFG phase only), respectively.
        """
        pockets = self.collators["pocket"]([pocket] * num_samples)
        if estimated_ligand_num is not None:
            pockets["estimated_ligand_num_atoms"] = torch.tensor(estimated_ligand_num, dtype=torch.long, device=pockets["pos"].device)
        molecules = self.create_dummy_molecule(pockets)
        device = molecules['mu_pos'].device
        num_atoms = molecules['mu_pos_batch'].shape[0]
        num_steps = self.config.num_sample_steps
        w = self.strategy.cfg_guide_weight

        def _snapshot(mols):
            # Detached copy of every tensor: ``molecules`` has its keys reassigned each
            # step, so appending the dict itself would make all entries alias one state.
            return {k: (v.detach().clone() if torch.is_tensor(v) else v) for k, v in mols.items()}

        in_traj, out_traj, bad_traj = [], [], []
        for step in range(1, num_steps + 1):
            in_traj.append(_snapshot(molecules))
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / num_steps
            classifier_input = torch.tensor(getattr(self.strategy, "classifier_input_scale", [1, 1, 1]), dtype=torch.float, device=device).repeat(num_atoms, 1)
            out_molecules = self.estimate_outcome_molecule(molecules, pockets, t, classifier_input)
            if step < self.strategy.cfg_end_step:
                # classifier-free guidance
                if abs(w - 1) > 1e-6:
                    gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
                    if self.strategy.bad_model == "ref":
                        coord_pred, p0_h, _ = self.bad_model.interdependency_modeling(
                            time=t,
                            protein_pos=pockets['pos'],
                            protein_v=pockets['atom_feature'],
                            batch_protein=pockets['pos_batch'],
                            batch_ligand=molecules['mu_pos_batch'],
                            theta_h_t=molecules['theta_h'],
                            mu_pos_t=molecules['mu_pos'],
                            gamma_coord=gamma_coord,
                        )
                        bad_molecules = {"mu_pos": coord_pred, "theta_h": p0_h}
                    else:
                        bad_scale = getattr(self.strategy, "bad_classifier_input_scale", 0.8)
                        bad_input = bad_scale * torch.ones(molecules["mu_pos"].shape[0], 3, device=molecules["mu_pos"].device)
                        bad_molecules = self.estimate_outcome_molecule(molecules, pockets, t, bad_input)
                    bad_traj.append(_snapshot(bad_molecules))
                    out_molecules["mu_pos"] = out_molecules["mu_pos"] + (w - 1) * gamma_coord * (out_molecules["mu_pos"] - bad_molecules["mu_pos"])
                    out_molecules["theta_h"] = out_molecules["theta_h"] + (w - 1) * (out_molecules["theta_h"] - bad_molecules["theta_h"])
            else:
                # classifier guidance from the value model
                grad_mu_pos, grad_theta_h = self._mixed_cg_value_grad(pockets, molecules, t)
                out_molecules["mu_pos"] = out_molecules["mu_pos"] + self.strategy.cg_guide_weight * grad_mu_pos * self.sigma1_coord ** (step / num_steps)
                out_molecules["theta_h"] = out_molecules["theta_h"] + self.strategy.cg_guide_weight * grad_theta_h
            out_traj.append(_snapshot(out_molecules))

            # Bayesian update with rollout noise (drawn identically in both phases).
            noise_mu_pos = torch.randn_like(out_molecules["mu_pos"]) * self.strategy.noise_scale
            noise_theta_h = torch.randn_like(out_molecules["theta_h"]) * self.strategy.noise_scale
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / num_steps
            molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, out_molecules['mu_pos'], rollout_noise=noise_mu_pos)
            molecules['theta_h'] = self.discrete_var_bayesian_update(t, out_molecules['theta_h'], self.config.ligand_atom_feature_dim, rollout_noise=noise_theta_h)

        # Final step: prediction at t = 1 with one more value-guidance nudge.
        t = torch.ones((num_atoms, 1), dtype=torch.float, device=device)
        out_molecules = self.estimate_outcome_molecule(molecules, pockets, t)
        grad_mu_pos, grad_theta_h = self._mixed_cg_value_grad(pockets, molecules, t)
        # NOTE(review): the loop scales the coordinate gradient by sigma1 ** (step / N),
        # which would be sigma1 ** 1 at t = 1; the final step uses sigma1 ** 2 — confirm intent.
        out_molecules["mu_pos"] = out_molecules["mu_pos"] + self.strategy.cg_guide_weight * grad_mu_pos * self.sigma1_coord**2
        out_molecules["theta_h"] = out_molecules["theta_h"] + self.strategy.cg_guide_weight * grad_theta_h

        output_molecules = [self.decode_ith_molecule(out_molecules, pockets, i) for i in range(num_samples)]
        return output_molecules, in_traj, out_traj, bad_traj

    def _mixed_cg_value_grad(self, pockets, molecules, t):
        """Return ``(d score / d mu_pos, d score / d theta_h)`` for the value model.

        The score is ``0.5 * log p[:, 0] + 0.5 * log p[:, 2]``, where ``p`` is
        ``self.value_model.predict(pockets, molecules, t)``; it is summed over the
        batch through ``grad_outputs=ones``. Gradients are taken on detached leaf
        copies of ``mu_pos``/``theta_h`` inside a shallow copy of ``molecules``, so
        the caller's dict is left unchanged.
        """
        probe = dict(molecules)
        probe["mu_pos"] = molecules["mu_pos"].clone().detach().requires_grad_(True)
        probe["theta_h"] = molecules["theta_h"].clone().detach().requires_grad_(True)
        with torch.enable_grad():
            probs = self.value_model.predict(pockets, probe, t)
            log_probs = probs.log()
            score = 0.5 * log_probs[:, 0] + 0.5 * log_probs[:, 2]
            # One backward pass for both inputs; no graph needs to be retained afterwards.
            grad_mu_pos, grad_theta_h = torch.autograd.grad(
                score,
                [probe["mu_pos"], probe["theta_h"]],
                grad_outputs=torch.ones_like(score),
            )
        return grad_mu_pos, grad_theta_h