import os
from config import Config
from model import NeuralAdvection
from sources import get_source
from utils import *


# Build the experiment configuration; every hyperparameter read below
# (ckpt, src, offset, resolutions, output directory, ...) comes from it.
cfg = Config()

# Create the network and its training agent from that configuration.
neuadv = NeuralAdvection(cfg)

# Resume from a numbered checkpoint when a positive index was requested.
if cfg.ckpt > 0:
    neuadv.load_ckpt(cfg.ckpt)

# Set up the initial condition: reuse a previously saved 'add_source'
# checkpoint when it can be loaded, otherwise fit the source function now.
# NOTE(review): this also runs after a numbered checkpoint was loaded above,
# so a resumed state may be replaced by the 'add_source' one -- confirm that
# this is intended.
try:
    neuadv.load_ckpt('add_source')
except Exception as err:
    # Deliberately broad: any load failure falls back to fitting the source.
    # The reason is reported so that a corrupt or incompatible checkpoint is
    # not silently hidden behind the fallback.
    print(f"could not load 'add_source' checkpoint ({err!r}); fitting source instead.")
    source_func = get_source(cfg.src, mu=cfg.offset)
    neuadv.add_source(source_func)
else:
    print("load pretrained model that fits initial condition.")

# Render the initial field as frame 000. The maximum of the sampled values is
# kept in y_max, which the time-stepping loop below passes to draw_field for
# every later frame.
# (np is not imported explicitly in this file; presumably it is provided by
# `from utils import *`.)
initial_values, _samples = neuadv.sample_field(cfg.vis_resolution, to_numpy=True)
y_max = np.max(initial_values)
save_figure(
    neuadv.draw_field(cfg.vis_resolution, y_max=y_max),
    os.path.join(cfg.results_dir, f'{0:03d}.png'),
)

# Run the simulation: one advect() call per timestep, writing a frame and a
# checkpoint after each step and, when cfg.save_h5 is set, collecting the
# sampled grid values for the final export.
grid_values = []
for step in range(1, cfg.n_timesteps + 1):
    # bump the agent's own step counter, then take the step
    neuadv.timestep += 1
    print("timestep:", neuadv.timestep)
    neuadv.advect()

    # frame files are numbered by the loop index, zero-padded to three digits
    save_figure(
        neuadv.draw_field(cfg.vis_resolution, y_max=y_max),
        os.path.join(cfg.results_dir, f'{step:03d}.png'),
    )

    # keep this step's field, sampled at the export resolution
    if cfg.save_h5:
        snapshot, _samples = neuadv.sample_field(cfg.sample_resolution, to_numpy=True)
        grid_values.append(snapshot)

    # write a checkpoint after every step
    neuadv.save_ckpt()

# Assemble the frames saved in the results directory into an animated GIF.
save_path = os.path.join(cfg.results_dir, 'anim.gif')
frames2gif(cfg.results_dir, save_path, fps=cfg.fps)

# Dump the per-timestep grid samples as one array stacked along axis 0.
# Skipped when nothing was collected (e.g. zero timesteps), because np.stack
# raises ValueError on an empty sequence and would crash the script after all
# other outputs were already written.
# NOTE: despite the flag name `save_h5`, the data is written with np.save
# (.npy). The 'girdV' spelling in the file name is kept as-is so that any
# existing readers of the output keep working.
if cfg.save_h5 and grid_values:
    save_path = os.path.join(cfg.results_dir, f'girdV_res{cfg.sample_resolution}.npy')
    grid_values = np.stack(grid_values, axis=0)
    np.save(save_path, grid_values)
