import torch

from .src_mcg.guided_diffusion import dist_util, logger
from .src_mcg.guided_diffusion.script_util import create_gaussian_diffusion
from .src_mcg.guided_diffusion.unet import UNetModel

from .base import InpainterBase
from guidance import Guidance

class MCG(InpainterBase):
    '''
    Inpainter backed by the guided-diffusion sampler from ``src_mcg``.

    Holds a UNet denoiser (``self.model``), a Gaussian diffusion process
    (``self.diffusion``) and optional guidance (``self.classifier`` /
    ``self.cond_fn``). All sampling logic lives in ``inpaint()``.
    '''

    def __init__(self, subconfig: dict, guidance: Guidance):
        '''
        Wrapper for MCG that combines our abstraction with nn.Module
        for convenience (such as moving to proper device). No real forward()
        is needed as all inner nn.Modules come with it set up properly and
        we put all further logic inside inpaint().

        subconfig - model configuration. Despite the ``dict`` annotation it is
                    read through attribute access (``config.unet``,
                    ``config.model_path``, ...), so it must support that
                    (TODO confirm the concrete config type).
        guidance  - guidance provider, or ``None`` to sample without guidance.
        '''
        super().__init__()
        self.set_config(subconfig)
        self.setup_diffusion()
        self.setup_guidance(guidance)


    def set_config(self, config):
        '''
        Store the configuration object.

        Attributes read from it elsewhere in this class: ``unet``,
        ``model_path``, ``use_fp16``, ``diffusion``, ``clip_denoised``,
        ``sample_method`` and ``step_size``.
        '''
        self.config = config


    def setup_diffusion(self):
        '''
        Build the UNet from ``config.unet``, load its weights from
        ``config.model_path``, optionally convert it to fp16, and create the
        Gaussian diffusion process from ``config.diffusion``.
        '''
        # create model, load ckpt
        model = UNetModel(**self.config.unet)
        # Weights are loaded onto CPU; device placement is presumably done
        # later through nn.Module.to() on this wrapper.
        # NOTE(review): torch.load unpickles the file, so model_path must
        # point to a trusted checkpoint.
        model.load_state_dict(torch.load(self.config.model_path, map_location="cpu"))

        if self.config.use_fp16:
            model.convert_to_fp16()

        # Inference only: put layers such as dropout/normalization in eval mode.
        model.eval()

        # create diffusion
        diffusion = create_gaussian_diffusion(**self.config.diffusion)

        # set model and diffusion
        self.model = model
        self.diffusion = diffusion


    def setup_guidance(self, guidance):
        '''
        Extract the conditioning module and conditioning function from
        ``guidance``; both are set to ``None`` when no guidance is given.
        '''
        if guidance is not None:
            self.classifier = guidance.get_cond_module()
            self.cond_fn = guidance.get_cond_fn()
        else:
            self.classifier = None
            self.cond_fn = None


    def forward(self, x):
        '''
        Intentionally a no-op that returns ``None``; use ``inpaint()`` instead.
        It is defined here presumably so the base class / nn.Module contract
        is satisfied (TODO confirm against InpainterBase).
        '''
        pass


    def reverse_mask(self, x):
        '''
        Invert a {0, 1} mask (1 becomes 0 and vice versa).

        Boolean tensors are converted to float first, since ``1 - x`` is not
        supported by torch for bool tensors.
        '''
        if isinstance(x, torch.Tensor) and x.dtype == torch.bool:
            x = x.float()
        return 1 - x


    def inpaint(self, x_gt: torch.Tensor, x_mask: torch.Tensor, guidance_classes: torch.Tensor):
        '''
        Regenerate the masked regions of ``x_gt`` with the diffusion sampler.

        x_gt - ground truth image with no mask applied, values in [0, 1],
               shaped (B, C, H, W)
        x_mask - binary mask indicating regions to alter (1 = alter),
                 shaped (B, H, W) without a channel dimension; it is inverted
                 and broadcast to x_gt's channel count before sampling
        guidance_classes - labels passed to the sampler as ``y`` (consumed by
                 ``cond_fn`` when guidance is configured)

        Returns the sampled images clamped and rescaled to [0, 1].

        Raises ValueError if x_gt has values outside [0, 1] (or NaNs).
        '''
        # we need x_gt to be in [-1, 1] range
        x_gt = (x_gt - 0.5) * 2
        # Check both ends of the range; written as `not (...)` so that NaN
        # values (which fail every comparison) are rejected as well.
        if not (x_gt.min() >= -1. and x_gt.max() <= 1.):
            raise ValueError('x_gt is expected to be in [0, 1] range')

        # The sampler receives the inverted mask (presumably 1 = keep pixels
        # from the reference image), expanded to one copy per image channel
        # so it matches the sample shape.
        channels = x_gt.shape[1]
        x_mask = self.reverse_mask(x_mask)
        x_mask = x_mask.unsqueeze(1).repeat_interleave(channels, 1)

        b, c, h, w = x_mask.shape

        model_kwargs = {}
        model_kwargs['ref_img'] = x_gt

        # TODO: try ddim
        x_inp = self.diffusion.p_sample_loop(
            self.model,
            (b, c, h, w),
            clip_denoised=self.config.clip_denoised,
            model_kwargs=model_kwargs,
            mask=x_mask,
            sample_method=self.config.sample_method,
            progress=True,
            cond_fn=self.cond_fn,
            y=guidance_classes,
            step_size=self.config.step_size,
            gt=x_gt,
        )

        # x_inp comes from [-1, 1] range
        # we scale it to [0, 1]
        x_inp = x_inp.clamp(-1, 1)
        x_inp = (x_inp / 2) + 0.5
        return x_inp