import os

import torch
import torch.nn.functional as F

from .base import Guidance
from inpainters.src_repaint.guided_diffusion.script_util import (
    create_classifier,
    select_args,
    classifier_defaults,
)
from inpainters.src_repaint.guided_diffusion import dist_util
from inpainters.src_repaint.conf_mgt.conf_base import Default_Conf

class RePaintClassifierGuidance(Guidance):
    """Classifier guidance using a RePaint/guided-diffusion classifier.

    Builds the classifier with RePaint's ``create_classifier`` and exposes a
    ``cond_fn`` that returns ``classifier_scale`` times the gradient of the
    selected class log-probabilities with respect to the input ``x``.
    """

    def __init__(self, subconfig: dict, uses_target_clf: bool, log: bool):
        """
        Args:
            subconfig: Overrides merged on top of RePaint's ``Default_Conf``.
                The merged config is read for ``classifier_path``,
                ``classifier_use_fp16``, ``classifier_scale`` and the keys
                listed by ``classifier_defaults()``.
            uses_target_clf: Stored on the instance; not read by this class.
            log: When true, ``cond_fn`` passes every computed gradient to
                ``self.log_grad_info``.
        """
        super().__init__()

        self.set_config(subconfig)
        self.clf = self.init_clf()
        self.uses_target_clf = uses_target_clf
        self.log = log

    def set_config(self, subconfig):
        """Set ``self.config`` to a ``Default_Conf`` updated with *subconfig*."""
        conf_arg = Default_Conf()
        conf_arg.update(subconfig)
        self.config = conf_arg

    def init_clf(self):
        """Create the classifier, load its checkpoint and switch it to eval mode.

        Ported from RePaint to ensure compatibility. As in RePaint's own
        loader, the checkpoint path is ``~``-expanded and the checkpoint
        tensors are mapped to CPU while loading, so a checkpoint saved from a
        GPU does not require a GPU just to be read. This method does not move
        the returned module to any device.

        Returns:
            The classifier module in eval mode (converted to fp16 when
            ``classifier_use_fp16`` is set).
        """
        # Build the architecture from the classifier-related config keys only.
        classifier = create_classifier(
            **select_args(self.config, classifier_defaults().keys()))
        checkpoint_path = os.path.expanduser(self.config.classifier_path)
        classifier.load_state_dict(
            dist_util.load_state_dict(checkpoint_path, map_location="cpu"))

        # Same order as RePaint: load the weights first, then convert.
        if self.config.classifier_use_fp16:
            classifier.convert_to_fp16()

        # Eval mode disables dropout etc. during guidance.
        classifier.eval()

        return classifier

    def get_cond_module(self):
        """Return the wrapped classifier module."""
        return self.clf

    def get_cond_fn(self):
        """Return a ``cond_fn(x, t, y=None, gt=None, **kwargs)`` closure.

        The closure computes
        ``classifier_scale * d/dx sum_i log_softmax(clf(x, t))[i, y_i]``.
        The result is a gradient w.r.t. ``x`` and so has ``x``'s shape.
        ``gt`` and any extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: If ``y`` is None (class labels are required).
        """

        def cond_fn(x, t, y=None, gt=None, **kwargs):
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            if y is None:
                raise ValueError(
                    "cond_fn requires class labels `y` for classifier guidance")
            guidance_scale = self.config.classifier_scale
            with torch.enable_grad():
                # Fresh leaf tensor so the gradient is taken w.r.t. x alone.
                x_in = x.detach().requires_grad_(True)
                logits = self.clf(x_in, t)
                log_probs = F.log_softmax(logits, dim=-1)
                # Pick log p(y_i | x_i, t) per sample.
                # Assumes logits are (batch, num_classes) -- TODO confirm.
                batch_idx = torch.arange(logits.shape[0], device=logits.device)
                labels = y.reshape(-1).to(logits.device)
                selected = log_probs[batch_idx, labels]
                grad = torch.autograd.grad(selected.sum(), x_in)[0]

                if self.log:
                    self.log_grad_info(t, grad, 'clf')

                return guidance_scale * grad

        return cond_fn