from model.score import *
from model.text_inv import *
from model.io import *


def slerp(p0, p1, t, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between ``p0`` and ``p1`` at fraction(s) ``t``.

    The cosine of the angle between the two tensors is taken over all of their
    elements (sum of elementwise products divided by the product of norms).
    When its magnitude exceeds ``DOT_THRESHOLD`` the tensors are nearly
    (anti-)parallel, ``sin(theta)`` is close to zero, and plain linear
    interpolation is used instead.

    Args:
        p0: start tensor.
        p1: end tensor, same shape as ``p0``.
        t: interpolation fraction(s); a float or a tensor broadcastable with
            ``p0.unsqueeze(0)`` (e.g. shape ``(K, 1)`` for K points between
            two flat vectors).
        DOT_THRESHOLD: cosine magnitude above which lerp is used.

    Returns:
        ``p0.unsqueeze(0)``-shaped tensor broadcast against ``t``.
    """
    start = p0.unsqueeze(0)
    end = p1.unsqueeze(0)
    cos_theta = torch.sum(p0 * p1) / (torch.norm(p0) * torch.norm(p1))

    # Guard clause: nearly collinear inputs make the slerp weights unstable.
    if torch.abs(cos_theta) > DOT_THRESHOLD:
        return torch.lerp(start, end, t)

    theta = torch.acos(cos_theta)
    sin_theta = torch.sin(theta)
    weight_start = torch.sin((1 - t) * theta) / sin_theta
    weight_end = torch.sin(t * theta) / sin_theta
    return weight_start * start + weight_end * end


class ScoreInterpolator(
    Score_Estimator,
    TextInversion,
    BVP_IO,
):
    """Interpolate between two images by optimizing a path of latents.

    The path is ``[latA, *lats, latB]``. The endpoints ``latA`` / ``latB`` are
    derived from ``imgA`` / ``imgB`` and stay fixed. The ``N - 2`` interior
    latents ``lats`` are initialized by spherical interpolation and then
    optimized with Adam (cosine-annealed learning rate). The objective is
    ``score_diff``: the summed squared differences between the
    ``grad_compute_batch`` outputs of consecutive points on the path.
    """

    def __init__(
        self,
        pipe: SimpleDiffusionPipeline,
        imgA,
        imgB,
        prompt,
        noise_level,
        alpha,
        grad_args,
        output_args,
        tv_args,
        opt_args,
        test_name="test_bvp",
        **kwargs,
    ):
        """Set up the endpoints, the initial path and the optimizer.

        Args:
            pipe: diffusion pipeline; provides ``get_t`` and is passed to all
                three base-class initializers.
            imgA: start image, encoded with ``self.pipe.img2latent``.
            imgB: end image, encoded with ``self.pipe.img2latent``.
            prompt: text prompt passed to ``text_inversion_load``.
            noise_level: converted to ``self.time_step`` via ``pipe.get_t``
                and forwarded to ``BVP_IO.__init__``.
            alpha: stored as ``self.alpha``. Not read elsewhere in this class;
                base classes may use it (TODO confirm).
            grad_args: keyword arguments for ``Score_Estimator.__init__``.
            output_args: keyword arguments for ``BVP_IO.__init__``.
            tv_args: keyword arguments for ``TextInversion.__init__``.
            opt_args: dict with keys ``"N"`` (number of path points including
                both endpoints), ``"iter_num"``, ``"lr"`` and ``"lr_min"``.
            test_name: run name used in log lines, for
                ``text_inversion_load`` and for the output routine.
            **kwargs: accepted and ignored.
        """
        self.imgA = imgA
        self.imgB = imgB
        self.prompt = prompt
        self.test_name = test_name
        self.alpha = alpha
        self.time_step = pipe.get_t(noise_level, return_single=True)
        self.N = opt_args["N"]
        self.iter_num = opt_args["iter_num"]

        Score_Estimator.__init__(self, pipe=pipe, time_step=self.time_step, **grad_args)
        TextInversion.__init__(self, pipe=pipe, **tv_args)
        BVP_IO.__init__(self, pipe=pipe, noise_level=noise_level, **output_args)

        self.cur_iter = 0

        latA0 = self.pipe.img2latent(self.imgA)
        latB0 = self.pipe.img2latent(self.imgB)

        # The conditioning embedding is obtained from both endpoint latents jointly.
        lat0 = torch.cat([latA0, latB0], dim=0)
        self.embed_cond = self.text_inversion_load(self.prompt, lat0, self.test_name)

        # Fixed path endpoints; only self.lats is handed to the optimizer below.
        self.latA = self.forward_single(latA0, self.embed_cond)
        self.latB = self.forward_single(latB0, self.embed_cond)

        # Initialize the N - 2 interior points by slerp between the flattened
        # endpoints; t = 0 and t = 1 are dropped because they are latA / latB.
        ts = torch.linspace(0, 1, self.N)[1:-1].to(self.device).unsqueeze(1)
        points = slerp(self.latA.reshape(-1), self.latB.reshape(-1), ts)

        # Restore the per-sample latent shape from the endpoints instead of
        # hard-coding (4, 64, 64), so other latent resolutions also work.
        self.lats = torch.nn.Parameter(points.reshape(-1, *self.latA.shape[1:]))
        self.optimizer = torch.optim.Adam([self.lats], lr=opt_args["lr"])
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.iter_num, eta_min=opt_args["lr_min"]
        )

    def _full_path(self, lats):
        """Return ``[latA, *lats, latB]`` concatenated along dim 0."""
        return torch.cat([self.latA, lats, self.latB], dim=0)

    def score_diff(self, lats):
        """Return the path energy for the interior latents ``lats``.

        ``grad_compute_batch`` is evaluated for every point of the full path in
        a single batch. Each result is flattened to one row per latent, and the
        energy is ``0.5 * sum_i ||s_{i+1} - s_i||^2`` over consecutive rows.
        """
        X = self._full_path(lats)
        embed_cond = self.embed_cond.repeat(X.shape[0], 1, 1)
        scrs = self.grad_compute_batch(X, embed_cond)
        # Row length = elements per latent (16384 for 4x64x64), derived from the
        # endpoint latent rather than hard-coded.
        scrs = scrs.reshape(-1, self.latA[0].numel())
        scrs_diff = scrs[1:] - scrs[:-1]
        energy = torch.sum(scrs_diff**2, dim=1)
        return 0.5 * torch.sum(energy)

    def step(self):
        """Run one optimizer + LR-scheduler update on the interior latents.

        Prints the loss every 5 iterations and increments ``self.cur_iter``.

        Returns:
            ``False`` always; there is no early-stopping criterion yet, but
            ``solve`` stops as soon as this returns ``True``.
        """
        self.optimizer.zero_grad()
        loss = self.score_diff(self.lats)

        if self.cur_iter % 5 == 0:
            print(
                "optimise {} iteration: {}, loss: {}".format(
                    self.test_name, self.cur_iter, loss.item()
                )
            )

        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()
        self.cur_iter += 1

        return False

    def solve(self):
        """Optimize the path and emit the start and final interpolations.

        ``output_interpolation_results`` is called once before optimization
        (iteration 0, tag ``"start"``). It is called again after up to
        ``iter_num`` steps (tag ``"final"``). The final output is produced even
        when ``iter_num <= 0``.
        """
        embed_cond_args = {"embed_cond": self.embed_cond}
        # Build the path under no_grad so the tensor handed to the output
        # routine is not attached to the autograd graph of self.lats.
        with torch.no_grad():
            self.output_interpolation_results(
                0,
                "start",
                self._full_path(self.lats),
                self.test_name,
                **embed_cond_args,
            )
        for _ in range(self.iter_num):
            if self.step():
                break
            torch.cuda.empty_cache()
        with torch.no_grad():
            self.output_interpolation_results(
                self.cur_iter,
                "final",
                self._full_path(self.lats),
                self.test_name,
                **embed_cond_args,
            )
