import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
from scipy.stats import entropy

import torch
import torch.nn as nn
import torch.optim as optim
import torchcde
import torchsde
from torch.utils.data import Dataset, DataLoader

def seed_everything(seed):
    """Seed every RNG used by this script and force deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG,
            torch's CPU generator and all CUDA generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic
    # algorithms per run, which defeats the seeding above; keep it off.
    torch.backends.cudnn.benchmark = False

def gaussian_kernel(x, y, sigma=1.0):
    """Pairwise Gaussian (RBF) kernel between the rows of ``x`` and ``y``.

    Entry ``[i, j]`` is ``exp(-mean_sq_dist(x_i, y_j) / (2 * sigma**2))``, where the
    squared Euclidean distance is divided by the feature count ``x.size(1)``.

    Args:
        x: 2-D tensor ``(n, d)``.
        y: 2-D tensor ``(m, d)``.
        sigma: kernel bandwidth.

    Returns:
        Tensor of shape ``(n, m)``.
    """
    n_features = x.size(1)
    # Broadcast (n, 1, d) against (1, m, d) to get every pairwise difference.
    pairwise_diff = x[:, None, :] - y[None, :, :]
    scaled_sq_dist = pairwise_diff.pow(2).sum(2) / float(n_features)
    return torch.exp(-scaled_sq_dist / (2 * sigma**2))

def compute_mmd(x, y):
    """Biased (V-statistic) estimate of the squared MMD between samples ``x`` and ``y``.

    Uses ``gaussian_kernel`` with its default bandwidth. Kernel means include the
    diagonal (self-pair) terms.

    Args:
        x: 2-D tensor ``(n, d)`` of samples from the first distribution.
        y: 2-D tensor ``(m, d)`` of samples from the second distribution.

    Returns:
        0-dim tensor ``mean K(x,x) + mean K(y,y) - 2 * mean K(x,y)``.
    """
    mean_xx = gaussian_kernel(x, x).mean()
    mean_yy = gaussian_kernel(y, y).mean()
    mean_xy = gaussian_kernel(x, y).mean()
    return mean_xx + mean_yy - 2 * mean_xy

def ou_process(T, N, theta, mu, sigma, X0):
    """Simulate one Ornstein-Uhlenbeck path with the Euler-Maruyama scheme.

    Discretises ``dX = theta * (mu - X) dt + sigma dW`` on ``N`` evenly spaced
    points covering ``[0, T]`` (inclusive), drawing noise from NumPy's global RNG.

    Args:
        T: final time.
        N: number of grid points (including ``t = 0``).
        theta: mean-reversion rate.
        mu: long-run mean.
        sigma: diffusion coefficient.
        X0: initial value.

    Returns:
        ``(t, X)`` — the time grid and the simulated values, both 1-D arrays of length ``N``.
    """
    t = np.linspace(0, T, N)
    X = np.zeros(N)
    if N == 0:
        return t, X
    # The step must equal the spacing of ``t`` (T / (N - 1)); with T / N the
    # simulated path would stop short of T even though ``t`` ends at T.
    dt = T / (N - 1) if N > 1 else 0.0
    sqrt_dt = np.sqrt(dt)
    X[0] = X0
    for i in range(1, N):
        dW = np.random.normal(0, sqrt_dt)
        X[i] = X[i - 1] + theta * (mu - X[i - 1]) * dt + sigma * dW
    return t, X

def generate_data(num_samples, T, N, theta, mu, sigma, X0):
    """Simulate ``num_samples`` OU paths and fit Hermite cubic spline coefficients to them.

    Args:
        num_samples: number of independent paths to simulate.
        T, N, theta, mu, sigma, X0: forwarded unchanged to ``ou_process``.

    Returns:
        total_data: float tensor ``(num_samples, N, 2)``; channel 0 is the time stamp
            returned by ``ou_process`` and channel 1 is the path value.
        coeffs: torchcde Hermite cubic coefficients of ``total_data``.
        times: ``torch.linspace(0, 1, N)`` — the normalised grid the spline is fitted on.
    """
    simulated = [list(ou_process(T, N, theta, mu, sigma, X0)) for _ in range(num_samples)]
    # (num_samples, 2, N) -> (num_samples, N, 2): time axis second, channels last.
    total_data = torch.Tensor(np.array(simulated)).permute(0, 2, 1)
    times = torch.linspace(0, 1, total_data.shape[1])
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(total_data, times)
    return total_data, coeffs, times

class OU_Dataset(Dataset):
    """Map-style dataset yielding ``(path, spline_coeffs)`` pairs aligned on dim 0."""

    def __init__(self, data, coeffs):
        # ``data`` and ``coeffs`` are expected to share their first (sample) dimension.
        self.data, self.coeffs = data, coeffs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx, ...]
        sample_coeffs = self.coeffs[idx, ...]
        return sample, sample_coeffs

def split_data(data, coeffs, train_ratio=0.8):
    """Randomly partition ``data`` and ``coeffs`` (aligned on dim 0) into train/test subsets.

    Uses NumPy's global RNG. Training rows appear in the random order drawn by
    ``np.random.choice``; test rows keep their original ascending order.

    Args:
        data: indexable array/tensor whose first dimension enumerates samples.
        coeffs: array/tensor aligned with ``data`` on the first dimension.
        train_ratio: fraction of samples (floored) assigned to the training split.

    Returns:
        ``(train_data, train_coeffs, test_data, test_coeffs)``.
    """
    total_size = len(data)
    train_size = int(total_size * train_ratio)
    # Same RNG draw as choice(range(total_size), ...) but always yields integer indices.
    train_idx = np.random.choice(total_size, train_size, replace=False)
    # Boolean mask gives the complement in O(n); `i not in train_idx` was O(n^2)
    # and produced a float64 (un-indexable) array when the test split was empty.
    is_test = np.ones(total_size, dtype=bool)
    is_test[train_idx] = False
    test_idx = np.flatnonzero(is_test)
    train_data = data[train_idx, ...]
    test_data = data[test_idx, ...]
    train_coeffs = coeffs[train_idx, ...]
    test_coeffs = coeffs[test_idx, ...]
    return train_data, train_coeffs, test_data, test_coeffs

def create_data_loaders(train_data, train_coeffs, test_data, test_coeffs, batch_size=16):
    """Wrap the train/test splits in ``OU_Dataset`` and build their DataLoaders.

    The training loader shuffles every epoch; the test loader preserves order.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def _make_loader(data, coeffs, shuffle):
        return DataLoader(OU_Dataset(data, coeffs), batch_size=batch_size, shuffle=shuffle)

    return (
        _make_loader(train_data, train_coeffs, True),
        _make_loader(test_data, test_coeffs, False),
    )

class LipSwish(nn.Module):
    """SiLU scaled by 0.909.

    SiLU's steepest slope is roughly 1.1, so the 0.909 factor brings the
    activation's Lipschitz constant down to about 1.
    """

    def forward(self, x):
        activated = torch.nn.functional.silu(x)
        return activated * 0.909

class MLP(nn.Module):
    """Fully connected network: ``num_layers`` hidden Linear+activation pairs, then a Linear head.

    Args:
        in_size: input feature size.
        out_size: output feature size.
        hidden_dim: width of every hidden layer.
        num_layers: number of hidden Linear layers (values < 1 behave like 1).
        tanh: if True, squash the output with ``nn.Tanh``.
        activation: ``'lipswish'`` for ``LipSwish``; any other value selects ``nn.ReLU``.
    """

    def __init__(self, in_size, out_size, hidden_dim, num_layers, tanh=False, activation='lipswish'):
        super().__init__()
        # A single (parameter-free) activation instance is reused after every hidden Linear.
        act = LipSwish() if activation == 'lipswish' else nn.ReLU()
        layers = [nn.Linear(in_size, hidden_dim), act]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), act]
        layers.append(nn.Linear(hidden_dim, out_size))
        if tanh:
            layers.append(nn.Tanh())
        # Attribute name `_model` fixes the state_dict keys (`_model.<i>.weight`, ...).
        self._model = nn.Sequential(*layers)

    def forward(self, x):
        return self._model(x)

class NeuralGSDEFunc(nn.Module):
    """Drift ``f`` and diagonal diffusion ``g`` for a neural SDE conditioned on a control path.

    ``sde_type`` / ``noise_type`` are the attributes passed along to ``torchsde.sdeint``.
    ``set_X`` must be called before ``f`` is evaluated, because the drift reads the
    interpolated control path ``self.X`` at the current time.

    Args:
        input_dim: channel count of the control path.
        hidden_dim: size of the SDE state ``y``.
        hidden_hidden_dim: width of the internal MLPs.
        num_layers: hidden-layer count of the internal MLPs.
        activation: activation name forwarded to ``MLP``.
    """

    def __init__(self, input_dim, hidden_dim, hidden_hidden_dim, num_layers, activation='lipswish'):
        super().__init__()
        self.sde_type = "ito"
        self.noise_type = "diagonal"
        # Submodules are created in the original order so seeded initialisation
        # and state_dict layout are unchanged.
        # Drift branch.
        self.linear_in = nn.Linear(hidden_dim + 1, hidden_dim)
        self.linear_X = nn.Linear(input_dim, hidden_dim)
        self.emb = nn.Linear(hidden_dim * 2, hidden_dim)
        self.f_net = MLP(hidden_dim, hidden_dim, hidden_hidden_dim, num_layers, activation=activation)
        self.linear_out = nn.Linear(hidden_dim, hidden_dim)
        # Diffusion branch (depends on time only, then gated by y).
        self.noise_in = nn.Linear(1, hidden_dim)
        self.g_net = MLP(hidden_dim, hidden_dim, hidden_hidden_dim, num_layers, activation=activation)

    @staticmethod
    def _time_column(t, y):
        """Broadcast a scalar time to a ``(batch, 1)`` column matching ``y``; pass others through."""
        if t.dim() == 0:
            t = torch.full_like(y[:, 0], fill_value=t).unsqueeze(-1)
        return t

    def set_X(self, coeffs, times):
        """Store spline coefficients/grid and build the cubic-spline control path ``self.X``."""
        self.coeffs, self.times = coeffs, times
        self.X = torchcde.CubicSpline(coeffs, times)

    def f(self, t, y):
        """Drift: embed (t, y) and X(t), mix them, gate by ``y``, then project."""
        # X is evaluated at the raw solver time, before it is broadcast to a column.
        x_embedding = self.linear_X(self.X.evaluate(t))
        t_col = self._time_column(t, y)
        state = self.linear_in(torch.cat((t_col, y), dim=-1))
        mixed = self.emb(torch.cat([state, x_embedding], dim=-1))
        return self.linear_out(self.f_net(mixed) * y)

    def g(self, t, y):
        """Diagonal diffusion: a time-only network gated element-wise by ``y``."""
        t_col = self._time_column(t, y)
        return self.g_net(self.noise_in(t_col)) * y

class NDE_model(nn.Module):
    """Neural SDE model: encode the first observation, integrate with torchsde, decode.

    Args:
        input_dim: channel count of the control path.
        hidden_dim: size of the latent SDE state.
        output_dim: size of the decoded output per time step.
        num_layers: hidden-layer count forwarded to the vector field.
        activation: activation name forwarded to the vector field.
        vector_field: class constructed as
            ``vector_field(input_dim, hidden_dim, hidden_dim, num_layers, activation=...)``;
            it must provide ``set_X`` and expose the spline as ``.X`` (e.g. ``NeuralGSDEFunc``).
        sde_dt: fixed solver step passed to ``torchsde.sdeint`` (default 0.05).
        sde_method: solver name passed to ``torchsde.sdeint`` (default ``'euler'``).

    Raises:
        TypeError: if ``vector_field`` is not supplied.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, activation='lipswish',
                 vector_field=None, sde_dt=0.05, sde_method='euler'):
        super(NDE_model, self).__init__()
        if vector_field is None:
            # Previously surfaced as "'NoneType' object is not callable".
            raise TypeError("NDE_model requires a `vector_field` class (e.g. NeuralGSDEFunc).")
        self.func = vector_field(input_dim, hidden_dim, hidden_dim, num_layers, activation=activation)
        self.initial = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, output_dim)
        self.sde_dt = sde_dt
        self.sde_method = sde_method

    def forward(self, coeffs, times):
        """Integrate the SDE over ``times`` and decode every step.

        Returns:
            Tensor ``(batch, len(times), output_dim)``.
        """
        self.func.set_X(coeffs, times)
        # Only the first observation seeds the SDE, so evaluate the spline at
        # times[0] rather than over the whole grid and then discarding the rest.
        y0 = self.initial(self.func.X.evaluate(times[0]))
        z = torchsde.sdeint(sde=self.func,
                            y0=y0,
                            ts=times,
                            dt=self.sde_dt,
                            method=self.sde_method)
        # sdeint returns (time, batch, hidden); decode in (batch, time, hidden) layout.
        return self.decoder(z.permute(1, 0, 2))

def save_results_and_plot(all_trues, all_preds, output_dir="results", num_plot_samples=5):
    """Score predictions at the final time step, write metrics to JSON and save two figures.

    Args:
        all_trues: tensor indexed ``[sample, time]`` holding the target paths.
        all_preds: tensor with the same layout holding the model's predictions.
        output_dir: directory receiving ``evaluation_results.json``,
            ``last_point_distribution.png`` and ``sample_predictions.png``; created if missing.
        num_plot_samples: number of paths overlaid in ``sample_predictions.png``,
            clipped to the number of available samples.

    Returns:
        ``(mse, mmd)`` between predicted and true values at the last time step.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Ensure `.numpy()` below works even if tensors arrive on GPU or with grad.
    all_trues = all_trues.detach().cpu()
    all_preds = all_preds.detach().cpu()

    true_last = all_trues[:, -1]
    pred_last = all_preds[:, -1]
    mse = torch.nn.functional.mse_loss(pred_last, true_last).item()
    # compute_mmd expects (n, features) matrices, so add a feature axis of size 1.
    mmd = compute_mmd(true_last.unsqueeze(1), pred_last.unsqueeze(1)).item()

    results = {
        "mse_last_point": mse,
        "mmd_last_point": mmd
    }
    with open(os.path.join(output_dir, "evaluation_results.json"), "w") as f:
        json.dump(results, f, indent=4)

    # Each figure is closed after saving so repeated calls don't accumulate open figures.
    fig = plt.figure(figsize=(10, 6))
    plt.hist(true_last.numpy(), bins=30, alpha=0.5, label='True', color='r')
    plt.hist(pred_last.numpy(), bins=30, alpha=0.5, label='Pred', color='b')
    plt.title(f'Distribution at Last Time Point\nMSE: {mse:.4f}, MMD: {mmd:.4f}')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.legend()
    plt.savefig(os.path.join(output_dir, "last_point_distribution.png"))
    plt.close(fig)

    # Clip so test sets smaller than num_plot_samples don't raise IndexError.
    n_show = min(num_plot_samples, all_trues.shape[0])
    fig = plt.figure(figsize=(12, 6))
    for i in range(n_show):
        plt.plot(all_trues[i].numpy(), color='r', alpha=0.7)
        plt.plot(all_preds[i].numpy(), color='b', alpha=0.7, linestyle='--')
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.ylim(-0.75, 1.25)
    plt.title('Model Predictions vs True Values')
    plt.savefig(os.path.join(output_dir, "sample_predictions.png"))
    plt.close(fig)

    print(f"Results saved in {output_dir}")
    print(f"Last time point MSE: {mse:.6f}")
    print(f"Last time point MMD: {mmd:.6f}")
    return mse, mmd

def _train_one_epoch(model, loader, times, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total_loss, total_count = 0.0, 0
    for data_batch, coeffs_batch in loader:
        coeffs_batch = coeffs_batch.to(device)
        # Channel 1 holds the path value (channel 0 is the time stamp).
        true = data_batch[:, :, 1].to(device)
        optimizer.zero_grad()
        pred = model(coeffs_batch, times).squeeze(-1)
        loss = criterion(pred, true)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch doesn't skew the average.
        total_loss += loss.item() * data_batch.shape[0]
        total_count += data_batch.shape[0]
    return total_loss / total_count if total_count else float('nan')


def _evaluate_model(model, loader, times, criterion, device):
    """Evaluate without gradients; return (mean loss, stacked predictions, stacked targets) on CPU."""
    model.eval()
    total_loss, total_count = 0.0, 0
    preds, trues = [], []
    with torch.no_grad():
        for data_batch, coeffs_batch in loader:
            coeffs_batch = coeffs_batch.to(device)
            true = data_batch[:, :, 1].to(device)
            pred = model(coeffs_batch, times).squeeze(-1)
            total_loss += criterion(pred, true).item() * data_batch.shape[0]
            total_count += data_batch.shape[0]
            preds.append(pred.cpu())
            trues.append(true.cpu())
    if not preds:
        raise ValueError("Test split is empty; lower config['train_ratio'].")
    return total_loss / total_count, torch.cat(preds, dim=0), torch.cat(trues, dim=0)


def train_and_evaluate(config, output_dir="results"):
    """Simulate OU data, train an ``NDE_model`` on it, evaluate, and save artefacts.

    Args:
        config: dict with keys ``seed``, ``num_samples``, ``T``, ``N``, ``theta``, ``mu``,
            ``sigma``, ``X0``, ``train_ratio``, ``batch_size``, ``lr``, ``num_epochs`` and
            optional ``hidden_dim`` (default 32), ``num_layers`` (default 1) and
            ``activation`` (default ``'lipswish'``).
        output_dir: directory receiving ``model.pt`` and the evaluation outputs.

    Returns:
        ``(model, mse, mmd)`` — the trained model and last-time-step test metrics.

    Raises:
        ValueError: if the test split is empty.
    """
    seed_everything(config['seed'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    total_data, coeffs, times = generate_data(
        config['num_samples'], config['T'], config['N'],
        config['theta'], config['mu'], config['sigma'], config['X0']
    )
    train_data, train_coeffs, test_data, test_coeffs = split_data(
        total_data, coeffs, config['train_ratio']
    )
    train_loader, test_loader = create_data_loaders(
        train_data, train_coeffs, test_data, test_coeffs, config['batch_size']
    )
    # Every batch shares the grid that generate_data fitted the spline on, so
    # move it to the device once instead of rebuilding it per batch.
    times = times.to(device)
    model = NDE_model(
        input_dim=2,
        hidden_dim=config.get('hidden_dim', 32),
        output_dim=1,
        num_layers=config.get('num_layers', 1),
        activation=config.get('activation', 'lipswish'),
        vector_field=NeuralGSDEFunc
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    criterion = torch.nn.MSELoss()

    num_epochs = config['num_epochs']
    for epoch in range(1, num_epochs + 1):
        avg_loss = _train_one_epoch(model, train_loader, times, optimizer, criterion, device)
        if epoch % 10 == 0 or epoch == num_epochs:
            print(f'Epoch {epoch}/{config["num_epochs"]}, Loss: {avg_loss:.6f}')

    avg_loss, all_preds, all_trues = _evaluate_model(model, test_loader, times, criterion, device)
    print(f'Test loss: {avg_loss:.6f}')

    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
    mse, mmd = save_results_and_plot(all_trues, all_preds, output_dir)
    return model, mse, mmd

if __name__ == "__main__":
    # Default experiment: OU simulation parameters, data split and training hyper-parameters.
    config = dict(
        num_samples=1000,
        T=10.0,
        N=20,
        theta=0.2,
        mu=0.0,
        sigma=0.1,
        X0=1.0,
        train_ratio=0.8,
        batch_size=16,
        seed=42,
        num_epochs=100,
        lr=1e-3,
        hidden_dim=32,
        num_layers=1,
    )
    output_dir = "gsde_results"
    model, mse, mmd = train_and_evaluate(config, output_dir=output_dir)
