from open_biomed.models.molecule.molcraft import MolCRAFT
from open_biomed.utils.config import Config
from open_biomed.data import Molecule, Pocket
from open_biomed.utils.featurizer import Featurized
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import copy
from tqdm import tqdm
from samplers.sbdd.bfn_sampler import BFNSampleStrategies
from typing import Tuple, List, Dict

class MolCRAFT4AdvancedSampling(MolCRAFT, BFNSampleStrategies):
    """MolCRAFT exposed through the ``BFNSampleStrategies`` sampling interface.

    Supplies the per-step hooks a sampling strategy drives: noise preparation,
    one Bayesian update step, outcome estimation, and per-molecule selection and
    decoding from a batched state holding ``mu_pos`` / ``theta_h`` / ``mu_pos_batch``.
    """

    def __init__(self, model_cfg: Config, strategy: Config) -> None:
        # Both bases are initialized explicitly: the model from ``model_cfg``,
        # the sampling strategy from ``strategy``.
        MolCRAFT.__init__(self, model_cfg)
        BFNSampleStrategies.__init__(self, strategy)
        self.configure_task(task="structure_based_drug_design")

    def prepare_noise(self, step: int, molecules: Featurized[Molecule]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Draw standard-normal noise for every remaining sampling step.

        Args:
            step: current step; noise is drawn for steps ``step + 1`` through
                ``self.config.num_sample_steps`` (inclusive).
            molecules: batched state whose ``mu_pos`` / ``theta_h`` tensors give the noise shapes.

        Returns:
            One ``(coordinate_noise, atom_type_noise)`` pair per remaining step.
        """
        return [
            (torch.randn_like(molecules['mu_pos']), torch.randn_like(molecules['theta_h']))
            for _ in range(step + 1, self.config.num_sample_steps + 1)
        ]

    def continuous_var_bayesian_update(self, t: torch.Tensor, x: torch.Tensor, rollout_noise: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample the coordinate parameter ``mu`` given a clean prediction ``x`` at time ``t``.

        Without ``rollout_noise`` this defers to ``MolCRAFT``'s implementation.
        Otherwise the supplied noise is used, so that rollouts are reproducible:
        ``gamma = 1 - sigma1 ** (2t)``, ``mu = gamma * x + sqrt(gamma * (1 - gamma)) * noise``.

        Returns:
            ``(mu, gamma)``.
        """
        if rollout_noise is None:
            return MolCRAFT.continuous_var_bayesian_update(self, t, x)
        gamma = 1 - torch.pow(self.sigma1_coord, 2 * t)
        mu = gamma * x + torch.sqrt(gamma * (1 - gamma)) * rollout_noise
        return mu, gamma

    def discrete_var_bayesian_update(self, t: torch.Tensor, x: torch.Tensor, K: int, rollout_noise: torch.Tensor=None, softmax: bool=True) -> torch.Tensor:
        """Sample the categorical parameter ``theta`` given class probabilities ``x`` at time ``t``.

        Without ``rollout_noise`` this defers to ``MolCRAFT``'s implementation,
        and ``softmax`` is NOT forwarded (it only affects the noise-supplied path).

        Args:
            t: time; ``beta`` takes the shape of ``t`` and broadcasts against ``x``.
            x: per-atom class probabilities (or one-hot), last dim of size ``K``.
            K: number of categories.
            rollout_noise: standard-normal noise shaped like ``x``.
            softmax: if True return ``softmax(y)``; otherwise return the raw ``y``.
        """
        if rollout_noise is None:
            return MolCRAFT.discrete_var_bayesian_update(self, t, x, K)
        beta = self.config.beta1 * (t ** 2)

        # Eq.(185): p_F(θ|x;t) = E_{N(y | β(t)(Ke_x−1), β(t)KI)} δ(θ − softmax(y)),
        # sampled by drawing y ~ N(β(t)(Ke_x−1), β(t)KI) and setting θ = softmax(y).
        mean = beta * (K * x - 1)
        std = (beta * K).sqrt()
        y = mean + std * rollout_noise
        return F.softmax(y, dim=-1) if softmax else y

    def _predict_clean_state(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on the current ``(mu_pos, theta_h)`` state at time ``t``.

        Returns the ``(coord_pred, p0_h)`` pair from ``interdependency_modeling``.
        """
        gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
        coord_pred, p0_h, _ = self.interdependency_modeling(
            time=t,
            protein_pos=pockets['pos'],
            protein_v=pockets['atom_feature'],
            batch_protein=pockets['pos_batch'],
            batch_ligand=molecules['mu_pos_batch'],
            theta_h_t=molecules['theta_h'],
            mu_pos_t=molecules['mu_pos'],
            gamma_coord=gamma_coord,
        )
        return coord_pred, p0_h

    def bayesian_update_step(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor, rollout_noise: Tuple[torch.Tensor, torch.Tensor]=None) -> Featurized[Molecule]:
        """Predict from the current state at ``t``, then apply one Bayesian update at the same ``t``.

        Args:
            rollout_noise: optional ``(coordinate_noise, atom_type_noise)`` pair,
                e.g. one entry of ``prepare_noise``; None draws fresh noise.

        Returns:
            A deep copy of ``molecules`` with updated ``mu_pos`` and ``theta_h``;
            the input is not modified.
        """
        coord_pred, p0_h = self._predict_clean_state(molecules, pockets, t)
        coord_noise, type_noise = (None, None) if rollout_noise is None else rollout_noise

        output_molecules = copy.deepcopy(molecules)
        output_molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, coord_pred, coord_noise)
        output_molecules['theta_h'] = self.discrete_var_bayesian_update(t, p0_h, self.config.ligand_atom_feature_dim, type_noise)
        return output_molecules

    def estimate_outcome_molecule(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor) -> Featurized[Molecule]:
        """Return a deep copy of ``molecules`` whose ``mu_pos`` / ``theta_h`` are replaced by the network's predictions at ``t``."""
        coord_pred, p0_h = self._predict_clean_state(molecules, pockets, t)

        output_molecules = copy.deepcopy(molecules)
        output_molecules['mu_pos'] = coord_pred
        output_molecules['theta_h'] = p0_h
        return output_molecules

    def select_ith_molecule(self, molecules: Featurized[Molecule], i: int) -> Featurized[Molecule]:
        """Extract the ``mu_pos`` / ``theta_h`` rows of molecule ``i`` (by ``mu_pos_batch``) into a ``Data`` object."""
        indices = torch.where(molecules["mu_pos_batch"] == i)[0]
        return Data(**{
            "mu_pos": molecules["mu_pos"][indices],
            "theta_h": molecules["theta_h"][indices],
        })

    def decode_ith_molecule(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], i: int) -> Molecule:
        """Decode molecule ``i`` using its positions and argmax atom types, relative to pocket ``i``'s center."""
        molecule = self.select_ith_molecule(molecules, i)
        return self.featurizers["molecule"].decode({
            "pos": molecule["mu_pos"],
            "atom_type": molecule["theta_h"].argmax(dim=-1),
        }, pockets["pocket_center"][i])

class MolCRAFTWithCFG(MolCRAFT):
    """MolCRAFT whose ligand embedding is additionally conditioned on a classifier input.

    ``classifier_input`` holds one row per molecule and is broadcast to atoms
    through the ligand batch index. In discrete mode (``config.discrete``) each
    column is one-hot encoded over 10 classes. Otherwise the columns are
    concatenated to the per-atom input features.
    """

    def __init__(self, model_cfg: Config) -> None:
        super(MolCRAFTWithCFG, self).__init__(model_cfg)
        if getattr(model_cfg, "discrete", False):
            # 30 = 3 classifier columns x 10 one-hot classes (see interdependency_modeling).
            self.classifier_emb = nn.Linear(30, model_cfg.hidden_dim - 1)
        else:
            # Replaces the base ligand embedding so the classifier columns fit in its input.
            # NOTE(review): "+ 4" presumably = 1 time-embedding feature + 3 classifier columns; confirm.
            self.ligand_atom_emb = nn.Linear(model_cfg.ligand_atom_feature_dim + 4, model_cfg.hidden_dim - 1)

    def interdependency_modeling(
        self,
        time,
        protein_pos,  # transform from the orginal BFN codebase
        protein_v,  # transform from
        batch_protein,  # index for protein
        theta_h_t,
        mu_pos_t,
        batch_ligand,  # index for ligand
        gamma_coord,
        classfier_input,
        return_all=False,  # legacy from targetdiff
        fix_x=False,
        softmax=True,
    ):
        """
        Compute output distribution parameters for p_O (x' | θ; t) (x_hat or k^(d) logits).
        Draw output_sample = x' ~ p_O (x' | θ; t).
            continuous x ~ δ(x - x_hat(θ, t))
            discrete k^(d) ~ softmax(Ψ^(d)(θ, t))_k
        Args:
            time: [node_num x batch_size, 1] := [N_ligand, 1]
            protein_pos: [node_num x batch_size, 3] := [N_protein, 3]
            protein_v: [node_num x batch_size, protein_atom_feature_dim] := [N_protein, 27]
            batch_protein: [node_num x batch_size] := [N_protein]
            theta_h_t: [node_num x batch_size, atom_type] := [N_ligand, 13]
            mu_pos_t: [node_num x batch_size, 3] := [N_ligand, 3]
            batch_ligand: [node_num x batch_size] := [N_ligand]
            gamma_coord: [node_num x batch_size, 1] := [N_ligand, 1]; accepted for
                interface compatibility but not used by this implementation.
            classfier_input: [batch_size, 3], one row per molecule, indexed by batch_ligand.
                In discrete mode the values are class indices in [0, 10) (float values are
                truncated to integers).
            softmax: if True the atom-type output is softmax-normalized, otherwise raw logits.
        Returns:
            (ligand positions [N_ligand, 3], atom-type probabilities or logits, zeros like mu_pos_t)
        """
        theta_h_t = 2 * theta_h_t - 1  # from 1/K \in [0,1] to 2/K-1 \in [-1,1]
        discrete = getattr(self.config, "discrete", False)

        # ---------for targetdiff-----------
        init_ligand_v = theta_h_t
        time_emb = self.time_emb_layer(time)
        # Broadcast the per-molecule conditioning to one row per ligand atom.
        classfier_input = classfier_input.index_select(0, batch_ligand)
        if discrete:
            input_ligand_feat = torch.cat([init_ligand_v, time_emb], -1)
        else:
            # Cast so callers may pass integer or differently-typed conditioning.
            input_ligand_feat = torch.cat([init_ligand_v, time_emb, classfier_input.to(init_ligand_v.dtype)], -1)

        h_protein = self.protein_atom_emb(protein_v)  # [N_protein, self.hidden_dim - 1]
        init_ligand_h = self.ligand_atom_emb(input_ligand_feat)  # [N_ligand, self.hidden_dim - 1]
        if discrete:
            # F.one_hot only accepts integer tensors; callers may pass float indices (e.g. torch.ones).
            classifier_one_hot = F.one_hot(classfier_input.long(), num_classes=10)
            classifier_one_hot = classifier_one_hot.view(classifier_one_hot.shape[0], -1).float()
            init_ligand_h = init_ligand_h + self.classifier_emb(classifier_one_hot)

        if self.node_indicator:
            # Append a 0/1 column marking protein vs. ligand nodes.
            h_protein = torch.cat(
                [h_protein, torch.zeros(len(h_protein), 1).to(h_protein)], -1
            )  # [N_protein, self.hidden_dim]
            init_ligand_h = torch.cat(
                [init_ligand_h, torch.ones(len(init_ligand_h), 1).to(h_protein)], -1
            )  # [N_ligand, self.hidden_dim]

        # Joint protein+ligand context; ligand features/positions are the current (noisy) state.
        h_all, pos_all, batch_all, mask_ligand = self.compose_context(
            h_protein=h_protein,
            h_ligand=init_ligand_h,
            pos_protein=protein_pos,
            pos_ligand=mu_pos_t,
            batch_protein=batch_protein,
            batch_ligand=batch_ligand,
        )

        outputs = self.unio2net(
            h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x
        )
        final_pos, final_h = outputs["x"], outputs["h"]
        final_ligand_pos, final_ligand_h = final_pos[mask_ligand], final_h[mask_ligand]
        final_ligand_v = self.v_inference(final_ligand_h)  # [N_ligand, 13]

        # The network's ligand coordinates are returned directly as the coordinate prediction;
        # atom-type logits Ψ(θ, t) are normalized with softmax when requested.
        ligand_type_out = F.softmax(final_ligand_v, dim=-1) if softmax else final_ligand_v
        return final_ligand_pos, ligand_type_out, torch.zeros_like(mu_pos_t)

    def forward_with_t(self, pocket: Featurized[Pocket], molecule: Featurized[Molecule], classifier_input: torch.Tensor, t: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the training losses at time ``t``.

        Runs the network on the state ``(mu_pos, theta_h)`` and scores its
        predictions against the ground truth ``pos`` / ``atom_feature``.

        Returns:
            Dict with ``loss`` (sum), ``loss_coord`` and ``loss_atom_type``.
        """
        gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
        coord_pred, p0_h, _ = self.interdependency_modeling(
            time=t,
            protein_pos=pocket['pos'],
            protein_v=pocket['atom_feature'],
            batch_protein=pocket['pos_batch'],
            batch_ligand=molecule['pos_batch'],
            theta_h_t=molecule['theta_h'],
            mu_pos_t=molecule['mu_pos'],
            gamma_coord=gamma_coord,
            classfier_input=classifier_input,
        )
        loss_coord = self.continuous_loss(t, molecule["pos"], coord_pred)
        loss_atom_type = self.discrete_loss(t, molecule["atom_feature"], p0_h)
        return {
            "loss": loss_coord + loss_atom_type,
            "loss_coord": loss_coord,
            "loss_atom_type": loss_atom_type,
        }

    def sample(self, pocket: Featurized[Pocket], molecule: Featurized[Molecule], classifier_input: torch.Tensor) -> List[Molecule]:
        """Run ``config.num_sample_steps`` sampling steps and decode one molecule per batch entry.

        ``molecule['mu_pos']`` and ``molecule['theta_h']`` are updated IN PLACE at
        every step. The decoded molecules use the final network prediction at t = 1:
        positions plus argmax atom types, relative to each pocket's ``pocket_center``.
        """
        device = molecule['mu_pos'].device
        num_atoms = molecule['mu_pos_batch'].shape[0]
        num_steps = self.config.num_sample_steps

        for step in tqdm(range(1, num_steps + 1), desc="Sampling"):
            # The network is evaluated at the start of the step, t = (step - 1) / N ...
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * (step - 1) / num_steps
            gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
            coord_pred, p0_h, _ = self.interdependency_modeling(
                time=t,
                protein_pos=pocket['pos'],
                protein_v=pocket['atom_feature'],
                batch_protein=pocket['pos_batch'],
                batch_ligand=molecule['mu_pos_batch'],
                theta_h_t=molecule['theta_h'],
                mu_pos_t=molecule['mu_pos'],
                gamma_coord=gamma_coord,
                classfier_input=classifier_input,
            )

            # ... and the Bayesian update uses the end of the step, t = step / N.
            t = torch.ones((num_atoms, 1), dtype=torch.float, device=device) * step / num_steps
            molecule['theta_h'] = self.discrete_var_bayesian_update(t, p0_h, self.config.ligand_atom_feature_dim)
            molecule['mu_pos'], _ = self.continuous_var_bayesian_update(t, coord_pred)

        # Compute final output distribution parameters for p_O (x' | θ; t) at t = 1.
        mu_pos_final, p0_h_final, _ = self.interdependency_modeling(
            time=torch.ones((num_atoms, 1)).to(device),
            protein_pos=pocket['pos'],
            protein_v=pocket['atom_feature'],
            batch_protein=pocket['pos_batch'],
            batch_ligand=molecule['mu_pos_batch'],
            theta_h_t=molecule['theta_h'],
            mu_pos_t=molecule['mu_pos'],
            gamma_coord=1 - self.sigma1_coord ** 2,  # γ(t) = 1 − (σ1**2) ** t
            classfier_input=classifier_input,
        )
        mu_pos_final = mu_pos_final.detach()
        p0_h_final = torch.clamp(p0_h_final, min=1e-6).detach()

        batch_ligand = molecule['mu_pos_batch']
        num_mols = int(batch_ligand.max()) + 1
        out_molecules = []
        for i in range(num_mols):
            idx = torch.where(batch_ligand == i)[0]
            cur_molecule = {
                "pos": mu_pos_final[idx],
                "atom_type": torch.argmax(p0_h_final[idx], dim=-1),
            }
            out_molecules.append(self.featurizers["molecule"].decode(cur_molecule, pocket["pocket_center"][i]))
        return out_molecules

    def predict_structure_based_drug_design(self, pocket: Featurized[Pocket], classifier_input: torch.Tensor) -> List[Molecule]:
        """Sample molecules for ``pocket`` starting from ``create_dummy_molecule``'s initial state."""
        molecule = self.create_dummy_molecule(pocket)
        return self.sample(pocket, molecule, classifier_input)

class MolCRAFTWithCFG4AdvancedSampling(MolCRAFTWithCFG, BFNSampleStrategies):
    """Classifier-conditioned MolCRAFT exposed through the ``BFNSampleStrategies`` interface.

    When no classifier input is supplied, an all-ones conditioning
    (one row per molecule, 3 columns) is used.
    """

    def __init__(self, model_cfg: Config, strategy: Config) -> None:
        super(MolCRAFTWithCFG4AdvancedSampling, self).__init__(model_cfg)
        BFNSampleStrategies.__init__(self, strategy)
        self.configure_task(task="structure_based_drug_design")

    def _default_classifier_input(self, molecules: Featurized[Molecule]) -> torch.Tensor:
        """Build the all-ones classifier input: one row per molecule in ``mu_pos_batch``, 3 columns.

        In discrete mode the tensor is integer-typed because the model one-hot encodes it.
        Otherwise it uses the default float dtype, because it is concatenated to float features.
        """
        batch_ligand = molecules["mu_pos_batch"]
        num_mols = int(batch_ligand.max()) + 1 if batch_ligand.numel() > 0 else 0
        ones = torch.ones(num_mols, 3, device=molecules["mu_pos"].device)
        return ones.long() if getattr(self.config, "discrete", False) else ones

    def _predict_clean_state(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor, classifier_input: torch.Tensor=None, softmax: bool=True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on the current ``(mu_pos, theta_h)`` state at time ``t``; returns ``(coord_pred, p0_h)``."""
        if classifier_input is None:
            classifier_input = self._default_classifier_input(molecules)
        gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
        coord_pred, p0_h, _ = self.interdependency_modeling(
            time=t,
            protein_pos=pockets['pos'],
            protein_v=pockets['atom_feature'],
            batch_protein=pockets['pos_batch'],
            batch_ligand=molecules['mu_pos_batch'],
            theta_h_t=molecules['theta_h'],
            mu_pos_t=molecules['mu_pos'],
            gamma_coord=gamma_coord,
            classfier_input=classifier_input,
            softmax=softmax,
        )
        return coord_pred, p0_h

    def estimate_outcome_molecule(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor, classifier_input: torch.Tensor=None, softmax: bool=True) -> Featurized[Molecule]:
        """Return a deep copy of ``molecules`` with ``mu_pos`` / ``theta_h`` replaced by the network's predictions at ``t``.

        Args:
            classifier_input: per-molecule conditioning; None uses the all-ones default.
            softmax: if False, ``theta_h`` holds raw atom-type logits instead of probabilities.
        """
        coord_pred, p0_h = self._predict_clean_state(molecules, pockets, t, classifier_input, softmax)

        output_molecules = copy.deepcopy(molecules)
        output_molecules['mu_pos'] = coord_pred
        output_molecules['theta_h'] = p0_h
        return output_molecules

    def bayesian_update_step(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], t: torch.Tensor, rollout_noise: Tuple[torch.Tensor, torch.Tensor]=None, classifier_input: torch.Tensor=None) -> Featurized[Molecule]:
        """Predict from the current state at ``t``, then apply one Bayesian update at the same ``t``.

        Args:
            rollout_noise: optional ``(coordinate_noise, atom_type_noise)`` pair; None draws fresh noise.
            classifier_input: per-molecule conditioning; None uses the all-ones default.

        Returns:
            A deep copy of ``molecules`` with updated ``mu_pos`` and ``theta_h``.
        """
        coord_pred, p0_h = self._predict_clean_state(molecules, pockets, t, classifier_input)
        coord_noise, type_noise = (None, None) if rollout_noise is None else rollout_noise

        output_molecules = copy.deepcopy(molecules)
        output_molecules['mu_pos'], _ = self.continuous_var_bayesian_update(t, coord_pred, coord_noise)
        output_molecules['theta_h'] = self.discrete_var_bayesian_update(t, p0_h, self.config.ligand_atom_feature_dim, type_noise)
        return output_molecules

    def select_ith_molecule(self, molecules: Featurized[Molecule], i: int) -> Featurized[Molecule]:
        """Extract the ``mu_pos`` / ``theta_h`` rows of molecule ``i`` (by ``mu_pos_batch``) into a ``Data`` object."""
        indices = torch.where(molecules["mu_pos_batch"] == i)[0]
        return Data(**{
            "mu_pos": molecules["mu_pos"][indices],
            "theta_h": molecules["theta_h"][indices],
        })

    def decode_ith_molecule(self, molecules: Featurized[Molecule], pockets: Featurized[Pocket], i: int) -> Molecule:
        """Decode molecule ``i`` using its positions and argmax atom types, relative to pocket ``i``'s center."""
        molecule = self.select_ith_molecule(molecules, i)
        return self.featurizers["molecule"].decode({
            "pos": molecule["mu_pos"],
            "atom_type": molecule["theta_h"].argmax(dim=-1),
        }, pockets["pocket_center"][i])

    def continuous_var_bayesian_update(self, t: torch.Tensor, x: torch.Tensor, rollout_noise: torch.Tensor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample the coordinate parameter ``mu`` given a clean prediction ``x`` at time ``t``.

        Without ``rollout_noise`` this defers to ``MolCRAFT``'s implementation.
        Otherwise: ``gamma = 1 - sigma1 ** (2t)``,
        ``mu = gamma * x + sqrt(gamma * (1 - gamma)) * noise``.

        Returns:
            ``(mu, gamma)``.
        """
        if rollout_noise is None:
            return MolCRAFT.continuous_var_bayesian_update(self, t, x)
        gamma = 1 - torch.pow(self.sigma1_coord, 2 * t)
        mu = gamma * x + torch.sqrt(gamma * (1 - gamma)) * rollout_noise
        return mu, gamma

    def discrete_var_bayesian_update(self, t: torch.Tensor, x: torch.Tensor, K: int, rollout_noise: torch.Tensor=None, softmax: bool=True) -> torch.Tensor:
        """Sample the categorical parameter ``theta`` given class probabilities ``x`` at time ``t``.

        Without ``rollout_noise`` this defers to ``MolCRAFT``'s implementation,
        and ``softmax`` is NOT forwarded (it only affects the noise-supplied path).
        """
        if rollout_noise is None:
            return MolCRAFT.discrete_var_bayesian_update(self, t, x, K)
        beta = self.config.beta1 * (t ** 2)  # same shape as t

        # Eq.(185): p_F(θ|x;t) = E_{N(y | β(t)(Ke_x−1), β(t)KI)} δ(θ − softmax(y)),
        # sampled by drawing y ~ N(β(t)(Ke_x−1), β(t)KI) and setting θ = softmax(y).
        mean = beta * (K * x - 1)
        std = (beta * K).sqrt()
        y = mean + std * rollout_noise
        return F.softmax(y, dim=-1) if softmax else y