import os

from utils.mfg_operator import MFGOperator
from utils.cs_model import CSModel
from generation.generate_cs_data import generate_samples, generate_flows
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR

plt.rcParams.update({'font.size': 14})

cs_model = CSModel('operator_train')

# Set neural network width
W = 64

# Load data. For an .npz archive np.load returns an NpzFile that keeps the
# underlying file open; indexing it reads each array into memory, so the
# context manager can close the file as soon as both arrays are extracted.
with np.load('data/cs_data.npz') as data:
    init_data, flow_data = data['arr_0'], data['arr_1']

# Set time discretization: N time indices spread between 1 and cs_model.Nt.
# Index 0 (t = 0) is skipped — presumably because the initial state is the
# operator's input rather than something to predict; TODO confirm.
# NOTE(review): .astype(int) truncates rather than rounds, and if
# cs_model.Nt < N - 1 some indices repeat; kept as-is so the sampled times
# match any model previously trained with this script.
N = 51
disc_points = np.linspace(1, cs_model.Nt, N).astype(int)

# Keep only the sampled time indices along axis 2 (assumed layout
# (samples, states, time) — consistent with how generate_flows output is
# plotted below; TODO confirm), then flatten each sample in C order: all N
# times of state 0, then all N times of state 1, and so on.
flow_data = flow_data[:, :, disc_points].reshape(flow_data.shape[0], -1)

print(f'Input data shape: {init_data.shape}')
print(f'Output data shape: {flow_data.shape}')

# Build the MFGOperator: one input feature per column of init_data, one
# output per entry of the flattened, time-discretized trajectory.
input_dim, output_dim = init_data.shape[1], flow_data.shape[1]
architecture = [W] * 4  # four entries of width W (presumably hidden layers — TODO confirm)
mfg_operator = MFGOperator(
    input_dim,
    architecture,
    output_dim,
    scheduler=CosineAnnealingLR,
    learning_rate=8e-4,
)

# Split into train/test sets (test_size=0.2, batch_size=32), train for
# 1000 epochs, then evaluate on the held-out split.
mfg_operator.split_data(init_data, flow_data, test_size=0.2, batch_size=32)
mfg_operator.train(epochs=1000)
mfg_operator.evaluate_test()

# Save the trained model first, so that a failure while plotting (e.g. a
# missing output directory) or a blocking plt.show() cannot lose the result
# of the training run above.
os.makedirs('models', exist_ok=True)
mfg_operator.save_model('models/cs_operator_relu_disc.pt')

# Plot four random examples: reference trajectories from generate_flows
# (lines) against operator predictions at the discretized times (markers).
rand_samples = 4
d = cs_model.NS
lower_bound = 0.0
upper_bound = 10.0

rand_inits = generate_samples(rand_samples, d, lower_bound, upper_bound)
rand_flows = generate_flows(rand_inits, cs_model)

# Both time grids are identical for every subplot and state; build them once.
t_full = np.linspace(0, cs_model.T, cs_model.Nt + 1)
t_disc = disc_points * cs_model.T / cs_model.Nt

# squeeze=False keeps `axes` two-dimensional even when rand_samples == 1,
# so indexing works for any number of panels.
fig, axes = plt.subplots(1, rand_samples, figsize=(5 * rand_samples, 5), squeeze=False)

for i, ax in enumerate(axes[0]):
    # The prediction is flattened state-major (the same C-order layout used
    # for flow_data): reshape to (states, len(disc_points)) and transpose so
    # that column iS holds the predicted trajectory of state iS.
    pred = mfg_operator.predict(rand_inits[i]).reshape((-1, len(disc_points))).T
    for iS in range(d):
        ax.plot(
            t_full,
            rand_flows[i, iS, :],
            label=rf"$u({cs_model.get_state(iS)})$",
            color=cs_model.colors[iS],
            linestyle=cs_model.linestyles[iS],
            linewidth=cs_model.linewidths[iS],
        )
        ax.scatter(
            t_disc,
            pred[:, iS],
            label=rf"$\hat{{u}}({cs_model.get_state(iS)})$",
            color=cs_model.colors[iS],
        )

    ax.set_xlabel('Time ($t$)')
    ax.set_ylabel(r'$u$')
    ax.legend()

fig.tight_layout()
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/cs_mfg_relu_disc.png', dpi=200)
plt.show()
