from utils.loss import *
from utils.data_load import *
from utils.model import Schroedinger_NO_3d_time_Diff, count_params
from timeit import default_timer
import os

####### Before Training #######

random_seed = 0

# torch.manual_seed seeds the RNG on every device (CPU and all CUDA devices);
# NumPy has its own global RNG and is seeded separately.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# GPU to train on. torch.cuda.set_device raises when CUDA is unavailable or the
# ordinal does not exist, which would defeat the CPU fallback below. So only pin
# the GPU when it is really present; otherwise keep the default CUDA device.
gpu_id = 2
if torch.cuda.is_available():
    if gpu_id < torch.cuda.device_count():
        torch.cuda.set_device(device=gpu_id)
    else:
        print("GPU {} not available; using default CUDA device".format(gpu_id))


# The checkpoint is written to MODEL_PATH + save_model_name + '.pt'. With both
# left empty the file is literally '.pt' in the working directory, so fill
# these in before running.
MODEL_PATH = ''

save_model_name = ''


BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# torch.device('cuda') without an index resolves to the current CUDA device,
# i.e. the one pinned above when it exists.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("device indicates {}".format(device))


####### Configuration #######

# Number of training / test samples requested from the data loader.
ntrain, ntest = 220, 40

# Model size: `modes` is passed once per spatial axis to the model constructor,
# `width` is its hidden channel count (n_hidden).
modes, width = 16, 32

# Optimisation settings. `iterations` counts optimizer steps (one per batch)
# over the whole run and is used as the LR-schedule horizon.
batch_size = 10
learning_rate = 0.001
epochs = 250
iterations = (ntrain // batch_size) * epochs

# Grid settings: `res` = S // sub is the per-axis size used when reshaping model
# outputs (res**3 points per sample). S and sub are also handed to Loading_Data.
sub, S = 4, 128
res = S // sub
print(f"res:{res}")

# Time-window settings forwarded to Loading_Data.
T_in, T = 20, 20
step = 1

####### Data #######

# Unpacked as (x_test, y_test, train_loader, test_loader); x_test / y_test are
# not referenced further in this script.
x_test, y_test, train_loader, test_loader = Loading_Data('none', sub, S, ntrain, ntest, batch_size, T_in, T, inverse=False)


####### Model, Training Modules #######

# .to(device) keeps placement consistent with the `device` chosen during setup.
model = Schroedinger_NO_3d_time_Diff(modes, modes, modes, n_hidden=width).to(device)
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# The scheduler is stepped once per training batch, so T_max must equal the
# real number of optimizer steps. len(train_loader) counts a trailing partial
# batch (or honours drop_last), whereas ntrain // batch_size silently undercounts
# when ntrain is not a multiple of batch_size. Past T_max the cosine LR would
# start rising again.
iterations = epochs * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

# NOTE(review): size_average=False presumably sums the per-sample loss over the
# batch (the epoch totals are later divided by ntrain / ntest) - confirm in utils.loss.
myloss = LpLoss(size_average=False)

# Define the L1 loss function
myloss_l1 = LpLoss(p=1, size_average=False)

####### Training #######

# Train for `epochs` epochs. After each epoch, evaluate on the test loader and
# checkpoint the whole model whenever the mean test L2 improves. Per-epoch
# training wall time is collected for summary statistics at the end.

best_test_l2 = 1e17
best_test_l2_train_loss = 0.0
best_test_l2_to_l1 = 0.0
best_epoch = 0

time_costs = []

# Create lists to store training and testing loss
train_loss_history = []
test_loss_history = []

# Create lists to store testing L1 loss
test_l1_history = []

for ep in range(epochs):

    model.train()

    t1 = default_timer()

    train_l2 = 0.0

    for x, y in train_loader:

        x, y = x.to(device), y.to(device)
        # Use the batch's real size rather than the configured batch_size, so a
        # trailing partial batch does not make the reshapes below fail.
        bsz = x.shape[0]

        optimizer.zero_grad()
        # Reshape (rather than view) also asserts the output holds
        # bsz * res**3 * 4 elements.
        out = model(x).reshape(bsz, res, res, res, 4)

        loss = myloss(out.reshape(bsz, -1), y.reshape(bsz, -1))
        loss.backward()

        optimizer.step()
        scheduler.step()  # per-batch LR schedule: T_max is counted in batches
        train_l2 += loss.item()

    t2 = default_timer()

    time_costs.append(t2 - t1)

    model.eval()
    test_l2 = 0.0
    test_l1 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            bsz = x.shape[0]

            out = model(x).reshape(bsz, res, res, res, 4)

            # Flatten once and reuse for both metrics.
            out_flat = out.reshape(bsz, -1)
            y_flat = y.reshape(bsz, -1)
            test_l2 += myloss(out_flat, y_flat).item()
            test_l1 += myloss_l1(out_flat, y_flat).item()

    # Losses were accumulated as per-batch sums; convert to per-sample means.
    train_l2 /= ntrain
    test_l2 /= ntest
    test_l1 /= ntest

    # Append the current epoch's loss to the history lists
    train_loss_history.append(train_l2)
    test_loss_history.append(test_l2)
    test_l1_history.append(test_l1)

    if test_l2 < best_test_l2:
        best_test_l2 = test_l2
        best_test_l2_train_loss = train_l2
        best_test_l2_to_l1 = test_l1
        best_epoch = ep
        # Saves the full pickled module (not just the state_dict), so loading
        # it requires the model class to be importable.
        torch.save(model, MODEL_PATH + save_model_name + '.pt')
        print("New checkpoint model saved at epoch {}".format(best_epoch))

    print("####### Epoch {} #######".format(ep))

    print("Train_L2: {}".format(train_l2))
    print("Test_l2: {}".format(test_l2))
    print("Test_l1: {}".format(test_l1))
    print("Training time cost in this epoch: {}".format(t2-t1))

print("The best test_l2 is: {}, happens on epoch {}".format(best_test_l2, best_epoch))
print("The corresponding best_test_epoch train_l2 is: {}".format(best_test_l2_train_loss))
print("The corresponding best_test_epoch test_l1 is {}".format(best_test_l2_to_l1))

# Calculate and display the maximum, minimum, and average time cost.
# Guarded so a zero-epoch run does not crash on max()/min() of an empty list.
if time_costs:
    max_time_cost = max(time_costs)
    min_time_cost = min(time_costs)
    avg_time_cost = sum(time_costs) / len(time_costs)

    print("\nTime Cost Statistics:")
    print("Max Time Cost: {:.4f} seconds".format(max_time_cost))
    print("Min Time Cost: {:.4f} seconds".format(min_time_cost))
    print("Avg Time Cost: {:.4f} seconds".format(avg_time_cost))
