import os
from tqdm import tqdm, trange
from phi.torch.flow import *
import pylab
import matplotlib.pyplot as plt
import numpy as np

# Time step and viscosity shared by the simulation code below.
DT = 0.1
NU = 0.01

# Make the GPU the Torch backend's default device and fix the RNG seeds so that
# repeated runs aim to produce the same dataset.
# NOTE(review): selecting 'GPU' presumably fails on a CPU-only machine — confirm
# against the PhiFlow version in use.
TORCH.set_default_device('GPU')
TORCH.seed(1234)
np.random.seed(1234)


def step(velocity, pressure, dt=1.0, buoyancy_factor=0.1, nu=None):
    """Advance ``velocity`` by one step of the incompressible flow solver.

    The step applies, in order: semi-Lagrangian self-advection, explicit
    diffusion, and a pressure projection via ``fluid.make_incompressible``.

    Args:
        velocity: velocity grid to advance (a ``StaggeredGrid`` in this script).
        pressure: pressure from the previous step. Currently unused: it is not
            passed to the pressure solve. It is accepted so callers can thread it
            through their time loop.
        dt: time step used for both advection and diffusion.
        buoyancy_factor: currently unused; kept for interface compatibility.
        nu: diffusivity (viscosity) for the diffusion term. ``None`` (default)
            uses the module-level ``NU``.

    Returns:
        ``(velocity, pressure)`` as returned by ``fluid.make_incompressible``.
    """
    if nu is None:
        # Resolved at call time (not as a default value) so later changes to
        # the module-level NU are still honoured.
        nu = NU
    velocity = advect.semi_lagrangian(velocity, velocity, dt)
    velocity = diffuse.explicit(velocity, nu, dt)
    velocity, pressure = fluid.make_incompressible(velocity)
    return velocity, pressure


# Dataset generation: for each sample, draw a random periodic velocity field,
# advance it NUM_STEPS times with `step`, and record the velocity components
# after the first step ("initial") and after the last step ("final").
NUM_SAMPLES = 1100  # number of independent trajectories
NUM_STEPS = 10  # solver steps per trajectory


def _velocity_components(grid):
    """Return the x and y components of ``grid.values`` as NumPy arrays in ('y', 'x') order."""
    components = grid.values.vector
    return components[0].numpy('y,x'), components[1].numpy('y,x')


initial_v_xs = []
initial_v_ys = []
final_v_xs = []
final_v_ys = []
os.makedirs('data', exist_ok=True)
for _ in trange(NUM_SAMPLES):
    velocity = StaggeredGrid(Noise(), extrapolation.PERIODIC, x=64, y=64, bounds=Box[0:128, 0:128])
    pressure = None
    # leave=False clears each inner bar when it finishes. Otherwise every one of
    # the NUM_SAMPLES trajectories would leave a finished progress bar behind.
    for time_step in trange(NUM_STEPS, leave=False):
        velocity, pressure = step(velocity, pressure, dt=DT)
        if time_step == 0:
            # NOTE(review): the "initial" snapshot is recorded after the first
            # call to `step` (advection + diffusion + projection), not from the
            # raw Noise() field. Confirm this is the intended initial state.
            v_x, v_y = _velocity_components(velocity)
            initial_v_xs.append(v_x)
            initial_v_ys.append(v_y)

    v_x, v_y = _velocity_components(velocity)
    final_v_xs.append(v_x)
    final_v_ys.append(v_y)

# Each array is stacked along a new leading sample axis.
np.savez('data/navier2d.npz',
         initial_v_x=np.stack(initial_v_xs, 0),
         initial_v_y=np.stack(initial_v_ys, 0),
         final_v_x=np.stack(final_v_xs, 0),
         final_v_y=np.stack(final_v_ys, 0))
