####### Import Lib #######

from utils.model import Schroedinger_NO_1d_Decompose, count_params      # change model
from utils.data_load import *
from utils.loss import *

from timeit import default_timer

import os

####### Checking Before Training #######

# Pin GPU 0 only when CUDA is actually available. Calling
# torch.cuda.set_device on a CPU-only host raises, which would defeat the
# CPU fallback chosen for `device` below.
if torch.cuda.is_available():
    torch.cuda.set_device(device=0)

# Base name embedded in the checkpoint file name.
# NOTE(review): the name says "Schroedinger" while DATA_PATH points at the
# Burgers dataset -- confirm this pairing is intended.
save_model_name = 'Schroedinger_KNO'                                     # change names

# Resolve data paths relative to this script, not the current working dir.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, 'data')

DATA_PATH = os.path.join(DATA_DIR, 'burgers_data_R10.mat')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("device indicates {}".format(device))

####### Configuration #######

# Number of training / testing samples requested from Loading_Data.
ntrain = 1000
ntest = 100

sub = 2**6  # subsampling rate
h = 2**13 // sub  # total grid size divided by the subsampling rate
s = h

batch_size = 20
learning_rate = 0.001
epochs = 500
# Total optimizer steps, used as the cosine-annealing period (the scheduler
# is stepped once per batch). Assumes the loader yields ntrain // batch_size
# batches per epoch -- TODO confirm against Loading_Data.
iterations = epochs * (ntrain // batch_size)

# Hyper-parameters passed to Schroedinger_NO_1d_Decompose(modes, width).
modes = 16
width = 64

# Create the checkpoint directory up front; otherwise torch.save fails the
# first time a new best model is found.
CHECKPOINT_DIR = os.path.join(BASE_DIR, 'checkpoints')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
MODEL_PATH = os.path.join(
    CHECKPOINT_DIR,
    'epo_' + str(epochs) + '_' + save_model_name + '_resolution_' + str(s) + '.pt',
)

####### Data #######

x_test, y_test, train_loader, test_loader = Loading_Data(DATA_PATH, sub, s, ntrain, ntest, batch_size)

####### Model, Training Modules #######

# model
# Move to the selected `device` instead of hard-coding .cuda(), so the
# script honors the CPU fallback chosen when `device` was defined.
model = Schroedinger_NO_1d_Decompose(modes, width).to(device)
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# The training loop steps the scheduler once per batch, so T_max is the
# total number of optimizer steps over all epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

# Lp loss with LpLoss's default p (presumably p=2 -- confirm in utils.loss).
# size_average=False presumably sums over the batch; the loop divides by
# the sample count to get a per-sample mean.
myloss = LpLoss(size_average=False)

# Define the L1 loss function (same reduction as above, p=1).
myloss_l1 = LpLoss(p=1, size_average=False)

####### Training #######

# Best-so-far tracking (by test L2); inf guarantees the first epoch wins.
best_test_l2 = float('inf')
best_test_l2_train_loss = 0.0
best_test_l2_to_l1 = 0.0
best_epoch = 0

# Per-epoch wall-clock duration of the training pass (evaluation excluded).
time_costs = []

# Per-epoch history of training and testing loss
train_loss_history = []
test_loss_history = []

# Per-epoch history of testing L1 loss
test_l1_history = []

for ep in range(epochs):

    model.train()

    t1 = default_timer()

    train_l2 = 0.0
    n_train_seen = 0

    for x, y in train_loader:

        x, y = x.to(device), y.to(device)
        # Flatten using the real batch size: a short final batch would make
        # view(batch_size, -1) fail or silently regroup samples.
        n = x.size(0)

        optimizer.zero_grad()
        out = model(x)

        loss = myloss(out.reshape(n, -1), y.reshape(n, -1))
        loss.backward()

        optimizer.step()
        # Per-batch step: the cosine schedule's T_max counts optimizer steps.
        scheduler.step()
        train_l2 += loss.item()
        n_train_seen += n

    t2 = default_timer()

    time_costs.append(t2 - t1)

    model.eval()
    test_l2 = 0.0
    test_l1 = 0.0
    n_test_seen = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            n = x.size(0)

            out = model(x)
            out_flat = out.reshape(n, -1)
            y_flat = y.reshape(n, -1)

            test_l2 += myloss(out_flat, y_flat).item()
            test_l1 += myloss_l1(out_flat, y_flat).item()
            n_test_seen += n

    # Per-sample means over the samples actually seen (equals ntrain / ntest
    # when the loaders yield every sample); max(..., 1) avoids 0/0 on an
    # empty loader.
    train_l2 /= max(n_train_seen, 1)
    test_l2 /= max(n_test_seen, 1)
    test_l1 /= max(n_test_seen, 1)

    # Append the current epoch's loss to the history lists
    train_loss_history.append(train_l2)
    test_loss_history.append(test_l2)
    test_l1_history.append(test_l1)

    if test_l2 < best_test_l2:
        best_test_l2 = test_l2
        best_test_l2_train_loss = train_l2
        best_test_l2_to_l1 = test_l1
        best_epoch = ep
        # Saves the whole pickled module (not just the state_dict), as before.
        torch.save(model, MODEL_PATH)
        print("New checkpoint model saved at epoch {}".format(best_epoch))

    print("####### Epoch {} #######".format(ep))

    print("Train_L2: {}".format(train_l2))
    print("Test_l2: {}".format(test_l2))
    print("Test_l1: {}".format(test_l1))
    print("Training time cost in this epoch: {}".format(t2-t1))

print("The best test_l2 is: {}, happens on epoch {}".format(best_test_l2, best_epoch))
print("The corresponding best_test_epoch train_l2 is: {}".format(best_test_l2_train_loss))
print("The corresponding best_test_epoch test_l1 is {}".format(best_test_l2_to_l1))

# Calculate and display the maximum, minimum, and average time cost.
# Guarded so epochs == 0 does not crash max()/min() on an empty list.
if time_costs:
    max_time_cost = max(time_costs)
    min_time_cost = min(time_costs)
    avg_time_cost = sum(time_costs) / len(time_costs)

    print("\nTime Cost Statistics:")
    print("Max Time Cost: {:.4f} seconds".format(max_time_cost))
    print("Min Time Cost: {:.4f} seconds".format(min_time_cost))
    print("Avg Time Cost: {:.4f} seconds".format(avg_time_cost))
else:
    print("\nNo epochs were run; no time cost statistics available.")