import os
import numpy as np
from config import Config
from models import get_model
from sources import get_source_velocity, get_source_density
from utils.vis_utils import save_figure, frames2gif
from utils.file_utils import ensure_dirs


# create experiment config containing all hyperparameters
cfg = Config()

# create network and training agent
fluid = get_model(cfg)

# load checkpoints
if cfg.ckpt > 0:
    fluid.load_ckpt(cfg.ckpt)

# add source
# The velocity source function is needed both to fit the initial condition
# (below) and by the simulation loop, which keeps injecting velocity while
# 0 < t < cfg.src_duration. Build it unconditionally: it used to be defined
# only when loading the 'add_source' checkpoint failed, so a successful load
# made the loop raise NameError.
source_func = get_source_velocity(cfg.src, cfg.src_start_frame)
# NOTE(review): when resuming from cfg.ckpt above, this load may overwrite
# the checkpoint just restored -- confirm this ordering is intended.
try:
    fluid.load_ckpt('add_source')
except Exception as e:
    # No usable 'add_source' checkpoint: fit the initial condition from the
    # source functions instead, and say why rather than swallowing the error.
    print(f"could not load 'add_source' checkpoint ({e!r}); "
          "fitting initial condition from source.")
    if cfg.use_density:
        density_func = get_source_density(cfg.src)
        fluid.add_source_density('density', density_func)
    fluid.add_source('velocity', source_func, is_init=True)
else:
    print("load pretrained model that fits initial condition.")

# Output folders for rendered frames: one sub-folder per visualized field
# under cfg.results_dir, created before the simulation starts writing.
vis_vel_dir, vis_vor_dir = (
    os.path.join(cfg.results_dir, field_name)
    for field_name in ('velocity', 'vorticity')
)
ensure_dirs([vis_vel_dir, vis_vor_dir])

# Run the simulation for cfg.n_timesteps steps. Each step renders frames,
# optionally samples the velocity field, and checkpoints the model.
grid_values = []
vis_targets = (('velocity', vis_vel_dir), ('vorticity', vis_vor_dir))
for t in range(cfg.n_timesteps):
    fluid.timestep += 1
    # t == 0 is skipped here: the initial source was either added with
    # is_init=True or restored from the 'add_source' checkpoint before the loop.
    if 0 < t < cfg.src_duration:
        fluid.add_source('velocity', source_func, is_init=False)

    # advance the simulation by one time step
    print("timestep:", fluid.timestep)
    fluid.step()

    # render each field and store it as a zero-padded numbered frame
    for field, out_dir in vis_targets:
        fig = fluid.draw(field, cfg.vis_resolution)
        save_figure(fig, os.path.join(out_dir, f'{field}_t{t:03d}.png'))

    # sample the velocity field on a regular grid for the final export
    if cfg.save_h5:
        grid_values.append(
            fluid.sample_velocity_field(cfg.sample_resolution, to_numpy=True)
        )

    # checkpoint after every step
    fluid.save_ckpt()

# frames to gif: assemble each field's saved frames into an animation
save_path = os.path.join(vis_vel_dir, 'velocity_anim.gif')
frames2gif(vis_vel_dir, save_path, fps=cfg.fps)
save_path = os.path.join(vis_vor_dir, 'vorticity_anim.gif')
frames2gif(vis_vor_dir, save_path, fps=cfg.fps)

# save grid values
# NOTE: despite the flag name, this writes a NumPy .npy file, not HDF5.
if cfg.save_h5:
    if grid_values:
        # NOTE(review): 'girdV' looks like a typo for 'gridV'; left unchanged
        # in case downstream tooling expects this filename.
        save_path = os.path.join(cfg.results_dir, f'girdV_res{cfg.sample_resolution}.npy')
        # stack the per-timestep samples along a new leading time axis
        grid_values = np.stack(grid_values, axis=0)
        np.save(save_path, grid_values)
    else:
        # np.stack raises ValueError on an empty sequence (e.g. n_timesteps == 0)
        print("no grid values were sampled; skipping grid export.")
