import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np

import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *
from tqdm import tqdm 
from scipy import stats
from matplotlib.colors import LogNorm
from twomoons_score import two_moons
def plot_samples(p, eps):
    """Plot a KDE density estimate of 2-D samples, save it, and show it.

    A ``scipy.stats.gaussian_kde`` is fitted to the samples. It is evaluated on
    a 100x100 grid covering the samples' bounding box, padded by 1 on every
    side. The result is drawn as a filled contour plot (10 evenly spaced
    levels, ``RdBu_r`` colormap) with a colorbar. The plot is then saved to
    ``cfm_8gaussiansTo2moons_{eps}.png`` in the current working directory and
    displayed with ``plt.show()``. Afterwards the figure is closed, so
    repeated calls do not accumulate open figures.

    Args:
        p: Sample array. ``p.T`` is unpacked into two coordinate rows, so it
            is expected to have shape ``(n_samples, 2)``.
        eps: Label used only in the plot title and the output filename.
    """
    p = np.asarray(p)
    x, y = p.T
    fig, ax = plt.subplots()

    # Use NumPy reductions rather than the builtin min/max. The builtins
    # iterate the array element by element in Python, which is slow for
    # large sample sets.
    xmin, xmax = x.min() - 1, x.max() + 1
    ymin, ymax = y.min() - 1, y.max() + 1
    # A complex step (100j) makes mgrid produce 100 points, endpoints included.
    X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([X.ravel(), Y.ravel()])
    values = np.vstack([x, y])

    # Kernel density estimate of the samples, evaluated on the grid.
    kernel = stats.gaussian_kde(values)
    # The small offset keeps Z strictly positive, presumably so a log colour
    # scale can be used. NOTE(review): confirm intent.
    Z = np.reshape(kernel(positions).T, X.shape) + 1.e-10
    z_min, z_max = Z.min(), Z.max()
    levels = np.linspace(z_min, z_max, 10)

    # Filled contour plot of the density.
    contour = ax.contourf(X, Y, Z, levels=levels, cmap='RdBu_r', vmin=z_min, vmax=z_max)
    cbar = plt.colorbar(contour, ax=ax)
    cbar.ax.tick_params(labelsize=16)

    ax.set_title(f'Stoc. Interp., eps = {eps} ', fontsize=16)
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)

    plt.tight_layout()
    plt.savefig(f'cfm_8gaussiansTo2moons_{eps}.png')
    plt.show()
    # Release the figure. Under a non-interactive backend, plt.show() does
    # not close it, so repeated calls would otherwise keep every figure alive.
    plt.close(fig)



savedir = "models/new-test-half-moons"
os.makedirs(savedir, exist_ok=True)
# Fall back to the CPU when no GPU is available instead of failing on 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
eps = 1.0
n_times = 100
# map_location remaps tensors that were saved on a GPU onto `device`. This lets
# a checkpoint written on a CUDA machine be loaded on a CPU-only one.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. Recent torch releases default to weights_only=True, which may
# reject si.pt if it holds a whole pickled model. Check against the installed
# torch version.
v = torch.load(f'{savedir}/si.pt', map_location=device)
with torch.no_grad():
    node = NeuralODE(torch_wrapper(v, eps), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4).to(device)
    n_samples = 50000
    x = sample_8gaussians(n_samples).to(device)
    # Integrate from t=0 to t=1, recording the state at n_times evenly spaced times.
    t_span = torch.linspace(0, 1, n_times).to(device)
    traj = node.trajectory(x, t_span=t_span)
    # Keep only the last recorded state, as a NumPy array for plotting.
    # This assumes time is the leading axis of the trajectory.
    traj = traj[-1].detach().cpu().numpy()
    plot_samples(traj, eps)
    
    


