"""
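Toy ODE dataset generator.

Simulates trajectories of the cubic system dy/dt = y^3 A with torchdiffeq,
adds Gaussian noise, randomly drops time points so the sampling becomes
irregular, and saves train/validation/test splits as pickle files.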

"""

import torch
import torch.nn as nn
from torchdiffeq import odeint
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import os
import pickle


# --- Script settings ---
method = 'dopri5'  # name of the ODE solver method
data_size = 200 # number of points on the time grid `t`; each patient has fewer than 200 time points (upper bound)
batch_time = 10  # consecutive time points per mini-batch window (see get_batch)
batch_size = 20  # number of windows sampled per mini-batch (see get_batch)
niters = 2000  # unused in the visible part of this file
test_freq = 20  # unused in the visible part of this file
viz = True  # when True, the 'png' folder and the three-panel figure are set up and visualize() draws
gpu = 0  # unused in the visible part of this file
adjoint = False  # unused in the visible part of this file
num_traj = 10000  # number of trajectories generated for the saved dataset

# Reference system: dy/dt = y^3*A
true_y0 = torch.tensor([[2., 0.]])  # initial state, shape (1, 2)
t = torch.linspace(0., 25., data_size)  # evenly spaced time grid on [0, 25]
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])  # 2x2 matrix A of the dynamics

class Lambda(nn.Module):
    """Right-hand side of the cubic ODE dy/dt = y^3 A.

    Args:
        A: Optional matrix used as ``A``. When omitted, the module-level
            ``true_A`` is looked up on every call, so ``Lambda()`` keeps
            following whatever ``true_A`` currently is.
    """

    def __init__(self, A=None):
        super().__init__()
        self.A = A

    def forward(self, t, y):
        """Return ``(y ** 3) @ A`` for a 2-D ``y``.

        ``t`` is part of the solver's calling convention and is not used:
        the system is autonomous.
        """
        A = true_A if self.A is None else self.A
        return torch.mm(y**3, A)

# Reference trajectory without noise, solved on the full time grid `t`.
# The solver is taken from the `method` setting so that changing the setting
# actually takes effect (the name used to be repeated here as a literal).
# Gradients are not needed for data generation, hence no_grad.
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method=method)


def add_noise(data, noise_level=0.1):
    """Return ``data`` plus zero-mean Gaussian noise.

    Args:
        data: Tensor to perturb. It is not modified in place.
        noise_level: Standard deviation of the added noise.

    Returns:
        A new tensor ``data + noise_level * eps`` where ``eps`` is drawn
        independently per element from a standard normal distribution.
    """
    # Draw the noise on the same device as `data`: noise created on the
    # default device cannot be added to a tensor living on an accelerator.
    noise = noise_level * torch.randn(data.shape, device=data.device)
    return data + noise

# Mini-batch sampler with optional observation noise
def get_batch(with_noise=True, noise_level=0.1):
    """Sample a mini-batch of short windows from the reference trajectory.

    ``batch_size`` distinct start indices are drawn from
    ``[0, data_size - batch_time)``; every window covers ``batch_time``
    consecutive points of the module-level ``true_y``.

    Args:
        with_noise: If True, the windows are perturbed with ``add_noise``.
        noise_level: Forwarded to ``add_noise``.

    Returns:
        Tuple ``(batch_y0, batch_t, batch_y)``: the states at the start
        indices (never perturbed), the first ``batch_time`` entries of
        ``t``, and the windows stacked along a new leading time axis.
    """
    candidate_starts = np.arange(data_size - batch_time, dtype=np.int64)
    chosen = np.random.choice(candidate_starts, batch_size, replace=False)
    s = torch.from_numpy(chosen)

    batch_y0 = true_y[s]
    batch_t = t[:batch_time]

    # One slice per time offset, stacked so that time is the leading axis.
    windows = [true_y[s + offset] for offset in range(batch_time)]
    batch_y = torch.stack(windows, dim=0)

    if not with_noise:
        return batch_y0, batch_t, batch_y
    return batch_y0, batch_t, add_noise(batch_y, noise_level=noise_level)

# Example: draw one noisy mini-batch (the arguments spelled out here equal
# get_batch's default values).
batch_y0, batch_t, batch_y = get_batch(with_noise=True, noise_level=0.1)

def makedirs(dirname):
    """Create directory ``dirname`` (with missing parents) if it is absent.

    ``exist_ok=True`` makes the call idempotent without a separate
    existence check, which could race with another process creating the
    directory between the check and the creation.

    Raises:
        FileExistsError: if ``dirname`` exists but is not a directory.
    """
    os.makedirs(dirname, exist_ok=True)


# One-time setup for visualize(): an output folder plus a figure with three
# side-by-side panels (trajectories, phase portrait, vector field).
if viz:
    makedirs('png')
    import matplotlib.pyplot as plt  # already imported at the top of the file
    fig = plt.figure(figsize=(12, 4), facecolor='white')
    ax_traj = fig.add_subplot(131, frameon=False)
    ax_phase = fig.add_subplot(132, frameon=False)
    ax_vecfield = fig.add_subplot(133, frameon=False)
    plt.show(block=False)  # non-blocking, so the script continues past this point


def visualize(true_y, pred_y, odefunc, itr):
    """Redraw the three diagnostic panels and save them under ``png/``.

    Does nothing unless the module-level ``viz`` flag is set. Draws on the
    module-level figure and axes and uses the module-level time grid ``t``.

    Args:
        true_y: Reference trajectory; read as ``[:, 0, 0]`` and ``[:, 0, 1]``.
        pred_y: Predicted trajectory, read the same way.
        odefunc: Callable ``odefunc(t, y)``; evaluated at ``t = 0`` on a
            21 x 21 grid of states over [-2, 2]^2 to draw the vector field.
        itr: Integer used for the image name (zero-padded to three digits).
    """
    if not viz:
        return

    # Convert once instead of per plot call; detach so a prediction that
    # still tracks gradients can be converted to NumPy.
    t_np = t.cpu().numpy()
    true_np = true_y.detach().cpu().numpy()
    pred_np = pred_y.detach().cpu().numpy()

    ax_traj.cla()
    ax_traj.set_title('Trajectories')
    ax_traj.set_xlabel('t')
    ax_traj.set_ylabel('x,y')
    # One plot call per curve so each one carries a label; without labels
    # legend() has nothing to show. Line formats are kept as before.
    ax_traj.plot(t_np, true_np[:, 0, 0], label='true x')
    ax_traj.plot(t_np, true_np[:, 0, 1], 'g-', label='true y')
    ax_traj.plot(t_np, pred_np[:, 0, 0], '--', label='pred x')
    ax_traj.plot(t_np, pred_np[:, 0, 1], 'b--', label='pred y')
    ax_traj.set_xlim(t_np.min(), t_np.max())
    ax_traj.set_ylim(-2, 2)
    ax_traj.legend()

    ax_phase.cla()
    ax_phase.set_title('Phase Portrait')
    ax_phase.set_xlabel('x')
    ax_phase.set_ylabel('y')
    ax_phase.plot(true_np[:, 0, 0], true_np[:, 0, 1], 'g-')
    ax_phase.plot(pred_np[:, 0, 0], pred_np[:, 0, 1], 'b--')
    ax_phase.set_xlim(-2, 2)
    ax_phase.set_ylim(-2, 2)

    ax_vecfield.cla()
    ax_vecfield.set_title('Learned Vector Field')
    ax_vecfield.set_xlabel('x')
    ax_vecfield.set_ylabel('y')

    y, x = np.mgrid[-2:2:21j, -2:2:21j]
    grid_states = torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2))
    dydt = odefunc(0, grid_states).cpu().detach().numpy()
    # Scale every vector to unit length so only the direction is drawn.
    mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
    dydt = (dydt / mag)
    dydt = dydt.reshape(21, 21, 2)

    ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
    ax_vecfield.set_xlim(-2, 2)
    ax_vecfield.set_ylim(-2, 2)

    fig.tight_layout()
    # Save through the figure itself rather than pyplot's "current figure",
    # which other plotting code in this file may have changed.
    fig.savefig('png/{:03d}'.format(itr))
    plt.draw()
    plt.pause(0.001)


###### ABOVE: ODEINT EXAMPLE ######

# Noise standard deviation for the single example trajectory simulated and
# plotted below; the dataset-generation loop reassigns it per trajectory.
noise_level = 0.5

# generate to df
def simulate_trajectories(true_y0, true_A, t, noise_level, drop_prob=0.6, name_idx=str(0)):
    """Simulate one noisy, irregularly sampled trajectory of dy/dt = y^3 A.

    The clean trajectory is solved on ``t``, Gaussian noise is added, and
    each time point is then dropped independently with probability
    ``drop_prob``.

    Args:
        true_y0: Initial state tensor. Expected to hold a single state of
            shape (1, 2): only ``[0, 0]`` and ``[0, 1]`` are recorded, and
            the table has one row per kept time point.
        true_A: 2 x 2 tensor ``A`` of the dynamics. This matrix drives the
            simulation and is also recorded in the table.
        t: 1-D tensor of time points to solve on.
        noise_level: Standard deviation of the additive Gaussian noise.
        drop_prob: Probability of discarding each time point.
        name_idx: Suffix of the ``trajectory`` label (``trajectory_<name_idx>``).

    Returns:
        Tuple ``(df, uneven_noisy_y, uneven_t)``: a DataFrame with columns
        ``t``, ``x``, ``y``, ``trajectory``, ``noise_level``,
        ``true_A_00`` .. ``true_A_11``, ``true_y0_0``, ``true_y0_1``; the
        kept noisy states; and the kept time points (both as tensors).
    """
    # Integrate with the matrix that was passed in. Going through Lambda()
    # would silently use the module-level true_A instead, so the recorded
    # true_A_* columns could disagree with the simulated dynamics.
    dynamics_A = true_A.to(true_y0.dtype)

    def dynamics(_t, y):
        return torch.mm(y ** 3, dynamics_A)

    # Simulate the clean trajectory first; no gradients are needed.
    with torch.no_grad():
        clean_y = odeint(dynamics, true_y0, t, method=method)

    # Add noise
    noisy_y = clean_y + noise_level * torch.randn(clean_y.shape)

    # Keep each time point with probability 1 - drop_prob
    keep_mask = torch.rand(len(t)) < (1 - drop_prob)
    uneven_t = t[keep_mask]
    uneven_noisy_y = noisy_y[keep_mask, :, :]

    # Conditioning information stored alongside every row
    A_np = true_A.detach().numpy()
    y0_np = true_y0.detach().numpy()

    # Hand pandas plain NumPy arrays rather than torch tensors so the column
    # dtypes do not depend on how pandas coerces foreign array types.
    df = pd.DataFrame({
        't': uneven_t.detach().numpy(),
        'x': uneven_noisy_y[:, :, 0].flatten().detach().numpy(),
        'y': uneven_noisy_y[:, :, 1].flatten().detach().numpy(),
        'trajectory': f'trajectory_{name_idx}',
        'noise_level': noise_level,
        # conditionals...
        'true_A_00': A_np[0, 0],
        'true_A_01': A_np[0, 1],
        'true_A_10': A_np[1, 0],
        'true_A_11': A_np[1, 1],
        'true_y0_0': y0_np[0, 0],
        'true_y0_1': y0_np[0, 1],
    })
    return df, uneven_noisy_y, uneven_t

# Example run with the reference settings. NOTE(review): these three results
# are reassigned by an identical call further down before anything reads them.
df, uneven_noisy_y, uneven_t = simulate_trajectories(true_y0, true_A, t, noise_level)

def plot_trajectory(t, true_y, simulated_y, folder='trajectories', name='test', full_t=None):
    """Plot a reference and a simulated trajectory and save the figure.

    The left panel is the phase portrait (x vs. y); the right panel shows
    both state variables over time. The image is written to
    ``<folder>/trajectory_<name>.png`` and the figure is closed afterwards.

    Args:
        t: 1-D tensor of time points belonging to ``simulated_y``.
        true_y: Reference trajectory tensor, read as ``[:, :, 0]`` (x) and
            ``[:, :, 1]`` (y).
        simulated_y: Simulated trajectory tensor, read the same way.
        folder: Output directory; created if missing.
        name: Suffix of the output file name.
        full_t: Time points belonging to ``true_y``. Defaults to an evenly
            spaced grid on [0, 25] with one point per row of ``true_y``.
    """
    os.makedirs(folder, exist_ok=True)

    if full_t is None:
        # Size the default grid from the data itself rather than from the
        # module-level data_size, so the lengths always agree.
        full_t = torch.linspace(0., 25., true_y.shape[0])

    # Detach and move to CPU so tensors from any source can be plotted.
    full_t_np = torch.as_tensor(full_t).detach().cpu().numpy()
    t_np = t.detach().cpu().numpy()
    true_np = true_y.detach().cpu().numpy()
    sim_np = simulated_y.detach().cpu().numpy()

    fig, (ax_phase, ax_time) = plt.subplots(1, 2, figsize=(12, 6))

    # Phase portrait (x vs. y)
    ax_phase.plot(true_np[:, :, 0], true_np[:, :, 1], 'g-', label='True')
    ax_phase.plot(sim_np[:, :, 0], sim_np[:, :, 1], 'b--', label='Simulated')
    ax_phase.set_title('Phase Portrait')
    ax_phase.set_xlabel('x')
    ax_phase.set_ylabel('y')
    ax_phase.legend()

    # x, y vs. t
    ax_time.plot(full_t_np, true_np[:, :, 0], 'g-', label='True x')
    ax_time.plot(full_t_np, true_np[:, :, 1], 'g--', label='True y')
    ax_time.plot(t_np, sim_np[:, :, 0], 'b-', label='Simulated x', alpha=0.5)
    ax_time.plot(t_np, sim_np[:, :, 1], 'b--', label='Simulated y', alpha=0.5)
    ax_time.set_title('State Variables Over Time')
    ax_time.set_xlabel('t')
    ax_time.set_ylabel('State Variables')
    ax_time.legend()

    # Work on this figure explicitly instead of pyplot's "current figure",
    # so other open figures are never saved or closed by mistake.
    fig.tight_layout()
    fig.savefig(os.path.join(folder, f'trajectory_{name}.png'))
    plt.close(fig)

# generate and plot (example)
df, uneven_noisy_y, uneven_t = simulate_trajectories(true_y0, true_A, t, noise_level)
plot_trajectory(uneven_t, true_y, uneven_noisy_y, folder='trajectories')

# Output location. Created before the (long) generation loop so that a
# missing or unwritable directory is reported immediately rather than after
# all trajectories have been simulated.
general_path = "/home/XXXX-1/x/XXXX-2/scratch/XXXX-3/toy/"
fname = "ode_demo"
os.makedirs(general_path, exist_ok=True)

# Generation of multiple trajectories (num_traj of them), randomizing the
# starting point, the noise level and the amount of dropout per trajectory.
# The dynamics matrix and the time grid are the same for every trajectory,
# so they are built once outside the loop.
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
t = torch.linspace(0., 25., data_size)
trajectory_frames = []
for i in range(num_traj):
    noise_level = np.random.uniform(0, 0.5)
    true_y0 = torch.tensor([[np.random.uniform(-2, 2), np.random.uniform(-2, 2)]])
    # randomize how much dropout
    drop_prob = np.random.uniform(0.6, 0.9) # some trajectories have very few points
    df, uneven_noisy_y, uneven_t = simulate_trajectories(true_y0, true_A, t, noise_level, drop_prob, name_idx=str(i))

    # normalize t to 0 - 1; a trajectory with a single kept point has zero
    # time span, so map it to 0 instead of producing 0/0 = NaN
    t_min = df['t'].min()
    t_span = df['t'].max() - t_min
    if t_span > 0:
        df['t'] = (df['t'] - t_min) / t_span
    else:
        df['t'] = 0.0

    trajectory_frames.append(df)

# Concatenate once: growing a DataFrame inside the loop copies all earlier
# rows on every iteration.
combined_df = pd.concat(trajectory_frames) if trajectory_frames else pd.DataFrame()

# split train, validation, test (0.8, 0.1, 0.1)
# NOTE(review): the split is over rows (individual time points), so points of
# one trajectory can end up in different splits - confirm this is intended.
from sklearn.model_selection import train_test_split
train, test = train_test_split(combined_df, test_size=0.2)
val, test = train_test_split(test, test_size=0.5)

# Save each split to its own pickle, plus one pickle holding all three.
# File names are <fname><split>.pkl with no separator, as before.
combined_lst = {'train': train, 'val': val, 'test': test}
for split_name, split_df in combined_lst.items():
    split_df.to_pickle(os.path.join(general_path, fname + split_name + '.pkl'))

with open(os.path.join(general_path, fname + 'combined.pkl'), 'wb') as f:
    pickle.dump(combined_lst, f)
