from generate import generate_ns_data

import json
import sys
import copy
from datetime import datetime
import random
import argparse
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import os
# Workaround for the Intel OpenMP runtime aborting when more than one copy of
# its library is loaded into the process (e.g. one shipped with torch and one
# with numpy/MKL). The runtime's own error message calls this override unsafe.
# NOTE(review): this is set *after* numpy/torch are imported above; it only
# helps if the duplicate runtime is initialized after this line. Consider
# moving it above those imports — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Render matplotlib animations as interactive JavaScript HTML when displayed
# in a notebook. Nothing visible in this script creates an animation;
# presumably left over from notebook use — verify before removing.
plt.rcParams["animation.html"] = "jshtml"



def build_parser():
    """Build the command-line parser for the data-generation script.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace is handed
        unchanged to ``generate_ns_data``. The attribute names (``device``,
        ``mode``, ``s``, ``sub``, ``N``, ``nu``, ``T``, ``dt``,
        ``record_ratio``, ``f_name``) are presumably read by that function,
        so they must not be renamed.
    """
    parser = argparse.ArgumentParser(
        description='Hyper-parameters of data generation',
        # Append "(default: ...)" to each help line so users can see the
        # values that will be used without reading the source.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Used device')

    # NOTE(review): the help text says "train or test", but any string is
    # accepted. Consider choices=['train', 'test'] once it is confirmed that
    # generate_ns_data accepts nothing else.
    parser.add_argument('--mode', type=str, default='train',
                        help='train or test')

    parser.add_argument('--s', type=int, default=256,
                        help='data original size')

    parser.add_argument('--sub', type=int, default=4,
                        help='ratio of down sampling')

    parser.add_argument('--N', type=int, default=5,
                        help='Number of the data generation')

    parser.add_argument('--nu', type=float, default=1e-4,
                        help='1/Re in NSE')

    parser.add_argument('--T', type=float, default=1.0,
                        help='final time')

    parser.add_argument('--dt', type=float, default=1e-4,
                        help='dt')

    parser.add_argument('--record_ratio', type=int, default=200,
                        help='record ratio')

    parser.add_argument('--f_name', type=str, default='kf',
                        help='the name of forcing')

    return parser


def main(argv=None):
    """Parse command-line arguments and run the data generation.

    Args:
        argv: list of argument strings, excluding the program name. ``None``
            makes argparse read ``sys.argv[1:]``, which matches how the
            script behaved before this function existed.
    """
    cfg = build_parser().parse_args(argv)
    generate_ns_data(cfg)


if __name__ == "__main__":
    main()