import sys

sys.path.append('src/inpainters/src_k_diffusion_inverse_problems')

import torch
from tqdm import tqdm
from functools import partial
from .src_k_diffusion_inverse_problems.guided_diffusion_k_diffusion import dist_util
from .src_k_diffusion_inverse_problems.guided_diffusion_k_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict
)

from .src_k_diffusion_inverse_problems import src_k_diffusion as K
from .src_k_diffusion_inverse_problems.condition.measurements import get_operator
from .src_k_diffusion_inverse_problems.condition.condition import ConditionOpenAIDenoiser

from .base import InpainterBase
from guidance import Guidance

import math

class KDiffusion(InpainterBase):
    """Inpainter that samples an OpenAI guided-diffusion model with k-diffusion samplers.

    Known pixels are enforced by wrapping the model in ``ConditionOpenAIDenoiser``
    together with a measurement operator built by ``get_operator`` from the keep mask.
    """

    # Number of images sampled together in one k-diffusion run.
    INNER_BATCH_SIZE = 2

    def __init__(self, subconfig: dict, guidance: Guidance):
        super().__init__()

        # dummy parameter to get device at any moment
        self.device_param = torch.nn.Parameter(torch.empty(0))
        self.set_config(subconfig)
        self.setup_diffusion()
        self.setup_guidance(guidance)

    def set_config(self, subconfig):
        """Store the config sections and load the precomputed ``recon_mse`` object.

        ``subconfig`` must provide ``'config'`` (which must contain ``'model'``)
        and ``'args'``.
        """
        self.config = subconfig['config']
        self.args = subconfig['args']
        self.model_config = self.config['model']
        # NOTE: torch.load unpickles the file; the path comes from the config,
        # so it must point at a trusted file.
        self.recon_mse = torch.load(self.model_config['recon_mse'])

    def setup_diffusion(self):
        """Create the model and diffusion, load the checkpoint and put the model in eval mode."""
        model_kwargs = args_to_dict(self.model_config['openai'], model_and_diffusion_defaults().keys())
        self.inner_model, self.diffusion = create_model_and_diffusion(**model_kwargs)

        # Load on CPU first; the module is moved to the current device below.
        self.inner_model.load_state_dict(
            dist_util.load_state_dict(self.args['checkpoint'], map_location="cpu")
        )

        device = self.device_param.device
        self.inner_model = self.inner_model.eval().to(device)

    def setup_guidance(self, guidance):
        """Take the classifier and conditioning function from ``guidance``, or disable both if it is None."""
        if guidance is None:
            self.classifier = None
            self.cond_fn = None
            return

        self.classifier = guidance.get_cond_module()
        self.cond_fn = guidance.get_cond_fn()

    def reverse_mask(self, x):
        """Return ``1 - x``, turning the caller's mask into the keep mask used for sampling."""
        return 1 - x

    def _build_denoiser(self, operator, measurement, model_kwargs, device):
        """Wrap ``self.model_fn`` in a measurement-conditioned denoiser with frozen parameters."""
        denoiser = ConditionOpenAIDenoiser(
            inner_model=self.model_fn,
            diffusion=self.diffusion,
            operator=operator,
            measurement=measurement,
            guidance=self.args['guidance'],
            x0_cov_type=self.args['xstart_cov_type'],
            recon_mse=self.recon_mse,
            lambda_=self.args['lam'],
            zeta=self.args['zeta'],
            eta=self.args['eta'],
            num_hutchinson_samples=self.args['num_hutchinson_samples'],
            mle_sigma_thres=self.args['mle_sigma_thres'],
            device=device,
            cond_fn=self.cond_fn,
            model_kwargs=model_kwargs,
        ).eval()

        for param in denoiser.parameters():
            param.requires_grad = False
        return denoiser

    def _sample(self, denoiser, n, sigmas, sigma_max, device):
        """Draw ``n`` samples from ``denoiser`` with the Heun sampler, or Euler if ``args['euler']`` is set."""
        size = self.model_config['input_size']
        # Start from pure noise scaled to the largest sigma of the schedule.
        x = torch.randn([n, self.model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
        sampler_fn = K.sampling.sample_euler if self.args['euler'] else K.sampling.sample_heun
        sampler = partial(sampler_fn, denoiser, x, sigmas, disable=False)
        if self.args['ode']:
            # deterministic sampling with the sampler's default (no-churn) settings
            return sampler()
        return sampler(s_churn=80, s_tmin=0.05, s_tmax=50, s_noise=1.003)

    def inpaint(self, x_gt: torch.Tensor, x_mask: torch.Tensor, guidance_classes: torch.Tensor):
        """Inpaint ``x_gt`` and return the result in the [0, 1] range.

        Args:
            x_gt: images expected in [0, 1]; they are rescaled to [-1, 1] for the
                model. Presumably (B, C, H, W) with C == 3 — TODO confirm.
            x_mask: masks without a channel dimension. The keep mask is
                ``1 - x_mask``, repeated over 3 channels.
            guidance_classes: per-image labels passed to the denoiser as ``y``.

        Returns:
            Inpainted images, one per input image, rescaled from [-1, 1] to [0, 1].
        """
        # we need x_gt to be in [-1, 1] range
        x_gt = (x_gt - 0.5) * 2
        # NOTE(review): the `< 0.` part also rejects images whose pixels are all >= 0.5.
        assert x_gt.min() < 0. and x_gt.min() >= -1.

        # we provide masks as tensors without channel dimension
        # so we add it here by unsqueezing and repeating
        x_keep_mask = self.reverse_mask(x_mask)
        x_keep_mask = x_keep_mask.unsqueeze(1).repeat_interleave(3, 1)

        device = self.device_param.device

        sigma_min = self.model_config['sigma_min']
        sigma_max = self.model_config['sigma_max']
        sigmas = K.sampling.get_sigmas_karras(self.args['steps'], sigma_min, sigma_max, rho=7., device=device)

        def model_fn(x, t, y=None, **kwargs):
            # Class labels only reach the network when the model is class-conditional.
            return self.inner_model(x, t, y if self.config['class_cond'] else None)

        self.model_fn = model_fn

        num_images = x_gt.shape[0]
        batch_size = self.INNER_BATCH_SIZE
        batch_starts = range(0, num_images, batch_size)

        inpaints = []
        # len(range) equals ceil(num_images / batch_size), so the last partial batch is included.
        for batch_start in tqdm(batch_starts, total=len(batch_starts)):
            batch_end = min(batch_start + batch_size, num_images)
            x_gt_curr = x_gt[batch_start:batch_end]
            x_keep_mask_curr = x_keep_mask[batch_start:batch_end]
            guidance_classes_curr = guidance_classes[batch_start:batch_end]

            model_kwargs = {
                'gt': x_gt_curr,
                'gt_keep_mask': x_keep_mask_curr,
                'y': guidance_classes_curr,
            }

            operator = get_operator(device=device, **self.args['operator_config'], mask=x_keep_mask_curr)
            measurement = operator.forward(x_gt_curr.clone(), flatten=True, noiseless=True)
            denoiser = self._build_denoiser(operator, measurement, model_kwargs, device)

            x_inp = self._sample(denoiser, x_gt_curr.shape[0], sigmas, sigma_max, device).detach().clone()
            inpaints.append(x_inp)

        x_inp = torch.cat(inpaints, dim=0)
        # x_inp comes from [-1, 1] range
        # we scale it to [0, 1]
        x_inp = (x_inp / 2) + 0.5
        return x_inp