import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

def update_plot(
        eta, u, v, ax,
        max_quivers,
        x, y, n_x, n_y,
        vmin, vmax
    ):
    """Redraw one frame: ``eta`` as a colour mesh plus a subsampled (u, v) quiver.

    The axes are cleared first. The outermost row/column on every side is
    excluded from both layers. The quiver is drawn only if at least one of
    the subsampled vectors is non-zero.

    Args:
        eta: 2-D field indexed ``[y, x]``, drawn with ``pcolormesh``.
        u, v: 2-D vector components, indexed like ``eta``.
        ax: Matplotlib axes to draw on (cleared before drawing).
        max_quivers: Divisor used to compute the arrow stride along each axis
            (stride = ``n // max_quivers``, at least 1).
        x, y: 1-D coordinates in metres; plotted in km.
        n_x, n_y: Grid sizes along x and y, used for the arrow stride.
        vmin, vmax: Colour limits for the mesh.

    Returns:
        The artist returned by ``ax.pcolormesh`` (e.g. for a colorbar).
    """
    # Clamp the stride to 1: a zero slice step raises ValueError whenever the
    # grid has fewer points than ``max_quivers``.
    # NOTE: floor division means more than ``max_quivers`` arrows per axis can
    # be drawn (e.g. n=32, max_quivers=21 -> stride 1, 30 arrows).
    step_y = max(1, n_y // max_quivers)
    step_x = max(1, n_x // max_quivers)
    quiver_stride = (
        slice(1, -1, step_y),
        slice(1, -1, step_x),
    )

    ax.clear()
    cs = ax.pcolormesh(
        x[1:-1] / 1e3,
        y[1:-1] / 1e3,
        eta[1:-1, 1:-1],
        vmin=vmin, vmax=vmax, cmap='RdBu_r'
    )

    u_q = u[quiver_stride]
    v_q = v[quiver_stride]
    # Skip the quiver entirely when every sampled vector is zero.
    if np.any((u_q != 0) | (v_q != 0)):
        ax.quiver(
            x[quiver_stride[1]] / 1e3,
            y[quiver_stride[0]] / 1e3,
            u_q,
            v_q,
            clip_on=False
        )

    ax.set_aspect('equal')
    ax.set_xlabel('$x$ (km)')
    ax.set_ylabel('$y$ (km)')

    return cs

def draw_png(sample: np.ndarray, folder: str, skip: int) -> None:
    """Render frames of one sample to PNG files in ``folder``.

    Every frame is unpacked as ``(eta, u, v)`` and drawn with ``update_plot``
    on one reused figure. Frame 0 is drawn (and its colorbar attached) but
    not saved. After that, every ``skip``-th frame starting at frame 1 is
    saved as ``<frame index>.png`` (frames 1, 1 + skip, 1 + 2*skip, ...).

    Args:
        sample: Sequence of frames, each unpackable into three 2-D fields
            ``(eta, u, v)`` indexed ``[y, x]``. Presumably shaped
            ``(n_frames, 3, n_y, n_x)``; TODO confirm against the data producer.
        folder: Directory (must already exist) to write the PNG files into.
        skip: Save every ``skip``-th frame after the first.
    """
    max_quivers = 21
    dx = dy = 20e3  # grid spacing in metres; update_plot plots in km
    # Fixed colour range; a data-derived alternative would be
    # sample[:, 0].min(), sample[:, 0].max().
    v_min, v_max = -1.5, 1.5

    eta, norm_u, norm_v = sample[0]
    # Take the grid size from the data rather than hard-coding 32x32.
    n_y, n_x = np.shape(eta)
    x, y = np.arange(n_x) * dx, np.arange(n_y) * dy

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        cs = update_plot(
            eta, norm_u, norm_v, ax,
            max_quivers,
            x, y, n_x, n_y,
            v_min, v_max
        )
        # Attach the colorbar once; ax.clear() in update_plot leaves it intact.
        fig.colorbar(cs, label='$\\eta$ (m)')
        # Frame 0 is intentionally not saved (its savefig was disabled).

        for i, (eta, norm_u, norm_v) in enumerate(sample[1:]):
            if i % skip != 0:
                continue
            update_plot(
                eta, norm_u, norm_v, ax,
                max_quivers,
                x, y, n_x, n_y,
                v_min, v_max
            )
            # i indexes sample[1:], so the actual frame index is i + 1.
            fig.savefig(os.path.join(folder, f'{i+1}.png'), format='png')
    finally:
        # Release the figure. Otherwise every call leaks one open figure, and
        # matplotlib warns once more than 20 are alive.
        plt.close(fig)



# Alternative input: the training set. The pickle is a dict whose 'data'
# entry holds records, each storing its array under 'data'.
# path = 'data/train.pkl'
# data = np.stack([sample['data'] for sample in pickle.load(open(path, 'rb'))['data']])
# save_folder = 'figs/train'

path = 'logs/pde_1.0---hidden_size-16---2024_08_02__20_16_09/sample-2024_08_03__13_29_13/samples_all_pde_7.80945.pkl'
# Load inside a ``with`` so the file handle is closed after reading.
# NOTE: unpickling can execute arbitrary code; only load files produced by
# your own runs.
with open(path, 'rb') as f:
    data = pickle.load(f)
save_folder = 'figs/gen'

# Render the first 5 samples, or fewer if the file holds fewer.
n_render = min(5, len(data))
for idx in tqdm(range(n_render)):
    out_dir = os.path.join(save_folder, str(idx))
    os.makedirs(out_dir, exist_ok=True)
    draw_png(data[idx], out_dir, skip=5)



