
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from flow_matching.solver import Solver, ODESolver

import einops
import lightning as L
import torch
from torch import Tensor, nn
# flow_matching
from flow_matching.path.scheduler import CondOTScheduler, Scheduler
from flow_matching.path import AffineProbPath

class FlowMatching(L.LightningModule):
    """Flow-matching model wrapped as a Lightning module.

    Training regresses ``self.model`` onto the velocity ``dx_t`` of an
    ``AffineProbPath`` between Gaussian noise (scaled by ``std``) and the data
    stored under ``batch['field']``. Inference integrates the learned velocity
    from fresh noise with an ``ODESolver`` (midpoint method).
    """

    def __init__(
        self,
        model: Callable[[], nn.Module],
        fm_scheduler: Scheduler,
        std: float,
    ):
        """Build the velocity model, probability path and ODE solver.

        Args:
            model: Zero-argument factory (e.g. a class or ``functools.partial``)
                that builds the velocity network. An already-constructed
                ``nn.Module`` is also accepted and used as-is.
            fm_scheduler: Scheduler passed to ``AffineProbPath``.
            std: Standard deviation of the Gaussian noise used as the source
                sample x_0, both in training and in inference.
        """
        super().__init__()
        # An nn.Module instance is itself callable (calling it runs forward()),
        # so detect it explicitly instead of unconditionally calling `model()`.
        self.model: nn.Module = model if isinstance(model, nn.Module) else model()
        self.path = AffineProbPath(scheduler=fm_scheduler)
        # NOTE(review): if ODESolver is an nn.Module (as in flow_matching), the
        # velocity model is registered twice, so its weights appear under both
        # `model.*` and `solver.velocity_model.*` in the state_dict.
        self.solver = ODESolver(velocity_model=self.model)
        self.std = std

    # https://github.com/facebookresearch/flow_matching/blob/main/examples/2d_flow_matching.ipynb
    def train_step(self, batch):
        """Return the flow-matching L2 loss for one batch.

        Draws noise x_0 ~ N(0, std^2 I) shaped like the data
        x_1 = ``batch['field']`` and one time t ~ U[0, 1) per sample. It then
        samples the probability path at (t, x_0, x_1) and returns the mean
        squared error between the model output and the path velocity ``dx_t``.

        NOTE(review): Lightning's training hook is
        ``training_step(batch, batch_idx)``. Presumably a subclass or code
        outside this view calls ``train_step``; confirm.
        """
        # x_0 ~ N(0, std^2 I) drawn independently of x_1, as in the referenced example.
        x_1 = batch['field']
        x_0 = torch.randn_like(x_1) * self.std  # inherits x_1's device and dtype

        # One time per sample along the leading (batch) dimension.
        t = torch.rand(x_1.shape[0], device=x_1.device)

        path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)

        # NOTE(review): **batch also forwards 'field' (the target x_1) to the
        # model; confirm the model does not read it, or the task is trivial.
        predicted_velocity = self.model(path_sample.x_t, path_sample.t, **batch)

        # Flow-matching L2 loss.
        return torch.pow(predicted_velocity - path_sample.dx_t, 2).mean()

    def inference(self, batch, n_steps=30, return_intermediates=False, **kwargs):
        """Generate samples by integrating the learned velocity from t=0 to t=1.

        Args:
            batch: Mapping whose ``'field'`` tensor fixes the shape and device of
                the samples. All entries are also forwarded to
                ``ODESolver.sample`` as extra keyword arguments.
            n_steps: Solver step size is ``1 / n_steps``; the output time grid
                has ``max(n_steps, 2)`` evenly spaced points on [0, 1].
                Must be >= 1.
            return_intermediates: Forwarded to ``ODESolver.sample``.
            **kwargs: Extra keyword arguments for ``ODESolver.sample``. On a key
                clash they override the corresponding ``batch`` entry.

        Returns:
            The result of ``ODESolver.sample``. Per flow_matching's docs, this
            is the final state, or the states at every grid time when
            ``return_intermediates`` is true.

        Raises:
            ValueError: If ``n_steps < 1``.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        x_1 = batch['field']
        device = x_1.device
        # NOTE(review): noise is forced to float32 here, whereas train_step draws
        # it in x_1's dtype; confirm this is intended (e.g. under mixed precision).
        x_init = torch.randn_like(x_1, dtype=torch.float32) * self.std

        # A single-point grid ([0.]) has start == end, leaving nothing to
        # integrate, so always span [0, 1] with at least two points.
        time_grid = torch.linspace(0, 1, max(n_steps, 2), device=device)
        # NOTE(review): the grid spacing is 1/(n_steps-1) while the step size is
        # 1/n_steps, so intermediate states presumably come from solver-side
        # interpolation rather than from the integration steps themselves.

        # Merge instead of `**batch, **kwargs`, which raised TypeError on any
        # shared key; explicit kwargs win.
        model_extras = {**batch, **kwargs}
        return self.solver.sample(
            time_grid=time_grid,
            x_init=x_init,
            method='midpoint',
            step_size=1.0 / n_steps,
            return_intermediates=return_intermediates,
            **model_extras,
        )