import os

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from generation.generate_quad_data import generate_samples, generate_flows
from utils.mfg_operator import MFGOperator
from utils.quadratic_model import QuadModel

# --- Experiment configuration ---

# Problem dimension; it also appears in the data filename loaded below.
d = 3

# Width W, used to build the architecture list of the non-ResNet operator.
W = 64

# Switch between the ResNet operator and the plain one.
use_resnet = True

# Filename suffix that marks outputs produced with the ResNet operator.
if use_resnet:
    resnet_tag = '_resnet'
else:
    resnet_tag = ''

# Quadratic model instance for dimension d.
quad_model = QuadModel('operator_train', d=d)

# Load data. The .npz archive is opened as a context manager so that its
# underlying file handle is closed as soon as both arrays have been read.
with np.load(f'data/quad_data_d={d}.npz') as data:
    init_data, flow_data = data['arr_0'], data['arr_1']

# Augment initial parameters and flows by randomly sampling a time step:
# every sample gets its own step index drawn from {0, ..., Nt}.
num_samples = init_data.shape[0]
step_indices = np.random.choice(quad_model.Nt + 1, size=num_samples)

# Pick, for sample i, the flow values at its own step index. Indexing the
# first and last axes together yields shape (num_samples, flow_data.shape[1]).
# The cast keeps the float64 dtype the targets had previously.
sampled_flows = flow_data[np.arange(num_samples), :, step_indices].astype(
    np.float64, copy=False
)

# Convert step indices to times on [0, T] and append them as the last input
# column, so each input row is (initial parameters, t).
ts = step_indices * quad_model.T / quad_model.Nt
init_data = np.hstack((init_data, ts.reshape(-1, 1)))

print(f'Input data shape: {init_data.shape}')
print(f'Output data shape: {sampled_flows.shape}')

# Initialize MFGOperator. Input/output sizes come straight from the data.
input_dim = init_data.shape[1]
output_dim = sampled_flows.shape[1]

# Training settings that are identical for both operator variants.
_shared_settings = dict(
    scheduler=CosineAnnealingLR,
    optimizer=optim.AdamW,
    loss_function=nn.SmoothL1Loss(reduction='sum'),
    learning_rate=8e-4,
)

if use_resnet:
    # ResNet variant: the architecture is a dict with its own fixed width
    # (128), so W is not used on this path.
    mfg_operator = MFGOperator(
        input_dim=input_dim,
        architecture={'width': 128, 'depth': 4, 'dropout': 0.05},
        output_dim=output_dim,
        resnet=True,
        **_shared_settings,
    )
else:
    # Plain variant: the architecture is a list of four entries equal to W.
    mfg_operator = MFGOperator(
        input_dim=input_dim,
        architecture=4 * [W],
        output_dim=output_dim,
        **_shared_settings,
    )

# Split the data into train/test parts, train the operator, then run its
# test evaluation.
_split_options = {'test_size': 0.2, 'batch_size': 64}
mfg_operator.split_data(init_data, sampled_flows, **_split_options)
mfg_operator.train(epochs=2000)
mfg_operator.evaluate_test()

# Dimension as stored on the model; used for sampling and in output filenames.
d = quad_model.d

# Save trained model first, so that a failure in the plotting code below
# cannot lose the result of the training run. The target directory is
# created if it does not exist yet.
model_path = f'models/quad_operator_time_d={d}{resnet_tag}.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
mfg_operator.save_model(model_path)

# Plot four random examples
rand_samples = 4
lower_bound = 0.0
upper_bound = 1.0

rand_inits = generate_samples(rand_samples, d, lower_bound, upper_bound)
rand_flows = generate_flows(rand_inits, quad_model)

# Time grid with Nt + 1 points on [0, T].
times = np.linspace(0, quad_model.T, quad_model.Nt + 1)

# At most ten components are drawn per panel.
num_curves = min(quad_model.d, 10)

fig, axes = plt.subplots(1, rand_samples, figsize=(20, 5))

for i in range(rand_samples):
    # Query the operator once per time point; column k holds the prediction
    # at times[k]. Enumerating gives the column index directly instead of
    # recovering it from the floating-point time value.
    preds = np.zeros((quad_model.d, len(times)))
    for step, t_value in enumerate(times):
        preds[:, step] = mfg_operator.predict(np.append(rand_inits[i], t_value))

    for comp in range(num_curves):
        # Reference flow for this component.
        axes[i].plot(
            times,
            rand_flows[i, comp, :],
            label=rf"$u({comp})$",
            color=quad_model.colors[comp],
            linestyle=quad_model.linestyles[comp],
            linewidth=quad_model.linewidths[comp]
        )

        # Operator prediction, drawn thicker and semi-transparent so that it
        # can be compared against the reference curve it overlays.
        axes[i].plot(
            times,
            preds[comp, :],
            label=rf"$\hat{{u}}({comp})$",
            color=quad_model.colors[comp],
            linestyle=quad_model.linestyles[comp],
            linewidth=4,
            alpha=0.5
        )

    axes[i].set_xlabel('Time ($t$)')
    axes[i].set_ylabel(r'$u$')
    # Only show a legend when d is small.
    if d < 5:
        axes[i].legend()

plt.tight_layout()

# savefig does not create missing directories, so make sure it exists.
plot_path = f'plots/quad_mfg_relu_time_d={d}{resnet_tag}.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
plt.savefig(plot_path, dpi=200)
plt.show()
