import torch
import numpy as np
import random
import math
import json
import sys
import os
import time
from nse import navier_stokes_2d, GaussianRF
# NOTE(review): presumably a workaround for the Intel OpenMP runtime refusing to
# start when more than one copy of its library is loaded into the process
# (e.g. pulled in by both torch and numpy). It tolerates the duplicate instead of
# removing it -- confirm it is still needed. Note it is set only after torch and
# numpy have already been imported above.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

def setup_seed(seed):
    """Seed every RNG used by this script and force deterministic cuDNN.

    Seeds torch (CPU, current CUDA device, all CUDA devices), numpy and the
    stdlib ``random`` module with ``seed``. It also exports ``PYTHONHASHSEED``
    and configures cuDNN as deterministic, non-benchmarking and enabled.

    Args:
        seed: integer seed applied to all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Hash randomization of the *running* interpreter is fixed at startup, so
    # this only affects child processes started after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = True


def _build_forcing(f_name, X, Y):
    """Return the forcing field evaluated on the grid (X, Y).

    'li' -> 0.1 * (sin(2*pi*(X+Y)) + cos(2*pi*(X+Y)));
    'kf' -> 0.1 * cos(8*pi*X);
    any other name -> zero forcing with the same shape/device as X.
    """
    if f_name == 'li':
        return 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y)))
    if f_name == 'kf':
        return 0.1 * torch.cos(8 * math.pi * X)
    return 0.0 * X


def generate_ns_data(cfg):
    """Generate a 2D Navier-Stokes dataset and save it under ``dataset/``.

    Attributes read from ``cfg``: device, s (grid points per axis), sub
    (spatial subsampling stride), dt, T, record_ratio, N (number of samples),
    nu (viscosity), f_name (forcing: 'li', 'kf', anything else -> zero) and
    mode ('train'/'test'/'val' seed all RNGs with 0/1/2; any other mode leaves
    the RNGs unseeded).

    Side effects:
      * writes ``cfg.__dict__`` as JSON to ``log/ns_{mode}_nu_{nu}_f_{f_name}.txt``;
      * saves a CPU tensor of shape (N, n, n, record_steps + 1), where
        n = ceil(s / sub) and record_steps = int(record_ratio * T), to
        ``dataset/ns_{mode}_nu_{nu}_f_{f_name}``. Slice ``[..., 0]`` holds the
        subsampled initial field. The remaining slices hold the (subsampled)
        output of ``navier_stokes_2d``.
    """
    device = cfg.device
    s = cfg.s
    sub = cfg.sub
    dt = cfg.dt
    T = cfg.T
    record_steps = int(cfg.record_ratio * T)
    N = cfg.N
    nu = cfg.nu
    f_name = cfg.f_name
    mode = cfg.mode

    # A fixed seed per split keeps each split reproducible and distinct.
    seed = {'train': 0, 'test': 1, 'val': 2}.get(mode)
    if seed is not None:
        setup_seed(seed)

    data_save_path = 'dataset'
    os.makedirs(data_save_path, exist_ok=True)
    log_save_path = 'log'
    os.makedirs(log_save_path, exist_ok=True)
    with open(f'{log_save_path}/ns_{mode}_nu_{nu}_f_{f_name}.txt', 'w') as fh:
        json.dump(cfg.__dict__, fh, indent=2)
    sys.stdout.flush()

    GRF = GaussianRF(s, device=device)

    # Uniform grid on [0, 1) with s points per axis (the endpoint 1 is dropped).
    t = torch.linspace(0, 1, s + 1, device=device)[:-1]
    X, Y = torch.meshgrid(t, t, indexing='ij')
    f = _build_forcing(f_name, X, Y)

    # Number of points kept per axis by a ``::sub`` slice (ceil, not floor),
    # so s need not be divisible by sub.
    n = (s + sub - 1) // sub
    # max(1, ...) keeps range() valid when N == 0 (nothing is generated).
    bsize = max(1, min(100, N))
    u = torch.zeros(N, n, n, record_steps + 1)

    for j, c in enumerate(range(0, N, bsize)):
        # The final batch may be smaller so that all N samples get generated.
        batch = min(bsize, N - c)
        # Sample random fields
        w0 = GRF(batch)
        visc = nu * torch.ones(batch, device=device)
        sol, _ = navier_stokes_2d(w0, f, visc, T, dt, record_steps)
        w0 = w0[:, ::sub, ::sub].reshape(-1, n, n, 1)
        sol = torch.concat([w0, sol[:, ::sub, ::sub, :]], dim=3)
        u[c:c + batch, ...] = sol
        print(j, c + batch)
        print(u.max())
    torch.save(u, f'{data_save_path}/ns_{mode}_nu_{nu}_f_{f_name}')