from utils.loss import *
from utils.data_load import *
from utils.model import Schroedinger_NO_2d_time, count_params
from timeit import default_timer
import os
import numpy as np

####### Before Training #######

# Query CUDA availability once and reuse the answer below.
cuda_available = torch.cuda.is_available()

# Select GPU 0 only when CUDA is actually present. Calling
# torch.cuda.set_device unconditionally raises on CPU-only machines, which
# would make the CPU fallback chosen for `device` below unreachable.
if cuda_available:
    torch.cuda.set_device(device=0)

# Tag embedded in the checkpoint file name.
save_model_name = 'Schroedinger_KNO'

# Resolve paths relative to this script so it can be launched from any
# working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(BASE_DIR, 'data')

# Device that the batches (and the model) are moved to.
device = torch.device('cuda' if cuda_available else 'cpu')

print("device indicates {}".format(device))


####### Configuration #######

# Sample counts: forwarded to Loading_Data and used to normalise the
# accumulated train / test losses after every epoch.
ntrain, ntest = 200, 20

# Model size: `modes` is passed as the first two constructor arguments of the
# network and `width` as its `n_hidden`.
modes, width = 12, 24

# Optimisation settings. The scheduler is stepped once per batch, so its
# horizon is the total number of optimizer steps over the whole run.
batch_size = 20
learning_rate = 0.001
epochs = 500
iterations = (ntrain // batch_size) * epochs

# Data-layout arguments forwarded to Loading_Data
# (presumably sub-sampling rate and grid size -- confirm in utils.data_load).
sub, S = 1, 64

# T_in is forwarded to Loading_Data; T bounds the autoregressive roll-out
# loop, which advances `step` slices at a time.
T_in, T, step = 10, 10, 1

# Destination of the best checkpoint written during training.
MODEL_PATH_FULL = ''.join((
    os.path.join(BASE_DIR, 'checkpoints'),
    '/epo_', str(epochs), '_', save_model_name, '_full.pt',
))

####### Data #######

# Build the raw test tensors and the batched train / test loaders from the
# files under DATA_PATH, using the sizes defined in the configuration section.
x_test, y_test, train_loader, test_loader = Loading_Data(
    DATA_PATH,
    sub,
    S,
    ntrain,
    ntest,
    batch_size,
    T_in,
    T,
    inverse=False,
)


####### Model, Training Modules #######

# Build the network and move it to the same device the batches are sent to.
# `.to(device)` (rather than `.cuda()`) keeps the CPU fallback usable.
model = Schroedinger_NO_2d_time(modes, modes, n_hidden=width).to(device)
print(count_params(model))

# Adam with weight decay; the cosine schedule is stepped once per batch, hence
# T_max is the total number of optimizer steps.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

# Training / evaluation loss with LpLoss' default order, not averaged over the batch.
myloss = LpLoss(size_average=False)

# Same loss with p=1, used as an additional test metric.
myloss_l1 = LpLoss(p=1, size_average=False)

####### Training #######


# Best-so-far bookkeeping: the lowest test L2 seen (starting from a large
# sentinel) and the epoch at which it occurred.
best_test_l2_full, best_epoch_full = 1e17, 0

# Train L2 and test L1 recorded at the epoch of the best test L2.
best_test_l2_train_loss_full = best_test_l2_to_l1_full = 0.0

# Wall-clock duration of the training pass of each epoch.
time_costs = list()

# Per-epoch histories: train L2, test L2 and test L1 (three separate lists).
train_loss_history_full, test_loss_history_full, test_l1_history_full = [], [], []


# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first "best model" is written.
checkpoint_dir = os.path.dirname(MODEL_PATH_FULL)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for ep in range(epochs):

    # ------------------------------------------------------------------
    # Training pass (this is the only part that is timed).
    # ------------------------------------------------------------------
    model.train()

    t1 = default_timer()

    train_l2_full = 0
    for xx, yy in train_loader:
        xx = xx.to(device)
        yy = yy.to(device)

        # Use the real size of this batch: the last batch may be shorter than
        # `batch_size`, and reshaping with the configured value would then
        # either fail or silently regroup samples.
        n_batch = xx.shape[0]

        loss = 0
        step_preds = []
        for t in range(0, T, step):
            # Target slice for this step; flattened per sample below, so no
            # squeeze is needed.
            y = yy[..., t:t + step, :]
            im = model(xx)
            loss += myloss(im.reshape(n_batch, -1), y.reshape(n_batch, -1))
            step_preds.append(im)

            # Autoregressive roll-out: drop the oldest `step*2` entries of the
            # last axis of the input and append the new prediction.
            xx = torch.cat((xx[..., step*2:], im), dim=-1)

        # Concatenate all steps once instead of growing `pred` step by step.
        pred = torch.cat(step_preds, -1)

        # Monitoring metric only: no autograd graph is needed for it.
        with torch.no_grad():
            l2_full = myloss(pred.reshape(n_batch, -1), yy.reshape(n_batch, -1))
        train_l2_full += l2_full.item()

        # Optimise the accumulated step-wise loss; the LR schedule advances
        # once per batch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    t2 = default_timer()

    time_costs.append(t2 - t1)

    # ------------------------------------------------------------------
    # Evaluation pass.
    # ------------------------------------------------------------------
    # Switch to inference mode so that layers such as dropout / batch norm
    # (if the model has any) behave deterministically; `model.train()` at the
    # top of the loop switches back.
    model.eval()

    test_l2_full = 0
    test_l1_full = 0
    with torch.no_grad():
        for xx, yy in test_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            n_batch = xx.shape[0]

            step_preds = []
            for _ in range(0, T, step):
                im = model(xx)
                step_preds.append(im)
                xx = torch.cat((xx[..., step*2:], im), dim=-1)

            pred = torch.cat(step_preds, -1)

            test_l2_full += myloss(pred.reshape(n_batch, -1), yy.reshape(n_batch, -1)).item()
            test_l1_full += myloss_l1(pred.reshape(n_batch, -1), yy.reshape(n_batch, -1)).item()

    # Normalise the accumulated losses by the number of samples
    # (presumably LpLoss(size_average=False) returns a per-batch sum --
    # confirm in utils.loss). The L1 metric is normalised the same way as the
    # L2 metric so the stored histories are on a comparable per-sample scale.
    train_l2_full = train_l2_full / ntrain
    test_l2_full = test_l2_full / ntest
    test_l1_full = test_l1_full / ntest

    # Append the current epoch's loss to the history lists
    train_loss_history_full.append(train_l2_full)
    test_loss_history_full.append(test_l2_full)
    test_l1_history_full.append(test_l1_full)

    # Keep the checkpoint with the lowest test L2 seen so far, together with
    # the train L2 / test L1 measured at that epoch.
    if test_l2_full < best_test_l2_full:
        best_test_l2_full = test_l2_full
        best_test_l2_train_loss_full = train_l2_full
        best_test_l2_to_l1_full = test_l1_full
        best_epoch_full = ep
        torch.save(model, MODEL_PATH_FULL)
        print("New checkpoint model saved in epoch {}".format(best_epoch_full))

    print("####### Epoch {} #######".format(ep))

    print("Train_L2 Full: {}".format(train_l2_full))
    print("Test_l2 Full: {}".format(test_l2_full))
    print("Time cost in this epoch: {}".format(t2-t1))


# Report the best test error and the epoch at which it was reached.
print("The best test_l2_full is: {}, happend on epoch {}".format(best_test_l2_full, best_epoch_full))


# Summarise the per-epoch training time: longest, shortest and mean epoch.
max_time_cost, min_time_cost = max(time_costs), min(time_costs)
avg_time_cost = sum(time_costs) / len(time_costs)

print("\nTime Cost Statistics:")
for line_template, seconds in (
    ("Max Time Cost: {:.4f} seconds", max_time_cost),
    ("Min Time Cost: {:.4f} seconds", min_time_cost),
    ("Avg Time Cost: {:.4f} seconds", avg_time_cost),
):
    print(line_template.format(seconds))