import torch
import torch.nn as nn
import pdb


class ValueGuide(nn.Module):
    """Wrap a reward model so it can score diffusion states and return gradients.

    ``self.model`` is called as ``model(points, t, x, edge_index)`` and its
    output is squeezed along the last dimension.
    """

    def __init__(self, model):
        super().__init__()
        # Registered as a submodule when ``model`` is an ``nn.Module``.
        self.model = model

    def forward(self, points, t, xt, epsilon, atbar, edge_index, diffusion_type):
        """Score the state ``xt`` with the wrapped model.

        Args:
            points: Passed through to ``self.model`` unchanged.
            t: Passed through to ``self.model`` unchanged.
            xt: Current state tensor.
            epsilon: Noise prediction; only used by ``"classifier_v2"``.
            atbar: Alpha-bar value (tensor, array or number); only used by
                ``"classifier_v2"``.
            edge_index: Passed through to ``self.model`` unchanged.
            diffusion_type: ``"classifier_v1"`` scores ``xt`` directly;
                ``"classifier_v2"`` scores the estimate
                ``x0_hat = (xt - sqrt(1 - atbar) * epsilon) / sqrt(atbar)``.

        Returns:
            The model output with its last dimension squeezed.

        Raises:
            NotImplementedError: If ``diffusion_type`` is not recognised.
        """
        if diffusion_type == "classifier_v1":
            model_input = xt
        elif diffusion_type == "classifier_v2":
            # as_tensor avoids the copy (and UserWarning) that torch.tensor()
            # performs when atbar is already a tensor; values/dtype match.
            # NOTE(review): reshape(-1) broadcasts against xt's LAST dim; fine
            # for a single alpha-bar value, confirm for per-sample values.
            atbar = torch.as_tensor(atbar, device=xt.device).reshape(-1)
            model_input = (xt - torch.sqrt(1.0 - atbar) * epsilon) / torch.sqrt(atbar)
        else:
            raise NotImplementedError(
                f"unsupported diffusion_type: {diffusion_type!r}"
            )
        reward = self.model(points, t, model_input.float(), edge_index)
        return reward.squeeze(dim=-1)

    def gradients(self, points, t, xt, pred, atbar, edge_index, diffusion_type):
        """Return the reward and its gradient with respect to ``xt``.

        Args:
            points, t, atbar, edge_index, diffusion_type: Forwarded to
                :meth:`forward` unchanged.
            xt: State to differentiate with respect to. It is not modified.
            pred: Noise prediction; ``squeeze(1)`` is applied before it is
                passed to :meth:`forward` as ``epsilon``.

        Returns:
            ``(y, grad)`` where ``y`` is the output of :meth:`forward` and
            ``grad`` is ``d(y.sum()) / d(xt)``, shaped like ``xt``.
        """
        # enable_grad() lets this work even when called under torch.no_grad().
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy so the caller's tensor is not
            # left with requires_grad=True (calling .detach() afterwards would
            # not undo that: it returns a new tensor). Any dependence of
            # ``pred`` on ``xt`` through an existing graph is not differentiated.
            x = xt.detach().requires_grad_(True)
            epsilon = pred.squeeze(1)
            y = self(points, t, x, epsilon, atbar, edge_index, diffusion_type)
            grad = torch.autograd.grad([y.sum()], [x])[0]
        return y, grad