import os

import matplotlib.pyplot as plt
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR

from generation.generate_quad_data import generate_samples, generate_flows
from utils.mfg_operator import MFGOperator
from utils.quadratic_model import QuadModel

# Set dimension (also selects which pre-generated data file is loaded below)
d = 3
quad_model = QuadModel('operator_train', d=d)

# Set neural network width
W = 64

# Load data. The context manager closes the .npz archive as soon as both
# arrays have been read into memory, instead of leaving the file handle open
# for the rest of the (long-running) script.
with np.load(f'data/quad_data_d={d}.npz') as data:
    init_data, flow_data = data['arr_0'], data['arr_1']

# Set time discretization: index 0 followed by N indices spread linearly over
# [1, quad_model.Nt], i.e. N + 1 indices in total.
# NOTE(review): astype(int) truncates rather than rounds, and indices can
# repeat if quad_model.Nt < N — left unchanged so the output layout stays
# compatible with previously trained models.
N = 101
disc_points = np.concatenate([[0], np.linspace(1, quad_model.Nt, N).astype(int)])
# Keep only the selected indices along the last axis, then flatten every
# sample into a single row.
flow_data = flow_data[:, :, disc_points].reshape(flow_data.shape[0], -1)

print(f'Input data shape: {init_data.shape}')
print(f'Output data shape: {flow_data.shape}')

# Initialize MFGOperator: input/output sizes are taken from the second axis of
# the data arrays; `architecture` is a list of four entries, each equal to W.
input_dim = init_data.shape[1]
output_dim = flow_data.shape[1]
architecture = 4 * [W]
mfg_operator = MFGOperator(input_dim, architecture, output_dim, scheduler=CosineAnnealingLR, learning_rate=8e-4)

# Make sure the output directory exists *before* the long training run, so a
# missing folder cannot cost us the trained model at the very end.
# NOTE(review): save_model may already create the directory itself — this is
# harmless in that case.
model_path = f'models/quad_operator_relu_disc_d={d}.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Split train/test data and train operator
mfg_operator.split_data(init_data, flow_data, test_size=0.2, batch_size=32)
mfg_operator.train(epochs=1000)
mfg_operator.evaluate_test()

# Save trained model
mfg_operator.save_model(model_path)

# Plot four random examples: reference flows as lines, operator predictions
# at the discretization points as scatter markers.
rand_samples = 4
# Re-read the dimension from the model; it is used below for sampling, for the
# legend threshold and in the output filename.
d = quad_model.d
lower_bound = 0.0
upper_bound = 1.0

rand_inits = generate_samples(rand_samples, d, lower_bound, upper_bound)
rand_flows = generate_flows(rand_inits, quad_model)

# squeeze=False keeps the result two-dimensional even for rand_samples == 1,
# so indexing a single row of panels always works.
fig, axes = plt.subplots(1, rand_samples, figsize=(20, 5), squeeze=False)
axes = axes[0]

# The x-coordinates are identical for every panel and component, so build
# them once instead of inside the nested loops.
full_times = np.linspace(0, quad_model.T, quad_model.Nt + 1)
disc_times = np.asarray(disc_points) * quad_model.T / quad_model.Nt

for i, ax in enumerate(axes):
    # Reshape the prediction to one row per component (len(disc_points)
    # columns), then transpose so that column iS holds component iS.
    pred = mfg_operator.predict(rand_inits[i]).reshape((-1, len(disc_points))).T
    for iS in range(quad_model.d):
        ax.plot(
            full_times,
            rand_flows[i, iS, :],
            label=rf"$u({iS})$",
            color=quad_model.colors[iS],
            linestyle=quad_model.linestyles[iS],
            linewidth=quad_model.linewidths[iS],
        )
        ax.scatter(
            disc_times,
            pred[:, iS],
            label=rf"$\hat{{u}}({iS})$",
            color=quad_model.colors[iS],
            s=12,
            alpha=0.7,
        )

    ax.set_xlabel('Time ($t$)')
    ax.set_ylabel(r'$u$')
    # Two legend entries per component; skip the legend when it would get
    # too crowded.
    if d <= 5:
        ax.legend()

plt.tight_layout()
# savefig does not create missing directories, so create the folder first.
os.makedirs('plots', exist_ok=True)
plt.savefig(f'plots/quad_mfg_relu_disc_d={d}.png', dpi=200)
plt.show()
