import sys
import torch
import numpy as np
import pickle
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

import wandb

import operator
from functools import reduce
from functools import partial
import torch.distributed as dist

from timeit import default_timer

from pdebench.models.fno.fno import FNO1d, FNO2d, FNO3d, FNO2d_MAE
from pdebench.models.fno.utils import FNODatasetSingle, FNODatasetMult, LpLoss, compute_grad_norm
from pdebench.models.metrics import metrics

import random
import os

def _reconstruction_loss(loss_fn, pred, target, masks, mask_ratio):
    """Return ``loss_fn(pred, target)``, weighted by ``1 - masks`` when masking is on.

    When ``mask_ratio`` is positive both tensors are multiplied by
    ``(1 - masks)`` before the loss is evaluated; otherwise the loss is taken
    over the unmodified tensors.
    """
    if mask_ratio > 0:
        weight = 1 - masks
        return loss_fn(pred * weight, target * weight)
    return loss_fn(pred, target)


def _save_checkpoint(path, epoch, model, optimizer, loss):
    """Write epoch, model/optimizer state dicts and ``loss`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def run_training(if_training,
                 continue_training,
                 num_workers,
                 modes,
                 width,
                 initial_step,
                 t_train,
                 num_channels,
                 batch_size,
                 epochs,
                 learning_rate,
                 scheduler_step,
                 scheduler_gamma,
                 model_update,
                 flnm,
                 single_file,
                 reduced_resolution,
                 reduced_resolution_t,
                 reduced_batch,
                 plot,
                 channel_plot,
                 x_min,
                 x_max,
                 y_min,
                 y_max,
                 t_min,
                 t_max,
                 base_path='../data/',
                 model_arch='fno2dMAE',
                 training_type='autoregressive',
                 blur=None,
                 mask_ratio=0,
                 seed=0
                 ):
    """Pretrain (or evaluate) a 2D FNO model on blurred / masked inputs.

    The model is trained to reproduce ``xx`` from ``xx_blur`` (plus grid and
    masks) as yielded by ``FNODatasetSingle(..., pretrain=True)``. Metrics are
    logged to Weights & Biases once per epoch. Two checkpoints are written:
    ``<model_name>.pt`` whenever the summed training loss improves and
    ``<model_name>_val.pt`` whenever the summed validation loss improves.

    Args:
        if_training: When falsy, load ``<model_name>.pt``, run ``metrics`` on
            the validation loader, pickle the result to
            ``<model_name>.pickle`` and return without training.
        continue_training: When truthy, restore model and optimizer state,
            start epoch and best training loss from ``<model_name>.pt``.
        num_workers: Forwarded to both ``DataLoader`` instances.
        modes: Passed as both ``modes1`` and ``modes2`` to the model.
        width: Passed as ``width`` to the model.
        initial_step: Ignored; the function always uses ``1``.
        t_train: Clamped to the second-to-last dimension of a validation
            sample; not otherwise used inside this function.
        num_channels: Passed as ``num_channels`` to the model.
        batch_size: Batch size of both data loaders.
        epochs: Exclusive upper bound of the epoch loop.
        learning_rate: Adam learning rate.
        scheduler_step: ``step_size`` of the ``StepLR`` scheduler.
        scheduler_gamma: ``gamma`` of the ``StepLR`` scheduler.
        model_update: Validation runs on epochs where
            ``ep % model_update == 0``.
        flnm: First positional argument of ``FNODatasetSingle``.
        single_file: Must be truthy; only the single-file dataset path is
            implemented here.
        reduced_resolution, reduced_resolution_t, reduced_batch: Forwarded to
            ``FNODatasetSingle``.
        plot, channel_plot, x_min, x_max, y_min, y_max, t_min, t_max:
            Forwarded to ``metrics`` in evaluation-only mode.
        base_path: Passed to ``FNODatasetSingle`` as ``saved_folder``.
        model_arch: ``'fno2dMAE'`` selects ``FNO2d_MAE``; anything else
            selects ``FNO2d``.
        training_type: Unused in this function.
        blur: Two-element sequence forwarded to the dataset and embedded in
            the run name; ``None`` means ``[0, 0]``.
        mask_ratio: Forwarded to the dataset; when ``> 0`` the loss is
            computed on tensors multiplied by ``1 - masks``.
        seed: Seed for ``random``, ``numpy`` and ``torch``.

    Raises:
        NotImplementedError: If ``single_file`` is falsy.
    """
    # Fail fast: without this guard the datasets and model_name are never
    # bound and the function dies later with an opaque NameError.
    if not single_file:
        raise NotImplementedError(
            "run_training only supports single_file=True "
            "(no multi-file dataset path is implemented)")

    # Fresh list per call instead of a shared mutable default argument.
    if blur is None:
        blur = [0, 0]

    world_rank = 0  # single-process run; rank 0 performs checkpointing/logging
    is_main = world_rank == 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    logs = {}

    print(f'Epochs = {epochs}, learning rate = {learning_rate}, scheduler step = {scheduler_step}, scheduler gamma = {scheduler_gamma}')

    ################################################################
    # load data
    ################################################################
    # The caller-supplied initial_step is deliberately overridden.
    initial_step = 1

    model_name = 'pretrain' + f'_b{blur[0]}b{blur[1]}_m{mask_ratio}_' + model_arch
    print("FNODatasetSingle")

    # Options shared by the train and validation datasets.
    dataset_kwargs = dict(reduced_resolution=reduced_resolution,
                          reduced_resolution_t=reduced_resolution_t,
                          reduced_batch=reduced_batch,
                          initial_step=initial_step,
                          saved_folder=base_path,
                          pretrain=True,
                          mask_ratio=mask_ratio,
                          blur=blur)
    train_data = FNODatasetSingle(flnm, **dataset_kwargs)
    val_data = FNODatasetSingle(flnm, if_test=True, **dataset_kwargs)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                               num_workers=num_workers, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                             num_workers=num_workers, shuffle=False)

    # NOTE(review): the fetched batch is unused. The call is retained because
    # building a shuffled iterator presumably draws from the global torch RNG,
    # so dropping it would change seeded model initialisation — confirm, then
    # remove if reproducibility with earlier runs is not required.
    next(iter(train_loader))

    wandb.init(
            entity='entity_name',
            name=model_name,
            # Set the project where this run will be logged
            project="2dDiffReact",
            # Track hyperparameters and run metadata
            config={
                "learning_rate": learning_rate,
                "epochs": epochs,
                "scheduler_step": scheduler_step,
                "scheduler_gamma": scheduler_gamma,
                "seed": seed,
                "blur": f'{blur[0]}-{blur[1]}',
                "mask_ratio": mask_ratio
            })

    ################################################################
    # training and evaluation
    ################################################################

    _, sample, _, _ = next(iter(val_loader))
    print('Spatial Dimension', len(sample.shape) - 2)

    model_cls = FNO2d_MAE if model_arch == 'fno2dMAE' else FNO2d
    model = model_cls(num_channels=num_channels,
                      width=width,
                      modes1=modes,
                      modes2=modes,
                      initial_step=initial_step).to(device)
    print(model)

    # Set maximum time step of the data to train
    t_train = min(t_train, sample.shape[-2])

    model_path = model_name + ".pt"
    val_model_path = model_name + "_val.pt"

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters = {total_params}')

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)

    lploss = LpLoss(size_average=True)
    # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
    loss_val_min = np.inf
    loss_train_min = np.inf

    start_epoch = 0

    if not if_training:
        # Evaluation-only mode: score the stored checkpoint and exit.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        Lx, Ly, Lz = 1., 1., 1.
        errs = metrics(val_loader, model, Lx, Ly, Lz, plot, channel_plot,
                       model_name, x_min, x_max, y_min, y_max,
                       t_min, t_max, initial_step=initial_step)
        with open(model_name + '.pickle', "wb") as handle:
            pickle.dump(errs, handle)
        # Close the run opened by wandb.init above before returning.
        wandb.finish()
        return

    # If desired, restore the network by loading the weights saved in the .pt
    # file
    if continue_training:
        print('Restoring model (that is the network\'s weights) from file...')
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state dict and move its tensors to the active device
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for state in optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device)

        # NOTE(review): the stored epoch is the one that produced the
        # checkpoint, so it is run again on resume; the scheduler state is not
        # part of the checkpoint either. Confirm whether that is intended.
        start_epoch = checkpoint['epoch']
        # model_path is written with the best *training* loss, so it must be
        # restored into the training minimum, not the validation minimum.
        loss_train_min = checkpoint['loss']

    for ep in range(start_epoch, epochs):
        model.train()
        train_l2_full = 0
        grad_full = 0
        for xx, xx_blur, grid, masks in train_loader:
            xx = xx.to(device)
            xx_blur = xx_blur.contiguous().to(device)
            masks = masks.to(device)
            grid = grid.to(device)

            # The optimizer owns every model parameter, so one zero_grad
            # call is sufficient.
            optimizer.zero_grad()

            pred = model(xx_blur, grid, masks)
            loss = _reconstruction_loss(lploss, pred, xx, masks, mask_ratio)
            train_l2_full += loss.item()

            loss.backward()
            optimizer.step()

            grad_full += compute_grad_norm(model)

        if train_l2_full < loss_train_min and is_main:
            loss_train_min = train_l2_full
            _save_checkpoint(model_path, ep, model, optimizer, loss_train_min)

        # NOTE: 'best_train_loss' is the per-epoch *sum* of batch losses while
        # 'train_loss' is the per-batch *mean*; kept as-is so logged series
        # remain comparable with earlier runs.
        logs['train_loss'] = train_l2_full/len(train_loader)
        logs['best_train_loss'] = loss_train_min
        logs['grad_norm'] = grad_full/len(train_loader)

        if ep % model_update == 0:
            val_l2_full = 0
            model.eval()
            # No gradients are needed for validation; avoid building graphs.
            with torch.no_grad():
                for xx, _, grid, masks in val_loader:
                    xx = xx.to(device)
                    masks = masks.to(device)
                    grid = grid.to(device)

                    # Validation feeds the un-blurred input to the model.
                    pred = model(xx, grid, masks)
                    loss = _reconstruction_loss(lploss, pred, xx, masks, mask_ratio)
                    val_l2_full += loss.item()

            if val_l2_full < loss_val_min and is_main:
                loss_val_min = val_l2_full
                _save_checkpoint(val_model_path, ep, model, optimizer, loss_val_min)

            logs['val_loss'] = val_l2_full/len(val_loader)
            logs['best_val_loss'] = loss_val_min

        if is_main:
            wandb.log(logs, step=ep+1)

        scheduler.step()
    wandb.finish()
            
# Script entry point.
# NOTE(review): run_training declares 27 parameters without defaults, so this
# zero-argument call raises TypeError when the file is executed directly.
# Presumably the function is meant to be driven by an external config/launcher
# — confirm and pass the arguments here.
if __name__ == "__main__":    
    run_training()
    print("Done.")

