#!/usr/bin/env python
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
from torch.utils.data import DataLoader, TensorDataset
from data import (
    generate_jump_diffusion_paths,
    generate_gbm_paths,
    generate_ou_paths,
    generate_cir_paths,
    generate_gamma_process_paths,
    generate_physical_sde_paths,
    generate_nonlinear_sde_paths,
    generate_power_law_volatility_paths,
    generate_polynomial_drift_paths
)

# Command-line argument parsing
def parse_args():
    """Build the CLI for the SDE path-generator training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``data_type``, ``n_paths``,
        ``epochs``, ``batch_size``, ``lr``, ``seed`` and ``viz_interval``.
    """
    generator_choices = ['jump_diffusion', 'gbm', 'ou', 'cir', 'gamma',
                         'physical', 'nonlinear', 'power_law', 'polynomial']
    # (flag, type, default, help) for the plain scalar options, in CLI order.
    scalar_options = (
        ('--n_paths', int, 10000, 'Number of sample paths to generate'),
        ('--epochs', int, 2000, 'Number of training epochs'),
        ('--batch_size', int, 1024, 'Batch size'),
        ('--lr', float, 1e-4, 'Learning rate'),
        ('--seed', int, 42, 'Random seed'),
        ('--viz_interval', int, 50, 'Visualization interval (epochs)'),
    )

    cli = argparse.ArgumentParser(description='Train SDE path generator model')
    cli.add_argument('--data_type', type=str, default='jump_diffusion',
                     choices=generator_choices,
                     help='Choose data generation function')
    for flag, value_type, default, help_text in scalar_options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    return cli.parse_args()

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed PyTorch (CPU and all CUDA devices) and NumPy, and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG listed above.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable autotuning and force deterministic kernels so reruns match.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Ensure output directory exists
def ensure_dir(directory):
    """Create ``directory`` and any missing parent directories if needed.

    ``exist_ok=True`` makes the call idempotent and avoids the race between a
    separate existence check and the creation step.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)
        
def create_save_directories(data_generator_name='jump_diffusion'):
    """Create the ``save_<name>/`` output tree with ``data`` and ``pics`` subfolders.

    Directories are created relative to the current working directory.

    Returns:
        Tuple ``(base_dir, data_dir, pic_dir)`` of path strings.
    """
    base_dir = f'save_{data_generator_name}'
    data_dir, pic_dir = (os.path.join(base_dir, sub) for sub in ('data', 'pics'))
    # Parent first, then children, matching the order the paths are returned.
    for path in (base_dir, data_dir, pic_dir):
        ensure_dir(path)
    return base_dir, data_dir, pic_dir

def compute_mmd(x, y):
    """Approximate the 1-D Wasserstein-1 distance between two samples.

    Despite the name, this is not a Maximum Mean Discrepancy. Both inputs are
    flattened and sorted. The result is the mean absolute difference between
    the sorted values, which matches order statistics (quantile coupling).
    When the sample sizes differ, each sorted sample is thinned to the smaller
    size by taking evenly spaced order statistics (indices floored).

    Args:
        x: array-like sample of any shape.
        y: array-like sample of any shape.

    Returns:
        The distance as a NumPy float.

    Raises:
        ValueError: if either input is empty (the mean would be undefined).
    """
    x_sorted = np.sort(np.asarray(x).ravel())
    y_sorted = np.sort(np.asarray(y).ravel())
    if x_sorted.size == 0 or y_sorted.size == 0:
        raise ValueError("compute_mmd requires non-empty inputs")

    def _thin(sorted_vals, length):
        # Pick `length` evenly spaced order statistics; identity when sizes match.
        idx = np.linspace(0, sorted_vals.size - 1, length).astype(int)
        return sorted_vals[idx]

    if x_sorted.size != y_sorted.size:
        min_length = min(x_sorted.size, y_sorted.size)
        x_sorted = _thin(x_sorted, min_length)
        y_sorted = _thin(y_sorted, min_length)
    return np.mean(np.abs(x_sorted - y_sorted))

class DriftNet(nn.Module):
    """Neural network for drift term f(x, t).

    A plain MLP on the concatenated (x, t) pair: Linear + SiLU per hidden
    layer, followed by a linear layer with a single output.

    Args:
        hidden_dims: widths of the hidden layers, in order. The default is an
            immutable tuple so it cannot be shared and mutated across instances.
    """

    def __init__(self, hidden_dims=(64, 128, 64)):
        super().__init__()
        layers = []
        input_dim = 2  # (x, t)
        for h_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, h_dim))
            layers.append(nn.SiLU())  # Swish activation
            input_dim = h_dim
        layers.append(nn.Linear(input_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate f(x, t).

        Args:
            x: state values; flattened to one value per row.
            t: time values; flattened the same way and must have as many
                elements as ``x``.

        Returns:
            Tensor of shape (N, 1), where N is the number of elements in ``x``.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        inputs = torch.cat([x.reshape(-1, 1), t.reshape(-1, 1)], dim=1)
        return self.net(inputs)

class DiffusionNet(nn.Module):
    """Neural network for diffusion term g(x, t).

    An MLP on the concatenated (x, t) pair: Linear + SiLU per hidden layer,
    then a single-output linear layer followed by Softplus to keep the
    output non-negative.

    Args:
        hidden_dims: widths of the hidden layers, in order. The default is an
            immutable tuple so it cannot be shared and mutated across instances.
    """

    def __init__(self, hidden_dims=(64, 128, 64)):
        super().__init__()
        layers = []
        input_dim = 2  # (x, t)
        for h_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, h_dim))
            layers.append(nn.SiLU())
            input_dim = h_dim
        layers.append(nn.Linear(input_dim, 1))
        layers.append(nn.Softplus())  # Ensure positivity
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate g(x, t).

        Args:
            x: state values; flattened to one value per row.
            t: time values; flattened the same way and must have as many
                elements as ``x``.

        Returns:
            Tensor of shape (N, 1), where N is the number of elements in ``x``.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        inputs = torch.cat([x.reshape(-1, 1), t.reshape(-1, 1)], dim=1)
        return self.net(inputs)

class EnhancedFlexibleDWGenerator(nn.Module):
    """Learned generator of noise increments dW for a given step size dt.

    ``dt`` is encoded by ``base_network``. Three branches with dropout rates
    0.1, 0.5 and 0.8 process that encoding. Their concatenated 3 * 128
    features parameterise a per-sample ``mean`` and a positive ``scale``.
    The output is ``(eps * scale + mean) * sqrt(dt) * dt_scale``, where
    ``eps ~ N(0, 1)`` is drawn fresh on every call in both train and eval mode.

    Args:
        input_dim: number of features per row of ``dt`` (input size of the
            first Linear layer).
        hidden_dims: widths of ``base_network``'s layers; the last width feeds
            the branches. The default is an immutable tuple.
        base_dropout: dropout probability inside ``base_network``.
    """

    def __init__(self, input_dim=1, hidden_dims=(256, 512, 256), base_dropout=0.2):
        super().__init__()
        self.base_network = self._build_residual_network(input_dim, hidden_dims, dropout_p=base_dropout)
        self.multi_scale_networks = nn.ModuleList([
            self._build_residual_network(hidden_dims[-1], [128, 128], dropout_p=0.1),
            self._build_residual_network(hidden_dims[-1], [128, 128], dropout_p=0.5),
            self._build_residual_network(hidden_dims[-1], [128, 128], dropout_p=0.8)
        ])
        # NOTE(review): mixing_network / mixing_weights do not influence the
        # output (see forward). They are kept so existing checkpoints keep the
        # same state_dict keys; confirm whether the branches were meant to be
        # weighted by a softmax over mixing_weights.
        self.mixing_network = self._build_residual_network(hidden_dims[-1], [128, 64], dropout_p=0.3)
        self.mixing_weights = nn.Linear(64, 3)
        self.output_mean = nn.Linear(128*3, 1)
        self.output_scale = nn.Linear(128*3, 1)
        # Learned global multiplier applied to every increment.
        self.dt_scale = nn.Parameter(torch.ones(1))

    def _build_residual_network(self, input_dim, hidden_dims, dropout_p):
        """Stack ``Linear -> BatchNorm1d -> Dropout -> SiLU`` blocks.

        Despite the name, no skip connection is added. When a block keeps the
        width unchanged, a parameter-free identity placeholder is appended, so
        module indices and state_dict keys match the earlier layout.
        """
        layers = []
        current_dim = input_dim
        for h_dim in hidden_dims:
            block = nn.Sequential(
                nn.Linear(current_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.Dropout(dropout_p),
                nn.SiLU()
            )
            layers.append(block)
            if current_dim == h_dim:
                # Placeholder where a residual path would go; returns its input unchanged.
                layers.append(nn.Identity())
            current_dim = h_dim
        return nn.Sequential(*layers)

    def forward(self, dt):
        """Sample one noise increment per row of ``dt``.

        Args:
            dt: step sizes, expected as (batch, input_dim) to suit the first
                Linear/BatchNorm1d layer. Negative values yield NaN through ``sqrt``.

        Returns:
            ``(eps * scale + mean) * sqrt(dt) * dt_scale``; shape (batch, 1)
            when ``input_dim == 1``.
        """
        base_features = self.base_network(dt)
        scale_outputs = [network(base_features) for network in self.multi_scale_networks]
        # NOTE(review): the mixing branch's result is unused. It is still run so
        # that, in training mode, its Dropout RNG draws and BatchNorm running
        # statistics update exactly as before. The softmax over
        # self.mixing_weights(...) that followed was discarded, so it is omitted.
        self.mixing_network(base_features)
        combined_features = torch.cat(scale_outputs, dim=1)
        mean = self.output_mean(combined_features)
        # abs + epsilon keeps the scale strictly positive.
        scale = torch.abs(self.output_scale(combined_features)) + 1e-6
        # One sampling path for both train and eval mode (they were identical).
        sample = torch.randn_like(mean) * scale + mean
        return sample * torch.sqrt(dt) * self.dt_scale

class ResidualConnection(nn.Module):
    """Parameter-free pass-through module.

    ``forward`` hands back the very tensor it receives. It does not add a
    skip/residual path on its own.
    """

    def forward(self, x):
        passthrough = x
        return passthrough

class NeuralSDEModel(nn.Module):
    """Euler-Maruyama simulator for the neural SDE dX = f(X, t) dt + g(X, t) dW.

    Args:
        params: configuration object; only ``params.n_steps`` (the default
            number of integration steps) is read here.
        dw_generator: module called with a (batch, 1) tensor filled with ``dt``;
            returns one noise increment dW per path.
        drift_net: module computing f(x, t) from (batch, 1) tensors x and t.
        diffusion_net: module computing g(x, t) from (batch, 1) tensors x and t.
    """

    def __init__(self, params, dw_generator, drift_net, diffusion_net):
        super().__init__()
        self.params = params
        self.dw_generator = dw_generator
        self.drift_net = drift_net
        self.diffusion_net = diffusion_net

    def forward(self, x0, dt, steps=None):
        """Simulate one path per initial state.

        Args:
            x0: initial states, one value per path, e.g. shape (batch,) or (batch, 1).
            dt: step size.
            steps: number of steps; defaults to ``self.params.n_steps``.

        Returns:
            Tensor of shape (batch, steps + 1) in the default floating dtype.
            Column 0 holds x0; column i holds the state after i steps.
        """
        batch_size = x0.size(0)
        if steps is None:
            steps = self.params.n_steps
        # Keep the state in the default float dtype (what a torch.zeros buffer
        # would hold), so the networks see the same inputs as before.
        x = x0.reshape(batch_size).to(torch.get_default_dtype())
        state_dtype = x.dtype
        path = [x]
        dt_tensor = torch.ones(batch_size, 1, device=x0.device) * dt
        for i in range(steps):
            x_col = x.view(-1, 1)
            t = torch.ones_like(x_col) * i * dt
            drift = self.drift_net(x_col, t)
            diffusion = self.diffusion_net(x_col, t)
            dW = self.dw_generator(dt_tensor)
            x = (x + drift.reshape(-1) * dt + diffusion.reshape(-1) * dW.reshape(-1)).to(state_dtype)
            path.append(x)
        # Stack the per-step states instead of writing each step in place into
        # a shared buffer that autograd must track on every write.
        return torch.stack(path, dim=1)

def wasserstein_distance(y_pred, y_true):
    """Wasserstein-1 distance between two equal-size 1D empirical distributions.

    Both tensors are flattened before sorting (as ``compute_mmd`` does). This
    means, e.g., a (batch, 1) prediction is compared correctly with a (batch,)
    target instead of silently broadcasting to a (batch, batch) difference.
    With equal sample counts, W1 equals the mean absolute difference between
    the sorted samples.

    Raises:
        ValueError: if the two inputs contain different numbers of elements.
    """
    pred = y_pred.reshape(-1)
    true = y_true.reshape(-1)
    if pred.numel() != true.numel():
        raise ValueError(
            f"wasserstein_distance needs equal sample counts, got {pred.numel()} and {true.numel()}"
        )
    y_pred_sorted, _ = torch.sort(pred)
    y_true_sorted, _ = torch.sort(true)
    return torch.mean(torch.abs(y_pred_sorted - y_true_sorted))



