import logging

import torch as th
from omegaconf import OmegaConf

from guidance import Guidance
from .base import InpainterBase
from .src_ddrm.datasets import data_transform, inverse_data_transform
from .src_ddrm.diffusion import get_beta_schedule
from .src_ddrm.functions.denoising import efficient_generalized_steps, compute_alpha
from .src_ddrm.functions.svd_replacement import Inpainting
from .src_ddrm.guided_diffusion.script_util import create_model, load_state_dict
from .src_ddrm.guided_diffusion.unet import UNetModel

log = logging.getLogger(__name__)



class DDRM(InpainterBase):
    """Inpainter built on the vendored DDRM code (``src_ddrm``).

    Wraps a guided-diffusion ``UNetModel`` and the ``efficient_generalized_steps``
    sampler, using DDRM's ``Inpainting`` operator as the degradation ``H``.
    """

    def __init__(self, subconfig: dict, guidance: Guidance):
        """
        subconfig - config node exposing a ``params`` attribute, which is stored as
            ``self.config`` (presumably an OmegaConf DictConfig - verify with callers).
        guidance - optional ``Guidance``; when None, sampling runs without a cond_fn.
        """
        super().__init__()
        self.config = subconfig.params
        # Zero-sized dummy parameter: it moves with the module on .to()/.cuda(),
        # so ``self.device_param.device`` always reports the current device.
        self.device_param = th.nn.Parameter(th.empty(0))

        self.setup_guidance(guidance)
        self.setup_diffusion_model()
        self.setup_diffusion_betas()

    def setup_diffusion_model(self):
        """Create the UNet from ``config.model``, load ``config.model_path`` and set eval mode."""
        config_dict = OmegaConf.to_container(self.config.model)
        model: UNetModel = create_model(**config_dict)
        # strict=False tolerates checkpoint/model key mismatches; report them instead of
        # silently running with partially initialised weights.
        incompatible = model.load_state_dict(load_state_dict(self.config.model_path), strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            log.warning(
                "Checkpoint %s does not match the model exactly: %d missing keys, %d unexpected keys",
                self.config.model_path,
                len(incompatible.missing_keys),
                len(incompatible.unexpected_keys),
            )
            log.debug("Missing keys: %s", incompatible.missing_keys)
            log.debug("Unexpected keys: %s", incompatible.unexpected_keys)

        if self.config.model.use_fp16:
            model.convert_to_fp16()

        model.eval()
        self.model = model

    def setup_guidance(self, guidance):
        """Store the guidance conditioning module and cond_fn, or None for both."""
        if guidance is not None:
            self.classifier = guidance.get_cond_module()
            self.cond_fn = guidance.get_cond_fn()
        else:
            self.classifier = None
            self.cond_fn = None

    def setup_diffusion_betas(self):
        """Build the beta schedule from ``config.diffusion`` and store it as ``self.betas``."""
        betas = get_beta_schedule(
            beta_schedule=self.config.diffusion.beta_schedule,
            beta_start=self.config.diffusion.beta_start,
            beta_end=self.config.diffusion.beta_end,
            num_diffusion_timesteps=self.config.diffusion.num_diffusion_timesteps,
        )
        betas = th.from_numpy(betas).float()
        # The schedule is fixed, not learned: keep it an nn.Parameter (same state_dict key,
        # follows the module across devices) but exclude it from autograd/optimizers.
        self.betas = th.nn.Parameter(betas, requires_grad=False)

    def forward(self, x):
        """Unused; inpainting is done through ``inpaint``. Returns None."""
        pass

    def get_H_functions(self, mask: th.Tensor):
        """Build the DDRM ``Inpainting`` degradation operator for ``mask``.

        mask - per-pixel mask of shape [img_size x img_size] (as produced by
            ``get_ddrm_mask``); pixels equal to 0 are treated as missing in every channel.
        Returns an ``Inpainting`` operator whose missing indices address the flattened
        image as ``pixel * channels + channel``.
        """
        channels = self.config.data.channels
        missing_pixels = th.nonzero(mask.reshape(-1) == 0).long().reshape(-1)
        # One block of indices per channel, in channel order (r, then g, then b for RGB);
        # identical to the former hard-coded ``* 3`` layout when channels == 3.
        missing = th.cat([missing_pixels * channels + c for c in range(channels)], dim=0)
        return Inpainting(channels, self.config.data.image_size, missing)

    def get_timesteps(self, num_timesteps=None):
        """Return ``config.timesteps`` evenly spaced integer steps in [0, num_timesteps - 1].

        num_timesteps - upper end of the range; defaults to the length of the beta schedule.
        Values are truncated to int, so duplicates are possible when
        ``config.timesteps`` exceeds ``num_timesteps``.
        """
        if num_timesteps is None:
            num_timesteps = self.betas.shape[0]

        steps = th.linspace(0, num_timesteps - 1, self.config.timesteps).int()
        return steps.tolist()

    def get_ddrm_mask(self, x_mask: th.Tensor):
        """Convert an "alter here" mask into the "keep here" mask used by ``get_H_functions``.

        x_mask - binary mask of shape [img_size x img_size] or [bs x img_size x img_size]
            marking regions to alter; for a batched mask only the first entry is used.
        Returns a 2D mask that is 0 (False for bool input) where the image is to be altered.
        Raises ValueError if ``x_mask`` is neither 2D nor 3D.
        """
        if len(x_mask.shape) == 3:
            log.warning("Assuming that every mask in the batch is exactly the same")
            mask = x_mask[0, :, :]
        elif len(x_mask.shape) == 2:
            mask = x_mask
        else:
            raise ValueError(f"Invalid shape for a mask. Expected 2 or 3 dimensional mask, but got: {x_mask.shape}")
        # torch rejects ``1 - t`` for bool tensors, so invert those logically instead.
        if mask.dtype == th.bool:
            return ~mask
        return 1 - mask

    def inpaint(self, x_gt: th.Tensor, x_mask: th.Tensor, guidance_classes: th.Tensor):
        """Inpaint the masked regions of ``x_gt`` with the DDRM sampler.

        x_gt - ground truth image with no mask applied [bs x channels x img_size x img_size]
        x_mask - binary mask indicating regions to alter [bs x img_size x img_size]
        guidance_classes - forwarded to the sampler as ``classes``
        Returns the samples mapped back through ``inverse_data_transform``, stacked into one tensor.
        """
        mask = self.get_ddrm_mask(x_mask)
        H_funcs = self.get_H_functions(mask)

        # Noisy observation: y_0 = H(x) + sigma0 * eps.
        x_orig = data_transform(self.config, x_gt)
        y_0 = H_funcs.H(x_orig)
        y_0 = y_0 + self.config.sigma0 * th.randn_like(y_0)

        ## Begin DDIM
        noise = th.randn(
            y_0.shape[0],
            self.config.data.channels,
            self.config.data.image_size,
            self.config.data.image_size,
            device=self.device_param.device,
        )

        if self.config.start_step < self.config.diffusion.num_diffusion_timesteps:
            # Partial diffusion: start from x_orig noised to ``start_step``.
            # NOTE(review): x_orig is unmasked, so the starting state still carries
            # ground-truth content inside the masked region - confirm this is intended.
            n = noise.shape[0]
            # compute_alpha presumably indexes ``self.betas`` with ``t`` (as upstream DDRM
            # does), so ``t`` must live on the same device as the betas.
            t = th.ones(n, device=self.betas.device) * self.config.start_step
            alfa = compute_alpha(self.betas, t.int())
            x = alfa.sqrt() * x_orig + (1 - alfa).sqrt() * noise
            timesteps = self.get_timesteps(num_timesteps=self.config.start_step)
        else:
            x = noise
            timesteps = self.get_timesteps()

        with th.no_grad():
            # NOTE(review): self.classifier (set in setup_guidance) is not passed here;
            # guidance reaches the sampler only through cond_fn.
            x = efficient_generalized_steps(
                x, timesteps, self.model, self.betas, H_funcs, y_0,
                self.config.sigma0, etaA=self.config.etaA, etaB=self.config.etaB, etaC=self.config.etaC,
                x_gt=x_orig, cls_fn=None, classes=guidance_classes, cond_fn=self.cond_fn,
            )
        # Assumes the sampler returns an iterable of image tensors (upstream DDRM returns
        # an (xs, x0_preds) tuple instead) - confirm against the vendored implementation.
        x = [inverse_data_transform(self.config, y) for y in x]
        return th.stack(x)
        

    
