import argparse
import os
# Command-line interface for the training script.
# NOTE(review): --use_tb, --max_grad_norm, --train_downsample,
# --test_downsample, --dropout and --dropout_type are parsed but not
# referenced anywhere in the visible part of this script.
parser = argparse.ArgumentParser('Training FNO')

# Optimisation hyper-parameters.
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--max_grad_norm', type=float, default=1)

# Architecture hyper-parameters (n1/n2/width are forwarded to the "ofno"
# model; width is also forwarded to "fnomlp").
parser.add_argument("--n1", type=int, default=32)
parser.add_argument("--n2", type=int, default=32)
parser.add_argument("--width", type=int, default=32, help="Width")
parser.add_argument('--dropout', type=float, default=0.)
parser.add_argument('--dropout_type', type=str, default="GD",
                    help="Dropout Type: MC for typical dropout, GD for Gaussian dropout")
# `choices` rejects an unknown model name at parse time with a readable
# message instead of failing later when the model is first used.
parser.add_argument("--model", type=str, default='ofno',
                    choices=["ofno", "ufno", "ffno", "fno", "fnoall", "fnomlp"],
                    help="Model to train")

# Data options.
parser.add_argument('--train_downsample', type=int, default=1)
parser.add_argument('--test_downsample', type=int, default=1)

# Runtime options.
parser.add_argument("--use_tb", type=int, default=0,
                    help="Use TensorBoard: 1 for True, 0 for False")
parser.add_argument("--gpu", type=str, default='1', help="GPU index to use")

args = parser.parse_args()

# Restrict the visible CUDA devices to the one requested on the command line.
# This assignment deliberately precedes `import torch` so the restriction is
# already in the environment when torch is loaded; do not hoist the imports
# below above this line.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
import torch
from timeit import default_timer
# NOTE(review): the star import is presumably what provides `np`, `device`,
# `FNODatasetSingle`, `LpLoss` and `count_params`, none of which are imported
# explicitly in this script -- confirm against utilities3.
from utilities3 import *
from AM_FNO import  FNO2d, FNO2dMLP
from FNOs import UFNO2d , FNOFactorizedMesh2D, vannilaFNO2d
# Fix the torch and NumPy seeds so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# ---- Dataset location -------------------------------------------------------
base_path = ""  # folder passed to the dataset as `saved_folder`
flnm = "2D_CFD_Rand_M0.1_Eta0.01_Zeta0.01_periodic_128_Train.hdf5"

# ---- Sub-sampling factors forwarded to the dataset --------------------------
reduced_resolution = 2
reduced_batch = 5
reduced_resolution_t = 1

# ---- Autoregressive rollout -------------------------------------------------
initial_step = 10  # number of leading frames given to the model as input
t_train = 21       # rollout stops before this time index during training

# Placeholder sizes; both are overwritten with the real dataset lengths once
# the datasets have been built.
ntrain = 1800
ntest = 200

# Number of predicted frames per rollout, used to average the per-step loss.
# Derived from the two rollout bounds so the three values cannot drift apart
# (evaluates to 11 with the settings above).
T = t_train - initial_step

# Keyword arguments common to the training and validation datasets; the two
# splits differ only in the `if_test` flag.
dataset_kwargs = dict(
    reduced_resolution=reduced_resolution,
    reduced_resolution_t=reduced_resolution_t,
    reduced_batch=reduced_batch,
    initial_step=initial_step,
    saved_folder=base_path,
)
train_data = FNODatasetSingle(flnm, **dataset_kwargs)
val_data = FNODatasetSingle(flnm, if_test=True, **dataset_kwargs)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    num_workers=2,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size=args.batch_size,
    num_workers=2,
    shuffle=False,
)

# Replace the placeholder sizes with the actual number of samples per split;
# these are the denominators of the per-sample metrics printed each epoch.
ntrain, ntest = len(train_data), len(val_data)
print(ntrain, ntest)
# Instantiate the architecture selected with --model. Every variant is built
# with input_dim = initial_step*4 and output_dim = 4; only the "ofno" and
# "fnomlp" variants read hyper-parameters from the command line.
if args.model == "ofno":
    model = FNO2d(n1=args.n1, n2=args.n2, width=args.width,
                  input_dim=initial_step*4, padding=2, output_dim=4,
                  H=64, W=64)
elif args.model == "ufno":
    model = UFNO2d(12, 12, 32, input_dim=initial_step*4, output_dim=4)
elif args.model == "ffno":
    model = FNOFactorizedMesh2D(modes_x=12, modes_y=12, width=32,
                                input_dim=initial_step*4, output_dim=4)
elif args.model == "fno":
    model = vannilaFNO2d(12, 12, 32, input_dim=initial_step*4, output_dim=4)
elif args.model == "fnoall":
    model = vannilaFNO2d(66, 34, 32, input_dim=initial_step*4, output_dim=4)
elif args.model == "fnomlp":
    model = FNO2dMLP(n1=64, n2=64, width=args.width,
                     input_dim=initial_step*4, padding=2, output_dim=4,
                     H=64, W=64)
else:
    # Without this branch an unrecognised name left `model` undefined and the
    # script died with a NameError on the first use below.
    raise ValueError(
        "Unknown --model {!r}; expected one of: "
        "ofno, ufno, ffno, fno, fnoall, fnomlp".format(args.model)
    )
model = model.cuda()

# Log the parameter count, the run configuration and the module structure.
print(count_params(model))
print(args)
print(model)
# AdamW optimiser configured from the command-line learning rate and decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# The annealing horizon is expressed in optimiser steps (epochs x batches per
# epoch), matching a scheduler that is advanced once per training batch.
total_steps = args.epochs * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

# NOTE(review): with size_average=False the loss is presumably summed over the
# batch, which is why epoch totals are later divided by the dataset size --
# confirm against LpLoss.
myloss = LpLoss(size_average=False)
# Main loop: each epoch runs one autoregressive training pass and one
# validation pass, then prints a single line of metrics.
# NOTE(review): `device` is not defined in this script; it is presumably
# exported by the star import from utilities3 -- confirm.
for ep in range(args.epochs):
    # ------------------------------ training ------------------------------
    model.train()
    t1 = default_timer()
    train_l2_step = 0  # sum over batches of the accumulated per-step loss
    train_l2_full = 0  # sum over batches of the whole-trajectory loss
    for xx, yy, grid in train_loader:

        loss = 0
        bsz = xx.shape[0]
        xx = xx.to(device)
        yy = yy.to(device)
        grid = grid.to(device)

        # Start the predicted trajectory from the ground-truth initial frames.
        pred = yy[..., :initial_step, :]
        # Extract shape of the input tensor for reshaping (i.e. stacking the
        # time and channels dimension together)
        inp_shape = list(xx.shape)
        inp_shape = inp_shape[:-2]
        inp_shape.append(-1)

        for t in range(initial_step, t_train):

            inp = xx.reshape(inp_shape)
            y = yy[..., t:t+1, :]
            # Re-insert the axis at position -2 so the prediction can be
            # concatenated with the other frames.
            im = model(inp, grid = grid).unsqueeze(-2)
            loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1))

            pred = torch.cat((pred, im), -2)

            # Slide the input window: drop the oldest frame and append the
            # frame that was just predicted.
            xx = torch.cat((xx[..., 1:, :], im), dim=-2)

        train_l2_step += loss.item()
        # The whole-trajectory loss is only reported, never back-propagated,
        # so compute it without recording an autograd graph.
        with torch.no_grad():
            _yy = yy[..., :t_train, :]
            l2_full = myloss(pred.reshape(bsz, -1), _yy.reshape(bsz, -1))
        train_l2_full += l2_full.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler is advanced once per batch, not once per epoch.
        scheduler.step()

    # ----------------------------- validation -----------------------------
    # Switch to inference behaviour (e.g. dropout disabled) before evaluating;
    # previously the model was left in training mode here.
    model.eval()
    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for xx, yy, grid in val_loader:
            loss = 0
            xx = xx.to(device)
            yy = yy.to(device)
            grid = grid.to(device)
            bsz = xx.shape[0]
            pred = yy[..., :initial_step, :]
            inp_shape = list(xx.shape)
            inp_shape = inp_shape[:-2]
            inp_shape.append(-1)

            # NOTE(review): this rollout runs to yy.shape[-2], whereas the
            # per-step average printed below divides by T; the two agree only
            # if yy holds exactly t_train frames -- confirm.
            for t in range(initial_step, yy.shape[-2]):
                inp = xx.reshape(inp_shape)
                y = yy[..., t:t+1, :]
                im = model(inp, grid = grid).unsqueeze(-2)

                loss += myloss(im.reshape(bsz, -1), y.reshape(bsz, -1))

                pred = torch.cat((pred, im), -2)

                xx = torch.cat((xx[..., 1:, :], im), dim=-2)

            test_l2_step += loss.item()
            # Unlike the training metric, the validation trajectory loss
            # excludes the ground-truth initial frames.
            _pred = pred[..., initial_step:t_train, :]
            _yy = yy[..., initial_step:t_train, :]
            test_l2_full += myloss(_pred.reshape(bsz, -1), _yy.reshape(bsz, -1)).item()

    t2 = default_timer()

    # Columns: epoch, seconds, train per-step loss, train trajectory loss,
    # validation per-step loss, validation trajectory loss (all per sample).
    print(ep, t2 - t1, train_l2_step / ntrain / T, train_l2_full / ntrain, test_l2_step / ntest / T ,
          test_l2_full / ntest)
    
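# NOTE: checkpoint saving is currently disabled. The path template passed to
# str.format() below is empty and must be filled in before re-enabling it.
#torch.save(model.state_dict(), "".format(args.model))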

