import torch
from torch import nn
from .mlp import MLP
from solver import RiemannianODESolver
from flow_matching.utils.manifolds import Sphere


class RiemannianWrapper(nn.Module):
    """Evaluate a velocity network with its input and output passed through a manifold.

    Args:
        velocity_model: Callable invoked as ``velocity_model(x, t, y)``.
        manifold: Object exposing ``projx(x)`` and ``proju(x, v)``.
    """

    def __init__(self, velocity_model, manifold):
        super().__init__()
        self.velocity_model = velocity_model
        self.manifold = manifold

    def forward(self, x, t, y):
        """Return the manifold-projected velocity at ``x`` for time ``t`` and condition ``y``.

        ``t`` is flattened to 1-D, ``x`` is first mapped through ``manifold.projx``,
        and the network output is mapped through ``manifold.proju`` at that
        projected point before being returned.
        """
        flat_t = t.view(-1)
        point = self.manifold.projx(x)
        raw_velocity = self.velocity_model(point, flat_t, y)
        return self.manifold.proju(point, raw_velocity)


class REPVLM(nn.Module):
    """Label-conditioned velocity model on a sphere with per-modality likelihood solvers.

    A single ``RiemannianWrapper``-wrapped MLP (``tan_vel``) is shared by two
    ``RiemannianODESolver`` instances. The image solver always evaluates the
    network with label 0 and the text solver with label 1.

    Args:
        num_steps: Number of Euler steps used by ``adapt_image`` / ``adapt_text``;
            the solver step size is ``1.0 / num_steps``. Defaults to 5.

    Raises:
        ValueError: If ``num_steps`` is not positive.
    """

    # Value returned by the base log-density callable passed to the solvers.
    # NOTE(review): the origin of 817.0 is not shown in this file; presumably it
    # is the (constant) log-density of the base distribution -- confirm.
    _BASE_LOG_DENSITY = 817.0

    # Class labels fed to the shared velocity network for each modality.
    _IMAGE_LABEL = 0
    _TEXT_LABEL = 1

    def __init__(self, num_steps=5):
        super().__init__()
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")

        self.manifold = Sphere()
        self.tan_vel = RiemannianWrapper(
            MLP(input_dim=512, hidden_dim=1024, depth=6), self.manifold)

        # initialize solvers
        self.image_solver = RiemannianODESolver(
            velocity_model=self._labelled_velocity(self._IMAGE_LABEL),
            manifold=self.manifold)
        self.text_solver = RiemannianODESolver(
            velocity_model=self._labelled_velocity(self._TEXT_LABEL),
            manifold=self.manifold)

        self.num_steps = num_steps

    def _labelled_velocity(self, label):
        """Return ``f(x, t)`` that evaluates ``tan_vel`` with a constant ``label``."""
        def func(x, t):
            # Build the label vector directly on x's device (one label per row of x).
            labels = torch.full(
                (x.shape[0],), label, dtype=torch.long, device=x.device)
            return self.tan_vel(x, t, labels)
        return func

    def _base_log_density(self, x):
        """Base log-density callable handed to the solvers; constant for every ``x``."""
        return self._BASE_LOG_DENSITY

    def _negative_log_likelihood(self, solver, z):
        """Project ``z`` with ``manifold.projx`` and return ``(z, -log_p1)`` from ``solver``."""
        z = self.manifold.projx(z)
        _, log_p1 = solver.compute_likelihood(
            x_1=z, method='euler',
            step_size=1.0 / self.num_steps,
            log_p0=self._base_log_density)
        return z, -log_p1

    def forward(self, xt, t, y):
        """Return ``tan_vel(xt, t, y)``, the wrapped velocity for points ``xt``."""
        return self.tan_vel(xt, t, y)

    # applications
    def adapt_image(self, z_i):
        """Uncertainty estimation for image embeddings.

        Returns:
            Tuple ``(z_i, nll)``: the embeddings after ``manifold.projx`` and the
            negated log-likelihood reported by the image solver.
        """
        return self._negative_log_likelihood(self.image_solver, z_i)

    def adapt_text(self, z_t):
        """Uncertainty estimation for text embeddings.

        Returns:
            Tuple ``(z_t, nll)``: the embeddings after ``manifold.projx`` and the
            negated log-likelihood reported by the text solver.
        """
        return self._negative_log_likelihood(self.text_solver, z_t)


if __name__ == "__main__":
    # Smoke test: push one random batch through the model and print the
    # shape of the returned velocity tensor.
    batch_size, dim = 4, 512
    model = REPVLM()
    points = torch.randn(batch_size, dim)
    times = torch.rand(batch_size)
    labels = torch.zeros(batch_size, dtype=torch.long)  # all-zero labels
    velocity = model(points, times, labels)
    print(velocity.shape)