import torch

from model.pipeline import *


class Score_Estimator:
    def __init__(
        self,
        pipe: SimpleDiffusionPipeline,
        time_step=0,
        grad_weight_type="uniform",
        grad_batch_size=10,
    ):
        self.pipe = pipe
        self.device = str(pipe.device)
        self.time_step = time_step
        self.grad_weight_type = grad_weight_type
        self.grad_batch_size = grad_batch_size
        self.embed_neg = pipe.prompt2embed(
            "A doubling image, unrealistic, artifacts, distortions, unnatural blending, ghosting effects,\
            overlapping edges, harsh transitions, motion blur, poor resolution, low detail"
        )

    def grad_weight(self, t):
        if self.grad_weight_type == "uniform":
            return 1
        if self.grad_weight_type == "decrease":
            return 1 / (1 - self.pipe.scheduler.alphas_cumprod[t]) ** 0.5

    def grad_compute(self, latent, embed_cond):
        assert latent.shape[0] == embed_cond.shape[0]
        b = latent.shape[0]
        embed_neg = self.embed_neg.repeat(b, 1, 1)
        grad_c, grad_d = 0, 0
        with torch.autocast(device_type=self.device, dtype=torch.float16):
            grad_c = -self.pipe.noise_pred(latent, self.time_step, embed_cond)
            grad_d = self.pipe.noise_pred(latent, self.time_step, embed_neg)
        w = self.grad_weight(self.time_step)
        grad = 0.5 * w * (grad_c + grad_d)
        return grad

    def grad_compute_batch(self, latents, embed_cond):
        assert latents.shape[0] == embed_cond.shape[0]
        n = latents.shape[0]
        grad_out = None
        j = 0
        while n > 0:
            b = min(n, self.grad_batch_size)
            lats = latents[j : j + b, :, :, :]
            grad = self.grad_compute(lats, embed_cond[j : j + b, :, :])
            grad_out = grad if grad_out is None else torch.cat([grad_out, grad])
            n -= b
            j += b
        return grad_out
