# -*- coding: utf-8 -*-
"""Copy of sdeflow_equivalent_sdes.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Tx_Yt90NRgHve--ocIXi6SGR-0ebwH0N
"""


import time
import numpy as np
import torch
import torch.nn as nn
import sys
import os

@torch.no_grad()
def EMstep(mu, delta, sigma, dW):
    """Return one Euler-Maruyama increment ``mu * delta + sigma . dW``.

    The noise contribution is a batched matrix-vector product when ``sigma``
    has more than two dimensions, and an elementwise (broadcast) product
    otherwise.

    Args:
        mu: drift term; multiplied by ``delta``.
        delta: step size applied to the drift.
        sigma: diffusion term. With more than two dimensions it is indexed
            as ``(b, i, j)`` by the einsum below.
        dW: noise increment; indexed as ``(b, j)`` in the matrix case.

    Returns:
        The increment only; the current state is not an argument and the
        caller is responsible for adding the increment to it.
    """
    drift_part = mu * delta
    if sigma.ndim <= 2:
        # No per-sample matrix: plain broadcasting multiply.
        return drift_part + sigma * dW
    # One matrix per batch entry: contract its last axis with dW.
    return drift_part + torch.einsum('bij, bj -> bi', sigma, dW)

### 2.0 Define Euler Maruyama method with a step size $\Delta t$
@torch.no_grad()
def euler_maruyama_sampler(sde, x_0, num_steps=1000, lmbd=0., keep_all_samples=True,
                           include_t0=False, T_=-1, norm_correction=False):
    """Integrate ``sde`` from ``x_0`` with the Euler-Maruyama scheme.

    Each step draws a fresh Gaussian increment with variance ``delta`` and
    applies ``x <- x + EMstep(mu, delta, sigma, dW)``, where ``mu`` and
    ``sigma`` are evaluated at the left endpoint of the step.

    Args:
        sde: object exposing ``T`` (a tensor; its device is used for the
            computation and its value is the default horizon),
            ``mu(t, x, lmbd=...)`` and ``sigma(t, x, lmbd=...)``.
        x_0: initial state; the first axis is the batch axis. It is copied,
            never modified.
        num_steps: number of equal steps on ``[0, T]``.
        lmbd: forwarded unchanged to ``sde.mu`` and ``sde.sigma``.
        keep_all_samples: return the whole trajectory instead of only the
            final state.
        include_t0: when keeping the trajectory, also store the initial state
            as its first entry.
        T_: time horizon. ``-1`` (the default) means ``sde.T``; otherwise a
            plain number or a one-element tensor.
        norm_correction: after every step, rescale the state so that its norm
            along ``dim=1`` equals that of the initial state.

    Returns:
        A CPU tensor. With ``keep_all_samples`` its shape is
        ``(num_steps, *x_0.shape)`` (``num_steps + 1`` with ``include_t0``);
        otherwise it is the final state, shaped like ``x_0``.
    """
    device = sde.T.device

    # Accept the -1 sentinel, a plain number, or a one-element tensor.
    horizon = T_.item() if isinstance(T_, torch.Tensor) else float(T_)
    if horizon == -1:
        horizon = sde.T.item()
    delta = horizon / num_steps
    sqrt_delta = delta ** 0.5
    # Left endpoints of the steps, pulled to Python floats once so the loop
    # does not need a per-step ``.item()``.
    step_times = (torch.linspace(0, 1, num_steps + 1) * horizon)[:-1].tolist()

    x_t = x_0.detach().clone().to(device)
    if norm_correction:
        norm_x_0 = torch.norm(x_t, dim=1)

    if keep_all_samples:
        offset = 1 if include_t0 else 0
        # Step-major layout in the state's own dtype: no final permute and no
        # silent downcast of non-float32 states.
        xs = torch.zeros((num_steps + offset, *x_t.shape), dtype=x_t.dtype, device='cpu')
        if include_t0:
            xs[0] = x_t.to('cpu')

    # Time tensor shaped (batch, 1, ..., 1) so it broadcasts against the state.
    t = torch.zeros(x_0.size(0), *([1] * (x_0.dim() - 1)), device=device)
    for i, t_i in enumerate(step_times):
        t.fill_(t_i)
        mu = sde.mu(t, x_t, lmbd=lmbd)
        sigma = sde.sigma(t, x_t, lmbd=lmbd)
        x_t = x_t + EMstep(mu, delta, sigma, sqrt_delta * torch.randn_like(x_t))
        if norm_correction:
            x_t = x_t * (norm_x_0 / torch.norm(x_t, dim=1))[:, None]
        if keep_all_samples:
            xs[i + offset] = x_t.to('cpu')

    if keep_all_samples:
        return xs
    return x_t.to('cpu')

@torch.no_grad()
def heun_sampler(sde, x_0, num_steps=1000, lmbd=0., keep_all_samples=True,
                 include_t0=False, T_=-1, norm_correction=False):
    """Integrate ``sde`` from ``x_0`` with the Heun (two-stage) scheme.

    Uses the Stratonovich drift ``sde.mu_Strato``. Each step makes an Euler
    predictor from the left endpoint, re-evaluates drift and diffusion at the
    predicted point and the right endpoint, and advances with the average of
    the two evaluations. The same Wiener increment is used in both stages.

    Args:
        sde: object exposing ``T`` (a tensor; its device is used for the
            computation and its value is the default horizon),
            ``mu_Strato(t, x, lmbd=...)`` and ``sigma(t, x, lmbd=...)``.
        x_0: initial state; the first axis is the batch axis. It is copied,
            never modified.
        num_steps: number of equal steps on ``[0, T]``.
        lmbd: forwarded unchanged to ``sde.mu_Strato`` and ``sde.sigma``.
        keep_all_samples: return the whole trajectory instead of only the
            final state.
        include_t0: when keeping the trajectory, also store the initial state
            as its first entry.
        T_: time horizon. ``-1`` (the default) means ``sde.T``; otherwise a
            plain number or a one-element tensor.
        norm_correction: after every step, rescale the state so that its norm
            along ``dim=1`` equals that of the initial state.

    Returns:
        A CPU tensor. With ``keep_all_samples`` its shape is
        ``(num_steps, *x_0.shape)`` (``num_steps + 1`` with ``include_t0``);
        otherwise it is the final state, shaped like ``x_0``.
    """
    device = sde.T.device

    # Accept the -1 sentinel, a plain number, or a one-element tensor.
    horizon = T_.item() if isinstance(T_, torch.Tensor) else float(T_)
    if horizon == -1:
        horizon = sde.T.item()
    delta = horizon / num_steps
    sqrt_delta = delta ** 0.5
    # Left endpoints of the steps, pulled to Python floats once so the loop
    # does not need a per-step ``.item()``.
    step_times = (torch.linspace(0, 1, num_steps + 1) * horizon)[:-1].tolist()

    x_t = x_0.detach().clone().to(device)
    if norm_correction:
        norm_x_0 = torch.norm(x_t, dim=1)

    if keep_all_samples:
        offset = 1 if include_t0 else 0
        # Step-major layout in the state's own dtype: no final permute and no
        # silent downcast of non-float32 states.
        xs = torch.zeros((num_steps + offset, *x_t.shape), dtype=x_t.dtype, device='cpu')
        if include_t0:
            xs[0] = x_t.to('cpu')

    # Time tensor shaped (batch, 1, ..., 1) so it broadcasts against the state.
    t = torch.zeros(x_0.size(0), *([1] * (x_0.dim() - 1)), device=device)
    for i, t_i in enumerate(step_times):
        t.fill_(t_i)
        t_next = t + delta
        dW = sqrt_delta * torch.randn_like(x_t)  # Wiener increment for this step

        # Stage 1: drift and diffusion at the start of the interval.
        mu_1 = sde.mu_Strato(t, x_t, lmbd=lmbd)
        sigma_1 = sde.sigma(t, x_t, lmbd=lmbd)

        # Predictor: plain Euler step.
        x_predict = x_t + EMstep(mu_1, delta, sigma_1, dW)

        # Stage 2: drift and diffusion at the predicted end point.
        mu_2 = sde.mu_Strato(t_next, x_predict, lmbd=lmbd)
        sigma_2 = sde.sigma(t_next, x_predict, lmbd=lmbd)

        # Corrector: (delta / 2) * (mu_1 + mu_2) + (sigma_1 + sigma_2) . (dW / 2)
        x_t = x_t + EMstep(mu_1 + mu_2, delta / 2, sigma_1 + sigma_2, dW / 2)
        if norm_correction:
            x_t = x_t * (norm_x_0 / torch.norm(x_t, dim=1))[:, None]
        if keep_all_samples:
            xs[i + offset] = x_t.to('cpu')

    if keep_all_samples:
        return xs
    return x_t.to('cpu')

@torch.no_grad()
def rk4_stratonovich_sampler(sde, x_0, num_steps=1000, lmbd=0., keep_all_samples=True,
                             include_t0=False, T_=-1, norm_correction=False):
    """Integrate ``sde`` from ``x_0`` with a four-stage Runge-Kutta scheme.

    Uses the Stratonovich drift ``sde.mu_Strato``. Every stage increment is
    ``EMstep(mu_Strato, delta, sigma, dW)`` with the *same* Wiener increment
    ``dW``, evaluated at the start, twice at the midpoint and at the end of
    the step; the stages are combined with the classical ``1, 2, 2, 1`` / 6
    weights.

    NOTE(review): the previous docstring described the intended use as
    Stratonovich SDEs with skew-symmetric noise; nothing in this function
    checks or relies on that property explicitly.

    Args:
        sde: object exposing ``T`` (a tensor; its device is used for the
            computation and its value is the default horizon),
            ``mu_Strato(t, x, lmbd=...)`` and ``sigma(t, x, lmbd=...)``.
        x_0: initial state; the first axis is the batch axis. It is copied,
            never modified.
        num_steps: number of equal steps on ``[0, T]``.
        lmbd: forwarded unchanged to ``sde.mu_Strato`` and ``sde.sigma``.
        keep_all_samples: return the whole trajectory instead of only the
            final state.
        include_t0: when keeping the trajectory, also store the initial state
            as its first entry.
        T_: time horizon. ``-1`` (the default) means ``sde.T``; otherwise a
            plain number or a one-element tensor.
        norm_correction: after every step, rescale the state so that its norm
            along ``dim=1`` equals that of the initial state.

    Returns:
        A CPU tensor. With ``keep_all_samples`` its shape is
        ``(num_steps, *x_0.shape)`` (``num_steps + 1`` with ``include_t0``);
        otherwise it is the final state, shaped like ``x_0``.
    """
    device = sde.T.device

    # Accept the -1 sentinel, a plain number, or a one-element tensor.
    horizon = T_.item() if isinstance(T_, torch.Tensor) else float(T_)
    if horizon == -1:
        horizon = sde.T.item()
    delta = horizon / num_steps
    sqrt_delta = delta ** 0.5
    # Left endpoints of the steps, pulled to Python floats once so the loop
    # does not need a per-step ``.item()``.
    step_times = (torch.linspace(0, 1, num_steps + 1) * horizon)[:-1].tolist()

    x_t = x_0.detach().clone().to(device)
    if norm_correction:
        norm_x_0 = torch.norm(x_t, dim=1)

    if keep_all_samples:
        offset = 1 if include_t0 else 0
        # Step-major layout in the state's own dtype: no final permute and no
        # silent downcast of non-float32 states.
        xs = torch.zeros((num_steps + offset, *x_t.shape), dtype=x_t.dtype, device='cpu')
        if include_t0:
            xs[0] = x_t.to('cpu')

    # Time tensor shaped (batch, 1, ..., 1) so it broadcasts against the state.
    t = torch.zeros(x_0.size(0), *([1] * (x_0.dim() - 1)), device=device)
    for i, t_i in enumerate(step_times):
        t.fill_(t_i)
        t_mid = t + delta / 2
        t_end = t + delta

        # One Wiener increment shared by all four stages.
        dW = sqrt_delta * torch.randn_like(x_t)

        # Stage 1: start of the interval.
        K1 = EMstep(sde.mu_Strato(t, x_t, lmbd=lmbd), delta,
                    sde.sigma(t, x_t, lmbd=lmbd), dW)

        # Stage 2: midpoint reached with half of K1.
        x_stage = x_t + K1 / 2
        K2 = EMstep(sde.mu_Strato(t_mid, x_stage, lmbd=lmbd), delta,
                    sde.sigma(t_mid, x_stage, lmbd=lmbd), dW)

        # Stage 3: midpoint reached with half of K2.
        x_stage = x_t + K2 / 2
        K3 = EMstep(sde.mu_Strato(t_mid, x_stage, lmbd=lmbd), delta,
                    sde.sigma(t_mid, x_stage, lmbd=lmbd), dW)

        # Stage 4: end point reached with the full K3.
        x_stage = x_t + K3
        K4 = EMstep(sde.mu_Strato(t_end, x_stage, lmbd=lmbd), delta,
                    sde.sigma(t_end, x_stage, lmbd=lmbd), dW)

        # Weighted combination of the four stage increments.
        x_t = x_t + (K1 + 2 * K2 + 2 * K3 + K4) / 6
        if norm_correction:
            x_t = x_t * (norm_x_0 / torch.norm(x_t, dim=1))[:, None]
        if keep_all_samples:
            xs[i + offset] = x_t.to('cpu')

    if keep_all_samples:
        return xs
    return x_t.to('cpu')
