from .base import SolverBase
import torch


class DDPMBased(SolverBase):
    """Solver whose step adds a guidance term to ``xt`` before a posterior update.

    All numerical work is delegated to callables supplied at construction
    time; this class only wires them together in :meth:`step`.
    """

    def __init__(
        self,
        pred_x0_fn,
        x1_forw,
        x_gt,
        mask,
        corrupt_method,
        ot_ode,
        cond_fn,
        guidance_classes,
        p_posterior_fn,
        step_size=1.0,
        **kwargs
    ):
        """Store the callables and tensors used by :meth:`step`.

        Args:
            pred_x0_fn: Called as ``pred_x0_fn(xt.float(), step)``; its result
                is used as the predicted x0.
            x1_forw: Subtracted from the corrupted prediction to form the
                residual.
            x_gt: Forwarded to ``cond_fn`` as ``gt``.
            mask: Forwarded to ``cond_fn`` as ``mask``.
            corrupt_method: Called on the predicted x0; only the first element
                of its return value is used.
            ot_ode: Forwarded to ``p_posterior_fn`` as ``ot_ode``.
            cond_fn: Called with keyword arguments only; its return value is
                added to ``xt``.
            guidance_classes: Forwarded to ``cond_fn`` as ``y``.
            p_posterior_fn: Called as ``p_posterior_fn(prev_step, step, xt,
                pred_x0, ot_ode=..., verbose=True)``; only the first element of
                its return value is used.
            step_size: Factor that, negated, scales the squared residual norm
                passed to ``cond_fn`` as ``comps``.
            **kwargs: Only ``photo_number`` is read; any other key is ignored.
        """
        super().__init__()

        # Injected model / sampler callables and their fixed inputs.
        self.pred_x0_fn = pred_x0_fn
        self.x1_forw = x1_forw
        self.x_gt = x_gt
        self.mask = mask
        self.corrupt_method = corrupt_method
        self.ot_ode = ot_ode
        self.cond_fn = cond_fn
        self.guidance_classes = guidance_classes
        self.p_posterior_fn = p_posterior_fn
        self.step_size = step_size

        # Optional extra; forwarded to cond_fn only when it is not None.
        self.photo_number = kwargs.get("photo_number")

    def step(self, prev_step, step, xt):
        """Run one guided update from ``step`` to ``prev_step``.

        Args:
            prev_step: Target step, forwarded to ``p_posterior_fn``.
            step: Current step, forwarded to ``pred_x0_fn``, ``cond_fn`` (as
                the single-element list ``[step]``) and ``p_posterior_fn``.
            xt: Current sample tensor. Gradient tracking is switched on for
                this tensor in place.

        Returns:
            A ``(xt_prev, pred_x0)`` tuple: the first element returned by
            ``p_posterior_fn`` and the x0 prediction made from the incoming
            ``xt``.
        """
        # In-place: the caller's tensor object is the one that gets flagged.
        xt.requires_grad_()
        x0_hat = self.pred_x0_fn(xt.float(), step)

        # Squared norm of the gap between the corrupted prediction and x1_forw.
        degraded, _ = self.corrupt_method(x0_hat)
        sq_residual = torch.linalg.norm(degraded - self.x1_forw).pow(2)

        guidance_kwargs = {
            "pred_xstart": x0_hat.float(),
            "comps": -self.step_size * sq_residual,
            "mask": self.mask,
        }
        if self.photo_number is not None:
            guidance_kwargs["photo_number"] = self.photo_number

        # Whatever cond_fn returns is added directly to xt.
        # NOTE(review): presumably a gradient of `comps` w.r.t. `x` - confirm
        # against the cond_fn implementation.
        xt_guided = xt + self.cond_fn(
            x=xt,
            t=[step],
            y=self.guidance_classes,
            gt=self.x_gt,
            **guidance_kwargs,
        )

        # The posterior update sees the guided sample, but the x0 prediction
        # it receives was computed from xt before guidance was added.
        xt_prev, _ = self.p_posterior_fn(
            prev_step, step, xt_guided, x0_hat, ot_ode=self.ot_ode, verbose=True
        )

        # log_xt_update is not defined in this class; presumably inherited
        # from SolverBase.
        self.log_xt_update(step, xt_guided - xt_prev)

        return xt_prev, x0_hat