from utils.loss import *
from utils.data_load import ERA5Dataset
from utils.model import Schroedinger_NO_2d_time, count_params
from timeit import default_timer
import os

####### Before Training #######

# Which ERA5 field to train on; the choice is embedded in the checkpoint name.
prop = 'wind_v'     # wind_u, wind_v

save_model_name = 'Schroedinger_NO' + '_' + prop

# Resolve paths relative to this script so it can be launched from any cwd.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(BASE_DIR, 'data', 'era_512_512.pt')

# NOTE(review): `torch` is not imported explicitly in this file; it is
# presumably provided by the star import from utils.loss.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pin GPU 0 only when CUDA is actually present. Calling set_device
# unconditionally made the CPU fallback above unreachable.
if device.type == 'cuda':
    torch.cuda.set_device(device=0)

print("device indicates {}".format(device))

####### Configuration #######

batch_size = 2

# Side length used when reshaping model output to (batch, s, s) in the
# training/evaluation loops.
s = 512

# Project-specific dataset wrapper; it exposes the loaders, sample counts and
# normalizer that the rest of the script reads from it.
data_set = ERA5Dataset(
    data_path=DATA_PATH,
    raw_resolution=[512, 512, 80],
    sample_resolution=[512, 512, 20],
    eval_resolution=[512, 512, 20],
    in_t=1, out_t=1, duration_t=16,
    train_day=4, test_day=1,
    train_batchsize=batch_size, eval_batchsize=batch_size,
    normalize=False, normalizer_type='PGN',
    prop=prop, sub=False,
)

ntrain = data_set.ntrain
ntest = data_set.ntest

print("ntrain: {}, ntest: {}".format(ntrain, ntest))

# Passed to the model constructor as (modes, modes, n_hidden=width).
modes = 10
width = 20

learning_rate = 0.001
epochs = 50
# Total number of optimizer steps; used as T_max for the cosine LR schedule.
# NOTE(review): assumes every epoch yields exactly ntrain // batch_size
# batches (i.e. no partial final batch) -- confirm against the loader.
iterations = epochs * (ntrain // batch_size)

# Build the checkpoint path entirely with os.path.join instead of gluing a
# '/' onto a joined path, so separators stay consistent on every platform.
MODEL_PATH = os.path.join(
    BASE_DIR,
    'checkpoints',
    'epo_' + str(epochs) + '_' + save_model_name + '.pt',
)

####### Data #######

train_loader = data_set.train_loader
test_loader = data_set.test_loader
y_normalizer = data_set.y_normalizer

####### Model, Training Modules #######

# Place the model on the device selected at start-up rather than a
# hard-coded .cuda(), so the CPU fallback chosen there is honoured.
model = Schroedinger_NO_2d_time(modes, modes, n_hidden=width).to(device)
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# The scheduler is stepped once per batch in the training loop, so T_max is
# the total step count, not the epoch count.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

# NOTE(review): size_average=False presumably makes the loss a per-batch sum
# (the loop later divides epoch totals by ntrain / ntest) -- confirm in utils.loss.
myloss = LpLoss(size_average=False)


####### Training #######

# Best-checkpoint bookkeeping (lowest test loss seen so far).
best_test_l2 = 1e17
best_test_l2_train_loss = 0.0
best_test_l2_to_l1 = 0.0  # never updated in this section; kept for compatibility
best_epoch = 0

# Wall-clock training time of each epoch (evaluation excluded).
time_costs = []

# Per-epoch, sample-normalised training and testing loss.
train_loss_history = []
test_loss_history = []

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save attempt.
checkpoint_dir = os.path.dirname(MODEL_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for ep in range(epochs):

    model.train()

    t1 = default_timer()

    train_l2 = 0.0

    for x, y in train_loader:

        x, y = x.to(device), y.to(device)
        # Use the actual size of this batch: the last batch of an epoch may be
        # smaller than the configured batch_size, which would break a reshape
        # that hard-codes batch_size.
        cur_batch = x.shape[0]

        optimizer.zero_grad()
        out = model(x).reshape(cur_batch, s, s)
        loss = myloss(out.view(cur_batch, -1), y.view(cur_batch, -1))
        loss.backward()

        optimizer.step()
        scheduler.step()  # LR schedule advances once per batch
        train_l2 += loss.item()

    t2 = default_timer()

    time_costs.append(t2 - t1)

    model.eval()
    test_l2 = 0.0

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            cur_batch = x.shape[0]
            out = model(x).reshape(cur_batch, s, s)
            test_l2 += myloss(out.view(cur_batch, -1), y.view(cur_batch, -1)).item()

    # Convert accumulated totals into per-sample averages.
    train_l2 /= ntrain
    test_l2 /= ntest

    # Append the current epoch's loss to the history lists
    train_loss_history.append(train_l2)
    test_loss_history.append(test_l2)

    if test_l2 < best_test_l2:
        best_test_l2 = test_l2
        best_test_l2_train_loss = train_l2
        best_epoch = ep
        # The whole model object is saved (not just a state_dict), so existing
        # loading code that expects this format keeps working.
        torch.save(model, MODEL_PATH)
        print("New checkpoint model saved at epoch {}".format(best_epoch))

    print("####### Epoch {} #######".format(ep))

    print("Train_L2: {}".format(train_l2))
    print("Test_l2: {}".format(test_l2))
    print("Training time cost in this epoch: {}".format(t2-t1))

print("The best test_l2 on {} is: {}, happens on epoch {}".format(prop, best_test_l2, best_epoch))
print("The corresponding best_test_epoch train_l2 is: {}".format(best_test_l2_train_loss))
print('\n')

# Summarise the per-epoch training times collected in time_costs:
# slowest epoch, fastest epoch, and the arithmetic mean.
max_time_cost, min_time_cost = max(time_costs), min(time_costs)
avg_time_cost = sum(time_costs) / len(time_costs)

print("\nTime Cost Statistics:")
for template, seconds in (
    ("Max Time Cost: {:.4f} seconds", max_time_cost),
    ("Min Time Cost: {:.4f} seconds", min_time_cost),
    ("Avg Time Cost: {:.4f} seconds", avg_time_cost),
):
    print(template.format(seconds))