"""
Dynamics on distorted S1, neural ODE models

Detailed result comparison is done in s1.py
"""
import copy
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

import torch
import torch.nn as nn
from torchdiffeq import odeint
# Use double precision as the default floating-point dtype, so tensors and
# network parameters created later match the float64 NumPy data.
torch.set_default_dtype(torch.float64)

# ----------------------------
# Choose the D parameter and the corresponding number of epochs
# ----------------------------
# D is the distortion amplitude passed to make_dyn (radius = 1 + D*cos(K*angle));
# epochs is the number of passes of the training loop.  Each commented-out
# pair below is an alternative (D, epochs) setting; keep exactly one pair active.
# D = 0.0
# epochs = 1000
# D = 0.1
# epochs = 4000
D = 0.3
epochs = 4000
# D = 0.5
# epochs = 6000
s5 = np.sqrt(5)


def make_dyn(K, D):
    """Return the closed-form trajectory map on a distorted circle.

    The phase follows  d(theta)/dt = 3/2 - cos(theta)  with theta(0) = 0,
    whose solution is  theta(t) = 2 * atan(tan(sqrt(5) * t / 4) / sqrt(5)).
    A point with phase theta is placed at radius  1 + D * cos(K * theta).

    Args:
        K: angular wavenumber of the radial distortion.
        D: amplitude of the radial distortion (D = 0 gives the unit circle).

    Returns:
        A function ``dyn(tt, K=K, D=D)`` mapping times ``tt`` (scalar or
        array) to Cartesian coordinates.  For a 1-D ``tt`` of length N the
        result has shape (N, 2); for a scalar it has shape (2,).  ``K`` and
        ``D`` may be overridden per call.
    """
    def dyn(tt, K=K, D=D):
        # Exact phase at time tt.
        half_angle_tan = np.tan(s5 * tt / 4) / s5
        theta = 2 * np.arctan(half_angle_tan)
        # Distorted radius, then polar -> Cartesian.
        radius = 1 + D * np.cos(K * theta)
        x_coord = radius * np.cos(theta)
        y_coord = radius * np.sin(theta)
        # Stack as rows and transpose so that time runs along the first axis.
        return np.transpose(np.array([x_coord, y_coord]))
    return dyn

# Neural ODE approach
class ODEFunc(nn.Module):
    """Autonomous right-hand side dy/dt = net(y) for a 2-D state.

    The network is a 2 -> hidden_dim -> hidden_dim -> 2 multilayer perceptron
    with PReLU activations.  Every linear layer is initialised with weights
    drawn from N(0, (1/hidden_dim)^2) and zero biases.

    Args:
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(2, hidden_dim),
            nn.PReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.PReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

        # Re-initialise only the linear layers, in network order; the PReLU
        # slopes keep their library defaults.
        weight_std = 1 / hidden_dim
        for layer in layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.normal_(layer.weight, mean=0, std=weight_std)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        """Evaluate the vector field at state ``y``; the time ``t`` is ignored."""
        return self.net(y)

# ------------------------
# Tunable parameters
# ------------------------
# Hyper-parameters of the training loop and of the network.
batch_size = 20  # trajectory segments drawn per optimizer step (see get_batch)
time_steps = 20  # consecutive samples per training segment
hidden_dim = 32  # width of the two hidden layers of ODEFunc
method = 'rk4'  # solver name forwarded to torchdiffeq's odeint

# ------------------------
# Data generation
# ------------------------
# Sample the exact trajectory on a uniform time grid.
T = 128  # final time of the reference trajectory
Nt = 3201  # number of samples on [0, T]
tt = np.linspace(0, T, Nt)
dt = tt[1] - tt[0]  # uniform sampling step, T / (Nt - 1)
Ntrain = (Nt - 1) // 16  # only the first Ntrain samples are used for training
Ltest = Nt  # number of time points in the test rollout
t_sim = np.arange(Ltest) * dt  # time grid integrated by the trained model
t_plt = t_sim  # time axis used for plotting
fdyn = make_dyn(3, D)  # K = 3 with the distortion amplitude chosen above
uu = fdyn(tt)  # reference trajectory, shape (Nt, 2)

FS = 16  # font size used in the figures

# ------------------------
# Training
# ------------------------
# Switches for the two script stages below; set to 1 to enable a stage.
ifrun = 0  # Train the NODE for the chosen D
ifplt = 0  # Plot and compare the prediction for sanity check

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_batch(data, batch_size, time_steps, device, step_size=None):
    """Draw a mini-batch of contiguous trajectory segments from ``data``.

    Args:
        data: tensor whose first axis is time.
        batch_size: number of segments; their start indices are drawn
            uniformly without replacement via ``np.random.choice``.
        time_steps: number of consecutive samples in each segment.
        device: device on which the returned tensors are placed.
        step_size: spacing of the returned time grid.  When omitted, the
            module-level sampling step ``dt`` is used.

    Returns:
        Tuple ``(batch_y0, batch_t, batch_y)``:
        ``batch_y0`` holds the first sample of every segment (M, D),
        ``batch_t`` is the time grid ``[0, step, ..., (time_steps-1)*step]``,
        ``batch_y`` holds the full segments with time first (T, M, D).

    Raises:
        ValueError: if ``data`` has fewer than ``time_steps`` samples, or
            (raised by NumPy) if ``batch_size`` exceeds the number of
            admissible start indices.
    """
    if step_size is None:
        step_size = dt

    # A segment starting at s reads indices s .. s + time_steps - 1, so the
    # last admissible start is len(data) - time_steps *inclusive*.
    n_starts = len(data) - time_steps + 1
    if n_starts < 1:
        raise ValueError(
            f'need at least {time_steps} samples, got {len(data)}')

    starts = torch.from_numpy(np.random.choice(
        np.arange(n_starts, dtype=np.int64), batch_size, replace=False))
    batch_y0 = data[starts].to(device)
    batch_t = (torch.arange(time_steps) * step_size).to(device)
    batch_y = torch.stack(
        [data[starts + k] for k in range(time_steps)], dim=0).to(device)  # (T, M, D)
    return batch_y0, batch_t, batch_y

if ifrun:
    # Local import: only this branch needs to touch the directory layout.
    import os

    # All outputs go to ./res.  Create it before training so that a missing
    # directory cannot discard the results once the long loop has finished.
    os.makedirs('./res', exist_ok=True)

    # Training uses the first Ntrain samples; the test rollout starts from
    # the very first sample of the reference trajectory.
    data_train = torch.tensor(uu[:Ntrain], dtype=torch.float64).to(device)
    data_test = torch.tensor(uu, dtype=torch.float64).to(device)
    a0_test = torch.tensor(uu[0], dtype=torch.float64).to(device)

    # Training Neural ODE
    func = ODEFunc(hidden_dim).to(device)
    optimizer = torch.optim.Adam(func.parameters(), lr=1e-4)

    t1 = time.time()
    loss_history = []
    for epoch in range(epochs):
        # Sum (not mean) of the per-batch losses over this epoch.
        epoch_loss = 0
        for _ in range(Ntrain // batch_size):
            batch_y0, batch_t, batch_y = get_batch(data_train, batch_size, time_steps, device)

            # Forward pass using ODE solver: integrate every segment from its
            # initial state over the segment's time grid.
            pred = odeint(func, batch_y0, batch_t, method=method)

            # Compute loss (MSE)
            loss = ((pred - batch_y) ** 2).mean()
            epoch_loss += loss.item()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_history.append(epoch_loss)
        if (epoch+1) % 10 == 0:
            print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:4.3e}')
    t2 = time.time()

    with open(f'./res/losses_case_{D}.pkl', 'wb') as f:
        pickle.dump(loss_history, f)

    # Testing Neural ODE: one long rollout from the first reference sample
    # over the whole t_sim grid, without tracking gradients.
    t3 = time.time()
    with torch.no_grad():
        a_nde = odeint(func, a0_test.unsqueeze(0), torch.tensor(t_sim).to(device), method=method).squeeze().cpu().numpy()
    t4 = time.time()

    # Wall-clock seconds spent on training and on the test rollout.
    print(t2-t1, t4-t3)
    np.save(f'./res/s1_nde_{int(10*D)}.npy', a_nde)

if ifplt:
    # Loss history written by the training branch for the same D.
    # NOTE: pickle.load can execute code embedded in the file; only load
    # files that this script produced.
    with open(f'./res/losses_case_{D}.pkl', 'rb') as f:
        loss = pickle.load(f)
    plt.figure()
    plt.semilogy(loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    # Compare the stored NODE rollout with the exact trajectory, one panel
    # per state component.
    a_tru = uu
    a_nde = np.load(f'./res/s1_nde_{int(10*D)}.npy')
    f, ax = plt.subplots(nrows=2, sharex=True, figsize=(10, 6))
    for _j in range(2):
        ax[_j].plot(t_plt, a_nde[:, _j], 'g-', label='NODE')
        ax[_j].plot(t_plt, a_tru[:, _j], 'k--', label='Truth')
        ax[_j].set_ylabel(f'$x_{_j + 1}$', fontsize=FS)
        ax[_j].set_title(f'D={D}', fontsize=FS)
        ax[_j].tick_params(axis='both', which='major', labelsize=FS)
    # The curves carry labels; draw the legend so they are actually shown.
    ax[0].legend(fontsize=FS)

# Runs regardless of ifplt; it simply returns when no figure was created.
plt.show()
