import os
import time

import matplotlib.pyplot as plt
import numpy as np

import torch
from torchdyn.core import NeuralODE

from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *
from tqdm import tqdm 
from scipy import stats

# Output directory for the trained model checkpoint; created up front (no
# error if it already exists) so the final torch.save cannot fail on a
# missing path.
savedir = "models/new-test-half-moons"
os.makedirs(name=savedir, exist_ok=True)

def plot_samples(p, eps, fname=None):
    """Plot a filled-contour Gaussian KDE of 2-D samples, save it and show it.

    Args:
        p: Array-like of samples, unpacked as ``x, y = p.T``; expected to have
            shape ``(n_samples, 2)``.
        eps: Value shown in the plot title and, by default, embedded in the
            output file name.
        fname: Output image path. Defaults to
            ``cfm_8gaussiansTo2moons_{eps}.png`` in the current working
            directory (the original hard-coded location).

    The figure is closed after being shown so that repeated calls (e.g. from
    inside a training loop) do not accumulate open matplotlib figures.
    """
    if fname is None:
        fname = f'cfm_8gaussiansTo2moons_{eps}.png'

    x, y = np.asarray(p).T
    fig, ax = plt.subplots()

    # Evaluation grid: data bounding box padded by 1 unit on every side.
    # Vectorized min/max instead of builtin min/max (which iterate in Python).
    xmin, xmax = x.min() - 1, x.max() + 1
    ymin, ymax = y.min() - 1, y.max() + 1
    X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([X.ravel(), Y.ravel()])
    values = np.vstack([x, y])

    # Kernel density estimate on the grid; the tiny offset keeps Z strictly
    # positive.
    kernel = stats.gaussian_kde(values)
    Z = np.reshape(kernel(positions).T, X.shape) + 1.e-10
    z_min, z_max = Z.min(), Z.max()
    levels = np.linspace(z_min, z_max, 10)

    # Filled contour plot of the density, drawn on the explicit figure/axes
    # rather than relying on pyplot's "current figure" state.
    contour = ax.contourf(X, Y, Z, levels=levels, cmap='RdBu_r', vmin=z_min, vmax=z_max)
    cbar = fig.colorbar(contour, ax=ax)
    cbar.ax.tick_params(labelsize=16)

    # Optional extras, disabled by default:
    #contour_lines = ax.contour(X, Y, Z, levels=levels, colors='white', alpha=0.5, linewidths=0.5)
    #ax.scatter(x, y, alpha=0.3, s=10, c='lightgray', edgecolor=None)
    #ax.set_xlabel('X')
    #ax.set_ylabel('Y')
    #ax.set_xlim([-1,1])
    #ax.set_ylim([-1,1])

    ax.set_title(f'Conditional flow matching, eps = {eps} ', fontsize=16)
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)

    fig.tight_layout()
    fig.savefig(fname)
    plt.show()
    # Release the figure; otherwise every call leaves one open (matplotlib
    # warns after 20 and memory grows, notably on non-interactive backends).
    plt.close(fig)


#%%time
# --- Training configuration ----------------------------------------------
sigma = 0.01
dim = 2
batch_size = 256
n_steps = 20000    # total optimizer steps
log_every = 10000  # report loss and elapsed wall-clock time every N steps
# Fall back to CPU so the script still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MLP(dim=dim, time_varying=True).to(device)
optimizer = torch.optim.Adam(model.parameters())
FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
#FM = ConditionalFlowMatcher(sigma=sigma)
#FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

# --- Training loop -------------------------------------------------------
# perf_counter is monotonic and intended for measuring elapsed intervals.
start = time.perf_counter()
for k in tqdm(range(n_steps)):
    optimizer.zero_grad()

    # Source batch from sample_8gaussians, target batch from sample_moons.
    x0 = sample_8gaussians(batch_size).to(device)
    #x0 = torch.randn(batch_size, 2).to(device)
    x1 = sample_moons(batch_size).to(device)

    # The flow matcher returns times t, points xt and target velocities ut;
    # the model sees xt with t appended as an extra input column.
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
    xt_new = torch.cat([xt, t[:, None]], dim=-1)

    # Mean-squared regression of the predicted velocity onto ut.
    vt = model(xt_new)
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % log_every == 0:
        end = time.perf_counter()
        # tqdm.write avoids garbling the progress bar (plain print does not).
        tqdm.write(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        #eps = 0
        #node = NeuralODE(
        #    torch_wrapper(model, eps), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        #).to(device)
        #with torch.no_grad():
        #    traj = node.trajectory(
        #        sample_8gaussians(50000).to(device),
        #        t_span=torch.linspace(0, 1, 100).to(device),
        #    )
        #    plot_samples(traj[-1].cpu().numpy(),eps)

# Saves the whole pickled nn.Module (not just its state_dict); loading it
# requires the MLP class to be importable at load time.
torch.save(model, f"{savedir}/si.pt")
