import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

import os
import os.path as osp
from tqdm import tqdm
import shutil 
import sys
import argparse
import pickle

from utils import *
from models import *


def get_args(argv=None):
    """Parse the hyperparameters of a training run.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    ``--model`` and ``--sche_type`` are restricted to the values the training
    script actually handles. This way a typo is rejected at parse time, before
    any results directory is created. Both remain optional: ``--model`` is not
    needed when resuming via ``--load_path``, because the saved args are
    reloaded in that case.
    """
    parser = argparse.ArgumentParser(description = 'Put your hyperparameters')
    # --- experiment / environment ---
    parser.add_argument('name', type=str, help='experiments name')
    parser.add_argument('--load_path', default=None, type=str, help='path of directory to resume the training')
    parser.add_argument('--model', default=None, type=str,
                        choices=['deeponet', 'shiftdeeponet', 'flexdeeponet', 'NOMAD', 'hyperdeeponet', 'chunk'],
                        help='model name')
    parser.add_argument('--data', default='identity', type=str, help='Data name')
    parser.add_argument('--gpu_idx', default=0, type=int, help='index of gpu which you want to use')
    parser.add_argument('--multgpu', action='store_true', help='whether multiple gpu or not')
    parser.add_argument('--seed', type=int, help='random seed')

    # --- optimisation ---
    parser.add_argument('--batch', default=10000, type=int, help = 'batch size')
    parser.add_argument('--epochs', default=100000, type=int, help = 'Number of Epochs')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--wd', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--sche_type', default=None, type=str, choices=['steplr', 'inversetime'],
                        help='type of scheduler: steplr, inversetime')
    parser.add_argument('--step_size', default=100, type=int, help='scheduler step size')
    parser.add_argument('--gamma', default=0.99, type=float, help='scheduler factor')

    # --- architecture ---
    parser.add_argument('--d_in', default=1, type=int, help='dimension of input for target function')
    parser.add_argument('--d_out', default=1, type=int, help='dimension of output for target function')
    parser.add_argument('--n_sensor', default=100, type=int, help='number of sensor points')
    parser.add_argument('--d_target', default=2, type=int, help='depth of target network (except basis)')
    parser.add_argument('--w_target', default=30, type=int, help='width of target network')
    parser.add_argument('--a_target', default='relu', type=str, help='activation of target network')
    parser.add_argument('--d_hyper', default=2, type=int, help='depth of hyper(branch) network (except basis)')
    parser.add_argument('--w_hyper', default=30, type=int, help='width of hyper(branch) network')
    parser.add_argument('--a_hyper', default='relu', type=str, help='activation of hyper(branch) network')
    parser.add_argument('--n_basis', default=100, type=int, help='number of basis (width of last layer in target network)')

    # --- chunked hypernetwork (only used by --model chunk) ---
    parser.add_argument('--chunk_in', default=100, type=int, help='number of inputs for one chunk')
    parser.add_argument('--chunk_out', default=100, type=int, help='number of outputs for one chunk')

    return parser.parse_args(argv)

if __name__=="__main__":
    args = get_args()
    print(args)
    NAME = args.name
    print(sys.argv)
    if args.load_path is None:
        # Fresh run: outputs go to results/<name>. makedirs also creates
        # results/ on first use. Without exist_ok it still refuses to reuse
        # an existing run directory, so earlier results are never overwritten.
        PATH = os.path.join('results/', NAME)
        os.makedirs(PATH)
    else:
        # Run based on an earlier one: reuse the hyperparameters saved in
        # <load_path>/args.bin and write this run's outputs to
        # <load_path>/<name>. This block restores only the args.
        PATH = args.load_path
        args_file = os.path.join(PATH, 'args.bin')
        try:
            # args.bin is a pickled argparse.Namespace. torch>=2.6 defaults to
            # weights_only=True, which refuses to unpickle it.
            args = torch.load(args_file, weights_only=False)
        except TypeError:
            # torch<1.13 does not accept the weights_only keyword.
            args = torch.load(args_file)
        args.load_path = PATH
        args.name = NAME
        PATH = os.path.join(PATH, NAME)
        os.mkdir(PATH)

    # Keep a copy of the launching script next to the results.
    shutil.copy(sys.argv[0], os.path.join(PATH, 'code.py'))

    # Set seed. --seed has no default, and torch.manual_seed(None) raises
    # TypeError, so seed only when a value was given.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    if args.multgpu:
        num_gpu = torch.cuda.device_count()
    else:
        num_gpu = 1
        # CUDA reads this when it initialises, so it only takes effect if no
        # CUDA context has been created yet.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_idx)
    # NOTE(review): `device` is not used further down; the script moves data
    # and model with .cuda(), so a CUDA device is effectively required.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ## setting: short names for the hyperparameters used below
    batch_size = args.batch
    epochs = args.epochs
    lr = args.lr
    w_decay = args.wd
    scheduler_type = args.sche_type
    schedule_step_size = args.step_size
    schedule_gamma = args.gamma

    d_in = args.d_in
    d_out = args.d_out
    num_sensor = args.n_sensor
    d_target = args.d_target
    w_target = args.w_target
    a_target = args.a_target
    d_hyper = args.d_hyper
    w_hyper = args.w_hyper
    a_hyper = args.a_hyper
    num_basis = args.n_basis
    num_chunk_in = args.chunk_in
    num_chunk_out = args.chunk_out

    # Persist the effective args so a later run can be started from this one.
    torch.save(args, os.path.join(PATH, 'args.bin'))

    ## load dataset
    # Files are named <data>_N<n_train>_M<n_sensor>.pickle, where n_train is
    # 100 for the 'shallow' set and 1000 for every other set.
    N_train = 100 if args.data == 'shallow' else 1000
    data_name = ''.join([args.data, '_N', str(N_train), '_M', str(num_sensor), '.pickle'])
    # pickle.load can execute arbitrary code: only point this at trusted files.
    with open("./data/" + data_name, "rb") as fr:
        raw_set = pickle.load(fr)

    # Move every split to the GPU as float32 up front.
    train_X, train_Y, test_X, test_Y = (
        raw_set[key].cuda().float() for key in ('train_X', 'train_Y', 'test_X', 'test_Y')
    )

    train_dataset = TensorDataset(train_X, train_Y)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # The whole test set is evaluated as a single batch.
    test_dataset = TensorDataset(test_X, test_Y)
    test_loader = DataLoader(test_dataset, batch_size=test_X.shape[0], shuffle=False)
    
    ## model
    # Keyword arguments shared by every architecture: the trunk/target
    # network plus input/output sizes.
    shared_kwargs = dict(
        depth_trunk=d_target,
        width_trunk=w_target,
        act_trunk=a_target,
        num_sensor=num_sensor,
        input_dim=d_in,
        num_basis=num_basis,
        output_dim=d_out,
    )
    # DeepONet-style models receive d_hyper/w_hyper/a_hyper as their branch
    # network; the hypernetwork models receive the same values as *_hyper.
    branch_kwargs = dict(depth_branch=d_hyper, width_branch=w_hyper, act_branch=a_hyper)
    hyper_kwargs = dict(depth_hyper=d_hyper, width_hyper=w_hyper, act_hyper=a_hyper)

    # Lambdas defer the class lookup, so only the selected architecture is
    # referenced and instantiated.
    model_factories = {
        'deeponet': lambda: deeponet(**shared_kwargs, **branch_kwargs),
        'shiftdeeponet': lambda: shiftdeeponet(**shared_kwargs, **branch_kwargs),
        'flexdeeponet': lambda: flexdeeponet(**shared_kwargs, **branch_kwargs),
        'NOMAD': lambda: NOMAD(**shared_kwargs, **branch_kwargs),
        'hyperdeeponet': lambda: hyperdeeponet(**shared_kwargs, **hyper_kwargs),
        'chunk': lambda: chunk_hyperdeeponet(
            **shared_kwargs,
            **hyper_kwargs,
            num_chunk_in=num_chunk_in,
            num_chunk_out=num_chunk_out),
    }
    # A missing or unknown --model would otherwise leave `model` unbound and
    # fail later with a NameError; report the valid choices instead.
    if args.model not in model_factories:
        raise ValueError('Unknown --model {!r}; expected one of: {}'.format(
            args.model, ', '.join(sorted(model_factories))))
    model = model_factories[args.model]()

    model = model.cuda()
    print('The number of total parameters:', get_n_params(model))

    if num_gpu > 1:
        print("Let's use", num_gpu, "GPUs!")
        model = nn.DataParallel(model).cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=w_decay)

    # Optional learning-rate schedule, stepped once per epoch below.
    # `scheduler` stays None when --sche_type is not given (its default), so
    # the loop and the checkpoint must not assume it exists.
    scheduler = None
    if scheduler_type == 'steplr':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=schedule_step_size, gamma=schedule_gamma)
    elif scheduler_type == 'inversetime':
        # Inverse-time decay: lr * 1 / (1 + gamma * t / step_size) after t scheduler steps.
        def fcn(x):
            return 1. / (1. + schedule_gamma * x / schedule_step_size)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=fcn)
    elif scheduler_type is not None:
        raise ValueError(
            "Unknown --sche_type {!r}; expected 'steplr' or 'inversetime'".format(scheduler_type))

    # NOTE(review): weights/ is created, but the snapshots below are written
    # to PATH itself; confirm which location downstream tooling expects.
    os.makedirs(os.path.join(PATH, 'weights'))

    ## model training
    train_losses = []
    test_losses = []
    test_rels = []

    # A single manually-driven progress bar whose description carries the
    # per-epoch metrics; the context manager closes it when training ends.
    with tqdm(total=epochs, file=sys.stdout) as pbar:
        for epoch in range(1, epochs + 1):
            model.train()
            train_loss = AverageMeter()
            test_loss = AverageMeter()
            test_rel = AverageMeter()
            for x, y in train_loader:
                # Columns [num_sensor:] of x go in as the model's first argument,
                # columns [:num_sensor] as its second (presumably query points
                # vs. sensor values - confirm against models.py).
                loss = mse_error(model(x[:, num_sensor:], x[:, :num_sensor]).reshape(-1), y)
                # zero gradients, backward pass, and update weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss.update(loss.item(), y.shape[0])

            model.eval()
            with torch.no_grad():
                for x, y in test_loader:
                    predict = model(x[:, num_sensor:], x[:, :num_sensor]).reshape(-1)
                    error_test = mse_error(predict, y)
                    test_loss.update(error_test.item(), y.shape[0])
                    # Relative L2 error for each row of n_sensor values (assumes
                    # each such row is one test function - TODO confirm layout).
                    for prediction, real in zip(predict.reshape(-1, args.n_sensor), y.reshape(-1, args.n_sensor)):
                        test_rel.update(rel_L2_error(prediction, real), 1)

            train_losses.append(train_loss.avg)
            test_losses.append(test_loss.avg)
            test_rels.append(test_rel.avg)

            pbar.set_description("###### Epoch : %d, Loss_train : %.5f, Loss_test : %.5f, rel_test : %.5f ######"%(epoch, train_losses[-1], test_losses[-1], test_rels[-1]))
            if scheduler is not None:
                scheduler.step()
            pbar.update()

            # Keep a permanent weight snapshot every 5000 epochs.
            if epoch % 5000 == 0:
                torch.save(model.state_dict(), os.path.join(PATH, 'weight_epoch_{}.bin'.format(epoch)))

            # Overwrite the latest weights, loss history and optimizer/scheduler
            # state after every epoch.
            torch.save(model.state_dict(), os.path.join(PATH, 'weight.bin'))
            torch.save({'train_loss': train_losses, 'test_loss': test_losses, 'test_rel': test_rels},
                       os.path.join(PATH, 'loss.bin'))
            torch.save({'epoch': epoch,
                        'optimizer': optimizer.state_dict(),
                        # None when no learning-rate scheduler is configured.
                        'scheduler': scheduler.state_dict() if scheduler is not None else None},
                       os.path.join(PATH, 'checkpoint.bin'))
