import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np

import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *
from tqdm import tqdm 
from scipy import stats
from matplotlib.colors import LogNorm
from twomoons_score import two_moons


# Directory that holds the trained model checkpoint (created if missing).
savedir = "models/new-test-half-moons"
os.makedirs(savedir, exist_ok=True)
# Prefer the first GPU, but fall back to the CPU so the script still runs on
# machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of time points used to discretise the flow, and the nominal step.
n_times = 1000
dt = 1/n_times
# `v` is used as a callable below, so the checkpoint is presumably a fully
# pickled model object rather than a state dict.  Such files need
# weights_only=False (the default became True in PyTorch 2.6).
# NOTE(security): this unpickles arbitrary objects -- only load checkpoints
# you produced yourself.  map_location makes GPU-saved checkpoints loadable
# on whichever device was selected above.
v = torch.load(f'{savedir}/si.pt', map_location=device, weights_only=False)


def dv(x):
    """Return the per-sample Jacobian of ``v`` for a batch of inputs.

    ``torch.func.jacrev`` differentiates ``v`` with respect to a single
    input row and ``torch.vmap`` maps that over the leading (batch)
    dimension of ``x``, so the result has one Jacobian per row of ``x``.
    """
    return torch.vmap(torch.func.jacrev(v))(x)
def _segment_endpoints(points, directions, half_length):
    """Return ``(xs, ys)`` for segments centred on ``points``.

    Each segment runs from ``points - half_length * directions`` to
    ``points + half_length * directions``.  ``xs`` and ``ys`` are two-element
    lists ``[start, end]`` in the layout ``Axes.plot`` expects; with a batch
    of points every column becomes its own line.
    """
    offset = half_length * directions
    start = points - offset
    end = points + offset
    return [start[..., 0], end[..., 0]], [start[..., 1], end[..., 1]]


# Nothing here is trained, so autograd bookkeeping is disabled.  torch.func
# transforms such as jacrev (used inside dv) are documented to ignore an
# enclosing no_grad, so dv still returns real Jacobians.
with torch.no_grad():
    # The meaning of the second torch_wrapper argument is defined by the
    # project's wrapper -- TODO confirm.
    node = NeuralODE(
        torch_wrapper(v, 0),
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-4,
        rtol=1e-4,
    ).to(device)
    n_samples = 50000
    x0 = sample_8gaussians(n_samples).to(device)

    # The trajectory is stored on n_times points spanning [0, 1], i.e.
    # n_times - 1 intervals.  The step used for the tangent dynamics must be
    # that spacing (not 1/n_times), otherwise the time fed to the Jacobian
    # drifts away from the time of the stored state.
    t_span = torch.linspace(0, 1, n_times, device=device)
    dt = 1 / (n_times - 1)
    total_time = (n_times - 1) * dt

    # Random initial tangent frame; the first QR step orthonormalises it.
    q = torch.rand(n_samples, 2, 2).to(device)
    # Running per-sample estimate of the two Lyapunov exponents.
    les = torch.zeros(n_samples, 2, device=device)
    ones_col = torch.ones(n_samples, 1, device=device)
    identity = torch.eye(2, device=device)  # hoisted: loop-invariant

    traj = node.trajectory(x0, t_span=t_span)
    for i in tqdm(range(1, n_times)):
        x_prev = traj[i - 1]
        t_col = t_span[i - 1] * ones_col
        # Jacobian of v w.r.t. the two spatial inputs (the time column of the
        # concatenated input is dropped by the [:, :, :2] slice).
        jac_x = dv(torch.cat([x_prev, t_col], dim=1))[:, :, :2]
        # First-order (explicit Euler) tangent map over one step.
        step_map = dt * jac_x + identity
        # Re-orthonormalise the frame; |diag(r)| holds the stretching factors.
        q, r = torch.linalg.qr(step_map @ q)
        stretch = torch.abs(torch.diagonal(r, dim1=-2, dim2=-1))
        les += torch.log(stretch) / total_time

    print("Computing score of 2 moons")
    # two_moons.score_moons is project code; presumably it returns the score
    # of the two-moons density at the given points -- TODO confirm.
    tm_obj = two_moons(sig=0.01)
    x_final = traj[-1]
    score_x = tm_obj.score_moons(x_final)
    # Normalise each score vector to unit length.
    ns_x = torch.sqrt(torch.sum(score_x**2, dim=1)).unsqueeze(1)
    s_np = (score_x / ns_x).detach().cpu().numpy()
    x_np = x_final.detach().cpu().numpy()

    print("plotting")
    # First column of the final frame for every sample.
    q0 = q.detach().cpu().numpy()[:, :, 0]
    eps = 4.e-1  # half-length of the segments drawn below

    # Histogram of |s . q0| (the absolute dot product of the unit score with
    # the first frame vector).
    fig, ax = plt.subplots()
    abs_dot = np.abs(np.sum(s_np * q0, axis=1))
    ax.hist(abs_dot, histtype="step", bins=30, color="k")
    ax.grid(True)
    ax.set_xlabel("target scr . LV1 SI ",fontsize=20)
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)
    plt.tight_layout()
    plt.savefig('angles_2moons.png')

    # Direction plot on every 100th sample.  The subsample gets its own
    # names so the full-resolution arrays stay available for saving.
    fig, ax = plt.subplots()
    x_sub = x_np[::100]
    q_sub = q0[::100]
    s_sub = s_np[::100]
    ax.scatter(x_sub[:, 0], x_sub[:, 1], color="k")
    xs, ys = _segment_endpoints(x_sub, q_sub, eps)
    ax.plot(xs, ys, color='darkred', lw=2.5)
    xs, ys = _segment_endpoints(x_sub, s_sub, eps)
    ax.plot(xs, ys, color='navy', lw=2.5)
    # Redraw the first segment of each family with a label so the legend
    # gets exactly one entry per colour.
    xs, ys = _segment_endpoints(x_sub[0], q_sub[0], eps)
    ax.plot(xs, ys, color='darkred', label='LV1 SI', lw=2.5)
    xs, ys = _segment_endpoints(x_sub[0], s_sub[0], eps)
    ax.plot(xs, ys, color='navy', label='scr data dist', lw=2.5)
    ax.legend(fontsize=20)
    ax.grid(True)
    ax.xaxis.set_tick_params(labelsize=20)
    ax.yaxis.set_tick_params(labelsize=20)
    plt.tight_layout()
    plt.savefig('lyap_vec_2moons.png')

    torch.save(q,"lyapunov_vectors.pt")
    # Save the full set of final points (numpy array) so that row k here
    # corresponds to row k of the saved frames; previously only the plotting
    # subsample was written.
    torch.save(x_np,"phase_points.pt")
    print("Lyapunov exponents:", les)

    
    
    


