import torch


def euler_method(model, input, start_t, end_t, steps=30):
    """Integrate ``dx/dt = model([x, t])`` from ``start_t`` to ``end_t`` with explicit Euler.

    At every step the current time is appended to the state as an extra final
    column and the result is fed to the model, i.e.
    ``model(torch.cat((x, t), dim=1))``.

    Args:
        model: Velocity field. Receives a ``(batch_size, state_dim + 1)`` tensor
            (state followed by time) and must return a tensor that can be added
            to the state.
        input: Initial state, presumably ``(batch_size, state_dim)``. It is
            cloned, never modified in place.
        start_t: Integration start time.
        end_t: Integration end time. It may be smaller than ``start_t`` to
            integrate backwards; if equal, the state is returned unchanged.
        steps: Number of Euler steps; must be at least 1.

    Returns:
        The state at ``end_t``, with the same shape as ``input``.

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    # Signed step size. This covers forward, backward and zero-length intervals
    # of any length (a fixed 1/steps is only right when |end_t - start_t| == 1).
    dt = (end_t - start_t) / steps
    x = input.clone()
    for i in range(steps):
        # Left-endpoint time of step i, one row per sample. It uses x's
        # dtype/device so the concatenation below does not mix precisions.
        t = start_t + x.new_ones((x.size(0), 1)) * (end_t - start_t) * (i / steps)
        v = model(torch.cat((x, t), dim=1))
        x = x + v * dt
    return x


def euler_method_with_conditioning(
    model, input, conditioning, start_t, end_t, steps=30
):
    """Euler method ODE solver with conditioning variables.

    For models that take additional conditioning inputs (e.g., physical parameters
    in Unifoil), this function concatenates the conditioning to the model input
    at each integration step. The model therefore receives
    ``torch.cat((state, conditioning, t), dim=1)``. The conditioning itself is
    held fixed; only the state is integrated.

    Args:
        model: The flow matching model. Expected input: [state, conditioning, t]
        input: Initial state, presumably (batch_size, state_dim). It is cloned,
            never modified in place.
        conditioning: Conditioning variables (batch_size, cond_dim). They must
            be concatenable with the state along dim 1.
        start_t: Integration start time.
        end_t: Integration end time. It may be smaller than ``start_t`` to
            integrate backwards; if equal, the state is returned unchanged.
        steps: Number of integration steps; must be at least 1.

    Returns:
        Final state after integration (batch_size, state_dim)

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    # Signed step size. This covers forward, backward and zero-length intervals
    # of any length (a fixed 1/steps is only right when |end_t - start_t| == 1).
    dt = (end_t - start_t) / steps
    x = input.clone()
    for i in range(steps):
        # Left-endpoint time of step i, one row per sample. It uses x's
        # dtype/device so the concatenation below does not mix precisions.
        t = start_t + x.new_ones((x.size(0), 1)) * (end_t - start_t) * (i / steps)
        v = model(torch.cat((x, conditioning, t), dim=1))
        x = x + v * dt
    return x


if __name__ == "__main__":
    import sys

    from uq_diagcfm.utils import get_device
    from uq_diagcfm.models import MLP

    def _run_euler_smoke_test():
        """Integrate a freshly seeded random MLP over t in [0, 1] and print the result."""
        dev = get_device()
        torch.manual_seed(1)

        state_dim = 6
        net = MLP(
            input_dim=state_dim + 1,  # state columns plus the appended time column
            output_dim=state_dim,
            hidden_dim=64,
            depth=3,
            dropout=0.0,
            activation="ReLU",
        ).to(dev)

        # Sampled after the model is built so the seeded RNG stream is consumed
        # in the same order as the weight initialisation expects.
        x0 = torch.randn((10, state_dim), device=dev)
        t_begin, t_finish = 0.0, 1.0
        print("Input shape:", x0.shape)

        result = euler_method(net, x0, t_begin, t_finish)
        print("out shape:", result.shape)

        print(result)

    # Run only when invoked exactly as: <script> test_euler
    if sys.argv[1:] == ["test_euler"]:
        _run_euler_smoke_test()
