import os
import torch
import numpy as np

def set_seed(seed=42):
    """Seed the torch (CPU and CUDA) and NumPy global RNGs.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    autotuner via the ``torch.backends.cudnn`` flags.

    Args:
        seed: integer seed passed to every generator.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # The three generators are independent, so they are seeded in one pass.
    for seed_fn in (np.random.seed, torch.cuda.manual_seed_all, torch.manual_seed):
        seed_fn(seed)

def ensure_dir(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    Args:
        directory: path of the directory to create.

    Raises:
        OSError: if the path cannot be created, e.g. it already exists as a
            regular file.
    """
    # exist_ok=True avoids the check-then-create race: another process may
    # create the directory between an os.path.exists() test and makedirs().
    os.makedirs(directory, exist_ok=True)

###### 1. Geometric Brownian Motion (GBM) ######
class GBMParams:
    """Parameter container for the Geometric Brownian Motion simulation."""

    def __init__(self, mu=0.05, sigma=0.2, x0=1.0, T=1.0, n_steps=100):
        # Drift rate and volatility.
        self.mu, self.sigma = mu, sigma
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_gbm_paths.
_default_gbm_params = GBMParams(
    mu=0.05,
    sigma=0.2,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_gbm_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for Geometric Brownian Motion with an Euler-Maruyama scheme:
    dX_t = mu * X_t * dt + sigma * X_t * dW_t

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``mu``, ``sigma``, ``x0``, ``n_steps`` and
            ``dt``; when None the module-level ``_default_gbm_params`` is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_gbm_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x = X[:, i]
        # squeeze(-1) drops only the trailing axis, whatever n_paths is.
        X[:, i+1] = x + params.mu * x * dt + params.sigma * x * dW.squeeze(-1)
    return X, dt, dW, params

###### 2. Ornstein-Uhlenbeck (OU) Process ######
class OUParams:
    """Parameter container for the Ornstein-Uhlenbeck simulation."""

    def __init__(self, kappa=1.0, alpha=1.0, sigma=0.3, x0=1.0, T=1.0, n_steps=100):
        # Mean-reversion speed and long-term mean.
        self.kappa, self.alpha = kappa, alpha
        # Volatility.
        self.sigma = sigma
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_ou_paths.
_default_ou_params = OUParams(
    kappa=1.0,
    alpha=1.0,
    sigma=0.3,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_ou_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for Ornstein-Uhlenbeck Process with an Euler-Maruyama scheme:
    dX_t = kappa * (alpha - X_t) * dt + sigma * dW_t

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``kappa``, ``alpha``, ``sigma``, ``x0``,
            ``n_steps`` and ``dt``; when None the module-level
            ``_default_ou_params`` is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_ou_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x = X[:, i]
        # squeeze(-1) drops only the trailing axis, whatever n_paths is.
        X[:, i+1] = x + params.kappa * (params.alpha - x) * dt + params.sigma * dW.squeeze(-1)
    return X, dt, dW, params

###### 3. Cox-Ingersoll-Ross (CIR) Process ######
class CIRParams:
    """Parameter container for the Cox-Ingersoll-Ross simulation."""

    def __init__(self, kappa=0.5, alpha=1.0, sigma=0.3, x0=1.0, T=1.0, n_steps=100):
        # Mean-reversion speed and long-term mean.
        self.kappa, self.alpha = kappa, alpha
        # Volatility.
        self.sigma = sigma
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_cir_paths.
_default_cir_params = CIRParams(
    kappa=0.5,
    alpha=1.0,
    sigma=0.3,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_cir_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for Cox-Ingersoll-Ross Process with an Euler-Maruyama scheme:
    dX_t = kappa * (alpha - X_t) * dt + sigma * sqrt(X_t) * dW_t

    Each new value is clamped to be non-negative.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``kappa``, ``alpha``, ``sigma``, ``x0``,
            ``n_steps`` and ``dt``; when None the module-level
            ``_default_cir_params`` is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_cir_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x = X[:, i]
        # abs() plus a small epsilon keeps the square root well defined.
        x_sqrt = torch.sqrt(torch.abs(x) + 1e-8)
        x_next = (x + params.kappa * (params.alpha - x) * dt
                  + params.sigma * x_sqrt * dW.squeeze(-1))
        X[:, i+1] = torch.clamp(x_next, min=0.0)
    return X, dt, dW, params

###### 4. Generalized-Gamma Process ######
class GammaProcessParams:
    """Parameter container for the Generalized-Gamma process simulation."""

    def __init__(self, a_coef=0.1, b_coef=0.2, c_coef=0.3, jump_rate=3.0,
                 jump_mean=0.1, jump_std=0.2, x0=1.0, T=1.0, n_steps=100):
        # Multipliers of X_t in the drift, diffusion and jump terms.
        self.a_coef, self.b_coef, self.c_coef = a_coef, b_coef, c_coef
        # Jump intensity, and mean / standard deviation of the jump size.
        self.jump_rate = jump_rate
        self.jump_mean, self.jump_std = jump_mean, jump_std
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_gamma_process_paths.
_default_gamma_params = GammaProcessParams(
    a_coef=0.1,
    b_coef=0.2,
    c_coef=0.3,
    jump_rate=3.0,
    jump_mean=0.1,
    jump_std=0.2,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_gamma_process_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for Generalized-Gamma Process with an Euler-Maruyama scheme:
    dX_t = a(X_t, t) * dt + b(X_t, t) * dW_t + c(X_t, t) * dN_t

    Here a, b and c are ``a_coef * X_t``, ``b_coef * X_t`` and
    ``c_coef * X_t``. In each step a jump occurs with probability
    ``1 - exp(-jump_rate * dt)`` and its size is drawn from a normal
    distribution with mean ``jump_mean`` and standard deviation ``jump_std``.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``a_coef``, ``b_coef``, ``c_coef``,
            ``jump_rate``, ``jump_mean``, ``jump_std``, ``x0``, ``n_steps``
            and ``dt``; when None the module-level ``_default_gamma_params``
            is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0); the jump increments are not returned.
        ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_gamma_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Per-step jump probability does not depend on the step, so build it once.
    jump_prob = 1 - torch.exp(torch.tensor(-params.jump_rate * dt, device=device))
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        x = X[:, i]
        drift = params.a_coef * x
        diffusion = params.b_coef * x
        # Draw order (randn, rand, randn) is kept so seeded runs are unchanged.
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        jump_occurs = torch.rand(n_paths, 1, device=device) < jump_prob
        jump_size = torch.randn(n_paths, 1, device=device) * params.jump_std + params.jump_mean
        dJ = jump_occurs.float() * jump_size
        X[:, i+1] = (x + drift * dt + diffusion * dW.squeeze(-1)
                     + params.c_coef * x * dJ.squeeze(-1))
    return X, dt, dW, params

###### 5. Physically Motivated SDE ######
class PhysicalSDEParams:
    """Parameter container for the physically motivated SDE simulation."""

    def __init__(self, eta=0.7, nu=0.3, x0=1.0, T=1.0, n_steps=100):
        # Coefficients of dx = (eta - nu/2) x^(2 eta - 1) dt + x^eta dW.
        self.eta, self.nu = eta, nu
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_physical_sde_paths.
_default_physical_params = PhysicalSDEParams(
    eta=0.7,
    nu=0.3,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_physical_sde_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for a physically motivated SDE with an Euler-Maruyama scheme:
    dx = (η−ν/2)x^(2η−1) dt + x^η dW

    Each new value is clamped from below at 1e-8.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``eta``, ``nu``, ``x0``, ``n_steps`` and
            ``dt``; when None the module-level ``_default_physical_params``
            is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_physical_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Loop-invariant scalars of the drift term.
    drift_coef = params.eta - params.nu / 2
    drift_exponent = 2 * params.eta - 1
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        x = X[:, i]
        drift = drift_coef * torch.pow(x, drift_exponent)
        diffusion = torch.pow(x, params.eta)
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x_next = x + drift * dt + diffusion * dW.squeeze(-1)
        # Keep the state strictly positive for the fractional powers.
        X[:, i+1] = torch.clamp(x_next, min=1e-8)
    return X, dt, dW, params

###### 6. Nonlinear SDE ######
class NonlinearSDEParams:
    """Parameter container for the nonlinear SDE simulation."""

    def __init__(self, kappa=0.5, sigma=0.3, x0=1.0, T=1.0, n_steps=100):
        # Drift parameter kappa and diffusion scale sigma.
        self.kappa, self.sigma = kappa, sigma
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_nonlinear_sde_paths.
_default_nonlinear_params = NonlinearSDEParams(
    kappa=0.5,
    sigma=0.3,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_nonlinear_sde_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for a nonlinear SDE with an Euler-Maruyama scheme:
    dX_t = X_t(kappa - (sigma^2 - kappa * X_t))dt + sigma * X_t^(3/2) * dW_t

    Each new value is clamped from below at 1e-8.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``kappa``, ``sigma``, ``x0``, ``n_steps`` and
            ``dt``; when None the module-level ``_default_nonlinear_params``
            is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_nonlinear_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        x = X[:, i]
        drift = x * (params.kappa - (params.sigma**2 - params.kappa * x))
        diffusion = params.sigma * torch.pow(x, 1.5)
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x_next = x + drift * dt + diffusion * dW.squeeze(-1)
        # Keep the state strictly positive for the 3/2 power.
        X[:, i+1] = torch.clamp(x_next, min=1e-8)
    return X, dt, dW, params

###### 7. Power-law Volatility Model ######
class PowerLawVolatilityParams:
    """Parameter container for the power-law volatility simulation."""

    def __init__(self, kappa=0.5, alpha=1.0, sigma=0.3, p=0.5, x0=1.0, T=1.0, n_steps=100):
        # Drift parameters of kappa * (alpha - X_t).
        self.kappa, self.alpha = kappa, alpha
        # Diffusion scale and exponent of sigma * X_t^p.
        self.sigma, self.p = sigma, p
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_power_law_volatility_paths.
_default_power_law_params = PowerLawVolatilityParams(
    kappa=0.5,
    alpha=1.0,
    sigma=0.3,
    p=0.5,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_power_law_volatility_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for power-law volatility model with an Euler-Maruyama scheme:
    dX_t = kappa * (alpha - X_t) * dt + sigma * X_t^p * dW_t

    Each new value is clamped from below at 1e-8.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``kappa``, ``alpha``, ``sigma``, ``p``,
            ``x0``, ``n_steps`` and ``dt``; when None the module-level
            ``_default_power_law_params`` is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_power_law_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        x = X[:, i]
        drift = params.kappa * (params.alpha - x)
        # abs() plus a small epsilon keeps the fractional power well defined.
        x_abs = torch.abs(x) + 1e-8
        diffusion = params.sigma * torch.pow(x_abs, params.p)
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x_next = x + drift * dt + diffusion * dW.squeeze(-1)
        X[:, i+1] = torch.clamp(x_next, min=1e-8)
    return X, dt, dW, params

###### 8. Polynomial Drift Model ######
class PolynomialDriftParams:
    """Parameter container for the polynomial drift simulation."""

    def __init__(self, alpha_minus1=-0.1, alpha0=0.1, alpha1=0.2, alpha2=-0.05,
                 sigma=0.3, x0=1.0, T=1.0, n_steps=100):
        # Drift coefficients of X_t^(-1), 1, X_t and X_t^2 respectively.
        self.alpha_minus1, self.alpha0 = alpha_minus1, alpha0
        self.alpha1, self.alpha2 = alpha1, alpha2
        # Diffusion scale.
        self.sigma = sigma
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps


# Default parameter set read by generate_polynomial_drift_paths.
_default_polynomial_params = PolynomialDriftParams(
    alpha_minus1=-0.1,
    alpha0=0.1,
    alpha1=0.2,
    alpha2=-0.05,
    sigma=0.3,
    x0=1.0,
    T=1.0,
    n_steps=100,
)

def generate_polynomial_drift_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for polynomial drift model with an Euler-Maruyama scheme:
    dX_t = (alpha_(-1)*X_t^(-1) + alpha0 + alpha1*X_t + alpha2*X_t^2) * dt + sigma * X_t^(3/2) * dW_t

    Each new value is clamped from below at 1e-8.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``alpha_minus1``, ``alpha0``, ``alpha1``,
            ``alpha2``, ``sigma``, ``x0``, ``n_steps`` and ``dt``; when None
            the module-level ``_default_polynomial_params`` is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the Brownian increment of the LAST step only
        (zeros when ``n_steps`` is 0). ``params`` is the object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_polynomial_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        x = X[:, i]
        # abs() plus a small epsilon guards the division and the 3/2 power.
        x_abs = torch.abs(x) + 1e-8
        drift = (params.alpha_minus1 / x_abs + params.alpha0 +
                 params.alpha1 * x + params.alpha2 * x**2)
        diffusion = params.sigma * torch.pow(x_abs, 1.5)
        dW = torch.randn(n_paths, 1, device=device) * sqrt_dt
        x_next = x + drift * dt + diffusion * dW.squeeze(-1)
        X[:, i+1] = torch.clamp(x_next, min=1e-8)
    return X, dt, dW, params

###### 9. Jump Diffusion Model ######
class JumpSDEParams:
    """Parameter container for the jump-diffusion simulation."""

    def __init__(self, mu=0.05, sigma=0.2, x0=1.0, T=1.0, n_steps=100,
                 jump_rate=5.0, jump_mean=-0.1, jump_std=0.8):
        # Drift and diffusion coefficients.
        self.mu, self.sigma = mu, sigma
        # Initial value.
        self.x0 = x0
        # Total time and number of steps.
        self.T, self.n_steps = T, n_steps
        # Step size, computed once at construction.
        self.dt = T / n_steps
        # Jump intensity, and mean / standard deviation of the jump size.
        self.jump_rate = jump_rate
        self.jump_mean, self.jump_std = jump_mean, jump_std


# Default parameter set read by generate_jump_diffusion_paths.
_default_params = JumpSDEParams(
    mu=0.05,
    sigma=0.2,
    x0=1.0,
    T=1.0,
    n_steps=100,
    jump_rate=5.0,
    jump_mean=-0.1,
    jump_std=0.8,
)

def generate_jump_diffusion_paths(n_paths=10000, device=None, params=None):
    """
    Generate paths for a jump-diffusion SDE model with an Euler-Maruyama scheme:
    dX_t = mu * X_t * dt + sigma * X_t * dW_t + jumps

    In each step a jump occurs with probability ``1 - exp(-jump_rate * dt)``
    and its size is drawn from a normal distribution with mean ``jump_mean``
    and standard deviation ``jump_std``.

    NOTE: the jump increment is added to the Brownian increment BEFORE the
    ``sigma * X_t`` scaling, so a jump of size J moves the path by
    ``sigma * X_t * J``.

    Args:
        n_paths: number of independent paths to simulate.
        device: torch device for all tensors; when None, CUDA is used if
            available and the CPU otherwise.
        params: object exposing ``mu``, ``sigma``, ``jump_rate``,
            ``jump_mean``, ``jump_std``, ``x0``, ``n_steps`` and ``dt``; when
            None the module-level ``_default_params`` is used.

    Returns:
        Tuple ``(X, dt, dW, params)``. ``X`` has shape
        (n_paths, n_steps + 1) and starts at ``x0``. ``dW`` has shape
        (n_paths, 1) and holds the combined Brownian-plus-jump increment of
        the LAST step only (zeros when ``n_steps`` is 0). ``params`` is the
        object that was used.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if params is None:
        params = _default_params
    X = torch.ones(n_paths, params.n_steps + 1, device=device) * params.x0
    dt = params.dt
    sqrt_dt = np.sqrt(dt)
    # Per-step jump probability does not depend on the step, so build it once.
    jump_prob = 1 - torch.exp(torch.tensor(-params.jump_rate * dt, device=device))
    # Defined up front so the return value exists even when n_steps == 0.
    dW = torch.zeros(n_paths, 1, device=device)
    for i in range(params.n_steps):
        # Draw order (randn, rand, randn) is kept so seeded runs are unchanged.
        dW_normal = torch.randn(n_paths, 1, device=device) * sqrt_dt
        jump_occurs = torch.rand(n_paths, 1, device=device) < jump_prob
        jump_size = torch.randn(n_paths, 1, device=device) * params.jump_std + params.jump_mean
        dJ = jump_occurs.float() * jump_size
        dW = dW_normal + dJ
        x = X[:, i]
        X[:, i+1] = x + params.mu * x * dt + params.sigma * x * dW.squeeze(-1)
    return X, dt, dW, params
