# Core algorithms for Physics-Constrained Flow Matching (PCFM)
# This module implements the sampling strategies used throughout our experiments: PCFM, vanilla flow matching, ECI, D-Flow, and DiffusionPDE, for both 1D PDEs and the 2D Navier-Stokes (NS) problem
# Constraints used in PCFM are specified via the Residuals and Residuals2D classes for the different PDE problems

import torch
from torch.func import vmap, jacrev
from torch import nn
from tqdm import tqdm

import gc
from torchdiffeq import odeint_adjoint as odeint
#from torchdiffeq import odeint

def compute_jacobian(fn, inputs):
    """Return the dense Jacobian of ``fn`` at ``inputs`` as a 2-D matrix.

    The output of ``fn`` is flattened before differentiation, so row ``i`` of
    the result holds the derivative of the i-th flattened output entry with
    respect to every entry of ``inputs``.

    Args:
        fn: Callable mapping a tensor to a tensor.
        inputs: Tensor at which the Jacobian is evaluated.

    Returns:
        Tensor of shape ``(M, N)`` with ``M`` the number of flattened outputs
        of ``fn`` and ``N = inputs.numel()``.
    """
    # Reverse-mode Jacobian of the flattened output; its leading axis indexes
    # the outputs and the trailing axes mirror ``inputs``.
    jac = jacrev(lambda z: fn(z).flatten())(inputs)
    n_outputs = jac.shape[0]
    return jac.reshape(n_outputs, inputs.numel())

# Final projection step in PCFM algorithm
def fast_project_batched(xi_batch, h_func, max_iter=1):
    """Project each row of ``xi_batch`` towards the constraint set ``h_func(u) = 0``.

    Every sample is processed independently through ``vmap``. One Newton step
    linearises ``h_func`` at the current iterate ``u`` and moves to the point
    closest to the original sample ``xi`` that satisfies the linearised
    constraint::

        (J J^T) lambda = J (xi - u) + h(u)
        u_new          = xi - J^T lambda

    Args:
        xi_batch: Tensor of shape ``(B, n)``; one flattened sample per row.
        h_func: Constraint function applied to a single flattened sample.
        max_iter: Number of Newton steps per sample.

    Returns:
        Tensor of shape ``(B, n)`` holding the projected samples.

    Raises:
        ValueError: If ``xi_batch`` is not 2-D.
    """
    if xi_batch.ndim != 2:
        raise ValueError(
            f"xi_batch must have shape (B, n), got {tuple(xi_batch.shape)}"
        )

    def newton_step(u, xi):
        h_val = h_func(u)
        if h_val.ndim == 1:
            h_val = h_val.unsqueeze(-1)
        # jacrev rejects chunk_size <= 0, which `len(h_val) // 4` produces for
        # fewer than four constraints; clamp to at least one row per chunk.
        J = jacrev(h_func, chunk_size=max(1, len(h_val) // 4))(u)
        if J.ndim == 1:
            J = J.unsqueeze(0)
        delta = (xi - u).unsqueeze(-1)
        JJt = J @ J.transpose(-2, -1)
        rhs = J @ delta + h_val
        lambda_ = torch.linalg.solve(JJt, rhs)
        du = delta - J.transpose(-2, -1) @ lambda_
        return u + du.squeeze(-1)

    def loop(xi):
        u = xi.clone()
        for _ in range(max_iter):
            u = newton_step(u, xi)
        return u

    # Release cached memory once before the (memory-hungry) batched solve.
    gc.collect()
    torch.cuda.empty_cache()

    return vmap(loop)(xi_batch)

# Final projection step in PCFM (chunked)
def fast_project_batched_chunk(xi_batch, h_func, max_iter=1, chunk_size=16):
    """Chunked variant of the batched constraint projection.

    The batch is split into slices of at most ``chunk_size`` samples; each
    slice is projected with ``vmap`` and cached memory is released between
    slices to bound peak memory usage.

    One Newton step linearises ``h_func`` at the current iterate ``u`` and
    moves to the point closest to the original sample ``xi`` that satisfies
    the linearised constraint::

        (J J^T) lambda = J (xi - u) + h(u)
        u_new          = xi - J^T lambda

    Args:
        xi_batch: Tensor of shape ``(B, n)``; one flattened sample per row.
        h_func: Constraint function applied to a single flattened sample.
        max_iter: Number of Newton steps per sample.
        chunk_size: Maximum number of samples projected at once.

    Returns:
        Tensor of shape ``(B, n)`` holding the projected samples. An empty
        batch is returned as an (empty) copy.

    Raises:
        ValueError: If ``xi_batch`` is not 2-D or ``chunk_size`` is not positive.
    """
    if xi_batch.ndim != 2:
        raise ValueError(
            f"xi_batch must have shape (B, n), got {tuple(xi_batch.shape)}"
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    B = xi_batch.shape[0]
    if B == 0:
        # Nothing to project; torch.cat on an empty list would raise.
        return xi_batch.clone()

    # The per-sample routines do not depend on the chunk, so they are built
    # once instead of once per loop iteration.
    def newton_step(u, xi):
        h_val = h_func(u)
        if h_val.ndim == 1:
            h_val = h_val.unsqueeze(-1)
        # At least one Jacobian row per chunk (jacrev rejects chunk_size <= 0).
        J = jacrev(h_func, chunk_size=max(1, len(h_val) // 4))(u)
        if J.ndim == 1:
            J = J.unsqueeze(0)
        delta = (xi - u).unsqueeze(-1)
        JJt = J @ J.transpose(-2, -1)
        rhs = J @ delta + h_val
        lambda_ = torch.linalg.solve(JJt, rhs)
        du = delta - J.transpose(-2, -1) @ lambda_
        return u + du.squeeze(-1)

    def loop(xi):
        u = xi.clone()
        for _ in range(max_iter):
            u = newton_step(u, xi)
        return u

    project_chunk = vmap(loop)

    results = []
    for start in range(0, B, chunk_size):
        xi_chunk = xi_batch[start:start + chunk_size]
        results.append(project_chunk(xi_chunk))
        # Drop the slice and flush caches before the next chunk is processed.
        del xi_chunk
        gc.collect()
        torch.cuda.empty_cache()

    return torch.cat(results, dim=0)



# Residuals used for constraints projection in PCFM (Heat, Reaction-Diffusion, Burgers)
class Residuals:
    """Constraint residuals for 1-D space-time solutions (heat, reaction-diffusion, Burgers).

    Every residual method takes a flattened solution ``u_flat`` that is viewed
    as an ``(nx, nt)`` array (space along the first axis, time along the
    second) and returns a 1-D tensor that is zero when the corresponding
    constraint holds.
    """

    def __init__(self, data, x, t_grid, dx=None, dt=None, nx=None, nt=None, rho=None, nu=None, bc=None, left_bc=None):
        """Store reference data, grids and coefficients.

        Args:
            data: Reference tensor; ``data[0][:, 0]`` is the initial-condition
                target and ``data.device`` is where the grids are placed.
            x: Spatial grid tensor.
            t_grid: Temporal grid tensor.
            dx: Spatial spacing (tensor or Python number). Defaults to
                ``x[1] - x[0]``.
            dt: Temporal spacing (tensor or Python number). Defaults to
                ``t_grid[1] - t_grid[0]``.
            nx: Number of spatial points used to reshape ``u_flat``.
            nt: Number of time points used to reshape ``u_flat``.
            rho: Coefficient of the ``u * (1 - u)`` source term in
                ``mass_residual_rd``.
            nu: Coefficient of the boundary-gradient flux in
                ``mass_residual_rd``.
            bc: Stored as-is; not read by any method of this class.
            left_bc: Tensor subtracted from the left boundary values in
                ``bc_residual_burgers``.
        """
        device = data.device
        self.data = data
        self.x = x.to(device)
        self.t_grid = t_grid.to(device)
        # as_tensor lets callers pass plain Python numbers as well as tensors.
        self.dx = torch.as_tensor(dx if dx is not None else (x[1] - x[0])).to(device)
        self.dt = torch.as_tensor(dt if dt is not None else (t_grid[1] - t_grid[0])).to(device)
        self.nx = nx
        self.nt = nt
        self.rho = rho
        self.nu = nu
        self.bc = bc
        self.left_bc = left_bc

    def ic_residual(self, u_flat):
        """Difference between the first time slice and ``data[0][:, 0]``; shape ``(nx,)``."""
        u = u_flat.view(self.nx, self.nt)
        ic_target = self.data[0][:, 0].to(u.device)
        return u[:, 0] - ic_target

    def mass_residual_heat(self, u_flat):
        """Mass at every later time minus the mass at the first time; shape ``(nt - 1,)``."""
        u = u_flat.view(self.nx, self.nt)
        dx = self.dx.to(u.device)
        mass_0 = torch.sum(u[:, 0]) * dx
        mass_t = torch.sum(u, dim=0) * dx
        return mass_t[1:] - mass_0

    def mass_residual_rd(self, u_flat):
        """Mass balance with a ``rho * u * (1 - u)`` source and ``nu``-scaled boundary fluxes.

        Source and boundary-flux contributions are integrated in time with the
        trapezoidal rule on ``t_grid``. Requires ``nx >= 5`` because of the
        five-point one-sided stencils. Returns a tensor of shape ``(nt,)``
        whose first entry is zero by construction.
        """
        sol = u_flat.view(self.nx, self.nt).T  # [nt, nx]
        device = sol.device
        dx = self.dx.to(device)
        dt = (self.t_grid[1:] - self.t_grid[:-1]).to(device)

        mass = sol.sum(dim=1) * dx

        # Cumulative (trapezoidal) time integral of the source term.
        S = self.rho * (sol * (1 - sol)).sum(dim=1) * dx
        S_mid = 0.5 * (S[:-1] + S[1:])
        S_cum = torch.cat([torch.zeros(1, device=device), torch.cumsum(S_mid * dt, dim=0)], dim=0)

        # -nu * du/dx at both ends via five-point one-sided differences.
        gL_t = -self.nu * (-25*sol[:, 0] + 48*sol[:, 1] - 36*sol[:, 2] + 16*sol[:, 3] - 3*sol[:, 4]) / (12 * dx)
        gR_t = -self.nu * (25*sol[:, -1] - 48*sol[:, -2] + 36*sol[:, -3] - 16*sol[:, -4] + 3*sol[:, -5]) / (12 * dx)
        F = gL_t - gR_t
        F_mid = 0.5 * (F[:-1] + F[1:])
        F_cum = torch.cat([torch.zeros(1, device=device), torch.cumsum(F_mid * dt, dim=0)], dim=0)

        return mass - (mass[0] + S_cum + F_cum)

    def bc_residual_burgers(self, u_flat, start_step=0):
        """Boundary residuals from time index ``start_step`` onwards.

        Concatenates the left residual ``u[:, 0] - left_bc`` and the right
        zero-gradient residual ``u[:, -1] - u[:, -2]``.
        """
        u = u_flat.view(self.nx, self.nt).T   # [nt, nx]
        # Dirichlet condition on the left boundary.
        resL = u[start_step:, 0] - self.left_bc.to(u.device)
        # Zero-gradient (Neumann) condition on the right boundary.
        resR = u[start_step:, -1] - u[start_step:, -2]
        return torch.cat([resL, resR], dim=0)

    def mass_residual_burgers(self, u_flat):
        """Mass balance against the boundary flux ``0.5 * u**2``; shape ``(nt,)``.

        The net outflow ``f(u[-1]) - f(u[0])`` is integrated in time with the
        trapezoidal rule using the uniform step ``self.dt``. The first entry
        is zero by construction.
        """
        u = u_flat.view(self.nx, self.nt).T  # [nt, nx]
        mass = u.sum(dim=1) * self.dx.to(u.device)
        f = 0.5 * u**2
        flux = f[:, -1] - f[:, 0]
        flux_mid = 0.5 * (flux[:-1] + flux[1:])
        flux_cum = torch.cat([
            torch.zeros(1, device=u.device),
            torch.cumsum(flux_mid * self.dt.to(u.device), dim=0)
        ], dim=0)
        return mass - (mass[0] - flux_cum)

    def godunov_flux(self, uL, uR):
        """Godunov numerical flux for ``f(u) = 0.5 * u**2``.

        * Shock (``uL > uR``): upwind flux chosen by the sign of the shock
          speed ``0.5 * (uL + uR)``.
        * Rarefaction (``uL <= uR``): minimum of ``f`` over ``[uL, uR]``,
          which is zero when the interval straddles the sonic point
          (``uL < 0 < uR``) and ``min(f(uL), f(uR))`` otherwise.
        """
        fL = 0.5 * uL ** 2
        fR = 0.5 * uR ** 2
        s = 0.5 * (uL + uR)
        flux_shock = torch.where(s > 0, fL, fR)
        # Transonic rarefaction: f attains its minimum (zero) inside the fan,
        # not at either end state.
        transonic = (uL < 0) & (uR > 0)
        flux_rarefaction = torch.where(transonic, torch.zeros_like(fL), torch.minimum(fL, fR))
        is_shock = (uL > uR)
        return torch.where(is_shock, flux_shock, flux_rarefaction)

    def burgers_local_multistep_residual(self, u_flat, k=5):
        """Residual of an explicit Godunov update over the first ``k`` time steps.

        For each step ``n < k`` (while ``n + 1 < nt``) the interior values of
        time slice ``n + 1`` are compared with a conservative update of slice
        ``n``. Returns the concatenated interior residuals, or an empty
        tensor when no step is available.
        """
        u = u_flat.view(self.nx, self.nt)
        ratio = self.dt.to(u.device) / self.dx.to(u.device)
        residuals = []
        for n in range(k):
            if n + 1 >= self.nt:
                break
            u_prev = u[:, n]
            u_next = u[:, n + 1]
            # Interface fluxes between neighbouring cells.
            F = self.godunov_flux(u_prev[:-1], u_prev[1:])
            flux_diff = F[1:] - F[:-1]
            rhs = u_prev[1:-1] - ratio * flux_diff
            residuals.append(u_next[1:-1] - rhs)
        if not residuals:
            # k <= 0 or a single time slice: nothing to constrain.
            return u.new_zeros(0)
        return torch.cat(residuals, dim=0)

    def full_residual_heat(self, u_flat):
        """Initial-condition and mass residuals for the heat problem."""
        return torch.cat([self.ic_residual(u_flat), self.mass_residual_heat(u_flat)], dim=0)

    def full_residual_rd(self, u_flat):
        """Initial-condition and mass residuals for the reaction-diffusion problem."""
        ic = self.ic_residual(u_flat)
        # Drop the first entry, which is identically zero.
        mass = self.mass_residual_rd(u_flat)[1:]
        return torch.cat([ic, mass], dim=0)

    # for IC case
    def full_residual_burgers(self, u_flat, k=5):
        """Initial-condition, local-dynamics (first ``k`` steps) and mass residuals for Burgers."""
        ic = self.ic_residual(u_flat)
        dyn = self.burgers_local_multistep_residual(u_flat, k=k)
        mass = self.mass_residual_burgers(u_flat)[1:]
        return torch.cat([ic, dyn, mass], dim=0)

    # for BC case
    def full_residual_burgers2(self, u_flat, start_step=1):
        """Boundary-condition (from ``start_step``) and mass residuals for Burgers."""
        bc   = self.bc_residual_burgers(u_flat, start_step)
        mass = self.mass_residual_burgers(u_flat)[1:]
        return torch.cat([bc, mass], dim=0)


# Residuals used for constraints projection in PCFM (Navier-Stokes)
class Residuals2D:
    """Constraint residuals for 2-D space-time solutions (Navier-Stokes experiments).

    Every residual method takes a flattened solution ``u_flat`` that is viewed
    as an ``(nx, ny, nt)`` array and returns a 1-D tensor that is zero when
    the corresponding constraint holds.
    """

    def __init__(self, data, x, y, t_grid, dx=None, dy=None, dt=None, nx=None, ny=None, nt=None, rho=None, nu=None):
        """Store reference data, grids and coefficients.

        Args:
            data: Reference tensor; ``data[0][:, :, 0]`` is the
                initial-condition target and ``data.device`` is where the
                grids are placed.
            x: Grid tensor along the first spatial axis.
            y: Grid tensor along the second spatial axis.
            t_grid: Temporal grid tensor.
            dx: Spacing along ``x`` (tensor or Python number). Defaults to
                ``x[1] - x[0]``.
            dy: Spacing along ``y`` (tensor or Python number). Defaults to
                ``y[1] - y[0]``.
            dt: Temporal spacing (tensor or Python number). Defaults to
                ``t_grid[1] - t_grid[0]``.
            nx: Size of the first spatial axis used to reshape ``u_flat``.
            ny: Size of the second spatial axis used to reshape ``u_flat``.
            nt: Number of time points used to reshape ``u_flat``.
            rho: Stored as-is; not read by any method of this class.
            nu: Stored as-is; not read by any method of this class.
        """
        device = data.device
        self.data = data
        self.x = x.to(device)
        self.y = y.to(device)
        self.t_grid = t_grid.to(device)
        # as_tensor lets callers pass plain Python numbers as well as tensors.
        self.dx = torch.as_tensor(dx if dx is not None else (x[1] - x[0])).to(device)
        self.dy = torch.as_tensor(dy if dy is not None else (y[1] - y[0])).to(device)
        self.dt = torch.as_tensor(dt if dt is not None else (t_grid[1] - t_grid[0])).to(device)
        self.nx = nx
        self.ny = ny
        self.nt = nt
        self.rho = rho
        self.nu = nu

    def ic_residual_ns(self, u_flat):
        """Flattened difference between the first time slice and ``data[0][:, :, 0]``."""
        u = u_flat.view(self.nx, self.ny, self.nt)
        target = self.data[0][:, :, 0].to(u.device)
        return (u[:, :, 0] - target).flatten()

    def mass_residual_ns(self, u_flat):
        """Total mass at every later time minus the mass at the first time; shape ``(nt - 1,)``."""
        u = u_flat.view(self.nx, self.ny, self.nt)
        cell_area = self.dx.to(u.device) * self.dy.to(u.device)
        mass0 = u[:, :, 0].sum() * cell_area
        mass_t = u.sum(dim=(0, 1)) * cell_area
        return mass_t[1:] - mass0

    def full_residual_ns(self, u_flat):
        """Initial-condition residual followed by the mass residual."""
        return torch.cat([self.ic_residual_ns(u_flat), self.mass_residual_ns(u_flat)], dim=0)


def make_grid(dims: tuple[int], device='cpu', start: float | tuple[float] = 0., end: float | tuple[float] = 1.):
    ndim = len(dims)
    if not isinstance(start, (tuple, list)):
        start = [start] * ndim
    if not isinstance(end, (tuple, list)):
        end = [end] * ndim
    if ndim == 1:
        return torch.linspace(start[0], end[0], dims[0], dtype=torch.float, device=device).unsqueeze(-1)
    xs = torch.meshgrid([
        torch.linspace(start[i], end[i], dims[i], dtype=torch.float, device=device)
        for i in range(ndim)
    ], indexing='ij')
    grid = torch.stack(xs, dim=-1).view(-1, ndim)
    return grid

# Relaxed constraint correction step in PCFM algorithm
def relaxed_penalty_constraint_interp_linear_detached(
    u0, u1_proj, v_flat, t, dt, hfunc, lam=1e-2, step_size=1e-2, num_steps=10, safe_clamp=1e-3
):
    """Refine the linear interpolant at ``t + dt`` with a soft constraint penalty.

    Starting from ``hat_u = (1 - t') * u0 + t' * u1_proj`` with ``t' = t + dt``,
    runs ``num_steps`` plain gradient-descent steps on

        ``||u - hat_u||^2 + lam * ||hfunc(u + gamma * v_flat)||^2``

    where ``gamma = max(1 - t', safe_clamp)``.

    Args:
        u0: Start point of the interpolation.
        u1_proj: End point of the interpolation.
        v_flat: Direction used to extrapolate ``u`` before evaluating ``hfunc``.
        t: Current time (Python number).
        dt: Time increment.
        hfunc: Constraint function whose squared norm is penalised.
        lam: Weight of the penalty term.
        step_size: Gradient-descent step size.
        num_steps: Number of gradient-descent steps.
        safe_clamp: Lower bound on ``gamma``.

    Returns:
        The refined point, detached from the autograd graph.
    """
    t_next = t + dt
    lookahead = max(1 - t_next, safe_clamp)
    anchor = (1 - t_next) * u0 + t_next * u1_proj

    def objective(point):
        # Penalise the constraint violation at the extrapolated point.
        violation = hfunc(point + lookahead * v_flat).pow(2).sum()
        return (point - anchor).pow(2).sum() + lam * violation

    current = anchor.detach().clone().requires_grad_(True)
    for _ in range(num_steps):
        (gradient,) = torch.autograd.grad(objective(current), current)
        # Re-create a fresh leaf so every step starts a new graph.
        current = (current - step_size * gradient).detach().clone().requires_grad_(True)

    return current.detach()


# PCFM sampling with hard constraints satisfaction
def pcfm_sample(
    u_flat, v_flat, t, u0_flat, dt, hfunc,
    mode='root', newtonsteps=1, eps=1e-6,
    guided_interpolation=False, interpolation_params=None
):
    """Compute the constraint-projected velocity for one flattened sample.

    Steps:
      1. Extrapolate to ``t = 1``: ``ut1 = u_flat + (1 - t) * v_flat``.
      2. Correct ``ut1`` towards ``hfunc(u) = 0`` with ``newtonsteps``
         regularised Gauss-Newton steps.
      3. Interpolate between ``u0_flat`` and the corrected end point at
         ``t + dt`` (linearly, or with the relaxed penalty refinement).
      4. Return the velocity ``(ut_interp - u_flat) / dt``.

    Args:
        u_flat: Current flattened state.
        v_flat: Flattened model velocity at ``(t, u_flat)``.
        t: Current time (0-dim tensor or Python number).
        u0_flat: Flattened initial noise sample.
        dt: Step size.
        hfunc: Constraint function of a flattened sample.
        mode: ``'least_squares'`` solves, at every step, for the point closest
            to ``ut1`` on the linearised constraint; any other value
            (default ``'root'``) applies a minimum-norm Newton update to the
            current iterate.
        newtonsteps: Number of correction steps.
        eps: Tikhonov regularisation added to ``J J^T`` before solving.
        guided_interpolation: If true, refine the interpolant with
            ``relaxed_penalty_constraint_interp_linear_detached``.
        interpolation_params: Optional dict with keys ``'custom_lam'``,
            ``'step_size'`` and ``'num_steps'``; missing keys (or an
            empty/``None`` value) fall back to ``1e0``, ``1e-2`` and ``20``.

    Returns:
        Detached projected velocity with the same shape as ``u_flat``.
    """
    ut1 = u_flat + (1.0 - t) * v_flat
    u_corr = ut1.clone()

    for _ in range(newtonsteps):
        # Flattened to match the row layout produced by compute_jacobian.
        res = hfunc(u_corr).flatten()
        J = compute_jacobian(hfunc, u_corr)
        JJt = J @ J.T
        reg = eps * torch.eye(JJt.shape[0], device=u_flat.device, dtype=JJt.dtype)

        if mode == 'least_squares':
            # Closest point to ut1 on the constraint linearised at u_corr:
            #   (J J^T) lam = J (ut1 - u_corr) + h(u_corr),  u = ut1 - J^T lam.
            # The update must start from ut1 (not u_corr); otherwise the
            # iterate leaves the linearised constraint once u_corr != ut1.
            rhs = J @ (ut1 - u_corr) + res
            lam = torch.linalg.solve(JJt + reg, rhs)
            u_corr = ut1 - J.T @ lam
        else:
            lam = torch.linalg.solve(JJt + reg, res)
            u_corr = u_corr - J.T @ lam

    t_next = t + dt

    if guided_interpolation:
        params = interpolation_params or {}
        ut_interp = relaxed_penalty_constraint_interp_linear_detached(
            u0=u0_flat,
            u1_proj=u_corr,
            v_flat=v_flat,
            t=float(t),  # accepts both 0-dim tensors and Python numbers
            dt=dt,
            hfunc=hfunc,
            lam=params.get('custom_lam', 1e0),
            step_size=params.get('step_size', 1e-2),
            num_steps=params.get('num_steps', 20)
        )
    else:
        ut_interp = (1.0 - t_next) * u0_flat + t_next * u_corr
    proj_vf = ((ut_interp - u_flat) / dt).detach()
    return proj_vf

# batched PCFM projection for 1D (nx, nt)
def pcfm_batched(ut, vf, t, u0, dt, hfunc, use_vmap=False, mode='root', newtonsteps=1, guided_interpolation=False, interpolation_params=None, eps=1e-6):
    """Apply ``pcfm_sample`` to every sample of a ``(B, nx, nt)`` batch.

    Args:
        ut: Current states, shape ``(B, nx, nt)``.
        vf: Model velocities, same shape as ``ut``.
        t: Current time, shared by the whole batch.
        u0: Initial noise samples, same shape as ``ut``.
        dt: Step size.
        hfunc: Constraint function of a single flattened sample.
        use_vmap: If true, vectorise over the batch with ``torch.vmap``;
            otherwise loop over samples in Python.
        mode, newtonsteps, guided_interpolation, eps: Forwarded to
            ``pcfm_sample``.
        interpolation_params: Optional dict forwarded to ``pcfm_sample``;
            ``None`` is treated as an empty dict.

    Returns:
        Projected velocities of shape ``(B, nx, nt)``.
    """
    B, nx, nt = ut.shape
    n = nx * nt
    params = {} if interpolation_params is None else interpolation_params

    def wrapped_project(u_flat, v_flat, u0_flat):
        return pcfm_sample(
            u_flat, v_flat, t, u0_flat, dt,
            hfunc=hfunc, mode=mode, newtonsteps=newtonsteps,
            guided_interpolation=guided_interpolation,
            interpolation_params=params,
            eps=eps
        )

    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # model output produced by a permute.
    u_flat = ut.reshape(B, n).detach().clone().requires_grad_(True)
    v_flat = vf.reshape(B, n)
    u0_flat = u0.reshape(B, n)

    if use_vmap:
        v_proj_flat = torch.vmap(wrapped_project)(u_flat, v_flat, u0_flat)
    else:
        v_proj_flat = torch.stack(
            [wrapped_project(u_flat[i], v_flat[i], u0_flat[i]) for i in range(B)],
            dim=0,
        )

    return v_proj_flat.reshape(B, nx, nt)

# batched PCFM projection for 2D (nx, ny, nt)
def pcfm_2d_batched(ut, vf, t, u0, dt, hfunc, mode='root', newtonsteps=1, guided_interpolation=True, interpolation_params=None, eps=1e-6):
    """Apply ``pcfm_sample`` to every sample of a ``(B, nx, ny, nt)`` batch.

    Args:
        ut: Current states, shape ``(B, nx, ny, nt)``.
        vf: Model velocities, same shape as ``ut``.
        t: Current time, shared by the whole batch.
        u0: Initial noise samples, same shape as ``ut``.
        dt: Step size.
        hfunc: Constraint function of a single flattened sample.
        mode, newtonsteps, guided_interpolation, eps: Forwarded to
            ``pcfm_sample``.
        interpolation_params: Optional dict forwarded to ``pcfm_sample``;
            ``None`` is treated as an empty dict.

    Returns:
        Projected velocities of shape ``(B, nx, ny, nt)``.
    """
    B, nx, ny, nt = ut.shape
    n = nx * ny * nt
    params = {} if interpolation_params is None else interpolation_params

    # Free cached memory before the per-sample Jacobians are built.
    gc.collect()
    torch.cuda.empty_cache()

    def wrapped_project(u_flat, v_flat, u0_flat):
        return pcfm_sample(
            u_flat, v_flat, t, u0_flat, dt,
            hfunc=hfunc, mode=mode, newtonsteps=newtonsteps,
            guided_interpolation=guided_interpolation,
            interpolation_params=params,
            eps=eps
        )

    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # model output produced by a permute.
    u_flat = ut.reshape(B, n).detach().clone().requires_grad_(True)
    v_flat = vf.reshape(B, n)
    u0_flat = u0.reshape(B, n)

    # Samples are processed one at a time instead of through vmap to prevent
    # out-of-memory errors.
    v_proj_flat = torch.stack(
        [wrapped_project(u_flat[i], v_flat[i], u0_flat[i]) for i in range(B)],
        dim=0,
    )
    return v_proj_flat.reshape(B, nx, ny, nt)


# Core sampler module to implement PCFM, vanilla flow matching, ECI, DiffusionPDE's guided sample, and D-Flow sampling
# uses a pretrained FFM model
class FFM_sampler:
    """Sampling strategies built on a pretrained flow-matching model.

    Provides PCFM, vanilla flow matching, ECI, DiffusionPDE-style guided
    sampling and D-Flow. The model is called as ``model(t, u)`` and returns a
    velocity field with the shape of ``u``.
    """

    def __init__(self, model, gp):
        """Store the velocity model and the prior sampler.

        Args:
            model: Velocity model called as ``model(t, u)``.
            gp: Prior sampler exposing ``sample(grid, dims, n_samples=...)``
                (used by ``eci_sample`` and ``dflow_sample``).
        """
        self.model = model
        self.gp = gp

    # our method, PCFM
    def pcfm_sample(self, u0, n_step, hfunc, mode='root', newtonsteps=1, eps=1e-6,
                    guided_interpolation=True, interpolation_params=None, use_vmap=False):
        """Euler sampling where each velocity is projected by ``pcfm_batched``.

        Args:
            u0: Initial noise batch; also used as the interpolation start point.
            n_step: Number of Euler steps (``dt = 1 / n_step``).
            hfunc: Constraint function of a single flattened sample.
            mode, newtonsteps, eps, guided_interpolation, use_vmap: Forwarded
                to ``pcfm_batched``.
            interpolation_params: Optional dict forwarded to ``pcfm_batched``;
                ``None`` is treated as an empty dict.

        Returns:
            Detached final sample with the shape of ``u0``.
        """
        params = {} if interpolation_params is None else interpolation_params
        dt = 1.0 / n_step
        u = u0.clone()
        for t in tqdm(torch.linspace(0, 1, n_step + 1, device=u0.device)[:-1], desc="PCFM sampling"):
            # The projected velocity is detached downstream, so the model's
            # autograd graph is never used; skip building it to save memory.
            with torch.no_grad():
                vf = self.model(t, u)
            v_proj = pcfm_batched(
                ut=u, vf=vf, t=t, u0=u0, dt=dt,
                hfunc=hfunc, mode=mode, newtonsteps=newtonsteps,
                guided_interpolation=guided_interpolation,
                interpolation_params=params,
                eps=eps,
                use_vmap=use_vmap
            )
            u = u + dt * v_proj
        return u.detach()

    # FFM
    @torch.no_grad()
    def vanilla_sample(self, u0, n_step):
        """Unconstrained Euler integration of the model velocity over ``n_step`` steps."""
        dt = 1.0 / n_step
        u = u0.clone()
        for t in tqdm(torch.linspace(0, 1, n_step + 1, device=u0.device)[:-1], desc="Vanilla"):
            vf = self.model(t, u)
            u = u + dt * vf
        return u.detach()

    # ECI sampling
    @torch.no_grad()
    def eci_sample(self, u0, n_step, n_mix, resample_step, constraint):
        """ECI sampling: extrapolate, correct with ``constraint.adjust``, interpolate.

        Args:
            u0: Initial noise batch.
            n_step: Number of outer time steps.
            n_mix: Number of mixing iterations per time step.
            resample_step: Draw fresh noise from ``gp`` every this many inner
                iterations; ``0`` or ``None`` disables resampling.
            constraint: Object whose ``adjust(u1)`` returns the corrected
                extrapolation.

        Returns:
            Detached final sample.
        """
        u = u0.clone()
        ts = torch.linspace(0, 1, n_step + 1, device=u0.device)
        cnt = 0
        dt = 1 / n_step
        grid = make_grid(u.shape[-2:], u.device)

        if resample_step is None or resample_step == 0:
            # Larger than the total number of inner iterations: never resample.
            resample_step = n_step * n_mix + 1

        for t in tqdm(ts[:-1], desc='ECI sampling'):
            for mix in range(n_mix):
                cnt += 1
                if cnt % resample_step == 0:
                    u0 = self.gp.sample(grid, u.shape[-2:], n_samples=u.shape[0])
                vf = self.model(t, u)
                u1 = u + vf * (1 - t)
                u1 = constraint.adjust(u1)
                if mix < n_mix - 1:
                    # Stay at the current time while mixing.
                    u = u1 * t + u0 * (1 - t)
                else:
                    # Last mixing iteration advances to t + dt.
                    u = u1 * (t + dt) + u0 * (1 - t - dt)
        return u.detach()

    # DiffusionPDE's
    # takes an IC and PINN loss (if known) on the extrapolated sample and updates the vector field
    def guided_sample(self, u0, u1_true, mask, n_step, loss_fn, eta=2e2):
        """Euler sampling with a loss-gradient correction at every step.

        At each step the velocity is averaged with the velocity at the next
        time (except on the last step), the state is extrapolated to
        ``t = 1``, and the gradient of ``loss_fn(u1_pred, u1_true, mask)``
        with respect to the current state is subtracted with weight ``eta``.

        Returns:
            Detached final sample.
        """
        device = u0.device
        # Detach so that ``u`` is a leaf and ``u.grad`` is populated below.
        u = u0.detach().clone().to(device)
        u1_true = u1_true.to(device)
        mask = mask.to(device)
        ts = torch.linspace(0, 1, n_step + 1, device=device)

        for t in tqdm(ts[:-1], desc='DiffusionPDE sampling'):
            # Velocities are treated as constants in the guidance loss, so no
            # autograd graph is needed for the model calls.
            with torch.no_grad():
                vf = self.model(t, u)
                if t < ts[-2]:
                    vf2 = self.model(t + 1 / n_step, u)
                    vf = (vf + vf2) / 2

            u.requires_grad_(True)
            u1_pred = u + vf * (1 - t)
            loss = loss_fn(u1_pred, u1_true, mask)
            loss.backward()
            grad = u.grad
            u = u.detach() + vf / n_step - eta * grad
        return u.detach()

    # D-Flow
    # optimizes the noise by differentiating through the flow matching ODE steps
    def dflow_sample(self, u1_true, mask, n_sample, n_step, n_iter=20, lr=1e-1, loss_fn=None):
        """D-Flow: optimise the initial noise with L-BFGS through the ODE solve.

        Args:
            u1_true: Target tensor passed to the loss.
            mask: Mask passed to the loss.
            n_sample: Number of noise samples drawn from ``gp``.
            n_step: Number of Euler steps of the ODE solve.
            n_iter: ``max_iter`` of the single L-BFGS ``step`` call.
            lr: L-BFGS learning rate.
            loss_fn: ``loss_fn(u_pred, u_true, mask)``; defaults to the masked
                sum of squared errors.

        Returns:
            Detached sample obtained by integrating the optimised noise.
        """
        device = u1_true.device
        mask = mask.to(device)
        grid = make_grid(u1_true.size()[1:], device)

        noise = self.gp.sample(grid, u1_true.size()[1:], n_samples=n_sample).to(device)
        noise.requires_grad_(True)

        ts = torch.linspace(0, 1, n_step + 1, device=device)

        def default_loss_fn(u_pred, u_true, mask):
            return ((u_pred - u_true) * mask).square().sum()
        loss_fn = loss_fn or default_loss_fn

        def euler_ffm(u):
            # Fixed-step Euler solve from t=0 to t=1 via odeint.
            print("DFlow sampling...")
            tspan = torch.tensor([0, 1.], device=device)
            u = odeint(self.model, u, tspan, method="euler", options = {"step_size":ts[1]-ts[0]})[-1]
            return u

        def closure():
            gc.collect()
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            u_pred = euler_ffm(noise)
            loss = loss_fn(u_pred, u1_true, mask)
            loss.backward()
            return loss

        optimizer = torch.optim.LBFGS([noise], max_iter=n_iter, lr=lr)
        optimizer.step(closure)

        with torch.no_grad():
            u_final = euler_ffm(noise)
        return u_final.detach()


# Core sampler module (for Navier-Stokes) to implement PCFM, vanilla flow matching, ECI, DiffusionPDE's guided sample, and D-Flow sampling
# uses a pretrained FFM model
class FFM_NS_sampler:
    """Sampling strategies for the Navier-Stokes experiments.

    Provides PCFM, vanilla flow matching, ECI, DiffusionPDE-style guided
    sampling and D-Flow. The model is called as ``model(t, u)`` and returns a
    velocity field with the shape of ``u``. Fresh noise is drawn with
    ``torch.randn_like``.
    """

    def __init__(self, model):
        """Store the velocity model, called as ``model(t, u)``."""
        self.model = model

    # our method, PCFM
    def pcfm_sample(self, u0, n_step, hfunc, mode='root', newtonsteps=1, eps=1e-6,
                    guided_interpolation=True, interpolation_params=None):
        """Euler sampling where each velocity is projected by ``pcfm_2d_batched``.

        Args:
            u0: Initial noise batch; also used as the interpolation start point.
            n_step: Number of Euler steps (``dt = 1 / n_step``).
            hfunc: Constraint function of a single flattened sample.
            mode, newtonsteps, eps, guided_interpolation: Forwarded to
                ``pcfm_2d_batched``.
            interpolation_params: Optional dict forwarded to
                ``pcfm_2d_batched``; ``None`` is treated as an empty dict.

        Returns:
            Detached final sample with the shape of ``u0``.
        """
        params = {} if interpolation_params is None else interpolation_params
        dt = 1.0 / n_step
        u = u0.clone()
        for t in tqdm(torch.linspace(0, 1, n_step + 1, device=u0.device)[:-1], desc="PCFM sampling"):
            # The projected velocity is detached downstream, so the model's
            # autograd graph is never used; skip building it to save memory.
            with torch.no_grad():
                vf = self.model(t, u)
            v_proj = pcfm_2d_batched(
                ut=u, vf=vf, t=t, u0=u0, dt=dt,
                hfunc=hfunc, mode=mode, newtonsteps=newtonsteps,
                guided_interpolation=guided_interpolation,
                interpolation_params=params,
                eps=eps
            )
            u = u + dt * v_proj
        return u.detach()

    # FFM
    @torch.no_grad()
    def vanilla_sample(self, u0, n_step):
        """Unconstrained Euler integration of the model velocity over ``n_step`` steps."""
        dt = 1.0 / n_step
        u = u0.clone()
        for t in tqdm(torch.linspace(0, 1, n_step + 1, device=u0.device)[:-1], desc="Vanilla"):
            vf = self.model(t, u)
            u = u + dt * vf
        return u.detach()

    # ECI sampling
    @torch.no_grad()
    def eci_sample(self, u0, n_step, n_mix, resample_step, constraint):
        """ECI sampling: extrapolate, correct with ``constraint.adjust``, interpolate.

        Args:
            u0: Initial noise batch.
            n_step: Number of outer time steps.
            n_mix: Number of mixing iterations per time step.
            resample_step: Draw fresh Gaussian noise every this many inner
                iterations; ``0`` or ``None`` disables resampling.
            constraint: Object whose ``adjust(u1)`` returns the corrected
                extrapolation.

        Returns:
            Detached final sample.
        """
        u = u0.clone()
        ts = torch.linspace(0, 1, n_step + 1, device=u0.device)
        cnt = 0
        dt = 1 / n_step

        if resample_step is None or resample_step == 0:
            # Larger than the total number of inner iterations: never resample.
            resample_step = n_step * n_mix + 1

        for t in tqdm(ts[:-1], desc='ECI sampling'):
            for mix in range(n_mix):
                cnt += 1
                if cnt % resample_step == 0:
                    u0 = torch.randn_like(u)
                vf = self.model(t, u)
                u1 = u + vf * (1 - t)
                u1 = constraint.adjust(u1)
                if mix < n_mix - 1:
                    # Stay at the current time while mixing.
                    u = u1 * t + u0 * (1 - t)
                else:
                    # Last mixing iteration advances to t + dt.
                    u = u1 * (t + dt) + u0 * (1 - t - dt)
        return u.detach()

    # DiffusionPDE's
    # takes an IC and PINN loss (if known) on the extrapolated sample and updates the vector field
    def guided_sample(self, u0, u1_true, mask, n_step, loss_fn, eta=2e2):
        """Euler sampling with a loss-gradient correction at every step.

        At each step the velocity is averaged with the velocity at the next
        time (except on the last step), the state is extrapolated to
        ``t = 1``, and the gradient of ``loss_fn(u1_pred, u1_true, mask)``
        with respect to the current state is subtracted with weight ``eta``.

        Returns:
            Detached final sample.
        """
        device = u0.device
        # Detach so that ``u`` is a leaf and ``u.grad`` is populated below.
        u = u0.detach().clone().to(device)
        u1_true = u1_true.to(device)
        mask = mask.to(device)
        ts = torch.linspace(0, 1, n_step + 1, device=device)

        for t in tqdm(ts[:-1], desc='DiffusionPDE sampling'):
            # Velocities are treated as constants in the guidance loss, so no
            # autograd graph is needed for the model calls.
            with torch.no_grad():
                vf = self.model(t, u)
                if t < ts[-2]:
                    vf2 = self.model(t + 1 / n_step, u)
                    vf = (vf + vf2) / 2

            u.requires_grad_(True)
            u1_pred = u + vf * (1 - t)
            loss = loss_fn(u1_pred, u1_true, mask)
            loss.backward()
            grad = u.grad
            u = u.detach() + vf / n_step - eta * grad
        return u.detach()

    # D-Flow
    # optimizes the noise by differentiating through the flow matching ODE steps
    def dflow_sample(self, u1_true, mask, n_sample, n_step, n_iter=20, lr=1e-1, loss_fn=None):
        """D-Flow: optimise the initial noise with L-BFGS through the ODE solve.

        Args:
            u1_true: Target tensor passed to the loss; the noise is drawn
                with its shape.
            mask: Mask passed to the loss.
            n_sample: Accepted for interface parity; not used here because
                the noise takes the shape of ``u1_true``.
            n_step: Number of Euler steps of the ODE solve.
            n_iter: ``max_iter`` of the single L-BFGS ``step`` call.
            lr: L-BFGS learning rate.
            loss_fn: ``loss_fn(u_pred, u_true, mask)``; defaults to the masked
                sum of squared errors.

        Returns:
            Detached sample obtained by integrating the optimised noise.
        """
        device = u1_true.device
        mask = mask.to(device)

        noise = torch.randn_like(u1_true) # randn for NS
        noise.requires_grad_(True)

        ts = torch.linspace(0, 1, n_step + 1, device=device)

        def default_loss_fn(u_pred, u_true, mask):
            return ((u_pred - u_true) * mask).square().sum()
        loss_fn = loss_fn or default_loss_fn

        def euler_ffm(u):
            # Fixed-step Euler solve from t=0 to t=1 via odeint.
            print("DFlow sampling...")
            tspan = torch.tensor([0, 1.], device=device)
            u = odeint(self.model, u, tspan, method="euler", options = {"step_size":ts[1]-ts[0]})[-1]
            return u

        def closure():
            gc.collect()
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            u_pred = euler_ffm(noise)
            loss = loss_fn(u_pred, u1_true, mask)
            loss.backward()
            return loss

        optimizer = torch.optim.LBFGS([noise], max_iter=n_iter, lr=lr)
        optimizer.step(closure)

        with torch.no_grad():
            u_final = euler_ffm(noise)
        return u_final.detach()

