import wandb
import lpips
import torch
import torch.nn.functional as F

from .base import Guidance
from classifiers.base import ClassifierBase

class RePaintBBTargetClassifierLPIPSGuidance(Guidance):
    '''
    Guidance combining a black-box (finite-difference) classifier gradient
    estimate with a true LPIPS gradient.

    The classifier term is estimated from two consecutive x_0 predictions
    (``pred_xstart_prev`` and ``pred_xstart``) without backpropagating
    through the classifier. The LPIPS term is obtained with autograd w.r.t.
    the current x_0 prediction. The returned guidance is
    ``clf_scale * grad_clf - lpips_scale * grad_lpips``, i.e. it pushes
    towards a higher target-class log-probability and a lower LPIPS distance.
    '''

    # Clamp bound for torch.logit when rescale_lpips is set. It keeps the
    # inverse sigmoid (and its gradient) finite for LPIPS values at 0 or >= 1,
    # which would otherwise produce -inf / NaN.
    _LOGIT_EPS = 1e-6

    def __init__(
            self,
            model: ClassifierBase,
            uses_target_clf: bool,
            clf_scale: float,
            lpips_scale: float,
            rescale_lpips: bool,
            lpips_net: str,
            ignore_inc_t_steps: bool,
            log: bool):
        '''
        model              - classifier whose logits are used for the
                             black-box gradient estimate
        uses_target_clf    - stored flag; not read within this class
        clf_scale          - weight of the classifier gradient estimate
        lpips_scale        - weight of the LPIPS gradient; 0 disables LPIPS
                             entirely (no LPIPS network is constructed)
        rescale_lpips      - boolean indicating whether to apply inverse sigmoid
                             to lpips before computing gradients
        lpips_net          - backbone name passed to lpips.LPIPS(net=...)
        ignore_inc_t_steps - if True, return zero guidance on calls where
                             t_step > t_step_prev for every entry
        log                - if True, report both gradient terms through
                             self.log_grad_info
        '''
        super().__init__()

        self.clf = model
        self.clf_scale = clf_scale
        self.uses_target_clf = uses_target_clf

        self.lpips = lpips.LPIPS(net = lpips_net) if lpips_scale != 0. else None
        self.lpips_scale = lpips_scale
        self.rescale_lpips = rescale_lpips
        self.ignore_inc_t_steps = ignore_inc_t_steps
        self.log = log

    def get_cond_module(self):
        '''Return the module(s) used for conditioning: clf, plus lpips if enabled.'''
        if self.lpips is not None:
            return self.clf, self.lpips
        else:
            return self.clf

    def get_clf_log_probs(self, x, y):
        '''
        Return log p(y | x) under the classifier, one value per sample.

        The gather on dim 1 with index y[:, None] requires 2D logits, so this
        assumes self.clf(x) returns (batch, num_classes) logits and y holds
        integer class indices of shape (batch,).
        '''
        clf_logits = self.clf(x)
        clf_log_probs = F.log_softmax(clf_logits, dim = -1)
        clf_log_probs = clf_log_probs.gather(dim = 1, index = y[:, None]).flatten()
        return clf_log_probs

    def _estimate_clf_grad(self, x_in, x_in_prev, y):
        '''Finite-difference estimate of d log p(y|x) / dx, min-max rescaled to [-1, 1].'''
        with torch.no_grad():
            clf_log_probs_prev = self.get_clf_log_probs(x_in_prev, y)
            clf_log_probs = self.get_clf_log_probs(x_in, y)

            dx = x_in - x_in_prev
            # Reshape the per-sample difference so it broadcasts over all
            # non-batch dims of dx (equivalent to view(-1, 1, 1, 1) for images).
            df = (clf_log_probs - clf_log_probs_prev).view(-1, *[1] * (dx.dim() - 1))
            grad_clf = df / dx
            # dx == 0 yields inf/nan, and a tiny dx can overflow to inf. Zero
            # them all so they cannot poison the min-max rescaling below.
            grad_clf[~torch.isfinite(grad_clf)] = 0.

            # NOTE: grad_clf has enormous scale. As a tmp fix we scale it to
            #       [-1, 1] (min/max taken over the whole batch). This needs
            #       attention, as it should rather be on a scale similar to
            #       the scores.
            lo = grad_clf.min()
            span = grad_clf.max() - lo
            if span > 0:
                return ((grad_clf - lo) / span - 0.5) * 2
            # Constant estimate: the original 0/0 produced NaN; use zeros instead.
            return torch.zeros_like(grad_clf)

    def _lpips_grad(self, x_in, gt):
        '''Autograd gradient of the (optionally logit-rescaled) LPIPS distance w.r.t. x_in.'''
        with torch.enable_grad():
            x_in = x_in.detach().requires_grad_(True)

            # Only x_in is mapped via (x - 0.5) * 2. This assumes pred_xstart
            # is in [0, 1] and gt is already in LPIPS's expected [-1, 1] range
            # -- TODO confirm against the caller.
            lpips_vals = self.lpips((x_in - 0.5) * 2, gt)

            if self.rescale_lpips:
                # Inverse sigmoid log(v / (1 - v)), clamped to stay finite.
                lpips_vals = torch.logit(lpips_vals, eps = self._LOGIT_EPS)

            return torch.autograd.grad(lpips_vals.sum(), x_in)[0]

    def get_cond_fn(self):
        '''Return the guidance callback cond_fn(x, t, y, gt, **kwargs) -> gradient tensor.'''

        def cond_fn(x, t, y = None, gt = None, **kwargs):
            '''
            Compute the guidance gradient for the current step.

            Reads kwargs 'pred_xstart' and 'pred_xstart_prev', plus 't_step'
            and 't_step_prev' when ignore_inc_t_steps is set. `x` is accepted
            for interface compatibility but is not used. Returns zeros shaped
            like pred_xstart when no previous prediction exists, or when the
            step is skipped.
            '''
            if y is None:
                raise ValueError('cond_fn requires target labels y')

            x_in = kwargs['pred_xstart'].detach()
            x_in_prev = kwargs['pred_xstart_prev']

            # No previous x_0 prediction: no finite difference is possible.
            # This is checked first so t_step_prev is never read in that case.
            if x_in_prev is None:
                return torch.zeros_like(x_in)

            if self.ignore_inc_t_steps and (kwargs['t_step'] > kwargs['t_step_prev']).all():
                return torch.zeros_like(x_in)

            grad_clf = self._estimate_clf_grad(x_in, x_in_prev.detach(), y)

            if self.lpips_scale != 0.:
                grad_lpips = self._lpips_grad(x_in, gt)
            else:
                grad_lpips = torch.zeros_like(grad_clf)

            # Ascend the classifier log-prob, descend the LPIPS distance.
            grad = self.clf_scale * grad_clf - self.lpips_scale * grad_lpips

            if self.log:
                self.log_grad_info(t, grad_clf, 'clf')
                self.log_grad_info(t, grad_lpips, 'lpips')

            return grad

        return cond_fn