import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import argparse
from timeit import default_timer
from utilities3 import *
from Adam import Adam
from NS_model import R2d


def _build_parser():
    """Create the command-line parser for the Navier-Stokes training run.

    Options are declared as (flag, type, default, help) rows and registered
    in row order, so ``--help`` lists them in that order.
    """
    options = (
        ('--gpu', int, 0, 'GPU ID'),
        ('--ntrain', int, 90, 'Number of training samples'),
        ('--ntest', int, 10, 'Number of test samples'),
        ('--modes', int, 20, 'Number of modes'),
        ('--modes1', int, 5, 'Number of modes1'),
        ('--modes2', int, 5, 'Number of modes2'),
        ('--width', int, 128, 'Model width'),
        ('--batch-size', int, 40, 'Batch size'),
        ('--epochs', int, 3, 'Number of epochs'),
        ('--lr', float, 0.001, 'Learning rate'),
        ('--scheduler-step', int, 10, 'Scheduler step size'),
        ('--scheduler-gamma', float, 0.5, 'Scheduler gamma'),
        ('--s', int, 128, 'Spatial resolution'),
        ('--t-in', int, 100, 'T_in parameter'),
        ('--t', int, 400, 'T parameter'),
    )
    cli = argparse.ArgumentParser(description='NS Training')
    for flag, kind, default, help_text in options:
        cli.add_argument(flag, type=kind, default=default, help=help_text)
    return cli


parser = _build_parser()
args = parser.parse_args()

# ---- Run configuration, copied out of the parsed command line -------------
ntrain, ntest = args.ntrain, args.ntest

# Model capacity: Fourier modes and channel width passed to the network.
modes, modes1, modes2 = args.modes, args.modes1, args.modes2
width = args.width

# Single-channel input and output fields.
in_dim = out_dim = 1

# Optimisation schedule.
batch_size, epochs = args.batch_size, args.epochs
learning_rate = args.lr
scheduler_step, scheduler_gamma = args.scheduler_step, args.scheduler_gamma

batch_size_train = batch_size_vali = args.batch_size

loss_k, loss_group = 0, True

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

# Spatial subsampling stride and the resulting grid size.
sub = 1
S = args.s

# Time window: samples are drawn from steps [T_in, T_in + T).
T_in, T = args.t_in, args.t
T_out = T_in + T
step = 1

# ---- Load the dataset and build one-step (t-1 -> t) training pairs ---------
t1 = default_timer()
data = np.load('/code/2D_NS_Re5000.npy')
data = torch.tensor(data, dtype=torch.float)[..., ::sub, ::sub]

# Validate the requested split and time window up front. Otherwise these
# mistakes surface as opaque reshape errors or, worse, go unnoticed:
# overlapping train/test trajectories, or T_in=0 wrapping to the last frame.
# NOTE(review): axis 0 is treated as trajectories and axis 1 as time, matching
# the slicing below; confirm against how the .npy file was produced.
n_traj, n_time = data.shape[0], data.shape[1]
if ntrain < 1 or ntest < 1:
    raise ValueError(f'ntrain and ntest must be positive, got {ntrain} and {ntest}')
if ntrain + ntest > n_traj:
    raise ValueError(
        f'ntrain + ntest = {ntrain + ntest} exceeds the {n_traj} trajectories '
        f'in the dataset; train and test sets would overlap')
if T_in < 1 or T < 1 or T_out > n_time:
    raise ValueError(
        f'invalid time window: need 1 <= T_in, 1 <= T and T_in + T <= {n_time}, '
        f'got T_in={T_in}, T={T}')
# Each (trajectory, time) frame must hold exactly S*S values for the reshapes
# below to yield S x S fields. The previous post-reshape assert on
# train_u.shape[2] could never fail, because reshape forces that axis to S.
if data.shape[2:].numel() != S * S:
    raise ValueError(
        f'each frame has {data.shape[2:].numel()} values after subsampling, '
        f'expected S*S = {S * S} (S={S})')

# Input is the field at step t-1 and target the field at t, for
# t in [T_in, T_out); every time step of every trajectory becomes a sample.
train_a = data[:ntrain, T_in - 1:T_out - 1].reshape(ntrain * T, S, S)
train_u = data[:ntrain, T_in:T_out].reshape(ntrain * T, S, S)

# Test trajectories are taken from the end of the dataset.
test_a = data[-ntest:, T_in - 1:T_out - 1].reshape(ntest * T, S, S)
test_u = data[-ntest:, T_in:T_out].reshape(ntest * T, S, S)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
# Target accelerator: CUDA device with the index given by --gpu.
device = torch.device('cuda', args.gpu)

# Build the network on the host first, then move its parameters to the device.
model = R2d(in_dim, out_dim, S, modes, width)
model = model.to(device)

def count_parameters(model):
    """Return the number of scalar parameters in ``model`` that require gradients.

    Frozen parameters (``requires_grad=False``) are excluded.
    """
    trainable = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable += param.numel()
    return trainable

# Parameter counts: trainable only vs. every registered parameter tensor.
total_params = count_parameters(model)
all_params = sum(map(torch.numel, model.parameters()))

# Snapshot the learnable scale factors of the first spectral layer before training.
scale_x, scale_r1, scale_r2 = (
    getattr(model.conv1.conv0, name).item()
    for name in ('scale_x', 'scale_r1', 'scale_r2')
)

# Optimiser with L2 weight decay and a step-wise learning-rate decay.
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
start_time = time.time()
myloss = LpLoss(size_average=False)

# Per-epoch metric history, one column each.
train_error, train_loss, vali_error, vali_loss = (np.zeros((epochs, 1)) for _ in range(4))
# Main optimisation loop. Each epoch runs one pass over the training set,
# then a no-grad pass over the held-out set. Per-sample metrics are recorded
# in the train_/vali_ history arrays and printed.
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0.0
    train_l2 = 0.0
    n_train = 0  # samples (not batches) seen this epoch
    for x, y in train_loader:
        # Use the real batch length. The final batch is shorter whenever the
        # dataset size is not a multiple of batch_size, and viewing it with
        # batch_size would raise.
        bs = x.shape[0]
        x = x.to(device).view(bs, S, S, in_dim)
        y = y.to(device).view(bs, S, S, out_dim)

        optimizer.zero_grad()
        out = model(x)

        mse = F.mse_loss(out.reshape(bs, -1), y.reshape(bs, -1), reduction='mean')
        # myloss was built with size_average=False. Presumably it returns the
        # sum of per-sample relative errors over the batch (TODO confirm in
        # utilities3.LpLoss). That summed value is what gets back-propagated.
        l2 = myloss(out.reshape(-1, S, S), y)
        l2.backward()

        optimizer.step()
        # Weight the per-batch mean MSE by the batch length, so the epoch
        # value is a true per-sample mean even when the last batch is short.
        train_mse += mse.item() * bs
        train_l2 += l2.item()
        n_train += bs

    scheduler.step()
    model.eval()
    vali_mse = 0.0
    vali_l2 = 0.0
    with torch.no_grad():
        n_vali = 0  # samples (not batches) seen during validation
        for x, y in test_loader:
            bs = x.shape[0]
            x = x.to(device).view(bs, S, S, in_dim)
            y = y.to(device).view(bs, S, S, out_dim)
            out = model(x)
            mse = F.mse_loss(out.reshape(bs, -1), y.reshape(bs, -1), reduction='mean')
            vali_l2 += myloss(out.reshape(-1, S, S), y).item()
            vali_mse += mse.item() * bs
            n_vali += bs

    # Turn the accumulated sums into per-sample means. Dividing the summed l2
    # by the batch count would over-report it by a factor of ~batch_size.
    train_mse /= n_train
    vali_mse /= n_vali
    train_l2 /= n_train
    vali_l2 /= n_vali
    train_error[ep, 0] = train_l2
    vali_error[ep, 0] = vali_l2
    train_loss[ep, 0] = train_mse
    vali_loss[ep, 0] = vali_mse
    t2 = default_timer()
    print("Epoch: %d, time: %.3f, Train Loss: %.3e,Vali Loss: %.3e, Train l2: %.4f, Vali l2: %.4f" % (
    ep, t2 - t1, train_mse, vali_mse, train_l2, vali_l2))
# Wall-clock training time, measured from just before the epoch loop.
elapsed = time.time() - start_time

# Query the device the model actually ran on. Without an argument these
# functions report the *current* CUDA device (cuda:0 unless it was changed),
# which gives wrong numbers when --gpu is not 0.
peak_memory_mb = torch.cuda.max_memory_allocated(device) / (1024**2)
current_memory_mb = torch.cuda.memory_allocated(device) / (1024**2)

# Learnable scale factors of the first spectral layer after training.
scale_x_final = model.conv1.conv0.scale_x.item()
scale_r1_final = model.conv1.conv0.scale_r1.item()
scale_r2_final = model.conv1.conv0.scale_r2.item()

# ---- Collect model predictions for the whole held-out set on the CPU ------
test_size = ntest * T
pred_u = torch.zeros(test_size, S, S)
true_u = torch.zeros(test_size, S, S)

model.eval()

print("Generating predictions for validation set...")
with torch.no_grad():
    # test_loader is not shuffled, so a running offset maps each batch back
    # to its position in the test set. The last batch may be shorter than
    # batch_size, so use each batch's real length.
    offset = 0
    for x, y in test_loader:
        n = x.shape[0]
        out = model(x.to(device).view(n, S, S, in_dim))
        pred_u[offset:offset + n] = out.reshape(n, S, S).cpu()
        # Targets already live on the CPU; no need to round-trip via the GPU.
        true_u[offset:offset + n] = y.reshape(n, S, S)
        offset += n

pred_u_np = pred_u.numpy()
true_u_np = true_u.numpy()

final_mse = F.mse_loss(pred_u, true_u, reduction='mean').item()
# myloss (size_average=False) presumably sums relative errors over samples,
# so divide by the sample count to report a per-sample mean.
final_l2 = myloss(pred_u, true_u).item() / test_size
print(f"Final validation MSE: {final_mse:.6e}")
print(f"Final validation L2: {final_l2:.6f}")

