# Copyright (c) 2022 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license

import torch as th

import sys
sys.path.append('src/inpainters/src_repaint')

from .src_repaint.conf_mgt.conf_base import Default_Conf
from .src_repaint.guided_diffusion import dist_util

# Workaround: best-effort preload of libgcc_s.so.1 before the remaining
# imports. Not being able to load it (library absent on this platform, or
# ctypes unavailable) is not an error, so only those failures are ignored
# rather than swallowing every exception (KeyboardInterrupt, SystemExit, ...).
try:
    import ctypes
    libgcc_s = ctypes.CDLL('libgcc_s.so.1')
except (ImportError, OSError):
    pass

from .src_repaint.guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    select_args,
)  # noqa: E402

from .base import InpainterBase
from guidance import Guidance

def toU8(sample):
    '''
    Convert a batch of images from the [-1, 1] float range to uint8 arrays.

    sample - tensor indexed as (batch, channel, height, width), or None.

    Returns a numpy uint8 array with the channel axis moved last
    (batch, height, width, channel). Values are scaled with
    (x + 1) * 127.5, clamped to [0, 255] and truncated by the uint8 cast.
    None is passed through unchanged.
    '''
    if sample is not None:
        scaled = th.clamp((sample + 1) * 127.5, 0, 255)
        as_bytes = scaled.to(th.uint8)
        # channels-first -> channels-last, made contiguous before export
        channels_last = as_bytes.permute(0, 2, 3, 1).contiguous()
        return channels_last.detach().cpu().numpy()

    return sample

class RePaint(InpainterBase):
    '''
    RePaint diffusion inpainter exposed through the InpainterBase interface.

    Construction builds the config, creates the model/diffusion pair, loads
    the checkpoint and wires in optional guidance. All sampling logic lives
    in inpaint().
    '''

    def __init__(self, subconfig: dict, guidance: Guidance):
        '''
        Wrapper for RePaint that combines our abstraction with nn.Module
        for convenience (such as moving to proper device). No forward()
        is needed as all inner nn.Modules come with it setup properly and
        we put all further logic inside inpaint().

        subconfig - config object with parameters that follow the original
            config. NOTE(review): annotated as dict, but set_config() reads
            it through the `.params` attribute, so a plain dict will not
            work - confirm the actual type with the caller.
        guidance - Guidance object which provides cond_fn and classifier
            to use during sampling; may be None to sample without guidance.
        '''
        super().__init__()

        self.set_config(subconfig)
        self.setup_diffusion()
        self.setup_guidance(guidance)
        # dummy parameter to get device at any moment
        self.device_param = th.nn.Parameter(th.empty(0))

    def set_config(self, subconfig):
        '''
        Build self.config: RePaint defaults overridden by subconfig.params.
        '''
        conf_arg = Default_Conf()
        conf_arg.update(subconfig.params)
        self.config = conf_arg

    def setup_diffusion(self):
        '''
        Create the model and diffusion process, load the checkpoint from
        config.model_path and store the sampling entry points
        (self.sample_fn, self.model, self.model_fn).
        '''
        # create model and diffusion, load ckpt
        model, diffusion = create_model_and_diffusion(
            **select_args(self.config, model_and_diffusion_defaults().keys()),
            conf=self.config)

        # strict=False: keys that do not match between checkpoint and model
        # are tolerated instead of raising
        model.load_state_dict(
            dist_util.load_state_dict(self.config.model_path), strict=False)

        if self.config.use_fp16:
            model.convert_to_fp16()

        model.eval()

        # set sample_fn
        self.sample_fn = (
            diffusion.p_sample_loop if not self.config.use_ddim else diffusion.ddim_sample_loop
        )

        def model_fn(x, t, y=None, gt=None, **kwargs):
            # y is always required by this wrapper, but it is only forwarded
            # to the model when the config enables class conditioning
            assert y is not None
            return model(x, t, y if self.config.class_cond else None, gt=gt)

        # set model_fn
        self.model = model
        self.model_fn = model_fn

    def setup_guidance(self, guidance):
        '''
        Store the guidance classifier and cond_fn, or None for both when no
        guidance object is given.
        '''
        if guidance is not None:
            self.classifier = guidance.get_cond_module()
            self.cond_fn = guidance.get_cond_fn()

        else:
            self.classifier = None
            self.cond_fn = None

    def forward(self, x):
        # intentionally a no-op: use inpaint() instead
        pass

    def reverse_mask(self, x):
        '''
        Flip a binary mask (1 - x): altered regions become 0, kept become 1.
        '''
        return 1 - x

    def inpaint(self, x_gt: th.Tensor, x_mask: th.Tensor, guidance_classes: th.Tensor):
        '''
        Inpaint the masked regions of x_gt with the configured sampler.

        x_gt - ground truth image with no mask applied, values in [0, 1]
        x_mask - binary mask indicating regions to alter, given without a
            channel dimension (one is inserted at dim 1 and repeated 3 times)
        guidance_classes - passed to the sampler as model_kwargs["y"]

        Returns the sampled images rescaled from [-1, 1] back to [0, 1].

        Raises ValueError if x_gt, once rescaled, falls outside [-1, 1].
        '''
        # we need x_gt to be in [-1, 1] range
        x_gt = (x_gt - 0.5) * 2

        # Explicit check instead of assert so it survives `python -O`.
        # Both bounds are verified; written as `not (... and ...)` so that
        # NaN values are rejected as well.
        gt_min, gt_max = x_gt.min(), x_gt.max()
        if not (gt_min >= -1. and gt_max <= 1.):
            raise ValueError(
                'x_gt is expected in [0, 1] (i.e. [-1, 1] after rescaling), '
                f'but rescaled values span [{gt_min.item()}, {gt_max.item()}]')

        # we provide masks as tensors with channel dimension
        # so we add it here by unsqueezing and repeating
        x_keep_mask = self.reverse_mask(x_mask)
        x_keep_mask = x_keep_mask.unsqueeze(1).repeat_interleave(3, 1)

        model_kwargs = {
            'gt': x_gt,
            'gt_keep_mask': x_keep_mask,
            'y': guidance_classes,
        }

        batch_size = x_gt.shape[0]
        device = self.device_param.device

        result = self.sample_fn(
            self.model_fn,
            (batch_size, 3, self.config.image_size, self.config.image_size),
            clip_denoised=self.config.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=self.cond_fn,
            device=device,
            progress=self.config.show_progress,
            return_all=True,
            conf=self.config)

        # extracted inpainted samples
        x_inp = result['sample']

        # x_inp comes from [-1, 1] range
        # we scale it to [0, 1]
        x_inp = (x_inp / 2) + 0.5
        return x_inp