import argparse
import os
# Command-line configuration for the training run.
# The title is passed as `description`: the first positional parameter of
# ArgumentParser is `prog`, so passing the title positionally would replace the
# program name in usage/error messages instead of describing the script.
parser = argparse.ArgumentParser(description='Training FNO')

# Optimisation settings.
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--weight_decay', type=float, default=1e-4)

# Model size settings.
parser.add_argument("--n1", type=int, default=64)
parser.add_argument("--n2", type=int, default=64)
parser.add_argument("--width", type=int, default=32, help="Width")

# Stored as `args.batch_size` (argparse converts the hyphen to an underscore).
parser.add_argument('--batch-size', type=int, default=8)

# Runtime settings. `--gpu` is kept as a string because it is written verbatim
# into an environment variable.
parser.add_argument("--use_tb", type=int, default=0, help="Use TensorBoard: 1 for True, 0 for False")
parser.add_argument("--gpu", type=str, default='4', help="GPU index to use")

# Remaining training / regularisation options.
parser.add_argument('--max_grad_norm', type=float, default=1)
parser.add_argument('--downsample', type=int, default=5)
parser.add_argument('--ntrain', type=int, default=1000)
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--dropout_type', type=str, default="GD", help="Dropout Type: MC for typical dropout, GD for Gaussian dropout")

# Parses sys.argv at import time; unknown options abort the script.
args = parser.parse_args()

# Restrict CUDA to the device(s) requested with --gpu. This is deliberately done
# before torch is imported so the variable is already in the environment when
# CUDA is initialised; keep this line above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
import torch
from timeit import default_timer
from utilities3 import *
from AM_FNO import  FNO2d, FNO2dMLP
from FNOs import UFNO2d , FNOFactorizedMesh2D
# Seed the torch and NumPy random number generators with a fixed value.
# NOTE(review): `np` is not imported explicitly in this file; presumably it is
# provided by the star import from utilities3 -- confirm.
torch.manual_seed(42)
np.random.seed(42)


# Paths of the .npy files holding the X coordinates, the Y coordinates and the
# target field. They are empty placeholders here and must be filled in before
# the script can run (np.load fails on an empty path).
INPUT_X = ''
INPUT_Y = ''
OUTPUT_Sigma = ''

# Split sizes. Only the first N samples are used: the training set is taken
# from the front of that window and the test set from its end.
# `ntrain` follows the --ntrain option (default 1000) instead of a hard-coded
# value that silently ignored the flag.
ntrain = args.ntrain
ntest = 200
N = 1200
if ntrain < 1 or ntrain + ntest > N:
    # Without this check a too-large ntrain would make the training slice
    # overlap the test slice (both are cut from the same N-sample window).
    raise ValueError(
        f"--ntrain must be between 1 and {N - ntest} so that the training and "
        f"test splits do not overlap (got {ntrain})"
    )

# Subsampling strides along the two grid axes, and the resulting grid sizes
# for a 129-point axis.
r1 = 1
r2 = 1
s1 = (129 - 1) // r1 + 1
s2 = (129 - 1) // r2 + 1


# Load the two coordinate arrays and stack them on a new trailing axis, so
# inputs[..., 0] is X and inputs[..., 1] is Y.
inputX = torch.tensor(np.load(INPUT_X), dtype=torch.float)
inputY = torch.tensor(np.load(INPUT_Y), dtype=torch.float)
inputs = torch.stack([inputX, inputY], dim=-1)

# Keep only index 0 along the second axis of the stored target array.
output = torch.tensor(np.load(OUTPUT_Sigma)[:, 0], dtype=torch.float)

# Cut the splits, subsample the grid with strides (r1, r2) and crop to (s1, s2).
x_train = inputs[:N][:ntrain, ::r1, ::r2][:, :s1, :s2]
y_train = output[:N][:ntrain, ::r1, ::r2][:, :s1, :s2]
x_test = inputs[:N][-ntest:, ::r1, ::r2][:, :s1, :s2]
y_test = output[:N][-ntest:, ::r1, ::r2][:, :s1, :s2]

# The reshape also acts as a shape check: it fails if the data does not hold
# the expected number of samples or grid points.
x_train = x_train.reshape(ntrain, s1, s2, 2)
x_test = x_test.reshape(ntest, s1, s2, 2)
print(x_train.shape, y_train.shape)

# Training batches are reshuffled every epoch; the test order is fixed.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train, y_train),
    batch_size=args.batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test, y_test),
    batch_size=args.batch_size,
    shuffle=False,
)

# Echo the parsed command-line configuration into the log.
print(args)


# Network under training, moved to the GPU.
# NOTE(review): mlp_dropout, H and W are fixed literals here rather than being
# derived from --dropout or the data grid size -- confirm this is intended.
model = FNO2dMLP(
    n1=args.n1,
    n2=args.n2,
    width=args.width,
    padding=8,
    input_dim=2,
    mlp_dropout=0.,
    H=129,
    W=129,
).cuda()

# Alternative architectures that were tried (left disabled):
#model = UFNO2d(20,20,32,input_dim=2).cuda()
#model = FNOFactorizedMesh2D(modes_x=20, modes_y=20, width=32, input_dim=4).cuda()
print(count_params(model))

# AdamW with a cosine learning-rate schedule. T_max is the total number of
# optimiser steps (epochs x batches per epoch) because the training loop
# advances the scheduler once per batch.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
total_steps = args.epochs * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
# Alternative optimiser / scheduler (left disabled):
#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=epochs,steps_per_epoch=len(train_loader))
print(model)

# Loss used for both training and evaluation.
myloss = LpLoss(size_average=False)

# Main loop: each epoch makes one pass over the training set and then evaluates
# on the whole test set, printing
#   <epoch index> <elapsed seconds> <train loss per sample> <test loss per sample>
# NOTE(review): --max_grad_norm is parsed but no gradient clipping is applied
# in this loop -- confirm whether clip_grad_norm_ was intended.
for epoch in range(args.epochs):
    model.train()
    tic = default_timer()

    # Running sum of the batch losses for this epoch.
    train_l2 = 0
    for xb, yb in train_loader:
        xb = xb.cuda()
        yb = yb.cuda()
        n_batch = xb.shape[0]

        optimizer.zero_grad()
        pred = model(xb)

        # Compare prediction and target with everything but the batch axis
        # flattened.
        batch_loss = myloss(pred.view(n_batch, -1), yb.view(n_batch, -1))
        batch_loss.backward()

        optimizer.step()
        # The learning-rate schedule advances once per batch, not once per epoch.
        scheduler.step()
        train_l2 += batch_loss.item()

    # Evaluation pass: no gradients, eval mode.
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.cuda()
            yb = yb.cuda()
            n_batch = xb.shape[0]
            pred = model(xb)

            test_l2 += myloss(pred.view(n_batch, -1), yb.view(n_batch, -1)).item()

    # Turn the accumulated sums into per-sample averages.
    # Presumably LpLoss(size_average=False) returns a per-batch sum -- confirm
    # against utilities3.
    train_l2 = train_l2 / ntrain
    test_l2 = test_l2 / ntest

    toc = default_timer()
    print(epoch, toc - tic, train_l2, test_l2)


#torch.save(model.state_dict(), "".format(args.dropout_type, args.dropout, args.n1, args.n2))


