import sys
import torch
import numpy as np
import pickle
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

import wandb

import operator
from functools import reduce
from functools import partial
import torch.distributed as dist

from timeit import default_timer

from pdebench.models.fno.fno import FNO1d, FNO2d, FNO3d, FNO2d_MAE
from pdebench.models.fno.utils import FNODatasetSingle, FNODatasetMult, LpLoss, compute_grad_norm
from pdebench.models.metrics import metrics

import random
import os

def run_training(if_training,
                 continue_training,
                 num_workers,
                 modes,
                 width,
                 initial_step,
                 t_train,
                 num_channels,
                 batch_size,
                 epochs,
                 learning_rate,
                 scheduler_step,
                 scheduler_gamma,
                 model_update,
                 flnm,
                 single_file,
                 reduced_resolution,
                 reduced_resolution_t,
                 reduced_batch,
                 plot,
                 channel_plot,
                 x_min,
                 x_max,
                 y_min,
                 y_max,
                 t_min,
                 t_max,
                 base_path='../data/',
                 model_arch='fno2dMAE',
                 training_type='autoregressive',
                 blur=None,
                 mask_ratio=0,
                 seed=0
                 ):
    """Run a pretrained FNO model on one training batch and save the result.

    Despite its name, this function performs no optimisation. It builds the
    train/validation datasets, instantiates the model, loads weights from
    ``pretrain_b{blur[0]}b{blur[1]}_m{mask_ratio}_{model_arch}.pt`` (a
    relative path), runs a single forward pass on the first batch drawn from
    the shuffled training loader and writes the tensors to ``model.pt``.

    Args:
        num_workers: Worker count for both data loaders.
        modes: Passed to the model as both ``modes1`` and ``modes2``.
        width: Passed to the model as ``width``.
        initial_step: Ignored; the value is overridden to 1 below.
        num_channels: Passed to the model as ``num_channels``.
        batch_size: Batch size of both data loaders.
        epochs, learning_rate, scheduler_step, scheduler_gamma: Only echoed
            to stdout; nothing is trained here.
        flnm: First positional argument of ``FNODatasetSingle``.
        single_file: Must be truthy; there is no multi-file code path.
        reduced_resolution, reduced_resolution_t, reduced_batch: Forwarded
            to ``FNODatasetSingle``.
        base_path: Forwarded to ``FNODatasetSingle`` as ``saved_folder``.
        model_arch: ``'fno2dMAE'`` selects ``FNO2d_MAE``; any other value
            selects ``FNO2d``. Also part of the checkpoint file name.
        blur: Forwarded to the dataset; its first two entries are part of
            the checkpoint file name. Defaults to ``[0, 0]``.
        mask_ratio: Forwarded to the dataset and part of the checkpoint
            file name.
        seed: Seed for ``random``, NumPy and torch.
        if_training, continue_training, t_train, model_update, plot,
            channel_plot, x_min, x_max, y_min, y_max, t_min, t_max,
            training_type: Accepted so existing callers keep working; not
            used by this function.

    Raises:
        NotImplementedError: If ``single_file`` is falsy.
    """
    # A fresh list per call instead of a shared mutable default argument.
    if blur is None:
        blur = [0, 0]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(f'Epochs = {epochs}, learning rate = {learning_rate}, scheduler step = {scheduler_step}, scheduler gamma = {scheduler_gamma}')

    ################################################################
    # load data
    ################################################################
    # NOTE(review): the caller's initial_step is deliberately(?) discarded
    # here; kept as in the original - confirm this is intended.
    initial_step = 1

    if not single_file:
        # Without this guard the datasets and model name are never bound and
        # the function fails later with an opaque UnboundLocalError.
        raise NotImplementedError(
            'run_training only supports single_file=True; '
            'no multi-file dataset branch is implemented'
        )

    # filename
    model_name = 'pretrain' + f'_b{blur[0]}b{blur[1]}_m{mask_ratio}_' + model_arch
    print("FNODatasetSingle")

    # Initialize the datasets; validation differs only by if_test=True.
    dataset_kwargs = dict(reduced_resolution=reduced_resolution,
                          reduced_resolution_t=reduced_resolution_t,
                          reduced_batch=reduced_batch,
                          initial_step=initial_step,
                          saved_folder=base_path,
                          pretrain=True,
                          mask_ratio=mask_ratio,
                          blur=blur)
    train_data = FNODatasetSingle(flnm, **dataset_kwargs)
    val_data = FNODatasetSingle(flnm, if_test=True, **dataset_kwargs)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                               num_workers=num_workers, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                             num_workers=num_workers, shuffle=False)

    # The batch fetched here is unused. NOTE(review): the draw is kept because
    # creating a shuffling iterator advances the global torch RNG, so dropping
    # it would change which batch is saved below for a given seed.
    next(iter(train_loader))

    ################################################################
    # model setup and inference
    ################################################################
    _, val_sample, _, _ = next(iter(val_loader))
    print('Spatial Dimension', len(val_sample.shape) - 2)

    # NOTE(review): both architectures are later called as
    # model(xx, grid, masks) - confirm FNO2d.forward accepts a mask.
    model_cls = FNO2d_MAE if model_arch == 'fno2dMAE' else FNO2d
    model = model_cls(num_channels=num_channels,
                      width=width,
                      modes1=modes,
                      modes2=modes,
                      initial_step=initial_step).to(device)
    print(model)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params}')

    model_path = model_name + ".pt"
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Only the first batch of a fresh pass over the training loader is used.
    xx, yy, grid, masks = next(iter(train_loader))
    xx = xx.to(device)
    masks = masks.to(device)
    grid = grid.to(device)

    # Inference only: skip autograd bookkeeping so the saved prediction does
    # not carry gradient tracking.
    with torch.no_grad():
        pred = model(xx, grid, masks)

    torch.save({'prediction': pred.cpu(),
                'source': xx.cpu(),
                'mask': masks.cpu(),
                'target': yy.cpu()}, 'model.pt')
                
            
if __name__ == "__main__":
    # run_training() declares many required parameters, so the previous bare
    # call could only fail with a TypeError. Read its keyword arguments from
    # a JSON file named on the command line instead.
    import json

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} CONFIG.json")
    with open(sys.argv[1], encoding="utf-8") as config_file:
        run_training(**json.load(config_file))
    print("Done.")

