import torch
import torch.nn as nn
import torch.nn.functional as F

# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint


def noise_to_params(vector_field, noise, ode_kwargs=None):
    """Transport ``noise`` to network parameters by integrating ``vector_field``.

    The ODE is solved from t=0 to t=1 with ``odeint`` (torchdiffeq's
    ``odeint_adjoint``). The time grid is created on the same device and with
    the same dtype as ``noise[0]``.

    Args:
        vector_field: callable with the solver signature ``f(t, state)``, where
            ``state`` is a tuple of tensors (e.g. a ``VFWrapper``).
        noise: tuple of initial-state tensors. It is assumed to hold all weight
            tensors followed by all bias tensors, because the result is split
            in half.
        ode_kwargs: extra solver keyword arguments. When omitted, an Euler
            solver with ``step_size=0.01`` and ``rtol=atol=1e-5`` is used. When
            given, it replaces these defaults entirely (no merging).

    Returns:
        ``(weights, biases)``: two lists holding the t=1 state of the first
        and second halves of ``noise``, respectively.
    """
    if ode_kwargs is not None:
        solver_kwargs = ode_kwargs
    else:
        solver_kwargs = {
            "rtol": 1e-5,
            "atol": 1e-5,
            "method": "euler",
            "options": {"step_size": 1 / 100},
        }

    reference = noise[0]
    time_span = torch.tensor(
        [0.0, 1.0], device=reference.device, dtype=reference.dtype
    )
    # Per the torchdiffeq docs, an empty ``adjoint_params`` means the adjoint
    # pass computes no gradients w.r.t. module parameters.
    trajectories = odeint(
        vector_field, noise, time_span, adjoint_params=(), **solver_kwargs
    )
    # Each trajectory stacks the solution at every time in ``time_span``;
    # keep only the final (t=1) state.
    final_states = [trajectory[-1] for trajectory in trajectories]
    half = len(final_states) // 2
    return final_states[:half], final_states[half:]
    


def expand_as(x, y):
    """Append trailing singleton dims to ``y`` so it broadcasts against ``x``.

    Typical use is broadcasting a per-sample ``(batch,)`` vector against a
    ``(batch, ...)`` tensor.

    Args:
        x: reference tensor whose rank the result should match.
        y: tensor to reshape.

    Returns:
        ``y`` with ``x.dim() - y.dim()`` trailing size-1 dimensions added.
        If ``y`` already has at least as many dimensions as ``x``, ``y`` is
        returned unchanged. The result is always derived from ``y``, never
        from ``x``.
    """
    extra = x.dim() - y.dim()
    if extra <= 0:
        # Equal (or higher) rank: y already broadcasts, so return y, not x.
        return y
    for _ in range(extra):
        y = y.unsqueeze(-1)
    return y

def flow_matching_loss(
    vector_field, x, t=None, noise=None, sigma_min=1e-4
):
    """Per-sample conditional flow-matching loss over batched network parameters.

    For data ``x1`` and noise ``x0``, each tensor is placed on the path
    ``y_t = t * x1 + (1 - (1 - sigma_min) * t) * x0``. The regression target is
    the path's velocity ``u = x1 - (1 - sigma_min) * x0``.

    Args:
        vector_field: callable ``((weights, biases), t) -> (weights, biases)``
            that predicts a velocity for every parameter tensor.
        x: ``(weights, biases)`` pair of sequences of tensors. Dim 0 of
            ``weights[0]`` is used as the batch size; all tensors are assumed
            to share that leading batch dim (TODO confirm with callers).
        t: optional per-sample times of shape ``(batch,)``. When ``None``,
            times are drawn from ``torch.rand``.
        noise: optional ``(weights_noise, biases_noise)`` with the same
            structure as ``x``. When ``None``, standard normal noise is drawn.
        sigma_min: minimum noise scale kept at t=1.

    Returns:
        Tensor of shape ``(batch,)``: the squared error between predicted and
        target velocities, averaged over all flattened parameters.

    Raises:
        ValueError: if ``noise`` does not contain one tensor for each tensor
            in ``x``.
    """
    weights, biases = x
    if t is None:
        t = torch.rand(weights[0].shape[0], device=weights[0].device)
    if noise is None:
        weights_noise = [torch.randn_like(w) for w in weights]
        biases_noise = [torch.randn_like(b) for b in biases]
    else:
        weights_noise, biases_noise = noise
        # zip() would silently truncate mismatched inputs; fail loudly instead.
        if len(weights_noise) != len(weights) or len(biases_noise) != len(biases):
            raise ValueError("noise must have the same structure as x")

    def _interpolate(data, eps_list):
        # Return (points on the path, target velocities) for paired tensor lists.
        ys, us = [], []
        for d, eps in zip(data, eps_list):
            t_b = expand_as(d, t)  # per-sample time broadcast to d's rank, computed once
            ys.append(t_b * d + (1 - (1 - sigma_min) * t_b) * eps)
            us.append(d - (1 - sigma_min) * eps)
        return ys, us

    y_weights, u_weights = _interpolate(weights, weights_noise)
    y_biases, u_biases = _interpolate(biases, biases_noise)

    u = torch.cat([torch.flatten(p, 1) for p in u_weights + u_biases], dim=1)

    y_hat_weights, y_hat_biases = vector_field((y_weights, y_biases), t)
    # Star-unpacking accepts any sequence types the vector field returns.
    y_hat = torch.cat(
        [torch.flatten(p, 1) for p in [*y_hat_weights, *y_hat_biases]], dim=1
    )

    return (y_hat - u).square().mean(dim=1)


class VFWrapper(nn.Module):
    """Adapt a parameter-space vector field to the ``f(t, state)`` ODE interface.

    The wrapped ``vector_field`` is called as ``vector_field((weights, biases), t)``
    and must return a ``(weights, biases)`` pair. This wrapper converts the flat
    ODE state tuple into that form and flattens the result back.
    """

    def __init__(self, vector_field):
        super().__init__()
        self.vector_field = vector_field

    def forward(self, t, x):
        """Evaluate the velocity at time ``t`` for the flat state tuple ``x``.

        Args:
            t: time tensor expandable to ``(batch,)``, e.g. the scalar time
                an ODE solver passes.
            x: tuple of state tensors. The first half are weights and the
                second half are biases. Dim 0 of ``x[0]`` is taken as the
                batch size.

        Returns:
            Tuple of velocity tensors in the same weights-then-biases order.
        """
        split = len(x) // 2
        batch_size = x[0].shape[0]
        per_sample_t = t.expand(batch_size)
        d_weights, d_biases = self.vector_field(
            (x[:split], x[split:]), per_sample_t
        )
        return (*d_weights, *d_biases)