import torch
from torch import nn
from .mlp import MLP
from flow_matching.solver import ODESolver


class Wrapper(nn.Module):
    """Thin adapter that flattens the time argument before calling a model.

    The wrapped ``velocity_model`` is invoked as ``velocity_model(x, t, y)``
    with ``t`` reshaped to a 1-D tensor.
    """

    def __init__(self, velocity_model):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(self, x, t, y):
        """Evaluate the wrapped model at ``(x, t, y)``.

        Args:
            x: Forwarded unchanged as the first positional argument.
            t: Time tensor; viewed as 1-D (``t.view(-1)``) before the call,
                so a scalar tensor becomes shape ``(1,)``.
            y: Forwarded unchanged as the third positional argument.

        Returns:
            Whatever the wrapped ``velocity_model`` returns.
        """
        flat_t = t.view(-1)
        return self.velocity_model(x, flat_t, y)


class EucREPVLM(nn.Module):
    """Velocity network plus two ODE solvers for scoring embeddings.

    One velocity network (``self.vel``) is shared by two ``ODESolver``
    instances.  Each solver evaluates the network with a fixed integer label
    tensor as the third argument: 0 for the image solver, 1 for the text
    solver.  ``adapt_image`` / ``adapt_text`` L2-normalise a batch of
    embeddings and return it together with the negated log-likelihood
    reported by the matching solver.
    """

    # Labels handed to the velocity network for each modality.
    _IMAGE_LABEL = 0
    _TEXT_LABEL = 1

    # Constant base log-density used when ``gaussian_base`` is False (the
    # original code labels this case "Hyperbolic base").
    # NOTE(review): the derivation of this value is not visible here -- confirm.
    _NON_GAUSSIAN_LOG_P0 = -817.0

    def __init__(self, gaussian_base=True, num_steps=50):
        """Build the velocity network and both solvers.

        Args:
            gaussian_base: If true, the base log-density is
                ``-0.5 * ||x||^2`` (a standard normal up to an additive
                constant); otherwise a fixed constant is used.
            num_steps: Number of fixed integration steps; the solvers are
                called with ``step_size = 1 / num_steps``.

        Raises:
            ValueError: If ``num_steps`` is not positive.
        """
        super().__init__()
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")

        self.vel = Wrapper(MLP(input_dim=512, hidden_dim=1024, depth=6))
        self.gaussian_base = gaussian_base

        # initialize solvers
        self.image_solver = ODESolver(
            velocity_model=self._make_velocity_fn(self._IMAGE_LABEL))
        self.text_solver = ODESolver(
            velocity_model=self._make_velocity_fn(self._TEXT_LABEL))

        self.num_steps = num_steps

    def _make_velocity_fn(self, label):
        """Return ``f(x, t)`` that calls ``self.vel`` with a constant label."""
        def velocity(x, t):
            # Create the labels directly on x's device instead of allocating
            # on CPU and copying on every solver step.
            labels = torch.full(
                (x.shape[0],), label, dtype=torch.long, device=x.device)
            return self.vel(x, t, labels)
        return velocity

    def _base_log_density(self, x):
        """Log-density of the base distribution evaluated at ``x``.

        Returns a tensor with the last dimension reduced for the Gaussian
        base, or a plain Python float for the constant (non-Gaussian) base.
        """
        if self.gaussian_base:
            return -0.5 * x.norm(dim=-1) ** 2
        return self._NON_GAUSSIAN_LOG_P0

    def _normalized_nll(self, solver, z):
        """L2-normalise ``z`` and return ``(z, -log_likelihood)`` from ``solver``."""
        # Clamp the norm so an all-zero row stays zero instead of turning
        # into NaN (0 / 0); rows with a normal-range norm are unaffected.
        norm = z.norm(dim=-1, keepdim=True)
        z = z / norm.clamp_min(torch.finfo(z.dtype).tiny)

        _, log_p1 = solver.compute_likelihood(
            x_1=z, method='rk4',
            step_size=1.0 / self.num_steps,
            log_p0=self._base_log_density)
        return z, -log_p1

    def forward(self, xt, t, y):
        """Evaluate the velocity network at ``(xt, t)`` with labels ``y``."""
        return self.vel(xt, t, y)

    # applications
    def adapt_image(self, z_i):
        """Score image embeddings (uncertainty estimation).

        Args:
            z_i: Embedding tensor; normalised along its last dimension.

        Returns:
            Tuple ``(z_i_normalised, nll)`` where ``nll`` is the negated
            log-likelihood returned by the image solver.
        """
        return self._normalized_nll(self.image_solver, z_i)

    def adapt_text(self, z_t):
        """Score text embeddings (uncertainty estimation).

        Args:
            z_t: Embedding tensor; normalised along its last dimension.

        Returns:
            Tuple ``(z_t_normalised, nll)`` where ``nll`` is the negated
            log-likelihood returned by the text solver.
        """
        return self._normalized_nll(self.text_solver, z_t)

# --- Test ---
# Smoke test: runs one forward pass and both adapt_* methods on random data
# and prints the resulting shapes.  This module uses a relative import, so it
# has to be executed as a module of its package (``python -m ...``).
if __name__ == "__main__":
    model = EucREPVLM()
    xt = torch.randn(4, 512)
    t = torch.rand((4,))
    y = torch.randint(0, 2, (4,))

    # move to device: use the last visible GPU, or fall back to the CPU when
    # CUDA is unavailable (device_count() == 0 would give the invalid
    # device string "cuda:-1").
    if torch.cuda.is_available():
        DEVICE = f"cuda:{torch.cuda.device_count() - 1}"
    else:
        DEVICE = "cpu"
    model = model.to(DEVICE)
    xt = xt.to(DEVICE)
    t = t.to(DEVICE)
    y = y.to(DEVICE)

    v_i = model(xt, t, y)
    print(v_i.shape)

    # test adapt_image (second return value is the negated log-likelihood)
    z_i = torch.randn(4, 512).to(DEVICE)
    z_i, nll_i = model.adapt_image(z_i)
    print(z_i.shape)
    print(nll_i.shape)

    # test adapt_text
    z_t = torch.randn(4, 512).to(DEVICE)
    z_t, nll_t = model.adapt_text(z_t)
    print(z_t.shape)
    print(nll_t.shape)