import torch
import numpy as np
import torch as th
import wandb
import matplotlib.pyplot as plt

from solvers import SolverBase
from torchvision.utils import save_image

from .base import InpainterBase
from .src_cddb.i2sb.diffusion import Diffusion
from .src_cddb.i2sb import util
from torch_ema import ExponentialMovingAverage 

import sys
sys.path.append('src/inpainters/src_cddb')

import logging
log = logging.getLogger(__name__)

def zero_function(*_ignored_args, **_ignored_kwargs):
    """Accept any positional/keyword arguments and always return ``0``."""
    constant = 0
    return constant

class CDDB_3D(InpainterBase):
    """Inpainter that samples from an I2SB diffusion bridge (CDDB / CDDB-deep).

    The constructor builds the beta schedule and ``Diffusion`` object, creates
    the network from ``model_partial`` and loads EMA weights from the checkpoint
    at ``config.load``. ``inpaint()`` then runs the sampling loop.
    """

    def __init__(
            self, 
            subconfig: dict,
            model_partial: torch.nn.Module, 
            deep: bool, 
            log_trajs: bool,
            solver_cls: SolverBase
            ):
        '''
        Wrapper for CDDB that combines our abstraction with nn.Module
        for convenience (such as moving to proper device). No forward()
        is needed as all inner nn.Modules come with it setup properly and
        we put all further logic inside inpaint().

        subconfig - contains parameters defined by the authors; it is read
            through attribute access (``config.interval``, ``config.load``, ...)
        model_partial - partially initialized nn.Module; called with
            ``noise_levels=...`` to obtain the network
        deep - whether to use standard CDDB or CDDB deep
        log_trajs - when true, inpaint() writes debug slices of the mask and
            ground truth to ``x_mask.png`` / ``x_gt.png``
        solver_cls - callable used in deep mode to build the guidance solver
        '''
        super().__init__()
        self.deep = deep
        self.set_config(subconfig)
        self.setup_diffusion(model_partial)
        # dummy parameter to get device at any moment
        self.device_param = th.nn.Parameter(th.empty(0))
        self.log_trajs = log_trajs

        self.solver_cls = solver_cls

    def set_config(self, subconfig):
        """Store the author-defined configuration on ``self.config``."""
        self.config = subconfig

    def setup_diffusion(self, model_partial):
        """Build the bridge beta schedule, the diffusion object and the network.

        Reads ``interval``, ``beta_max``, ``t0``, ``T`` and ``load`` from
        ``self.config``. Requires CUDA, since the current CUDA device is used
        for both the diffusion object and the noise levels.
        """
        interval = self.config.interval
        device = torch.cuda.current_device()

        # setup diffusion object
        betas = make_beta_schedule(
            n_timestep=interval,
            linear_end=self.config.beta_max / interval)
        log_plot(betas, "diffusion_beta_schedule")

        # Symmetric bridge schedule: first half followed by its mirror image.
        first_half = betas[:interval // 2]
        betas = np.concatenate([first_half, np.flip(first_half)])
        log_plot(betas, "bridge_beta_schedule")

        self.diffusion = Diffusion(betas, device)
        log.info(f"[Diffusion] Built I2SB diffusion: steps={len(betas)}!")

        # setup model and load checkpoint
        noise_levels = torch.linspace(self.config.t0, self.config.T, interval) * interval
        # Plot the noise levels themselves (previously the betas were plotted
        # a second time under this name). Plot on CPU before moving to device.
        log_plot(noise_levels.numpy(), "bridge_noise_levels")
        noise_levels = noise_levels.to(device)
        self.model = model_partial(noise_levels=noise_levels)

        # Load ema network weights
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(self.config.load, map_location="cpu")
        self.ema = ExponentialMovingAverage(self.model.parameters(), decay=0.9999) # temp ema
        missing = self.ema.load_state_dict(checkpoint['ema'])
        log.info(f'[Net] Checkpoint loading: {missing}')
        self.ema.copy_to(self.model.parameters())

        self.model.eval()
        log.info(f"[Net] Loaded network ckpt: {self.config.load}!")

    def reverse_mask(self, x):
        """Return the complement ``1 - x`` of a mask."""
        return 1 - x

    def inpaint(self, x_gt: th.Tensor, x_mask: th.Tensor, start_step_overwrite=None):
        '''
        Inpaint the masked region of x_gt and return the result in [0, 1].

        x_gt - ground truth image with no mask applied, expected in [0, 1]
               (it is rescaled to [-1, 1] internally)
        x_mask - binary mask indicating regions to alter (1 = region to fill)
        start_step_overwrite - optional starting step that takes precedence
               over ``config.start_step``
        '''
        # we need x_gt to be in [-1, 1] range
        x_gt = (x_gt - 0.5) * 2
        lowest = x_gt.min()
        assert lowest < 0. and lowest >= -1.

        x_gt = x_gt.float()
        mask = x_mask.float()

        # Known region is kept; masked region is white (corrupt_img) or
        # Gaussian noise (x1).
        corrupt_img = x_gt * (1. - mask) + mask
        x1 = x_gt * (1. - mask) + mask * torch.randn_like(x_gt)
        x1_forw = corrupt_img

        cond = x1.detach() if self.config.cond_x1 else None

        assert self.config.interval is not None

        if start_step_overwrite is not None:
            start_step = start_step_overwrite
        else:
            start_step = self.config.start_step or self.config.interval
        assert start_step <= self.config.interval

        # Number of function evaluations, capped by the available steps.
        nfe = self.config.nfe or start_step - 1
        nfe = min(nfe, start_step - 1)
        steps = util.space_indices(start_step, nfe + 1)

        if start_step != self.config.interval:
            # Starting mid-trajectory: sample x_t from the bridge between
            # the ground truth and x1.
            x1 = self.diffusion.q_sample(start_step - 1, x_gt, x1, ot_ode=self.config.ot_ode)
        else:
            x1 = (1. - mask) * x1 + mask * torch.randn_like(x1)

        log_count = 1
        log_count = min(len(steps) - 1, log_count)
        log_steps = [steps[i] for i in util.space_indices(len(steps) - 1, log_count)]
        assert log_steps[0] == 0
        log.info(f"[CDDB Sampling] steps={start_step}, {nfe=}, {log_steps=}!")

        def corrupt_method(img):
            # img: [-1,1]
            # img[mask==0] = img[mask==0], img[mask==1] = 1 (white)
            return img * (1. - mask) + mask, mask

        def pred_x0_fn(xt, step):
            # Broadcast the scalar step to one entry per batch element.
            step = torch.full((xt.shape[0],), step, dtype=torch.long).to(xt.device)
            out = self.model(xt, step, cond=cond)
            return self.compute_pred_x0(step, xt, out, clip_denoise=self.config.clip_denoise)

        if self.deep:
            log.info("Conditioning with guidance")
            # NOTE(review): cond_fn is not defined in this class; presumably
            # it is provided by InpainterBase — confirm.
            self.solver = self.solver_cls(
                    pred_x0_fn=pred_x0_fn,
                    x1_forw=x1_forw,
                    x_gt=x_gt,
                    mask=mask,
                    corrupt_method=corrupt_method,
                    ot_ode=self.config.ot_ode,
                    cond_fn=self.cond_fn,
                    guidance_classes=None,
                    p_posterior_fn=self.diffusion.p_posterior,
                    step_size=self.config.step_size,
            )

            xs, pred_x0s = self.diffusion.ddpm_dps_sampling(
                steps,
                x1,
                self.solver,
                mask=mask,
                ot_ode=self.config.ot_ode,
                log_steps=log_steps,
                verbose=True,
            )
        else:
            xs, pred_x0s = self.diffusion.ddpm_sampling(steps,
                pred_x0_fn,
                x1,
                mask=mask,
                ot_ode=self.config.ot_ode,
                log_steps=log_steps,
                verbose=True
            )
            xs = xs.to(x1.device).to(x1.dtype)

        b, *xdim = x1.shape
        assert xs.shape == pred_x0s.shape == (b, log_count, *xdim)
        # log_steps[0] == 0, so the first logged entry is the final sample.
        x_inp = xs[:, 0, ...]

        if self.log_trajs:
            # Debug dump of slice 32 along the last axis.
            # NOTE(review): assumes squeeze() leaves a 3D volume with at least
            # 33 slices on the last axis — confirm against callers.
            save_image(x_mask.squeeze()[:, :, 32].unsqueeze(0), "x_mask.png")
            # x_gt is in [-1, 1] here; map it to [0, 1] so the saved image is
            # not saturated (adding 1.0 alone produced a [0, 2] range).
            save_image((x_gt.squeeze()[:, :, 32].unsqueeze(0) + 1.0) / 2, "x_gt.png")

        # x_inp comes from [-1, 1] range
        # we scale it to [0, 1]
        x_inp = x_inp.clamp(-1, 1)
        x_inp = (x_inp / 2) + 0.5
        return x_inp

    def compute_pred_x0(self, step, xt, net_out, clip_denoise=False):
        """ Given network output, recover x0. This should be the inverse of Eq 12

        Computes ``xt - std_fwd(step) * net_out``; when ``clip_denoise`` is
        true the result is clamped in place to [-1, 1].
        """
        std_fwd = self.diffusion.get_std_fwd(step, xdim=xt.shape[1:])
        pred_x0 = xt - std_fwd * net_out
        if clip_denoise:
            pred_x0.clamp_(-1., 1.)
        return pred_x0

    def on_end(self):
        """Delegate end-of-run handling to the base class."""
        return super().on_end()


def compute_batch(ckpt_opt, out):
    """Build the corrupted / noised inputs for one ``(clean_img, y, mask)`` batch.

    The unmasked region of ``clean_img`` is kept; the masked region is set to
    1 for ``corrupt_img`` and to Gaussian noise for ``x1``. ``cond`` is a
    detached ``x1`` when ``ckpt_opt.cond_x1`` is truthy, otherwise ``None``.

    Returns the tuple
    ``(corrupt_img, x1, mask, cond, y, clean_img, x1_pinv, x1_forw)`` where
    ``x1_pinv`` and ``x1_forw`` are both ``corrupt_img``.
    """
    clean_img, y, mask = out

    # Part of the image outside the mask, shared by both outputs.
    visible = clean_img * (1. - mask)
    corrupt_img = visible + mask
    x1 = visible + mask * torch.randn_like(clean_img)

    if ckpt_opt.cond_x1:
        cond = x1.detach()
    else:
        cond = None

    x1_pinv = x1_forw = corrupt_img
    return corrupt_img, x1, mask, cond, y, clean_img, x1_pinv, x1_forw


def make_beta_schedule(n_timestep=1000, linear_start=1e-4, linear_end=2e-2):
    """Return a beta schedule that is linear in ``sqrt(beta)``.

    The square roots of ``linear_start`` and ``linear_end`` are interpolated
    linearly over ``n_timestep`` points (float64) and squared again. The
    result is a NumPy array of length ``n_timestep``.
    """
    sqrt_first = linear_start ** 0.5
    sqrt_last = linear_end ** 0.5
    sqrt_betas = torch.linspace(sqrt_first, sqrt_last, n_timestep, dtype=torch.float64)
    squared = sqrt_betas ** 2
    return squared.numpy()


def log_plot(data, name):
    """Plot ``data`` as a line chart with a grid and log it to wandb.

    The figure is logged under the key ``misc/<name>`` after conversion with
    ``wandb.Plotly.make_plot_media``.

    data - sequence/array accepted by ``plt.plot``
    name - suffix of the wandb key
    """
    fig = plt.figure()
    try:
        plt.plot(data)
        plt.grid()
        wandb.log({f'misc/{name}': wandb.Plotly.make_plot_media(fig)})
    finally:
        # Close the figure even when plotting or logging raises, so failed
        # calls do not leave open figures behind.
        plt.close(fig)
