import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

import os
import os.path as osp
from tqdm import tqdm
import shutil 
import sys
import argparse
import pickle

from utils import *
from models import *




    
    
def get_args(argv=None):
    """Parse the command-line hyperparameters for a training run.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
        Options without an explicit default (``--seed``, ``--n_layer``,
        ``--n_mode``) are ``None`` when not given on the command line.
    """
    parser = argparse.ArgumentParser(description='Put your hyperparameters')

    # Experiment bookkeeping and hardware selection.
    parser.add_argument('name', type=str, help='experiments name')
    parser.add_argument('--load_path', default=None, type=str, help='path of directory to resume the training')
    parser.add_argument('--model', default=None, type=str, help='model name')
    parser.add_argument('--data', default='identity', type=str, help='Data name')
    parser.add_argument('--gpu_idx', default=0, type=int, help='index of gpu which you want to use')
    parser.add_argument('--multgpu', action='store_true', help='whether multiple gpu or not')
    parser.add_argument('--seed', type=int, help='random seed')

    # Optimisation settings.
    parser.add_argument('--batch', default=10000, type=int, help='batch size')
    parser.add_argument('--epochs', default=100000, type=int, help='Number of Epochs')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--wd', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--sche_type', default=None, type=str, help='type of scheduler: steplr, inversetime')
    parser.add_argument('--step_size', default=100, type=int, help='scheduler step size')
    parser.add_argument('--gamma', default=0.99, type=float, help='scheduler factor')

    # Problem and model dimensions.
    parser.add_argument('--d_in', default=1, type=int, help='dimension of input for target function')
    parser.add_argument('--d_out', default=1, type=int, help='dimension of output for target function')
    parser.add_argument('--n_layer', type=int, help='number of fourier layer in FNO model')
    # This value is forwarded as ``modes=`` to the model, so the help text must
    # describe modes rather than repeat the ``--n_layer`` description.
    parser.add_argument('--n_mode', type=int, help='number of fourier modes in FNO model')
    parser.add_argument('--n_sensor', default=100, type=int, help='number of sensor points')

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse hyperparameters, create the run directory,
    # load the pickled dataset, build the model, then train and evaluate for
    # `epochs` epochs. Weights, the loss history and the optimizer/scheduler
    # state are written to disk after every epoch.
    args = get_args()
    print(args)
    NAME = args.name
    print(sys.argv)
    if args.load_path is None:
        PATH = os.path.join('results/', NAME)
    else:
        PATH = args.load_path
        # Reuse the hyperparameters stored by the earlier run; only the run
        # name and load path are taken from the current command line.
        # NOTE(review): only args.bin is restored here -- no model, optimizer
        # or scheduler state is loaded anywhere below, so training restarts
        # from scratch. Confirm whether a real resume is intended.
        # NOTE(review): recent PyTorch releases may refuse to unpickle an
        # argparse.Namespace unless weights_only=False is passed -- confirm
        # against the installed torch version.
        args = torch.load(os.path.join(args.load_path, 'args.bin'))
        args.load_path = PATH
        args.name = NAME
        PATH = os.path.join(PATH, NAME)
    # makedirs creates missing parents (e.g. a fresh 'results/') but still
    # raises if the run directory itself exists, so runs are never overwritten.
    os.makedirs(PATH)

    # Keep a copy of the launching script next to the results.
    shutil.copy(sys.argv[0], os.path.join(PATH, 'code.py'))

    # Set seed (skipped when --seed is omitted: torch.manual_seed rejects None)
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    if args.multgpu:
        num_gpu = torch.cuda.device_count()
    else:
        num_gpu = 1
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_idx)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ## setting
    batch_size = args.batch
    epochs = args.epochs
    lr = args.lr
    w_decay = args.wd
    scheduler_type = args.sche_type
    schedule_step_size = args.step_size
    schedule_gamma = args.gamma

    num_sensor = args.n_sensor
    num_layer = args.n_layer
    num_mode = args.n_mode

    torch.save(args, os.path.join(PATH, 'args.bin'))

    ## load dataset
    N_train = 1000
    N_test = 200
    data_name = args.data + '_N' + str(N_train) + '_M' + str(num_sensor) + '.pickle'
    # NOTE: pickle.load executes arbitrary code from the file; only use
    # dataset files from a trusted source.
    with open("./data/data_fno/" + 'fno_' + data_name, "rb") as fr:
        raw_set = pickle.load(fr)
    raw_set['train_X'] = raw_set['train_X'].reshape(N_train, num_sensor, 1)
    raw_set['test_X'] = raw_set['test_X'].reshape(N_test, num_sensor, 1)

    # The whole dataset is moved to `device` up front, so the loaders yield
    # batches that are already on the right device.
    train_dataset = TensorDataset(raw_set['train_X'].to(device).float(),
                                  raw_set['train_Y'].to(device).float())
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = TensorDataset(raw_set['test_X'].to(device).float(),
                                 raw_set['test_Y'].to(device).float())
    # The test set is evaluated as a single batch.
    test_loader = DataLoader(test_dataset, batch_size=raw_set['test_X'].shape[0], shuffle=False)

    ## model
    if args.model == 'fno':
        model = FNO1d(num_layer=num_layer, modes=num_mode)
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError("Unsupported --model {!r}; expected 'fno'".format(args.model))

    model = model.to(device)
    if num_gpu > 1:
        print("Let's use", num_gpu, "GPUs!")
        model = nn.DataParallel(model).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=w_decay)

    # `scheduler` stays None when no scheduler was requested, so the training
    # loop and the checkpoint code can test for it explicitly.
    scheduler = None
    if scheduler_type == 'steplr':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=schedule_step_size, gamma=schedule_gamma)
    elif scheduler_type == 'inversetime':
        # lr multiplier 1 / (1 + gamma * epoch / step_size)
        fcn = lambda x: 1. / (1. + schedule_gamma * x / schedule_step_size)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=fcn)
    elif scheduler_type is not None:
        raise ValueError(
            "Unsupported --sche_type {!r}; expected 'steplr' or 'inversetime'".format(scheduler_type))

    # NOTE(review): this directory is created but nothing below writes into
    # it; all weight files go directly into PATH. Confirm the intended layout.
    os.makedirs(os.path.join(PATH, 'weights'))

    ## model training
    train_losses = []
    test_losses = []
    test_rels = []

    # A single, manually advanced progress bar (the loop itself is not wrapped
    # in tqdm, which would draw a second bar).
    pbar = tqdm(total=epochs, file=sys.stdout)
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = AverageMeter()
        test_loss = AverageMeter()
        test_rel = AverageMeter()
        for batch in train_loader:
            x, y = batch

            loss = mse_error(model(x).reshape(-1), y.reshape(-1))
            # zero gradients, backward pass, and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), y.shape[0])
        model.eval()
        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                predict = model(x).reshape(-1)
                error_test = mse_error(predict, y.reshape(-1))
                test_loss.update(error_test.item(), y.shape[0])
                # Relative L2 error is accumulated per sample, one row of
                # num_sensor values at a time.
                for prediction, real in zip(predict.reshape(-1, num_sensor), y.reshape(-1, num_sensor)):
                    test_rel.update(rel_L2_error(prediction, real), 1)

        train_losses.append(train_loss.avg)
        test_losses.append(test_loss.avg)
        test_rels.append(test_rel.avg)

        pbar.set_description("###### Epoch : %d, Loss_train : %.5f, Loss_test : %.5f, rel_test : %.5f ######"%(epoch, train_losses[-1], test_losses[-1], test_rels[-1]))
        if scheduler is not None:
            scheduler.step()
        pbar.update()

        # Keep a separate snapshot of the weights every 5000 epochs.
        if epoch % 5000 == 0:
            torch.save(model.state_dict(), os.path.join(PATH, 'weight_epoch_{}.bin'.format(epoch)))

        # Latest weights, loss history and optimizer state are rewritten every epoch.
        torch.save(model.state_dict(), os.path.join(PATH, 'weight.bin'))
        torch.save({'train_loss': train_losses, 'test_loss': test_losses, 'test_rel': test_rels}, os.path.join(PATH, 'loss.bin'))
        torch.save({'epoch': epoch,
                    'optimizer': optimizer.state_dict(),
                    # None when training runs without a scheduler.
                    'scheduler': scheduler.state_dict() if scheduler is not None else None},
                   os.path.join(PATH, 'checkpoint.bin'))
    pbar.close()
