import os

import matplotlib.pyplot as plt
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from generation.generate_cs_data import generate_samples, generate_flows
from utils.cs_model import CSModel
from utils.mfg_operator import MFGOperator

plt.rcParams.update({'font.size': 12})

# Model object; the rest of the script reads its time grid (T, Nt), state
# count (NS) and per-state plot styling from it.
cs_model = CSModel('operator_train')

# Load data. For an .npz file np.load returns an archive that keeps the
# underlying file open, so both arrays are read inside a `with` block and the
# handle is closed as soon as they are in memory.
with np.load('data/cs_data.npz') as data:
    init_data, flow_data = data['arr_0'], data['arr_1']

# Set neural network width
W = 64

# Augment initial parameters and flows by randomly sampling a time step.
# Every sample gets one uniformly drawn time index: the flow at that index
# becomes the regression target and the matching time value is appended to the
# inputs as an extra feature column.
n_samples = init_data.shape[0]
if flow_data.shape[0] != n_samples:
    raise ValueError(
        f'init_data has {n_samples} samples but flow_data has {flow_data.shape[0]}'
    )

# Draw from every time step stored in flow_data instead of only the first
# cs_model.Nt of them, so the last stored step can also be a training target.
# NOTE(review): presumably flow_data holds Nt + 1 steps (t = 0, ..., T) - confirm
# against the data generator. If it holds exactly Nt, this is unchanged.
n_steps = flow_data.shape[2]
t_idx = np.random.randint(n_steps, size=n_samples)

# Pairing a per-sample row index with a per-sample time index picks one time
# slice per sample; the result has shape (n_samples, flow_data.shape[1]).
sampled_flows = flow_data[np.arange(n_samples), :, t_idx].astype(np.float64)

# Convert the step index to a time value on the model's grid.
ts = t_idx * cs_model.T / cs_model.Nt

init_data = np.hstack((init_data, ts.reshape(-1, 1)))

print(f'Input data shape: {init_data.shape}')
print(f'Output data shape: {sampled_flows.shape}')

# Make sure the output directory exists before training starts, so a missing
# folder is caught up front rather than after the full training run.
model_path = 'models/cs_operator_time.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Initialize MFGOperator: input/output sizes come from the data, and the
# architecture argument is a list of four entries, each equal to W.
input_dim = init_data.shape[1]
output_dim = sampled_flows.shape[1]
architecture = [W] * 4
mfg_operator = MFGOperator(
    input_dim,
    architecture,
    output_dim,
    optimizer=AdamW,
    scheduler=CosineAnnealingLR,
    learning_rate=8e-4,
)

# Split train/test data and train operator
mfg_operator.split_data(init_data, sampled_flows, test_size=0.2, batch_size=64)
mfg_operator.train(epochs=2000)
mfg_operator.evaluate_test()

# Save trained model
mfg_operator.save_model(model_path)

# Plot random examples: for each freshly generated initial condition, overlay
# the reference flow from generate_flows with the operator's prediction at
# every point of the time grid.
rand_samples = 4
d = cs_model.NS
lower_bound = 0.0
upper_bound = 10.0

rand_inits = generate_samples(rand_samples, d, lower_bound, upper_bound)
rand_flows = generate_flows(rand_inits, cs_model)

times = np.linspace(0, cs_model.T, cs_model.Nt + 1)

fig, axes = plt.subplots(1, rand_samples, figsize=(20, 5))
# plt.subplots returns a bare Axes (not an array) when there is one panel;
# wrapping keeps the indexing below valid for rand_samples == 1 as well.
axes = np.atleast_1d(axes)

for i in range(rand_samples):
    ax = axes[i]

    # One prediction per time point. The column index comes straight from
    # enumerate rather than being recovered from the float time value.
    preds = np.zeros((cs_model.NS, len(times)))
    for k, t in enumerate(times):
        preds[:, k] = mfg_operator.predict(np.append(rand_inits[i], t))

    for iS in range(cs_model.NS):
        state = cs_model.get_state(iS)

        # Reference flow, drawn with the model's per-state styling.
        ax.plot(
            times,
            rand_flows[i, iS, :],
            label=rf"$u({state})$",
            color=cs_model.colors[iS],
            linestyle=cs_model.linestyles[iS],
            linewidth=cs_model.linewidths[iS],
        )

        # Prediction: same colour and dash pattern, but wide and translucent
        # so the reference line stays visible on top of it.
        ax.plot(
            times,
            preds[iS, :],
            label=rf"$\hat{{u}}({state})$",
            color=cs_model.colors[iS],
            linestyle=cs_model.linestyles[iS],
            linewidth=6,
            alpha=0.5,
        )

    ax.set_xlabel('Time ($t$)')
    ax.set_ylabel(r'$u$')
    ax.legend()

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/cs_mfg_relu_time.png', dpi=200)
plt.show()
