import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import warnings
from scipy.stats import gaussian_kde
import scipy.stats

def _simulate_segment(x, start, stop, dt, mu0, beta, sigma0, lambda_rate, jump_size):
    """Advance trajectory ``x`` in place for time indices ``start`` .. ``stop - 1``.

    Each step applies an Euler-Maruyama style update:
    drift ``(mu0 + beta * t * dt) * x * dt``, diffusion
    ``sigma0 * sqrt(|x|) * dW`` with ``dW ~ N(0, dt)``, and a multiplicative
    jump when a Poisson(``lambda_rate * dt``) count is positive. The new value
    is clipped at 0.

    The order of random draws (normal, Poisson, then a jump-size normal only
    when a jump occurs) is fixed so results are reproducible under
    ``np.random.seed``.
    """
    for t in range(start, stop):
        dW = np.random.normal(0.0, np.sqrt(dt))
        dN = np.random.poisson(lambda_rate * dt)
        # One jump-size draw is scaled by the jump count dN (not a sum of dN draws).
        jump = dN * np.random.normal(jump_size, jump_size / 2) if dN > 0 else 0
        drift = (mu0 + beta * t * dt) * x[t - 1] * dt
        diffusion = sigma0 * np.sqrt(abs(x[t - 1])) * dW
        jump_term = x[t - 1] * jump
        x[t] = x[t - 1] + drift + diffusion + jump_term
        x[t] = max(x[t], 0)


def _normalize_to_range(data, target_min=0, target_max=800):
    """Min-max scale each feature (last axis) of a 3-D array to [target_min, target_max].

    The min and max are taken over all samples and time steps of a feature.
    A feature whose values are all equal maps to ``target_min`` instead of
    producing NaN from 0/0. An empty array is returned as an empty float32 array.
    """
    normalized_data = np.zeros_like(data)
    if data.size == 0:
        # np.min/np.max raise on zero-size arrays.
        return normalized_data.astype(np.float32)
    for k in range(data.shape[2]):
        feature_data = data[:, :, k]
        feature_min = np.min(feature_data)
        value_range = np.max(feature_data) - feature_min
        if value_range == 0:
            # Constant feature: avoid division by zero; everything maps to target_min.
            value_range = 1.0
        normalized = (feature_data - feature_min) / value_range
        normalized_data[:, :, k] = normalized * (target_max - target_min) + target_min
    return normalized_data.astype(np.float32)


def generate_hybrid_jump_diffusion_data(
    num_samples,
    seq_len=150,
    input_len=100,
    output_len=50,
    dt=1.0,
    mu0_1=0.1, beta_1=0.001,
    sigma0_1=0.01,
    lambda_1=0.05,
    jump_size_1=0.005,
    mu0_2=0.2, beta_2=0.002,
    sigma0_2=0.002,
    lambda_2=0.8,
    jump_size_2=0.001,
    x0_base=40.0,
    x0_perturb=0.1,
    switch_probability=0.3,
    seed=None
):
    """Simulate a batch of jump-diffusion trajectories with a possible regime switch.

    Every trajectory starts at ``x0_base * (1 + e)`` with ``e`` uniform in
    ``[-x0_perturb, x0_perturb)``. It evolves under regime-1 parameters for time
    indices ``1 .. input_len - 1``. From index ``input_len`` onward it uses
    regime-2 parameters with probability ``switch_probability``; otherwise it
    keeps regime 1. Values are clipped at 0. Finally the whole batch is min-max
    scaled to ``[0, 800]`` using the global min/max of the batch.

    Args:
        num_samples: Number of trajectories to generate.
        seq_len: Ignored. The sequence length is always
            ``input_len + output_len``. Kept for backward compatibility.
        input_len: Index at which the (possible) regime switch happens.
        output_len: Number of steps after the switch point.
        dt: Time step size.
        mu0_1, beta_1, sigma0_1, lambda_1, jump_size_1: Regime-1 drift base,
            drift time slope, diffusion scale, jump intensity, and mean
            relative jump size.
        mu0_2, beta_2, sigma0_2, lambda_2, jump_size_2: Same parameters for
            regime 2.
        x0_base: Base initial value.
        x0_perturb: Relative half-width of the uniform initial perturbation.
        switch_probability: Per-trajectory probability of switching to regime 2.
        seed: If given, seeds NumPy's and torch's global RNGs. This mutates
            global state. Only NumPy's RNG is used for sampling here.

    Returns:
        ``(data, timestamps)`` as float32 torch tensors. ``data`` has shape
        ``(num_samples, input_len + output_len, 1)``. ``timestamps`` has shape
        ``(num_samples, input_len + output_len + 1)``, and every row equals
        ``arange(input_len + output_len + 1) * dt``.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # The seq_len argument is deliberately overridden (see docstring).
    seq_len = input_len + output_len
    switch_point = input_len
    data = np.zeros((num_samples, seq_len, 1), dtype=np.float32)
    timestamps = np.zeros((num_samples, seq_len + 1), dtype=np.float32)
    timestamps[:] = np.arange(seq_len + 1, dtype=np.float32) * dt

    regime_1 = (mu0_1, beta_1, sigma0_1, lambda_1, jump_size_1)
    regime_2 = (mu0_2, beta_2, sigma0_2, lambda_2, jump_size_2)
    # x[0] is the initial condition, so stepping must start at t >= 1.
    # Otherwise input_len == 0 would read x[-1] and overwrite x[0].
    second_start = max(switch_point, 1)

    for i in range(num_samples):
        e = (2 * x0_perturb) * np.random.rand() - x0_perturb
        x = np.zeros(seq_len, dtype=np.float32)
        x[0] = x0_base * (1 + e)

        _simulate_segment(x, 1, switch_point, dt, *regime_1)

        use_params2 = np.random.rand() < switch_probability
        post_switch = regime_2 if use_params2 else regime_1
        _simulate_segment(x, second_start, seq_len, dt, *post_switch)

        data[i, :, 0] = x

    data = _normalize_to_range(data)
    data_torch = torch.tensor(data, dtype=torch.float32)
    timestamps_torch = torch.tensor(timestamps, dtype=torch.float32)
    return data_torch, timestamps_torch

def visualize_jump_diffusion_data(save_dir='./figures'):
    """
    Visualize jump diffusion process data and save the figure as a PNG:
    1. Plot 10 sample trajectories, with a dashed line at the switch point
    2. Plot the final time point distribution histogram (with KDE) for 3000 samples

    Args:
        save_dir: Output directory (created if missing). The image is written to
            ``<save_dir>/jump_diffusion_visualization.png``.
    """
    import seaborn as sns

    os.makedirs(save_dir, exist_ok=True)

    # One set of generator settings drives both datasets and the switch
    # marker, so the plotted switch line always matches the data.
    input_len, output_len, dt = 100, 50, 1.0
    gen_kwargs = dict(
        seq_len=input_len + output_len,
        input_len=input_len,
        output_len=output_len,
        dt=dt,
        x0_base=40.0,
        x0_perturb=0.1,
        seed=42,
    )

    data_10, timestamps_10 = generate_hybrid_jump_diffusion_data(num_samples=10, **gen_kwargs)
    data_3000, _ = generate_hybrid_jump_diffusion_data(num_samples=3000, **gen_kwargs)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        # timestamps has one more column than data, so drop the last timestamp.
        for i in range(data_10.shape[0]):
            ax1.plot(timestamps_10[i, :-1].numpy(),
                     data_10[i, :, 0].numpy(),
                     alpha=0.7,
                     label=f'Sample {i+1}')
        ax1.axvline(x=input_len * dt, color='r', linestyle='--', alpha=0.5, label='Switch point')
        ax1.set_title('Jump Diffusion Trajectories (10 samples)')
        ax1.set_xlabel('Time')
        ax1.set_ylabel('Normalized Value')
        ax1.legend()
        ax1.grid(True)

        final_values = data_3000[:, -1, 0].numpy()
        sns.histplot(data=final_values, ax=ax2, kde=True, bins=50)
        ax2.set_title('Final Time Distribution (3000 samples)')
        ax2.set_xlabel('Normalized Value')
        ax2.set_ylabel('Frequency')
        ax2.grid(True)
        fig.tight_layout()

        save_path = os.path.join(save_dir, 'jump_diffusion_visualization.png')
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

if __name__ == "__main__":
    # Script entry point: render and save the demo visualization.
    visualize_jump_diffusion_data()
