from open_biomed.models.molecule.molcraft import MolCRAFT
from open_biomed.utils.featurizer import Featurized
from open_biomed.utils.config import Config
from open_biomed.data import Pocket, Molecule
import torch
import torch.nn as nn
from torch_scatter import scatter_mean
from typing import Dict

class CritiqueSBDD(MolCRAFT):
    """MolCRAFT backbone with a value head that scores pocket-ligand pairs.

    The inherited MolCRAFT encoder embeds a (possibly noisy) ligand inside its
    pocket; a small MLP head maps every ligand atom embedding to three scores,
    which are then averaged per molecule. Score columns are, in order:
    vina score / affinity, QED, SA.
    """

    def __init__(self, model_cfg: Config) -> None:
        super().__init__(model_cfg)
        # Value head for (vina score, qed, sa).
        # NOTE: the final Sigmoid bounds predictions to (0, 1), so the
        # training labels must be normalized to [0, 1].
        self.value_pred = nn.Sequential(
            nn.Linear(model_cfg.hidden_dim, model_cfg.hidden_dim),
            nn.ReLU(),
            nn.Linear(model_cfg.hidden_dim, 3),
            nn.Sigmoid(),
        )

    def interdependency_modeling(
        self,
        time,
        protein_pos,  # transformed from the original BFN codebase
        protein_v,
        batch_protein,  # batch index for each protein atom
        theta_h_t,
        mu_pos_t,
        batch_ligand,  # batch index for each ligand atom
        gamma_coord,
        return_all=False,  # legacy from targetdiff
        fix_x=False,
    ):
        """Run the MolCRAFT encoder over the joint pocket + ligand context.

        No sample is drawn here. The method returns the encoder outputs for the
        ligand atoms, including the per-atom hidden states consumed by the
        value head.

        Args:
            time: [N_ligand, 1] time step per ligand atom.
            protein_pos: [N_protein, 3] protein atom coordinates.
            protein_v: [N_protein, protein_atom_feature_dim] (e.g. 27) protein
                atom features.
            batch_protein: [N_protein] batch index for each protein atom.
            theta_h_t: [N_ligand, K] (e.g. K = 13) atom-type distribution
                parameters in [0, 1].
            mu_pos_t: [N_ligand, 3] (noisy) ligand coordinates.
            batch_ligand: [N_ligand] batch index for each ligand atom.
            gamma_coord: [N_ligand, 1]; accepted for interface compatibility,
                not used by this method.
            return_all: forwarded to ``self.unio2net``.
            fix_x: forwarded to ``self.unio2net``.

        Returns:
            Tuple ``(final_ligand_pos, final_ligand_v, final_ligand_h)``:
            ligand positions output by the encoder, ``self.v_inference``
            outputs for the ligand atoms (presumably [N_ligand, K] type
            logits), and the ligand hidden states.
        """
        # Rescale type parameters from [0, 1] to [-1, 1].
        theta_h_t = 2 * theta_h_t - 1

        # Ligand input features: rescaled type parameters + time embedding.
        time_emb = self.time_emb_layer(time)
        input_ligand_feat = torch.cat([theta_h_t, time_emb], -1)

        # Presumably [N, hidden_dim - 1] each when node_indicator is enabled.
        h_protein = self.protein_atom_emb(protein_v)
        init_ligand_h = self.ligand_atom_emb(input_ligand_feat)

        if self.node_indicator:
            # Append one flag column: 0 for protein atoms, 1 for ligand atoms.
            # new_zeros/new_ones allocate directly with h_protein's dtype and
            # device instead of building on CPU and copying over.
            h_protein = torch.cat(
                [h_protein, h_protein.new_zeros(len(h_protein), 1)], -1
            )  # [N_protein, self.hidden_dim]
            init_ligand_h = torch.cat(
                [init_ligand_h, h_protein.new_ones(len(init_ligand_h), 1)], -1
            )  # [N_ligand, self.hidden_dim]

        # Merge protein and ligand atoms into one context; the ligand part is
        # the noisy state (theta_h_t, mu_pos_t).
        h_all, pos_all, batch_all, mask_ligand = self.compose_context(
            h_protein=h_protein,
            h_ligand=init_ligand_h,
            pos_protein=protein_pos,
            pos_ligand=mu_pos_t,
            batch_protein=batch_protein,
            batch_ligand=batch_ligand,
        )

        outputs = self.unio2net(
            h_all, pos_all, mask_ligand, batch_all, return_all=return_all, fix_x=fix_x
        )
        # Keep only the ligand rows of the encoder output.
        final_ligand_pos = outputs["x"][mask_ligand]
        final_ligand_h = outputs["h"][mask_ligand]
        final_ligand_v = self.v_inference(final_ligand_h)
        return final_ligand_pos, final_ligand_v, final_ligand_h

    def _aggregate_value(self, atom_h: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
        """Score every ligand atom with the value head and average per molecule."""
        return scatter_mean(self.value_pred(atom_h), batch_index, dim=0)

    def forward(self, pocket: Featurized[Pocket], molecule: Featurized[Molecule], label: torch.Tensor, t: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Noise the ligand at time ``t`` and regress its normalized scores.

        Args:
            pocket: featurized pocket ("pos", "atom_feature", "pos_batch").
            molecule: featurized ligand ("pos", "atom_feature", "pos_batch").
            label: target scores; assumed [B, 3] in [0, 1] with columns
                (vina/affinity, qed, sa) -- TODO confirm against the dataset.
            t: diffusion/flow time passed to the Bayesian updates.

        Returns:
            Dict of un-reduced absolute errors: "loss" (all three columns) and
            "loss_affinity", "loss_qed", "loss_sa" (one column each).
        """
        # Noisy ligand state at time t.
        mu_coord, gamma_coord = self.continuous_var_bayesian_update(t, molecule["pos"])
        theta_h_t = self.discrete_var_bayesian_update(
            t, molecule["atom_feature"], self.config.ligand_atom_feature_dim
        )

        _, _, atom_h = self.interdependency_modeling(
            time=t,
            protein_pos=pocket["pos"],
            protein_v=pocket["atom_feature"],
            batch_protein=pocket["pos_batch"],
            theta_h_t=theta_h_t,
            mu_pos_t=mu_coord,
            batch_ligand=molecule["pos_batch"],
            gamma_coord=gamma_coord,
        )

        value = self._aggregate_value(atom_h, molecule["pos_batch"])
        abs_err = torch.abs(value - label)
        return {
            "loss": abs_err,
            "loss_affinity": abs_err[:, 0],
            "loss_qed": abs_err[:, 1],
            "loss_sa": abs_err[:, 2],
        }

    def predict(self, pocket: Featurized[Pocket], molecule: Featurized[Molecule], t: torch.Tensor) -> torch.Tensor:
        """Score ligands from their current BFN state (``theta_h``, ``mu_pos``).

        Returns:
            Per-molecule scores in (0, 1), one row per molecule index in
            ``molecule["mu_pos_batch"]``; columns are (vina/affinity, qed, sa).
        """
        # Not consumed by interdependency_modeling in this class; kept so the
        # call matches the method's signature.
        gamma_coord = 1 - torch.pow(self.sigma1_coord, 2 * t)
        _, _, atom_h = self.interdependency_modeling(
            time=t,
            protein_pos=pocket["pos"],
            protein_v=pocket["atom_feature"],
            batch_protein=pocket["pos_batch"],
            theta_h_t=molecule["theta_h"],
            mu_pos_t=molecule["mu_pos"],
            batch_ligand=molecule["mu_pos_batch"],
            gamma_coord=gamma_coord,
        )
        return self._aggregate_value(atom_h, molecule["mu_pos_batch"])
