import lpips
import torch
import torch.nn.functional as F

from .base import Guidance
from classifiers.base import ClassifierBase

class RePaintTargetClassifierLPIPSGuidance(Guidance):
    '''
    Guidance combining a classifier term with an LPIPS term.

    The conditioning function produced by `get_cond_fn` evaluates both
    terms on `pred_xstart` (taken from the keyword arguments it is called
    with) and returns

        clf_scale * d log p(y | x) / dx  -  lpips_scale * d LPIPS(x, gt) / dx

    i.e. it pushes towards a higher classifier log-probability of the
    label `y` and a lower LPIPS distance to `gt`.
    '''

    # Clamp margin used by the inverse-sigmoid rescaling. The plain formula
    # log(v / (1 - v)) is -inf at v == 0 and NaN/inf for v >= 1; LPIPS is not
    # bounded above by 1, so unclamped values can poison the whole gradient.
    _LOGIT_EPS = 1e-6

    def __init__(
            self, 
            model: ClassifierBase, 
            uses_target_clf: bool, 
            clf_scale: float,
            lpips_scale: float,
            rescale_lpips: bool,
            lpips_net: str,
            log: bool):
        '''
        model           - classifier; called on `pred_xstart` and expected to
                          return logits with classes on the last dimension
        uses_target_clf - stored on the instance; not read anywhere in this
                          class (NOTE(review): presumably consumed by the
                          caller or the base class - confirm)
        clf_scale       - weight of the classifier gradient
        lpips_scale     - weight of the LPIPS gradient; when it is 0 the
                          LPIPS network is not instantiated at all
        rescale_lpips   - boolean indicating whether to apply inverse sigmoid
                          to lpips before computing gradients
        lpips_net       - backbone name forwarded to `lpips.LPIPS(net = ...)`
        log             - whether to call `log_grad_info` for both gradient
                          terms on every invocation of the cond_fn
        '''
        super().__init__()
        
        self.clf = model
        self.clf_scale = clf_scale
        self.uses_target_clf = uses_target_clf

        # Only pay for building the LPIPS network when it is actually used.
        self.lpips = lpips.LPIPS(net = lpips_net) if lpips_scale != 0. else None
        self.lpips_scale = lpips_scale
        self.rescale_lpips = rescale_lpips

        self.log = log
    
    def get_cond_module(self):
        '''
        Return the module(s) used for conditioning: a `(clf, lpips)` tuple
        when the LPIPS network was built, otherwise the classifier alone.
        '''
        if self.lpips is not None:
            return self.clf, self.lpips
        else:
            return self.clf
    
    def get_cond_fn(self):
        '''
        Build and return the conditioning function.

        The returned `cond_fn(x, t, y, gt, **kwargs)` requires `y` (class
        indices, one per sample), `kwargs['pred_xstart']` and - when
        `lpips_scale != 0` - `gt`. `x` is not used; `t` is only forwarded to
        `log_grad_info`. It returns the combined gradient with respect to
        `pred_xstart`.
        '''

        def cond_fn(x, t, y = None, gt = None, **kwargs):
            assert y is not None, 'cond_fn requires target labels `y`'

            use_lpips = self.lpips_scale != 0.
            # Fail early with a clear message instead of deep inside LPIPS.
            assert not use_lpips or gt is not None, \
                'cond_fn requires `gt` when lpips_scale != 0'

            with torch.enable_grad():
                # Detached leaf: gradients are taken w.r.t. the predicted x_0
                # only and never flow back into the diffusion model.
                x_in = kwargs['pred_xstart'].detach().requires_grad_(True)

                # Log-probability of the requested class for every sample.
                clf_logits = self.clf(x_in)
                clf_log_probs = F.log_softmax(clf_logits, dim = -1)
                clf_log_probs = clf_log_probs.gather(dim = 1, index = y[:, None]).flatten()

                grad_clf = torch.autograd.grad(clf_log_probs.sum(), x_in)[0]

                if use_lpips:
                    # NOTE(review): x_in is mapped with (x - 0.5) * 2 while gt
                    # is passed through unchanged - confirm that the two are
                    # really expected to arrive in different value ranges.
                    lpips_vals = self.lpips((x_in - 0.5) * 2, gt)

                    if self.rescale_lpips:
                        # Inverse sigmoid, log(v / (1 - v)), with the input
                        # clamped to [eps, 1 - eps]. Samples that hit the
                        # clamp get a zero LPIPS gradient rather than
                        # inf/NaN.
                        lpips_vals = torch.logit(lpips_vals, eps = self._LOGIT_EPS)

                    grad_lpips = torch.autograd.grad(lpips_vals.sum(), x_in)[0]

                else:
                    grad_lpips = torch.zeros_like(grad_clf)

                # we aim to increase clf_log_probs while decreasing
                # lpips_vals, so we use '-' here
                grad = self.clf_scale * grad_clf - self.lpips_scale * grad_lpips

                if self.log:
                    # `log_grad_info` is not defined in this class
                    # (presumably inherited from Guidance).
                    self.log_grad_info(t, grad_clf, 'clf')
                    self.log_grad_info(t, grad_lpips, 'lpips')

                return grad
            
        # Marker attribute on the function object; NOTE(review): presumably
        # inspected by the sampler to decide whether to enable autograd -
        # confirm against the caller.
        cond_fn.require_grad = True

        return cond_fn