from SINO import *
from train import train

import json
import sys
import copy
from datetime import datetime
import random
import argparse
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import os
# --- Process-wide runtime configuration (executed at import time) ---------
# Allow the process to continue if more than one OpenMP runtime is loaded.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Represent matplotlib animations as JavaScript/HTML ("jshtml").
plt.rcParams["animation.html"] = "jshtml"

# Cap CPU parallelism to a fixed thread count.
cpu_num = 4
# NOTE(review): OpenMP/MKL usually read these variables when their runtime is
# first initialised; numpy/torch are already imported above, so they may have
# no effect on those libraries. torch.set_num_threads below still applies to
# torch's intra-op thread pool regardless.
os.environ.update(OMP_NUM_THREADS=str(cpu_num), MKL_NUM_THREADS=str(cpu_num))
torch.set_num_threads(cpu_num)


def setup_seed(seed):
    """Seed every random number generator used in training and pin cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, and torch's CPU and CUDA
    generators with the same value, exports ``PYTHONHASHSEED`` and configures
    cuDNN for deterministic (non-autotuned) kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # The running interpreter's hash seed is fixed at startup; this export only
    # influences Python subprocesses launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False  # autotuning may pick different kernels per run
    cudnn.enabled = True


def main(cfg):
    model_version = cfg.model_version
    nu, f = cfg.nu, cfg.f
    if not os.path.exists(f'log'):
        os.mkdir(f'log')
    if not os.path.exists(f'log/log_model_{model_version}_nu_{nu}_f_{f}'):
        os.mkdir(f'log/log_model_{model_version}_nu_{nu}_f_{f}')
    if not os.path.exists(f'model'):
        os.mkdir(f'model')
    setup_seed(cfg.seed)
    dateTimeObj = datetime.now()
    timestring = f'{dateTimeObj.date().month}_{dateTimeObj.date().day}_{dateTimeObj.time().hour}_{dateTimeObj.time().minute}_{dateTimeObj.time().second}'
    cfg.model_name = f'Train_num-{cfg.num_train}net-seed-{cfg.seed}-{cfg.model_version}-num_step1-{cfg.num_step1}-num_step2-{cfg.num_step2}-lr-{cfg.lr}-channel-{cfg.channel}-k_num-{cfg.k_num}-dt-{cfg.dt}-step_forward-{cfg.step_forward}-{timestring}'
    logfile = f'log/log_model_{model_version}_nu_{nu}_f_{f}/Train_num-{cfg.num_train}log-seed-{cfg.seed}-{cfg.model_version}-num_step1-{cfg.num_step1}-num_step2-{cfg.num_step2}-lr-{cfg.lr}-channel-{cfg.channel}-k_num-{cfg.k_num}-dt-{cfg.dt}-step_forward-{cfg.step_forward}-{timestring}.csv'
    sys.stdout = open(logfile, 'w')

    print('--------args----------')
    for k in list(vars(cfg).keys()):
        print('%s: %s' % (k, vars(cfg)[k]))
    print('--------args----------\n')
    
    if model_version == 'SINOmodel':
        net = SINOmodel(input_dim=cfg.input_dim, channel=cfg.channel, k_num=cfg.k_num, output_dim=cfg.output_dim, dt=cfg.dt, step_forward=cfg.step_forward, anti_alias_ratio=cfg.anti_alias_ratio)

    sys.stdout.flush()
    train(cfg, net)

    final_time = datetime.now()
    print(f'FINAL_TIME_{final_time.date().month}_{final_time.date().day}_{final_time.time().hour}_{final_time.time().minute}_{final_time.time().second}')

if __name__ == "__main__":
    # Command-line entry point. The experiment settings that used to be
    # assigned after parse_args() (device cuda:4, lr 0.01, nu 1e-05,
    # channel 32, f 'kf') are now the argument defaults, so a run without
    # flags is unchanged while explicit command-line flags are honoured
    # instead of being silently overwritten.
    parser = argparse.ArgumentParser(description='Hyper-parameters of training')

    parser.add_argument('--data_path', type=str, default='/data/',
                        help='path of data')

    parser.add_argument('--nu', type=float, default=1e-05,
                        help='nu')

    parser.add_argument('--f', type=str, default='kf',
                        help='forcing')

    parser.add_argument('--model_name', type=str, default='SINOmodel',
                        help='model name')

    parser.add_argument('--device', type=str, default='cuda:4',
                        help='Used device')

    parser.add_argument('--seed', type=int, default=0,
                        help='seed')

    # model hyper-parameter
    parser.add_argument('--model_version', type=str, default='SINOmodel',
                        help='model_version')

    parser.add_argument('--input_dim', type=int, default=1,
                        help='input_dim')

    parser.add_argument('--channel', type=int, default=32,
                        help='channel')

    parser.add_argument('--k_num', type=int, default=8,
                        help='k_num')

    parser.add_argument('--output_dim', type=int, default=1,
                        help='output_dim')

    parser.add_argument('--dt', type=float, default=0.005,
                        help='dt, 1/dt should be int')

    parser.add_argument('--step_forward', type=str, default='rk4',
                        help='step_forward')

    parser.add_argument('--anti_alias_ratio', type=float, default=2.0/3,
                        help='anti_alias_ratio')

    # hyper-parameter during training
    parser.add_argument('--num_train', type=int, default=5,
                        help='number of training data')

    # --num_step1 / --num_step2 accept one or more values; one training run is
    # launched per combination (the defaults reproduce the single 4 x 8 run).
    parser.add_argument('--num_step1', type=int, nargs='+', default=[4],
                        help='num_step1 (one or more values to sweep)')

    parser.add_argument('--num_step2', type=int, nargs='+', default=[8],
                        help='num_step2 (one or more values to sweep)')

    parser.add_argument('--lr', type=float, default=0.01,
                        help='lr of optim')

    parser.add_argument('--weight_decay', type=float, default=0.00,
                        help='weight_decay of optim')

    parser.add_argument('--num_iterations', type=int, default=20000,
                        help='num_iterations of optim')

    cfg = parser.parse_args()

    # Each run gets its own copy of the config so attributes mutated during a
    # run (main assigns cfg.model_name) cannot leak into the next run.
    for num_step1 in cfg.num_step1:
        for num_step2 in cfg.num_step2:
            run_cfg = copy.deepcopy(cfg)
            run_cfg.num_step1 = num_step1
            run_cfg.num_step2 = num_step2
            main(run_cfg)
            
            
