import os
import pickle
import tempfile
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm

def quiver_slices(n_x, n_y, max_quivers):
    """Return ``(row_slice, col_slice)`` that subsample interior grid points for quiver arrows.

    Both slices skip the first and last row/column, matching the interior
    region drawn by :func:`update_plot`. The step along each axis is
    ``n // max_quivers`` clamped to at least 1: without the clamp a grid
    smaller than ``max_quivers`` yields a zero step, and indexing an array
    with a zero-step slice raises ``ValueError``.
    """
    return (
        slice(1, -1, max(1, n_y // max_quivers)),
        slice(1, -1, max(1, n_x // max_quivers)),
    )


def update_plot(
        eta, u, v, ax,
        max_quivers,
        x, y, n_x, n_y,
        vmin, vmax
    ):
    """Redraw *ax* with the field *eta* as a colour mesh and (u, v) as arrows.

    The axes are cleared first, so the same ``ax`` can be reused frame after
    frame. Only interior points (``[1:-1]`` along each axis) are drawn --
    presumably the outermost rows/columns are boundary cells; confirm against
    the data producer.

    Args:
        eta, u, v: 2-D arrays indexed ``[y, x]``, i.e. shape ``(len(y), len(x))``.
        ax: matplotlib Axes to draw into.
        max_quivers: controls arrow density; the arrow stride along each axis
            is ``n // max_quivers`` (at least 1).
        x, y: 1-D grid coordinates in metres (plotted in km).
        n_x, n_y: grid sizes used to compute the arrow stride.
        vmin, vmax: fixed colour limits for ``eta``.

    Returns:
        The mesh returned by ``ax.pcolormesh``, so callers can attach a colorbar.
    """
    quiver_stride = quiver_slices(n_x, n_y, max_quivers)

    ax.clear()
    cs = ax.pcolormesh(
        x[1:-1] / 1e3,
        y[1:-1] / 1e3,
        eta[1:-1, 1:-1],
        vmin=vmin, vmax=vmax, cmap='RdBu_r'
    )

    # Skip the quiver when every sampled vector is zero -- presumably to avoid
    # quiver autoscaling on an all-zero field; confirm.
    if np.any((u[quiver_stride] != 0) | (v[quiver_stride] != 0)):
        ax.quiver(
            x[quiver_stride[1]] / 1e3,
            y[quiver_stride[0]] / 1e3,
            u[quiver_stride],
            v[quiver_stride],
            clip_on=False
        )

    ax.set_aspect('equal')
    ax.set_xlabel('$x$ (km)')
    ax.set_ylabel('$y$ (km)')

    return cs

def draw_png(
    sample: np.ndarray,
    folder: str,
    v_min: float = -1.5,
    v_max: float = 1.5,
) -> None:
    """Render every frame of *sample* to ``<folder>/<frame_index>.png``.

    Each ``sample[t]`` is unpacked into ``(eta, u, v)`` and drawn with
    ``update_plot`` on one shared figure. Files are named ``0.png``,
    ``1.png``, ... in frame order. A colorbar for ``eta`` is added on the
    first frame and kept for the rest. It stays valid because the colour
    limits are fixed across frames.

    Args:
        sample: sequence of frames; ``sample[t]`` unpacks into three 2-D
            arrays ``(eta, u, v)`` of the same shape ``(n_y, n_x)``.
        folder: existing directory to write the PNG files into.
        v_min, v_max: colour limits for ``eta``, shared by all frames so the
            animation does not flicker. For data-driven limits pass e.g.
            ``sample[:, 0].min()`` / ``sample[:, 0].max()``.
    """
    max_quivers = 21
    dx = dy = 20e3  # grid spacing in metres (axes are labelled in km)
    # Derive the grid size from the data instead of assuming 32x32.
    n_y, n_x = np.shape(sample[0][0])
    x, y = np.arange(n_x) * dx, np.arange(n_y) * dy

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        for i, (eta, norm_u, norm_v) in enumerate(sample):
            cs = update_plot(
                eta, norm_u, norm_v, ax,
                max_quivers,
                x, y, n_x, n_y,
                v_min, v_max
            )
            if i == 0:
                fig.colorbar(cs, ax=ax, label='$\\eta$ (m)')
            fig.savefig(os.path.join(folder, f'{i}.png'), format='png')
    finally:
        # Without this, every call leaks a figure (matplotlib keeps all
        # pyplot figures alive until they are explicitly closed).
        plt.close(fig)



path = 'logs/pde_1.0---hidden_size-16---2024_08_02__20_16_09/sample-2024_08_03__13_29_13/samples_all_pde_7.80945.pkl'

gif = True
save_folder = 'temp'
n_samples = 5  # how many samples from the pickle to render

# NOTE: unpickling can execute arbitrary code; only point `path` at files
# produced by this project.
with open(path, 'rb') as f:
    data = pickle.load(f)

for idx in tqdm(range(min(n_samples, len(data)))):
    sample = data[idx]
    if not gif:
        out_dir = os.path.join(save_folder, str(idx))
        os.makedirs(out_dir, exist_ok=True)
        draw_png(sample, out_dir)
        continue

    with tempfile.TemporaryDirectory() as temp_dir:
        draw_png(sample, temp_dir)

        image_files = sorted(
            (name for name in os.listdir(temp_dir) if name.endswith('.png')),
            key=lambda name: int(name[:-4]),
        )
        # Load each frame into memory and close its file immediately: the
        # temporary directory is deleted when this `with` exits, which fails
        # (e.g. on Windows) while PNG handles are still open.
        frames = []
        for name in image_files:
            with Image.open(os.path.join(temp_dir, name)) as im:
                frames.append(im.copy())

    # The GIF is written next to the pickle file.
    output_path = os.path.join(os.path.dirname(path), f'output_{idx}.gif')
    frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=100)





