import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import numpy as np
from numpy.random import uniform as uniform

def update_edge_(mat: np.ndarray) -> np.ndarray:
    """Overwrite the outer ring of ``mat`` with copies of neighbouring cells.

    The array is modified in place and also returned.

    Steps, in this exact order:

    1. Copy the non-corner cells of the first and last rows from the adjacent
       rows.
    2. Copy the non-corner cells of the first and last columns from the
       adjacent columns.
    3. Copy each corner from the cell one row inward in the same column.

    The steps must run sequentially rather than as one vectorised assignment,
    because on very small arrays (e.g. 2x2) the source and target cells
    overlap.
    """
    inner = slice(1, -1)

    # Top and bottom rows, corners excluded.
    for dst, src in ((0, 1), (-1, -2)):
        mat[dst, inner] = mat[src, inner]

    # Left and right columns, corners excluded.
    for dst, src in ((0, 1), (-1, -2)):
        mat[inner, dst] = mat[inner, src]

    # Corners: (0, 0), (0, -1), (-1, 0), (-1, -1), in that order.
    for dst_row, src_row in ((0, 1), (-1, -2)):
        for col in (0, -1):
            mat[dst_row, col] = mat[src_row, col]

    return mat

def update_plot(
        t, h, u, v, ax,
        depth, max_quivers,
        x, y, n_xy
    ):
    """Redraw one animation frame of surface elevation and velocity.

    Clears ``ax`` and then:

    * draws ``eta = h - depth`` over the grid interior as a colour mesh;
    * overlays a strided quiver plot of ``(u, v)`` when any sampled vector is
      non-zero;
    * pauses briefly so an interactive backend can refresh.

    Args:
        t: Simulation time in seconds; shown in the title in days.
        h: Layer thickness on the 2-D grid (rows follow ``y``, columns
            follow ``x``, as used for the quiver coordinates below).
        u: Velocity component in x, on the same grid as ``h``.
        v: Velocity component in y, on the same grid as ``h``.
        ax: Matplotlib axes to draw into; it is cleared first.
        depth: Resting depth subtracted from ``h`` to obtain ``eta``.
        max_quivers: Approximate number of arrows along each axis.
        x: 1-D coordinates in metres (plotted in km).
        y: 1-D coordinates in metres (plotted in km).
        n_xy: Grid size, used to choose the arrow stride.

    Returns:
        The mesh artist returned by ``ax.pcolormesh`` (suitable for a colorbar).
    """
    eta = h - depth

    # Arrow stride. Clamp to >= 1: a zero step (grid smaller than
    # max_quivers, or max_quivers < 1) would raise instead of plotting.
    step = max(1, n_xy // max(1, max_quivers))
    quiver_stride = (slice(1, -1, step), slice(1, -1, step))

    # Fixed, symmetric colour limits so every frame shares one colour scale.
    plot_range = 1.5
    ax.clear()
    cs = ax.pcolormesh(
        x[1:-1] / 1e3,
        y[1:-1] / 1e3,
        eta[1:-1, 1:-1],
        vmin=-plot_range, vmax=plot_range, cmap='RdBu_r'
    )

    u_q = u[quiver_stride]
    v_q = v[quiver_stride]
    # Only draw arrows when at least one sampled vector is non-zero.
    if np.any((u_q != 0) | (v_q != 0)):
        ax.quiver(
            x[quiver_stride[1]] / 1e3,
            y[quiver_stride[0]] / 1e3,
            u_q,
            v_q,
            clip_on=False
        )

    ax.set_aspect('equal')
    ax.set_xlabel('$x$ (km)')
    ax.set_ylabel('$y$ (km)')
    ax.set_title(
        't=%5.2f days'
        % (t / 86400)
    )
    plt.pause(0.1)
    return cs

def iterate_shallow_water(
        n_xy,
        h0, u0, v0,
        coriolis_param, depth, gravity,
        dx, dy, dt
    ):
    """Yield successive states of a linearised rotating shallow-water model.

    Judging from the stencils below, the fields are staggered:

    * ``u[i, j]`` sits between ``h[i, j]`` and ``h[i, j + 1]``;
    * ``v[i, j]`` sits between ``h[i, j]`` and ``h[i + 1, j]``;
    * rows pair with ``dy`` and columns with ``dx``.

    Each step advances only interior points, in this order:

    1. Update ``u``.
    2. Update ``v``, using the freshly updated ``u``.
    3. Update ``h`` from the divergence of the new velocities.

    Boundary ``u``/``v`` values are never written, so they keep their initial
    values. The boundary of ``h`` is refreshed by ``update_edge_`` after each
    step.

    Args:
        n_xy: Number of grid points along each axis.
        h0: Initial thickness, broadcastable to ``(n_xy, n_xy)``.
        u0: Initial x-velocity, broadcastable to ``(n_xy, n_xy)``.
        v0: Initial y-velocity, broadcastable to ``(n_xy, n_xy)``.
        coriolis_param: Coriolis parameter multiplying the rotation terms.
        depth: Resting depth multiplying the divergence in the ``h`` update.
        gravity: Gravitational acceleration multiplying the ``h`` gradients.
        dx: Grid spacing along columns.
        dy: Grid spacing along rows.
        dt: Time step.

    Yields:
        ``(h, u, v)`` after every step. The same three arrays are mutated in
        place and re-yielded, so callers must copy them to keep a snapshot.
        The generator never terminates on its own.
    """
    shape = (n_xy, n_xy)
    h = np.empty(shape)
    u = np.empty(shape)
    v = np.empty(shape)

    h[...] = h0
    u[...] = u0
    v[...] = v0

    mid = slice(1, -1)      # interior points
    back = slice(None, -2)  # interior shifted one cell toward index 0
    fwd = slice(2, None)    # interior shifted one cell toward the last index

    while True:
        # x-momentum: rotation term from v averaged to u points, minus the
        # pressure-gradient term.
        v_on_u = 0.25 * (v[mid, mid] + v[back, mid] + v[mid, fwd] + v[back, fwd])
        u[mid, mid] = u[mid, mid] + dt * (
            coriolis_param * v_on_u
            - gravity * (h[mid, fwd] - h[mid, mid]) / dx
        )

        # y-momentum: deliberately uses the u just computed above.
        u_on_v = 0.25 * (u[mid, mid] + u[mid, back] + u[fwd, mid] + u[fwd, back])
        v[mid, mid] = v[mid, mid] + dt * (
            - coriolis_param * u_on_v
            - gravity * (h[fwd, mid] - h[mid, mid]) / dy
        )

        # Continuity: thickness changes by -depth * divergence of (u, v).
        h[mid, mid] = h[mid, mid] - dt * depth * (
            (u[mid, mid] - u[mid, back]) / dx
            + (v[mid, mid] - v[back, mid]) / dy
        )

        h = update_edge_(h)

        yield h, u, v

def generate(n_samples: int, save_name: str, plot: bool = False) -> None:
    """Simulate randomised shallow-water runs and pickle them to ``save_name``.

    For each sample:

    1. Draw gravity uniformly from [7, 13) and depth from [60, 140).
    2. Build an initial surface from two sine terms with random amplitudes.
    3. Integrate it with ``iterate_shallow_water``.
    4. Discard the first ``n_skip`` steps and keep the next ``n_recorded``.

    Each kept state is stacked as ``(eta, u, v)`` with ``eta = h - depth``.
    A sample's ``'data'`` array therefore has shape
    ``(n_recorded, 3, n_xy, n_xy)``.

    Args:
        n_samples: Number of independent simulations; passed through ``int()``.
        save_name: Output path; must end with ``.pkl``.
        plot: If true, animate every recorded frame with ``update_plot``.

    Raises:
        ValueError: If ``save_name`` does not end with ``.pkl``.
        FloatingPointError: If a simulation produced NaN or infinite values.
    """
    if not save_name.endswith('.pkl'):
        raise ValueError(f"save_name must end with '.pkl', got {save_name!r}")

    max_quivers = 21
    phase_speed = 31
    coriolis_param = 2e-4
    # Leading steps to skip, then the number of steps to record.
    n_skip = 50
    n_recorded = 50

    n_xy = 32
    dx = dy = 20e3
    x, y = np.arange(n_xy) * dx, np.arange(n_xy) * dy
    Y, X = np.meshgrid(y, x, indexing='ij')
    # Courant number 0.5 with respect to phase_speed.
    # NOTE(review): sqrt(gravity * depth) reaches sqrt(13 * 140) ~= 42.7 for
    # the sampled ranges. That exceeds phase_speed, so the effective Courant
    # number is up to ~0.69. Confirm this margin is intended.
    dt = 0.5 * min(dx, dy) / phase_speed

    samples = {
        'grid_size': n_xy,
        'dx': dx, 'dy': dy, 'dt': dt,
        'max_quivers': max_quivers,
        'phase_speed': phase_speed,
        'coriolis_param': coriolis_param,
        'data': []
    }
    for _ in tqdm(range(int(n_samples))):
        gravity = uniform(7, 13)
        depth = uniform(60, 140)
        sample = {
            'gravity': gravity,
            'depth': depth,
            'rossby_radius': np.sqrt(gravity * depth) / coriolis_param,
        }

        # The outer amplitude is drawn before the inner one. Keep this order
        # so seeded runs reproduce the same datasets.
        h0 = depth + uniform(0.5, 1.5) * (
            np.sin(X / dx / 4) + uniform(0.5, 1.5) * np.sin(Y / dy / 4)
        )
        u0 = np.zeros_like(h0)
        v0 = np.zeros_like(h0)

        fig = ax = cs = None
        if plot:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))

        model = iterate_shallow_water(
            n_xy,
            h0, u0, v0,
            coriolis_param, depth, gravity,
            dx, dy, dt
        )
        frames = []
        for iteration, (h, u, v) in enumerate(model):
            if iteration < n_skip:
                continue
            if iteration >= n_skip + n_recorded:
                break
            if plot:
                # The generator yields after taking a step, so this state is
                # at time (iteration + 1) * dt.
                t = (iteration + 1) * dt
                first_frame = cs is None
                cs = update_plot(
                    t, h, u, v, ax,
                    depth, max_quivers,
                    x, y, n_xy
                )
                # Create the colorbar once per figure, on its first frame.
                if first_frame:
                    plt.colorbar(cs, label='$\\eta$ (m)')
            # np.stack returns a new array, so the generator's later in-place
            # updates cannot alter stored frames.
            frames.append(np.stack([h - depth, u, v]))

        # One figure per sample; close it so figures don't accumulate.
        if fig is not None:
            plt.close(fig)

        sample['data'] = np.stack(frames)
        if not np.isfinite(sample['data']).all():
            raise FloatingPointError('simulation produced non-finite values')
        samples['data'].append(sample)

    with open(save_name, 'wb') as f:
        pickle.dump(samples, f)
    print('Done')

if __name__ == '__main__':
    os.makedirs('data', exist_ok=True)

    # Each split is seeded separately so every dataset is reproducible.
    # Sample counts are ints, matching generate's ``n_samples: int``.
    for seed, n_samples, save_name in (
        (42, 10_000, 'data/train.pkl'),
        (420, 1_000, 'data/val.pkl'),
        (4200, 1_000, 'data/test.pkl'),
    ):
        np.random.seed(seed)
        generate(n_samples, save_name, plot=False)


