import torch
import models
import utils
import os
from tqdm import tqdm
import losses
import geomloss
from torchdiffeq import odeint 

def proximal(w, dims, lam=0.1, eta=0.1):
    """Apply an in-place group soft-thresholding (proximal) step to ``w``.

    ``w`` is viewed as ``[dims[0], m1, dims[0]]`` and every group along the
    middle axis is shrunk: its Euclidean norm is reduced by ``lam * eta`` and
    clamped at zero, so groups whose norm does not exceed ``lam * eta`` become
    exactly zero while the direction of the remaining groups is kept.

    Args:
        w: 2-D tensor (or parameter) of shape ``[dims[0] * m1, dims[0]]``.
            Its ``.data`` is replaced by the thresholded weights.
        dims: sequence; only ``dims[0]`` is read (used for both outer axes).
        lam: group-sparsity strength.
        eta: step size; the threshold is ``lam * eta``.

    Returns:
        None. ``w`` is modified in place.
    """
    n = dims[0]
    # The result is assigned straight to ``w.data``, so tracking gradients for
    # the intermediate tensors would only build a graph that is thrown away.
    with torch.no_grad():
        grouped = w.view(n, -1, n)  # [j, m1, i]
        group_norm = torch.sum(grouped**2, dim=1).pow(0.5)
        shrunk_norm = torch.clamp(group_norm - lam * eta, min=0)
        direction = torch.nn.functional.normalize(grouped, dim=1)
        w.data = (direction * shrunk_norm[:, None, :]).view(-1, n)


def train_denoising_score(s, opt, sigmas, data, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'checkpoint_iter' : None, 'checkpoint_file' : "score",
                                     'save_final' : False, 'save_file' : 'score', 'outdir' : './'},
                          sample_batch_options = {}):
    """Train the score model ``s`` with a denoising score loss.

    Each iteration draws a batch through ``utils.sample_batch``, adds up
    ``losses.denoising_score_loss`` over every noise level in ``sigmas`` and
    performs one optimiser step.

    Args:
        s: score model, called as ``s(t, x, sigma)``.
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        sigmas: iterable of noise levels.
        data: mapping with entries ``'X'`` (samples) and ``'t'`` (times).
        batch_size: forwarded to ``utils.sample_batch``; also the number of
            times each entry of ``data['t']`` is repeated.
        options: run settings (``iters``, ``print_iter``, ``checkpoint_iter``,
            ``checkpoint_file``, ``save_final``, ``save_file``, ``outdir``).
        sample_batch_options: extra keyword arguments for ``utils.sample_batch``.

    Returns:
        list of float: the summed loss of every iteration.
    """
    samples, times = data['X'], data['t']
    history = []
    for step in tqdm(range(options['iters'])):
        opt.zero_grad()
        batch = utils.sample_batch(samples, batch_size = batch_size, **sample_batch_options)
        # One time value per flattened sample, as a column vector
        # (assumes the batch is ordered time-major -- TODO confirm against utils.sample_batch).
        t_col = times.repeat_interleave(batch_size).unsqueeze(1)
        # Accumulate the denoising loss over all noise levels.
        total = sum(
            losses.denoising_score_loss(lambda x: s(t_col, x, sigma), batch.view(-1, batch.shape[-1]), sigma)
            for sigma in sigmas
        )
        total.backward()
        opt.step()
        history.append(total.item())
        if step % options['print_iter'] == 0:
            print(f"Iteration {step}, loss = {total.item()}")
        every = options['checkpoint_iter']
        if every is not None and (step % every == 0) and (step > 0):
            ckpt_path = os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{step}.pt")
            torch.save(s.state_dict(), ckpt_path)
    if options['save_final']:
        final_path = os.path.join(options['outdir'], f"{options['save_file']}_final.pt")
        torch.save(s.state_dict(), final_path)
    return history

def get_reg(v, Xs, ts, reg_kind = "vf", **kwargs):
    """Evaluate the energy regulariser of model ``v`` along a trajectory.

    The model type selects the energy:

    * ``models.ODEFlowGrowth`` / ``models.ODEFlowGrowth_linear`` and
      ``models.MultiplicativeNoiseFlowGrowth`` use
      ``losses.wasserstein_fisher_rao_energy``; the first feature column of
      ``Xs`` is exponentiated and passed as the third argument, the remaining
      columns are the network input.
    * ``models.VectorField`` / ``models.LinearVectorField`` and
      ``models.MultiplicativeNoiseFlow`` use ``losses.benamou_brenier_energy``
      with constant weights ``1 / Xs.shape[1]``.

    For the multiplicative-noise models the regularised quantity is the
    difference between the ``u`` and ``v`` sub-networks.

    Args:
        v: model instance of one of the types above.
        Xs: rank-3 tensor; the leading axis is paired with ``ts``.
        ts: rank-1 tensor of times, one entry per leading index of ``Xs``.
        reg_kind: ``"vf"`` to regularise network outputs, ``"jac"`` to
            regularise their Jacobian (for time-dependent networks the last
            Jacobian column, belonging to the appended time input, is dropped).
        **kwargs: forwarded to ``losses.wasserstein_fisher_rao_energy``.

    Returns:
        The energy returned by the vmapped loss (one value per leading index
        of ``Xs``), or ``0`` when ``v`` is none of the supported model types.

    Raises:
        ValueError: if ``v`` is a supported model and ``reg_kind`` is neither
            ``"vf"`` nor ``"jac"``.
    """
    def with_time(x):
        # Append the time of each leading index as an extra feature column.
        return torch.cat([x, ts[:, None, None].expand((*x.shape[:2], 1))], axis = 2)

    def evaluate(net, x, time_dependent):
        # Output (or Jacobian) of `net` on `x`, optionally with time appended.
        if reg_kind == "vf":
            return net(with_time(x)) if time_dependent else net(x)
        if reg_kind == "jac":
            if time_dependent:
                # Drop the derivative with respect to the appended time column.
                return torch.vmap(net.jacobian)(with_time(x))[..., :-1]
            return torch.vmap(net.jacobian)(x)
        raise ValueError(f"Unknown reg_kind {reg_kind!r}; expected 'vf' or 'jac'")

    def growth_values(g_model, x):
        # Growth network output; never differentiated, whatever `reg_kind` is.
        return g_model.net(with_time(x)) if g_model.time_dependent else g_model.net(x)

    def uniform_weights(vals):
        return torch.full((*vals.shape[:2], ), 1 / Xs.shape[1], device = vals.device)

    E = 0
    if isinstance(v, (models.ODEFlowGrowth, models.ODEFlowGrowth_linear)):
        x = Xs[..., 1:]
        v_vals = evaluate(v.v_net.net, x, v.v_net.time_dependent)
        if reg_kind == "vf" and not v.v_net.time_dependent:
            # Only this branch reshapes the output back to the input shape.
            v_vals = v_vals.view(x.shape)
        g_vals = growth_values(v.g_net, x)
        E = torch.vmap(losses.wasserstein_fisher_rao_energy)(v_vals, g_vals, Xs[..., 0].exp(), **kwargs)
    elif isinstance(v, (models.VectorField, models.LinearVectorField)):
        if reg_kind == "vf" and not v.time_dependent:
            # Time-independent fields are evaluated through the model's own call.
            v_vals = v(None, Xs)
        else:
            v_vals = evaluate(v.net, Xs, v.time_dependent)
        E = torch.vmap(losses.benamou_brenier_energy)(v_vals, uniform_weights(v_vals))
    elif isinstance(v, models.MultiplicativeNoiseFlow):
        time_dependent = v.v.time_dependent
        v_vals = evaluate(v.u.net, Xs, time_dependent) - evaluate(v.v.net, Xs, time_dependent)
        E = torch.vmap(losses.benamou_brenier_energy)(v_vals, uniform_weights(v_vals))
    elif isinstance(v, models.MultiplicativeNoiseFlowGrowth):
        x = Xs[..., 1:]
        time_dependent = v.v.time_dependent
        v_vals = evaluate(v.u.net, x, time_dependent) - evaluate(v.v.net, x, time_dependent)
        g_vals = growth_values(v.g, x)
        E = torch.vmap(losses.wasserstein_fisher_rao_energy)(v_vals, g_vals, Xs[..., 0].exp(), **kwargs)
    return E

def train_upfi(v, opt, s, sigmas, data, params, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'reg_kind' : 'vf', 'checkpoint_iter' : None, 'checkpoint_file' : "v_upfi",
                                     'save_final' : False, 'save_file' : 'v_upfi',
                                     'anneal_sigma_iters' : None, 'outdir' : './',
                                     'teacher_forcing_iter' : None},
                          sample_batch_options = {}, odeint_options = {'method' : 'euler'}, samplesloss_options = {'loss' : 'sinkhorn', 'reach' : 5.0}):
    """Train flow model ``v`` against snapshots, using the score model ``s``.

    The integrated dynamics are ``v(t, x) - (D/2) * pad(s(t, x[:, 1:], sigma))``
    where ``pad`` is ``utils.pad_zeros_upfi``. The objective is a
    ``geomloss.SamplesLoss`` fit term (weights are ``exp`` of the first state
    column, points are the remaining columns) plus ``reg_wfr`` times the
    regulariser from ``get_reg``; both are divided by ``data['m_ratios']``
    and the regulariser is weighted by averaged time intervals.

    Args:
        v: flow model; must expose ``v_net`` (NGM penalties and the proximal
            step are applied when it is a ``models.NGMVectorField``).
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        s: score model, called as ``s(t, x, sigma)``.
        sigmas: noise levels (tensor-like). With ``anneal_sigma_iters`` set,
            ``sigma`` starts at ``sigmas[0]``, is multiplied by a constant
            ratio each iteration and is never allowed below ``sigmas[-1]``;
            otherwise ``sigmas[-1]`` is used.
        data: mapping with ``'X'``, ``'t'`` and ``'m_ratios'``.
        params: mapping with ``'D'``, ``'alpha_wfr'``, ``'reg_wfr'`` and, for
            NGM models, ``'reg_ngm_l2'`` / ``'reg_ngm_l1'``.
        batch_size: forwarded to ``utils.sample_batch_upfi``.
        options: run settings (iterations, printing, checkpointing, annealing,
            ``reg_kind``, ``teacher_forcing_iter``).
        sample_batch_options: extra keyword arguments for the batch sampler.
        odeint_options: keyword arguments for ``odeint``.
        samplesloss_options: keyword arguments for ``geomloss.SamplesLoss``.

    Returns:
        list of float: total loss of every iteration.
    """
    trace = []
    Xs = data['X']
    ts = data['t']
    m_ratios = data['m_ratios']
    D = params['D']
    alpha_wfr = params['alpha_wfr']
    reg_wfr = params['reg_wfr']
    teacher_forcing = False

    def F_ode_upfi(t, x, sigma):
        # The score only sees columns 1:, and is zero-padded back to full width.
        return v(t, x) - (D/2)*utils.pad_zeros_upfi(s(t, x[:, 1:], sigma))

    # calculate averaged time intervals (end points reuse their single neighbouring interval)
    dt = ts[1:] - ts[:-1]
    dt = (torch.hstack([dt[:1], dt]) + torch.hstack([dt, dt[-1:]])) / 2
    # possible annealing
    if options['anneal_sigma_iters'] is not None:
        sigma = sigmas[0]
        r_sigma = torch.exp(torch.log(sigmas[-1] / sigmas[0]) / options['anneal_sigma_iters'])
    else:
        sigma = sigmas[-1]
        r_sigma = 1
    # The sample loss does not depend on the iteration, so it is built once.
    loss = geomloss.SamplesLoss(**samplesloss_options)
    for it in tqdm(range(options['iters'])):
        opt.zero_grad()
        if options['teacher_forcing_iter'] is not None:
            teacher_forcing = it < options['teacher_forcing_iter']
        X_batch = utils.sample_batch_upfi(Xs, m_ratios, batch_size = batch_size, **sample_batch_options)
        if teacher_forcing:
            # Integrate every interval separately, restarting from the sampled
            # batch at the interval start and keeping only the end state.
            X_hat = torch.cat([X_batch[0].unsqueeze(0), ] +
                        [odeint(lambda t, x: F_ode_upfi(t, x, sigma), x0, torch.tensor([t0, t1]), **odeint_options)[-1:]
                            for (x0, t0, t1) in zip(X_batch[:-1], ts[:-1], ts[1:])])
        else:
            X_hat = odeint(lambda t, x: F_ode_upfi(t, x, sigma), X_batch[0], ts, **odeint_options)
        L = loss(X_hat[..., 0].exp(), X_hat[..., 1:], X_batch[..., 0].exp(), X_batch[..., 1:])
        E = get_reg(v, X_hat, ts, alpha = alpha_wfr, reg_kind = options['reg_kind'])
        loss_fit, loss_reg = (L / m_ratios).mean(), (E / m_ratios * dt).mean()
        loss_total = loss_fit + reg_wfr*loss_reg
        if isinstance(v.v_net, models.NGMVectorField):
            loss_total += params['reg_ngm_l2']*v.v_net.net.net.l2_reg() + params['reg_ngm_l1']*v.v_net.net.net.fc1_reg()
        loss_total.backward()
        opt.step()
        if isinstance(v.v_net, models.NGMVectorField):
            # Group-sparsity proximal step on the first layer after the gradient step.
            proximal(v.v_net.net.net.fc1.weight, v.v_net.net.net.dims, lam=v.v_net.net.net.GL_reg, eta=0.01)
        trace += [loss_total.item(), ]
        if it % options['print_iter'] == 0:
            print(f"Iteration {it}, loss_total = {loss_total.item()}, fit = {loss_fit.item()}, reg_wfr = {loss_reg.item()}, sigma = {sigma}, teacher_forcing = {teacher_forcing}")
        if options['checkpoint_iter'] is not None and (it % options['checkpoint_iter'] == 0) and (it > 0):
            torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{it}.pt"))
        # Anneal sigma, never going below the final noise level.
        sigma = max(sigmas[-1], sigma*r_sigma)
    # save final
    if options['save_final']:
        torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['save_file']}_final.pt"))
    return trace

def train_tigon(v, opt, data, params, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'reg_kind' : 'vf', 'checkpoint_iter' : None, 'checkpoint_file' : "v_tigon",
                                     'save_final' : False, 'save_file' : 'v_tigon', 'outdir' : './',
                                     'teacher_forcing_iter' : None},
                          sample_batch_options = {}, odeint_options = {'method' : 'euler'}, samplesloss_options = {'loss' : 'sinkhorn', 'reach' : 5.0}):
    """Train flow model ``v`` by integrating it directly with ``odeint``.

    The objective combines a ``geomloss.SamplesLoss`` fit term (weights are
    ``exp`` of the first state column, points are the remaining columns) with
    ``reg_wfr`` times the regulariser from ``get_reg``. Both terms are divided
    by ``data['m_ratios']``; the regulariser is also weighted by averaged time
    intervals.

    Args:
        v: flow model passed to ``odeint``; must expose ``v_net`` (NGM
            penalties and the proximal step are applied when it is a
            ``models.NGMVectorField``).
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        data: mapping with ``'X'``, ``'t'`` and ``'m_ratios'``.
        params: mapping with ``'alpha_wfr'``, ``'reg_wfr'`` and, for NGM
            models, ``'reg_ngm_l2'`` / ``'reg_ngm_l1'``.
        batch_size: forwarded to ``utils.sample_batch_upfi``.
        options: run settings (iterations, printing, checkpointing,
            ``reg_kind``, ``teacher_forcing_iter``).
        sample_batch_options: extra keyword arguments for the batch sampler.
        odeint_options: keyword arguments for ``odeint``.
        samplesloss_options: keyword arguments for ``geomloss.SamplesLoss``.

    Returns:
        list of float: total loss of every iteration.
    """
    snapshots, times, mass_ratios = data['X'], data['t'], data['m_ratios']
    alpha, reg_weight = params['alpha_wfr'], params['reg_wfr']
    # Weight of each time point: mean of its two neighbouring intervals
    # (the end points reuse their only neighbouring interval).
    gaps = times[1:] - times[:-1]
    time_weights = (torch.hstack([gaps[:1], gaps]) + torch.hstack([gaps, gaps[-1:]])) / 2
    sample_loss = geomloss.SamplesLoss(**samplesloss_options)
    history = []
    teacher_forcing = False
    for step in tqdm(range(options['iters'])):
        opt.zero_grad()
        forcing_limit = options['teacher_forcing_iter']
        if forcing_limit is not None:
            teacher_forcing = step < forcing_limit
        batch = utils.sample_batch_upfi(snapshots, mass_ratios, batch_size = batch_size, **sample_batch_options)
        if teacher_forcing:
            # Restart from the sampled batch at each interval start and keep
            # only the state reached at the interval end.
            segments = [batch[0].unsqueeze(0)]
            for start, t_from, t_to in zip(batch[:-1], times[:-1], times[1:]):
                segments.append(odeint(v, start, torch.tensor([t_from, t_to]), **odeint_options)[-1:])
            predicted = torch.cat(segments)
        else:
            predicted = odeint(v, batch[0], times, **odeint_options)
        fit = sample_loss(predicted[..., 0].exp(), predicted[..., 1:], batch[..., 0].exp(), batch[..., 1:])
        energy = get_reg(v, predicted, times, alpha = alpha, reg_kind = options['reg_kind'])
        loss_fit = (fit / mass_ratios).mean()
        loss_reg = (energy / mass_ratios * time_weights).mean()
        loss_total = loss_fit + reg_weight*loss_reg
        uses_ngm = isinstance(v.v_net, models.NGMVectorField)
        if uses_ngm:
            ngm = v.v_net.net.net
            loss_total += params['reg_ngm_l2']*ngm.l2_reg() + params['reg_ngm_l1']*ngm.fc1_reg()
        loss_total.backward()
        opt.step()
        if uses_ngm:
            # Group-sparsity proximal step on the first layer after the gradient step.
            ngm = v.v_net.net.net
            proximal(ngm.fc1.weight, ngm.dims, lam=ngm.GL_reg, eta=0.01)
        history.append(loss_total.item())
        if step % options['print_iter'] == 0:
            print(f"Iteration {step}, loss_total = {loss_total.item()}, fit = {loss_fit.item()}, reg_wfr = {loss_reg.item()}, teacher_forcing = {teacher_forcing}")
        every = options['checkpoint_iter']
        if every is not None and (step % every == 0) and (step > 0):
            ckpt_path = os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{step}.pt")
            torch.save(v.state_dict(), ckpt_path)
    # Optionally persist the final weights.
    if options['save_final']:
        final_path = os.path.join(options['outdir'], f"{options['save_file']}_final.pt")
        torch.save(v.state_dict(), final_path)
    return history

# Special case for MultiplicativeNoiseFlow
def train_multinoise_pfi(v, opt, s, sigmas, data, params, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'checkpoint_iter' : None, 'checkpoint_file' : "v_pfi_mult",
                                     'save_final' : False, 'save_file' : 'v_pfi_mult',
                                     'anneal_sigma_iters' : None, 'outdir' : './'},
                          sample_batch_options = {}, odeint_options = {'method' : 'euler'}, samplesloss_options = {'loss' : 'sinkhorn'}):
    """Train a multiplicative-noise flow ``v`` against snapshot data.

    ``v`` is integrated directly with ``odeint``; the objective is a
    ``geomloss.SamplesLoss`` fit between predicted and sampled batches plus
    ``params['reg']`` times the regulariser from ``get_reg``, weighted by
    averaged time intervals.

    ``options['teacher_forcing_iter']`` and ``options['reg_kind']`` are
    optional and default to ``None`` and ``'vf'``. Previously they were read
    unconditionally although the default ``options`` does not contain them,
    so calling with the defaults raised ``KeyError``.

    NOTE(review): ``s`` is not used here and ``sigma`` is only annealed and
    printed; presumably the score enters through ``v`` itself -- confirm.

    Args:
        v: flow model passed to ``odeint`` and ``get_reg``.
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        s: score model (unused in this function).
        sigmas: noise levels (tensor-like), used only for the annealed
            ``sigma`` value shown in the progress print.
        data: mapping with ``'X'`` and ``'t'``.
        params: mapping with ``'reg'``.
        batch_size: forwarded to ``utils.sample_batch``.
        options: run settings (iterations, printing, checkpointing, annealing,
            optional ``reg_kind`` and ``teacher_forcing_iter``).
        sample_batch_options: extra keyword arguments for the batch sampler.
        odeint_options: keyword arguments for ``odeint``.
        samplesloss_options: keyword arguments for ``geomloss.SamplesLoss``.

    Returns:
        list of float: total loss of every iteration.
    """
    trace = []
    Xs = data['X']
    ts = data['t']
    # averaged time intervals (end points reuse their single neighbouring interval)
    dt = ts[1:] - ts[:-1]
    dt = (torch.hstack([dt[:1], dt]) + torch.hstack([dt, dt[-1:]])) / 2
    reg = params['reg']
    # Optional settings: absent from the default `options`, so read them with fallbacks.
    teacher_forcing_iter = options.get('teacher_forcing_iter')
    reg_kind = options.get('reg_kind', 'vf')
    loss = geomloss.SamplesLoss(**samplesloss_options)
    teacher_forcing = False
    # possible annealing
    if options['anneal_sigma_iters'] is not None:
        sigma = sigmas[0]
        r_sigma = torch.exp(torch.log(sigmas[-1] / sigmas[0]) / options['anneal_sigma_iters'])
    else:
        sigma = sigmas[-1]
        r_sigma = 1
    for it in tqdm(range(options['iters'])):
        opt.zero_grad()
        if teacher_forcing_iter is not None:
            teacher_forcing = it < teacher_forcing_iter
        X_batch = utils.sample_batch(Xs, batch_size = batch_size, **sample_batch_options)
        if teacher_forcing:
            # Integrate every interval separately, restarting from the sampled
            # batch at the interval start and keeping only the end state.
            X_hat = torch.cat([X_batch[0].unsqueeze(0), ] +
                        [odeint(v, x0, torch.tensor([t0, t1]), **odeint_options)[-1:]
                            for (x0, t0, t1) in zip(X_batch[:-1], ts[:-1], ts[1:])])
        else:
            X_hat = odeint(v, X_batch[0], ts, **odeint_options)
        L = loss(X_hat, X_batch)
        E = get_reg(v, X_hat, ts, reg_kind = reg_kind)
        loss_fit, loss_reg = L.mean(), (E * dt).mean()
        loss_total = loss_fit + reg*loss_reg
        loss_total.backward()
        opt.step()
        trace += [loss_total.item(), ]
        if it % options['print_iter'] == 0:
            print(f"Iteration {it}, loss_total = {loss_total.item()}, fit = {loss_fit.item()}, reg_wfr = {loss_reg.item()}, sigma = {sigma}, teacher_forcing = {teacher_forcing}")
        if options['checkpoint_iter'] is not None and (it % options['checkpoint_iter'] == 0) and (it > 0):
            torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{it}.pt"))
        # Anneal sigma, never going below the final noise level.
        sigma = max(sigmas[-1], sigma*r_sigma)
    # save final
    if options['save_final']:
        torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['save_file']}_final.pt"))
    return trace

def train_multinoise_upfi(v, opt, s, sigmas, data, params, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'checkpoint_iter' : None, 'checkpoint_file' : "v_upfi_mult",
                                     'save_final' : False, 'save_file' : 'v_upfi_mult',
                                     'anneal_sigma_iters' : None, 'outdir' : './'},
                          sample_batch_options = {}, odeint_options = {'method' : 'euler'}, samplesloss_options = {'loss' : 'sinkhorn', 'reach' : 5.0}):
    """Train a multiplicative-noise flow with growth against snapshot data.

    ``v`` is integrated directly with ``odeint``. The objective is a
    ``geomloss.SamplesLoss`` fit term (weights are ``exp`` of the first state
    column, points are the remaining columns) plus ``reg_wfr`` times the
    regulariser from ``get_reg``. Both are divided by ``data['m_ratios']``;
    the regulariser is also weighted by averaged time intervals.

    ``options['teacher_forcing_iter']`` and ``options['reg_kind']`` are
    optional and default to ``None`` and ``'vf'``. Previously they were read
    unconditionally although the default ``options`` does not contain them,
    so calling with the defaults raised ``KeyError``.

    NOTE(review): ``s`` is not used here and ``sigma`` is only annealed and
    printed; presumably the score enters through ``v`` itself -- confirm.

    Args:
        v: flow model passed to ``odeint`` and ``get_reg``.
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        s: score model (unused in this function).
        sigmas: noise levels (tensor-like), used only for the annealed
            ``sigma`` value shown in the progress print.
        data: mapping with ``'X'``, ``'t'`` and ``'m_ratios'``.
        params: mapping with ``'alpha_wfr'`` and ``'reg_wfr'``.
        batch_size: forwarded to ``utils.sample_batch_upfi``.
        options: run settings (iterations, printing, checkpointing, annealing,
            optional ``reg_kind`` and ``teacher_forcing_iter``).
        sample_batch_options: extra keyword arguments for the batch sampler.
        odeint_options: keyword arguments for ``odeint``.
        samplesloss_options: keyword arguments for ``geomloss.SamplesLoss``.

    Returns:
        list of float: total loss of every iteration.
    """
    trace = []
    Xs = data['X']
    ts = data['t']
    m_ratios = data['m_ratios']
    # averaged time intervals (end points reuse their single neighbouring interval)
    dt = ts[1:] - ts[:-1]
    dt = (torch.hstack([dt[:1], dt]) + torch.hstack([dt, dt[-1:]])) / 2
    alpha_wfr = params['alpha_wfr']
    reg_wfr = params['reg_wfr']
    # Optional settings: absent from the default `options`, so read them with fallbacks.
    teacher_forcing_iter = options.get('teacher_forcing_iter')
    reg_kind = options.get('reg_kind', 'vf')
    loss = geomloss.SamplesLoss(**samplesloss_options)
    teacher_forcing = False
    # possible annealing
    if options['anneal_sigma_iters'] is not None:
        sigma = sigmas[0]
        r_sigma = torch.exp(torch.log(sigmas[-1] / sigmas[0]) / options['anneal_sigma_iters'])
    else:
        sigma = sigmas[-1]
        r_sigma = 1
    for it in tqdm(range(options['iters'])):
        opt.zero_grad()
        if teacher_forcing_iter is not None:
            teacher_forcing = it < teacher_forcing_iter
        X_batch = utils.sample_batch_upfi(Xs, m_ratios, batch_size = batch_size, **sample_batch_options)
        if teacher_forcing:
            # Integrate every interval separately, restarting from the sampled
            # batch at the interval start and keeping only the end state.
            X_hat = torch.cat([X_batch[0].unsqueeze(0), ] +
                        [odeint(v, x0, torch.tensor([t0, t1]), **odeint_options)[-1:]
                            for (x0, t0, t1) in zip(X_batch[:-1], ts[:-1], ts[1:])])
        else:
            X_hat = odeint(v, X_batch[0], ts, **odeint_options)
        L = loss(X_hat[..., 0].exp(), X_hat[..., 1:], X_batch[..., 0].exp(), X_batch[..., 1:])
        E = get_reg(v, X_hat, ts, reg_kind = reg_kind, alpha = alpha_wfr)
        loss_fit, loss_reg = (L / m_ratios).mean(), (E / m_ratios * dt).mean()
        loss_total = loss_fit + reg_wfr*loss_reg
        loss_total.backward()
        opt.step()
        trace += [loss_total.item(), ]
        if it % options['print_iter'] == 0:
            print(f"Iteration {it}, loss_total = {loss_total.item()}, fit = {loss_fit.item()}, reg_wfr = {loss_reg.item()}, sigma = {sigma}, teacher_forcing = {teacher_forcing}")
        if options['checkpoint_iter'] is not None and (it % options['checkpoint_iter'] == 0) and (it > 0):
            torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{it}.pt"))
        # Anneal sigma, never going below the final noise level.
        sigma = max(sigmas[-1], sigma*r_sigma)
    # save final
    if options['save_final']:
        torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['save_file']}_final.pt"))
    return trace

def train_ode(v, opt, data, params, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'checkpoint_iter' : None, 'checkpoint_file' : "v_ode",
                                     'save_final' : False, 'save_file' : 'v_ode', 'outdir' : './', 'teacher_forcing_iter' : None},
                          sample_batch_options = {}, odeint_options = {'method' : 'euler'}, samplesloss_options = {'loss' : 'sinkhorn', 'reach' : 5.0}):
    """Train ODE model ``v`` with a sample-loss fit and an explicit energy term.

    ``v`` is integrated with ``odeint``. The fit term is a
    ``geomloss.SamplesLoss`` (weights are ``exp`` of the first state column,
    points are the remaining columns). The regulariser is
    ``losses.wasserstein_fisher_rao_energy`` evaluated on ``v.v`` and ``v.g``
    along the predicted trajectory. Both are divided by ``data['m_ratios']``;
    the regulariser is also weighted by averaged time intervals.

    Args:
        v: model passed to ``odeint``; must expose callables ``v.v(t, x)``
            and ``v.g(t, x)``.
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        data: mapping with ``'X'``, ``'t'`` and ``'m_ratios'``.
        params: mapping with ``'alpha_wfr'`` and ``'reg_wfr'``.
        batch_size: forwarded to ``utils.sample_batch_upfi``.
        options: run settings (iterations, printing, checkpointing,
            ``teacher_forcing_iter``).
        sample_batch_options: extra keyword arguments for the batch sampler.
        odeint_options: keyword arguments for ``odeint``.
        samplesloss_options: keyword arguments for ``geomloss.SamplesLoss``.

    Returns:
        list of float: total loss of every iteration.
    """
    snapshots, times, mass_ratios = data['X'], data['t'], data['m_ratios']
    # Weight of each time point: mean of its two neighbouring intervals
    # (the end points reuse their only neighbouring interval).
    gaps = times[1:] - times[:-1]
    time_weights = (torch.hstack([gaps[:1], gaps]) + torch.hstack([gaps, gaps[-1:]])) / 2
    alpha, reg_weight = params['alpha_wfr'], params['reg_wfr']
    teacher_forcing = False
    sample_loss = geomloss.SamplesLoss(**samplesloss_options)
    history = []
    for step in tqdm(range(options['iters'])):
        opt.zero_grad()
        forcing_limit = options['teacher_forcing_iter']
        if forcing_limit is not None:
            teacher_forcing = step < forcing_limit
        batch = utils.sample_batch_upfi(snapshots, mass_ratios, batch_size = batch_size, **sample_batch_options)
        if teacher_forcing:
            # Restart from the sampled batch at each interval start and keep
            # only the state reached at the interval end.
            segments = [batch[0].unsqueeze(0)]
            for start, t_from, t_to in zip(batch[:-1], times[:-1], times[1:]):
                segments.append(odeint(v, start, torch.tensor([t_from, t_to]), **odeint_options)[-1:])
            predicted = torch.cat(segments)
        else:
            predicted = odeint(v, batch[0], times, **odeint_options)
        fit = sample_loss(predicted[..., 0].exp(), predicted[..., 1:], batch[..., 0].exp(), batch[..., 1:])
        # Velocity and growth evaluated at every time point of the prediction.
        velocity = torch.stack([v.v(t, x) for t, x in zip(times, predicted)])
        growth = torch.stack([v.g(t, x) for t, x in zip(times, predicted)])
        energy = torch.vmap(losses.wasserstein_fisher_rao_energy)(velocity, growth, predicted[..., 0].exp(), alpha = alpha)
        loss_fit = (fit / mass_ratios).mean()
        loss_reg = (energy / mass_ratios * time_weights).mean()
        loss_total = loss_fit + reg_weight*loss_reg
        loss_total.backward()
        opt.step()
        history.append(loss_total.item())
        if step % options['print_iter'] == 0:
            print(f"Iteration {step}, loss_total = {loss_total.item()}, fit = {loss_fit.item()}, reg_wfr = {loss_reg.item()}, teacher_forcing = {teacher_forcing}")
        every = options['checkpoint_iter']
        if every is not None and (step % every == 0) and (step > 0):
            ckpt_path = os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{step}.pt")
            torch.save(v.state_dict(), ckpt_path)
    # Optionally persist the final weights.
    if options['save_final']:
        final_path = os.path.join(options['outdir'], f"{options['save_file']}_final.pt")
        torch.save(v.state_dict(), final_path)
    return history

def train_pfi(v, opt, s, sigmas, data, params, batch_size,
                          options = {'iters' : 1000, 'print_iter' : 100, 'reg_kind' : 'vf', 'checkpoint_iter' : None, 'checkpoint_file' : "v_pfi",
                                     'save_final' : False, 'save_file' : 'v_pfi',
                                     'anneal_sigma_iters' : None, 'outdir' : './',
                                     'teacher_forcing_iter' : None},
                          sample_batch_options = {}, odeint_options = {'method' : 'euler'}, samplesloss_options = {'loss' : 'sinkhorn'}):
    """Train vector field ``v`` against snapshots, using the score model ``s``.

    The integrated dynamics are ``utils.get_flow(v(t, x), s(t, x, sigma), D)``.
    The objective is a ``geomloss.SamplesLoss`` fit between predicted and
    sampled batches plus ``params['reg']`` times the regulariser from
    ``get_reg``, weighted by averaged time intervals. When ``v`` is a
    ``models.NGMVectorField`` its L2/L1 penalties are added and a group-sparsity
    proximal step follows each optimiser step.

    Args:
        v: vector field, called as ``v(t, x)``.
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        s: score model, called as ``s(t, x, sigma)``.
        sigmas: noise levels (tensor-like). With ``anneal_sigma_iters`` set,
            ``sigma`` starts at ``sigmas[0]``, is multiplied by a constant
            ratio each iteration and is never allowed below ``sigmas[-1]``;
            otherwise ``sigmas[-1]`` is used.
        data: mapping with ``'X'`` and ``'t'``.
        params: mapping with ``'D'``, ``'reg'`` and, for NGM models,
            ``'reg_ngm_l2'`` / ``'reg_ngm_l1'``.
        batch_size: forwarded to ``utils.sample_batch``.
        options: run settings (iterations, printing, checkpointing, annealing,
            ``reg_kind``, ``teacher_forcing_iter``).
        sample_batch_options: extra keyword arguments for the batch sampler.
        odeint_options: keyword arguments for ``odeint``.
        samplesloss_options: keyword arguments for ``geomloss.SamplesLoss``.

    Returns:
        list of float: total loss of every iteration.
    """
    trace = []
    Xs = data['X']
    ts = data['t']
    # averaged time intervals (end points reuse their single neighbouring interval)
    dt = ts[1:] - ts[:-1]
    dt = (torch.hstack([dt[:1], dt]) + torch.hstack([dt, dt[-1:]])) / 2
    D = params['D']
    reg = params['reg']
    teacher_forcing = False
    # possible annealing
    if options['anneal_sigma_iters'] is not None:
        sigma = sigmas[0]
        r_sigma = torch.exp(torch.log(sigmas[-1] / sigmas[0]) / options['anneal_sigma_iters'])
    else:
        sigma = sigmas[-1]
        r_sigma = 1

    def flow(t, x):
        # Reads the current (possibly annealed) `sigma` each time it is called.
        return utils.get_flow(v(t, x), s(t, x, sigma), D)

    # The sample loss does not depend on the iteration, so it is built once.
    loss = geomloss.SamplesLoss(**samplesloss_options)
    for it in tqdm(range(options['iters'])):
        opt.zero_grad()
        if options['teacher_forcing_iter'] is not None:
            teacher_forcing = it < options['teacher_forcing_iter']
        xs_batch = utils.sample_batch(Xs, batch_size, **sample_batch_options)
        if teacher_forcing:
            # Integrate every interval separately, restarting from the sampled
            # batch at the interval start and keeping only the end state.
            xs_t = torch.cat([xs_batch[0].unsqueeze(0), ] +
                        [odeint(flow, x0, torch.tensor([t0, t1]), **odeint_options)[-1:]
                            for (x0, t0, t1) in zip(xs_batch[:-1], ts[:-1], ts[1:])])
        else:
            xs_t = odeint(flow, xs_batch[0, ...], ts, **odeint_options)
        L = loss(xs_t, xs_batch)
        E = get_reg(v, xs_t, ts, reg_kind = options['reg_kind'])
        loss_fit = L.mean()
        loss_reg = (E * dt).mean()
        loss_total = loss_fit + reg*loss_reg
        if isinstance(v, models.NGMVectorField):
            loss_total += params['reg_ngm_l2']*v.net.net.l2_reg() + params['reg_ngm_l1']*v.net.net.fc1_reg()
        loss_total.backward()
        opt.step()
        if isinstance(v, models.NGMVectorField):
            # Group-sparsity proximal step on the first layer after the gradient step.
            proximal(v.net.net.fc1.weight, v.net.net.dims, lam=v.net.net.GL_reg, eta=0.01)
        trace += [loss_total.item(), ]
        if it % options['print_iter'] == 0:
            print(f"Iteration {it}, loss_total = {loss_total.item()}, fit = {loss_fit.item()}, reg_wfr = {loss_reg.item()}, sigma = {sigma}, teacher_forcing = {teacher_forcing}")
        if options['checkpoint_iter'] is not None and (it % options['checkpoint_iter'] == 0) and (it > 0):
            torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['checkpoint_file']}_ckpt_{it}.pt"))
        # Anneal sigma, never going below the final noise level.
        sigma = max(sigmas[-1], sigma*r_sigma)
    # save final
    if options['save_final']:
        torch.save(v.state_dict(), os.path.join(options['outdir'], f"{options['save_file']}_final.pt"))
    return trace
