import torch
from tqdm import tqdm
from .base import Algo
from utils.scheduler import Scheduler
import numpy as np

# Switch for how LGD.inference obtains the likelihood gradient:
#   True  -> call forward_op.gradient(...) on the perturbed samples and
#            back-propagate the averaged result through the denoiser output;
#   False -> evaluate forward_op.loss(...) and differentiate it w.r.t. the
#            current iterate with autograd.
GRAD_INTERFACE = True


class LGD(Algo):
    """Diffusion sampler with a Monte Carlo likelihood-gradient correction.

    At every scheduler step the network's denoised estimate is perturbed with
    ``num_samples`` Gaussian draws, the forward operator's data-fit gradient
    is averaged over those draws and back-propagated to the current iterate,
    and the result (times ``guidance_scale``) is subtracted from the
    unconditional update.
    """

    def __init__(self,
                 net,
                 forward_op,
                 diffusion_scheduler_config,
                 guidance_scale,
                 num_samples=10,
                 sde=True):
        """Store the sampler configuration and build the scheduler.

        Args:
            net: denoiser; called as ``net(x, sigma)`` and expected to expose
                ``img_channels`` and ``img_resolution``.
            forward_op: measurement operator exposing ``device`` and
                ``gradient`` / ``loss``.
            diffusion_scheduler_config: keyword arguments for ``Scheduler``.
            guidance_scale: multiplier applied to the likelihood gradient.
            num_samples: number of Gaussian draws taken around the denoised
                estimate at every step (not the output batch size).
            sde: if True use the stochastic update, otherwise the
                deterministic one.
        """
        super().__init__(net, forward_op)
        self.scale = guidance_scale
        self.diffusion_scheduler_config = diffusion_scheduler_config
        self.scheduler = Scheduler(**diffusion_scheduler_config)
        self.sde = sde
        self.num_samples = num_samples

    @staticmethod
    def _sqrt_loss(loss, like):
        """Return ``sqrt(loss)`` as a detached tensor matching ``like``.

        ``np.sqrt`` cannot be used here: converting a tensor to NumPy raises
        when it requires grad or lives on a CUDA device. Working in torch also
        accepts plain Python / NumPy scalars.
        """
        if torch.is_tensor(loss):
            loss = loss.detach()
        return torch.sqrt(torch.as_tensor(loss, dtype=like.dtype, device=like.device))

    @torch.no_grad()
    def inference(self, observation, num_samples=1, **kwargs):
        """Run the guided reverse process and return the final iterate.

        Args:
            observation: measurement handed to ``forward_op``. When
                ``num_samples > 1`` it is repeated along dim 0 with
                ``repeat(num_samples, 1, 1, 1)``.
                NOTE(review): this presumably expects a 4-D tensor with a
                leading batch dimension of 1 -- confirm against callers.
            num_samples: number of images generated in one batch. This is
                distinct from ``self.num_samples`` (draws per step).
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            Tensor of shape ``(num_samples, img_channels, img_resolution,
            img_resolution)`` that does not require grad.
        """
        device = self.forward_op.device
        resolution = self.net.img_resolution
        x_next = torch.randn(num_samples, self.net.img_channels, resolution, resolution,
                             device=device) * self.scheduler.sigma_max
        if num_samples > 1:
            observation = observation.repeat(num_samples, 1, 1, 1)

        # Tile the observation once per Monte Carlo draw. The layout is
        # draw-major: [draw 0 batch, draw 1 batch, ...], matching the
        # ``view(self.num_samples, ...)`` used on the gradient below.
        observation = observation[None].repeat(
            self.num_samples, *[1] * observation.dim()).flatten(0, 1)

        for i in tqdm(range(self.scheduler.num_steps)):
            # Fresh leaf every step so autograd only spans the current step.
            x_cur = x_next.detach().requires_grad_(True)

            sigma = self.scheduler.sigma_steps[i]
            factor = self.scheduler.factor_steps[i]
            scaling_factor = self.scheduler.scaling_factor[i]
            scaling_step = self.scheduler.scaling_steps[i]
            # Standard deviation of the perturbation around the denoised estimate.
            rt = sigma / np.sqrt(1 + sigma ** 2)

            # ``inference`` runs under no_grad; re-enable it for the part that
            # must be differentiated with respect to ``x_cur``.
            with torch.enable_grad():
                denoised = self.net(x_cur / scaling_step, torch.as_tensor(sigma).to(x_cur.device))
                denoised_shape = denoised.shape
                mu = denoised[None].repeat(self.num_samples, *[1] * len(denoised_shape)).flatten(0, 1)
                samples = mu + torch.randn_like(mu) * rt

                if GRAD_INTERFACE:
                    gradient, loss = self.forward_op.gradient(samples, observation, return_loss=True)
                    # Average the per-draw gradients, then pull the result back
                    # through the denoiser (vector-Jacobian product).
                    gradient = gradient.view(self.num_samples, *denoised_shape).mean(0)
                    ll_grad = torch.autograd.grad(denoised, x_cur, grad_outputs=gradient)[0]
                else:
                    loss = self.forward_op.loss(samples, observation)
                    # NOTE(review): this view assumes ``forward_op.loss`` returns
                    # one value per element of ``samples`` -- confirm.
                    loss = loss.view(self.num_samples, *denoised_shape).mean(0).sum()
                    ll_grad = torch.autograd.grad(loss, x_cur)[0]

                # d sqrt(loss) = 0.5 / sqrt(loss) * d loss
                ll_grad = ll_grad * 0.5 / self._sqrt_loss(loss, ll_grad)

            score = (denoised - x_cur / scaling_step) / sigma ** 2 / scaling_step

            if self.sde:
                epsilon = torch.randn_like(x_cur)
                x_next = x_cur * scaling_factor + factor * score + np.sqrt(factor) * epsilon
            else:
                x_next = x_cur * scaling_factor + factor * score * 0.5
            x_next -= ll_grad * self.scale

        return x_next
