# https://www.psl.noaa.gov/data/gridded/data.noaa.oisst.v2.html
# https://psl.noaa.gov/repository/entry/show/PSL+Climate+Data+Repository/Public/PSL+Datasets/NOAA+OI+SST/Weekly+and+Monthly/sst.wkmean.1990-present.nc?entryid=12159560-ab82-48a1-b3e4-88ace20475cd&output=data.cdl
# https://psl.noaa.gov/repository/entry/show?entryid=12159560-ab82-48a1-b3e4-88ace20475cd
# https://downloads.psl.noaa.gov/Datasets/noaa.oisst.v2.highres/

import os

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.basemap import Basemap
import numpy as np
import torch
from tqdm import tqdm

from data.climate import SSTDataset


def vis_dataset(*, num_frames: int = 10, out_path: str = 'geology_data_animation.mp4'):
    """Render the first ``num_frames`` SST maps of ``SSTDataset`` as an MP4 animation.

    The dataset's values are de-normalised with its ``std``/``mean`` and
    multiplied by 1e-2 before plotting; the colour bar is labelled in °C
    (assumes that scaling yields °C — TODO confirm against ``SSTDataset``).
    After saving, the animation is also shown interactively.

    Args:
        num_frames: Number of leading time steps to animate (default 10, as before).
        out_path: Destination video file, written with the ``ffmpeg`` writer.

    Raises:
        ValueError: If ``num_frames`` < 1 (there would be no first frame to draw).
    """
    if num_frames < 1:
        raise ValueError(f'num_frames must be >= 1, got {num_frames}')

    sst_dataset = SSTDataset(chunk_size=208)
    sst, lat, lon = sst_dataset.data['sst'], sst_dataset.data['lat'], sst_dataset.data['lon']
    # Reshape the flat data to (lat, lon, time), then move time to the front: (time, lat, lon).
    sst = sst.reshape([sst_dataset.lat_dim, sst_dataset.lon_dim, sst_dataset.time_steps]).transpose(2, 0, 1)
    sst = (sst * sst_dataset.std + sst_dataset.mean) * 1e-2

    sst = sst[:num_frames]
    time_steps = len(sst)

    # One set of colour limits for every frame, so the colour bar stays valid throughout.
    vmin, vmax = -5., 40.

    fig, ax = plt.subplots(figsize=(8., 4.))

    # Cylindrical equidistant projection covering the data's lat/lon extent.
    basemap = Basemap(
        projection='cyl',
        llcrnrlat=min(lat),
        urcrnrlat=max(lat),
        llcrnrlon=min(lon),
        urcrnrlon=max(lon),
        resolution='c',
        ax=ax,
    )

    # Convert the lat/lon grid to map-projection coordinates.
    lon_grid, lat_grid = np.meshgrid(lon, lat)
    x, y = basemap(lon_grid, lat_grid)

    # The first frame only exists to create the mappable the colour bar is attached to.
    c_scheme = basemap.pcolormesh(x, y, sst[0], cmap='viridis', shading='auto', vmin=vmin, vmax=vmax)
    cbar = basemap.colorbar(c_scheme, location='right', pad='1%')
    cbar.set_label('Temperature [°C]')
    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()

    def update(frame: int):
        """Clear the map axes and redraw time step ``frame``."""
        ax.cla()
        # Same limits as the initial mesh; without them each frame autoscales
        # and the colours no longer correspond to the colour bar.
        mesh = basemap.pcolormesh(x, y, sst[frame], cmap='viridis', shading='auto', vmin=vmin, vmax=vmax)
        basemap.fillcontinents(color='white', lake_color='white')
        ax.set_title(f'Weekly Mean of Sea Surface Temperature - Time Step {frame:04}')
        return mesh,

    anim = FuncAnimation(fig, update, frames=range(time_steps), blit=False, repeat=True)
    # Report progress from the writer rather than wrapping `frames` in tqdm. A wrapped
    # iterator is also consumed by the initial draw and by plt.show(), which garbles the bar.
    with tqdm(total=time_steps) as progress:
        anim.save(out_path, writer='ffmpeg', dpi=100,
                  progress_callback=lambda _current, _total: progress.update(1))

    plt.show()


def load_saved(log_dir: str, prefix: str, num: int, *,
               lat_dim: int = 180, lon_dim: int = 360, map_location=None) -> torch.Tensor:
    """Load ``num`` saved tensor chunks and reassemble them into a stack of grids.

    Processing steps:

    1. Load ``{log_dir}/{prefix}_0.pth`` ... ``{prefix}_{num-1}.pth`` in index order
       with ``weights_only=True``.
    2. Concatenate the chunks along dim 0.
    3. Reshape to ``(lat_dim, lon_dim, -1)``.
    4. Permute the trailing axis to the front.

    Args:
        log_dir: Directory that contains the chunk files.
        prefix: File-name prefix, e.g. ``'gt'`` or ``'pred'``.
        num: Number of chunks to load; must be >= 1.
        lat_dim: Number of grid rows (default 180, the previously hard-coded value).
        lon_dim: Number of grid columns (default 360, the previously hard-coded value).
        map_location: Forwarded to ``torch.load``. For example, pass ``'cpu'`` to load
            GPU-saved tensors on a CPU-only machine. ``None`` keeps ``torch.load``'s
            default.

    Returns:
        Tensor of shape ``(-1, lat_dim, lon_dim)``. It is a permuted view of the
        concatenated data.

    Raises:
        ValueError: If ``num`` < 1 (``torch.cat`` cannot join an empty list).
    """
    if num < 1:
        raise ValueError(f'num must be >= 1, got {num}')
    chunks = [
        torch.load(os.path.join(log_dir, f'{prefix}_{i}.pth'), map_location=map_location, weights_only=True)
        for i in range(num)
    ]
    # The reshape assumes the concatenated rows are laid out lat-major, then lon
    # (lat_dim * lon_dim rows in total) — TODO confirm against whatever writes these files.
    return torch.cat(chunks, dim=0).reshape(lat_dim, lon_dim, -1).permute(2, 0, 1)


def _to_celsius(values, dataset):
    """De-normalise ``values`` with ``dataset.std``/``dataset.mean`` and scale by 1e-2.

    The plots label the result as °C. This assumes the stored scale is hundredths
    of a degree — TODO confirm against ``SSTDataset``.
    """
    return (values * dataset.std + dataset.mean) * 1e-2


def _plot_sst_map(lat, lon, field, *, title, out_path,
                  cmap='viridis', vmin=-5., vmax=40., continent_color='white'):
    """Draw ``field`` on a global cylindrical Basemap, save it to ``out_path`` and show it."""
    fig, ax = plt.subplots(figsize=(4.5, 3))

    # Cylindrical equidistant projection covering the data's lat/lon extent.
    basemap = Basemap(
        projection='cyl',
        llcrnrlat=min(lat),
        urcrnrlat=max(lat),
        llcrnrlon=min(lon),
        urcrnrlon=max(lon),
        resolution='c',
        ax=ax,
    )

    # Convert the lat/lon grid to map-projection coordinates.
    lon_grid, lat_grid = np.meshgrid(lon, lat)
    x, y = basemap(lon_grid, lat_grid)

    # rasterized=True: the mesh is rasterised in the vector (PDF) output.
    mesh = basemap.pcolormesh(x, y, field, cmap=cmap, shading='auto', vmin=vmin, vmax=vmax, rasterized=True)
    basemap.colorbar(mesh, location='right', pad='1%')
    fig.set_facecolor((1., 1., 1., 0.))  # fully transparent figure background
    fig.tight_layout()

    basemap.fillcontinents(color=continent_color, lake_color=continent_color)
    ax.set_title(title)

    fig.savefig(out_path, bbox_inches='tight', pad_inches=0., transparent=False)
    plt.show()


def main():
    """Compare saved SST predictions with the ground truth and plot the last time step.

    Steps:

    1. Load the ``gt_*`` and ``pred_*`` chunks from ``../logs/saved/default``
       with ``load_saved``.
    2. Check that the saved ground truth equals the tail of ``SSTDataset``.
    3. Print the MSE over the first ``chunk_size`` steps, computed on the
       normalised values.
    4. Write ground-truth, prediction and signed-error maps of the final time
       step into ``vis/``.

    Raises:
        RuntimeError: If the saved ground truth differs from the dataset.
    """
    chunk_size = 208  # dataset chunk size; also the number of leading steps used for the MSE
    num_chunks = 10   # number of saved gt_/pred_ files to load
    log_dir = '../logs/saved/default'
    out_dir = 'vis'

    sst_dataset = SSTDataset(chunk_size=chunk_size)
    sst, lat, lon = sst_dataset.data['sst'], sst_dataset.data['lat'], sst_dataset.data['lon']
    # Reshape the flat data to (lat, lon, time), then move time to the front: (time, lat, lon).
    sst = sst.reshape([sst_dataset.lat_dim, sst_dataset.lon_dim, sst_dataset.time_steps]).transpose(2, 0, 1)
    sst = _to_celsius(sst, sst_dataset)

    gt_ = load_saved(log_dir, 'gt', num_chunks)
    gt = _to_celsius(gt_, sst_dataset)
    # Sanity check: the saved ground truth must equal the final len(gt) steps of the dataset.
    # An explicit raise, unlike `assert`, is not stripped under `python -O`.
    max_diff = (gt - sst[-len(gt):]).abs().max()
    if max_diff != 0.:
        raise RuntimeError(f'saved ground truth differs from the dataset (max abs diff {max_diff})')

    pred_ = load_saved(log_dir, 'pred', num_chunks)
    pred = _to_celsius(pred_, sst_dataset)

    # MSE on the normalised values (before _to_celsius).
    mse_error = ((pred_[:chunk_size] - gt_[:chunk_size]) ** 2.).mean()
    print('mse', mse_error)

    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    _plot_sst_map(lat, lon, gt[-1], title='Weekly Mean Temperature [°C]',
                  out_path=os.path.join(out_dir, 'sst_gt.pdf'))
    _plot_sst_map(lat, lon, pred[-1], title='Weekly Mean Temperature [°C]',
                  out_path=os.path.join(out_dir, 'sst_pred.pdf'))
    # Signed error on the diverging 'bwr' map. Continents are black rather than white,
    # because 'bwr' renders zero error as white.
    _plot_sst_map(lat, lon, pred[-1] - gt[-1], title='Error [°C]',
                  out_path=os.path.join(out_dir, 'sst_diff.pdf'),
                  cmap='bwr', vmin=-5., vmax=5., continent_color='black')


if __name__ == '__main__':
    main()
