import lpips
import torch
import torch.nn.functional as F

from .base import Guidance
from .utils import ADAMGradientStabilization, l2_normalize_gradient
from classifiers.base import ClassifierBase

class RePaintTargetClassifierFromTweedie(Guidance):
    """Classifier guidance evaluated on the sampler's predicted clean sample.

    The conditioning function built by :meth:`get_cond_fn` does the following:

    * classifies ``kwargs["pred_xstart"]``, presumably the Tweedie / x0
      estimate supplied by the sampler (confirm against the caller);
    * takes the log-probability of the target labels ``y``;
    * returns ``scale * d(sum_i log p(y_i | pred_xstart)) / dx``, where ``x``
      is the current noisy sample.

    Before scaling, the gradient is optionally passed through a stabilizer
    and then L2-normalized.
    """

    def __init__(
            self,
            model: ClassifierBase,
            uses_target_clf: bool,
            scale: float,
            normalize: bool,
            stabilization: ADAMGradientStabilization,
            require_grad: bool,
            keep_graph: bool,
            log: bool):
        '''
        model - classifier called on the predicted clean sample; its output is
            treated as class logits
        uses_target_clf - stored as ``self.uses_target_clf``; not read by this
            class (presumably consumed by callers -- verify)
        scale - multiplier applied to the returned gradient
        normalize - whether to apply ``l2_normalize_gradient`` to the gradient
        stabilization - callable applied to the gradient before normalization;
            pass None to disable it
        require_grad - whether guidance requires computational graph of input
            (exposed to callers as ``cond_fn.require_grad``)
        keep_graph - whether to keep the computational graph after backward pass
            (forwarded as ``create_graph`` to ``torch.autograd.grad``)
        log - whether to report the gradient via ``self.log_grad_info``
        '''
        super().__init__()

        self.clf = model
        self.scale = scale
        self.uses_target_clf = uses_target_clf
        self.normalize = normalize
        self.stabilization = stabilization
        self.require_grad = require_grad
        self.keep_graph = keep_graph
        self.log = log

    def get_cond_module(self):
        """Return the classifier module used for guidance."""
        return self.clf

    def get_cond_fn(self):
        """Build the conditioning function handed to the sampler.

        Returns:
            ``cond_fn(x, t, y=None, gt=None, **kwargs)``. It returns the scaled
            guidance gradient, which has the same shape as ``x``. The function
            object carries ``keep_graph`` and ``require_grad`` attributes
            copied from this instance. ``keep_graph`` is re-read from the
            function object on every call, so callers may toggle it.

        Raises (from ``cond_fn``):
            ValueError: if ``y`` is None.
            KeyError: if ``pred_xstart`` is not passed as a keyword argument.
        """

        def cond_fn(x, t, y=None, gt=None, **kwargs):
            # ``t`` is only used for logging; ``gt`` is accepted for interface
            # compatibility with other guidance functions and is unused here.

            # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
            if y is None:
                raise ValueError("cond_fn requires target labels `y`, got None")
            if "pred_xstart" not in kwargs:
                raise KeyError("cond_fn requires the `pred_xstart` keyword argument")

            with torch.enable_grad():

                # The classifier sees the predicted clean sample, but the gradient
                # is taken w.r.t. the noisy input ``x``. This requires pred_xstart
                # to have been computed from ``x`` with autograd tracking enabled.
                x_in = kwargs["pred_xstart"]

                clf_logits = self.clf(x_in)
                clf_log_probs = F.log_softmax(clf_logits, dim=-1)
                # Pick each sample's target-class log-probability
                # (shape (batch,) assuming 2-D logits).
                clf_log_probs = clf_log_probs.gather(dim=1, index=y[:, None]).flatten()

                grad_clf = torch.autograd.grad(
                    clf_log_probs.sum(), x, create_graph=cond_fn.keep_graph
                )[0]

                # NOTE(review): this post-processing runs under no_grad. When a
                # stabilizer or normalization is active, the result is detached
                # even if keep_graph is True -- confirm this is intended.
                with torch.no_grad():

                    if self.stabilization is not None:
                        grad_clf = self.stabilization(grad_clf)

                    if self.normalize:
                        grad_clf = l2_normalize_gradient(grad_clf)

                if self.log:
                    self.log_grad_info(t, grad_clf, 'clf')

                return self.scale * grad_clf

        cond_fn.keep_graph = self.keep_graph
        cond_fn.require_grad = self.require_grad

        return cond_fn