"""
Creates plots of ARSO events using cartopy library
"""

import numpy as np
import matplotlib.pyplot as plt
from experiments.arso.display.display import get_cmap
from pyproj import CRS, Transformer
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def get_lon_lat(
    nx=401,
    ny=301,
    dx=1000,
    dy=1000,
    projstring="+proj=lcc +lat_0=46.12 +lon_0=14.815 +lat_1=46.12 +lat_2=46.12 +R=6371000 +x_0=4000 +y_0=6000 +units=m +no_defs",
):
    """
    Compute pixel-centre longitudes and latitudes of the Slovenian radar grid.

    The grid is laid out in the Lambert Conformal Conic projection given by
    ``projstring``. It has ``nx`` columns spaced ``dx`` metres apart and ``ny``
    rows spaced ``dy`` metres apart, placed symmetrically about projected
    coordinate (0, 0). Every node is then converted to WGS84 (EPSG:4326).

    Parameters
    ----------
    nx, ny : int
        Number of columns (x) and rows (y) in the grid (default 401, 301).
    dx, dy : float
        Pixel spacing in metres along x and y (default 1000, 1000).
    projstring : str
        PROJ string of the source Lambert Conformal Conic projection.

    Returns
    -------
    lon : np.ndarray of shape (ny, nx)
        Longitude of every pixel centre.
    lat : np.ndarray of shape (ny, nx)
        Latitude of every pixel centre.
    """
    # always_xy=True keeps (x, y) -> (lon, lat) ordering regardless of the
    # axis order that EPSG:4326 declares.
    lcc_to_wgs84 = Transformer.from_crs(
        crs_from=CRS.from_string(projstring),
        crs_to=CRS.from_epsg(4326),
        always_xy=True,
    )

    # Projected coordinates of the pixel centres, symmetric about 0.
    half_nx = (nx - 1) / 2
    half_ny = (ny - 1) / 2
    x_axis = np.linspace(-half_nx, half_nx, nx) * dx
    y_axis = np.linspace(-half_ny, half_ny, ny) * dy

    # "xy" indexing yields (ny, nx) arrays: rows follow y, columns follow x.
    grid_x, grid_y = np.meshgrid(x_axis, y_axis, indexing="xy")

    lon, lat = lcc_to_wgs84.transform(grid_x, grid_y)
    return lon, lat


def make_animation_arso(
    frames,  # shape: (T, ny, nx)
    img_type="zm",  # Default image type
    interval=200,
    title="ARSO Radar",
    cartopy_features=True,
    **imshow_kwargs,
):
    """
    Create an animation of ARSO/Slovenia frames on a Cartopy PlateCarree map.

    Frames are drawn over the fixed 401 x 301 Slovenian radar grid with 1 km
    spacing (generated by ``get_lon_lat``), using the colormap and
    normalization returned by ``get_cmap``. A label in the upper-right corner
    shows the lead time, counting 5 minutes per frame and starting at "+5min".

    Parameters
    ----------
    frames : np.ndarray
        Shape (T, ny, nx) with (ny, nx) == (301, 401). Radar or precipitation
        frames over time.
    img_type : str, optional
        Image type passed to ``get_cmap`` to select the colormap and
        normalization (default='zm').
    interval : int, optional
        Delay between frames in milliseconds (default=200).
    title : str, optional
        Title for the plot; a falsy value omits it (default='ARSO Radar').
    cartopy_features : bool, optional
        Whether to draw ocean, borders, coastline, lakes and rivers
        (default=True).
    **imshow_kwargs : dict
        Additional arguments passed to ``ax.imshow``. ``cmap``, ``norm``,
        ``vmin`` and ``vmax`` are always overridden by ``get_cmap``.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object.

    Raises
    ------
    ValueError
        If ``frames`` is not 3-D, holds no frames, or its spatial shape does
        not match the radar grid.
    """
    # --- 1. Validate shapes ---
    frames_shape = tuple(frames.shape)
    if len(frames_shape) != 3:
        raise ValueError(
            f"frames must have shape (T, ny, nx); got {frames_shape}."
        )
    T, ny, nx = frames_shape
    if T == 0:
        raise ValueError("frames must contain at least one frame.")

    # The grid is fixed to the 1 km ARSO radar grid (not derived from
    # frames.shape), so this check rejects frames on any other grid.
    lon, lat = get_lon_lat(nx=401, ny=301, dx=1000, dy=1000)
    if lon.shape != (ny, nx) or lat.shape != (ny, nx):
        raise ValueError(
            "lon/lat shapes must match frames' spatial dimensions (ny, nx)."
        )

    # --- 2. Retrieve custom colormap and normalization ---
    cmap, norm, vmin, vmax = get_cmap(img_type)
    imshow_kwargs.update({"cmap": cmap, "norm": norm, "vmin": vmin, "vmax": vmax})

    # --- 3. Set up the figure and Cartopy axes in PlateCarree ---
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

    # NOTE(review): the grid is regular in LCC metres, not in lon/lat, so
    # stretching it over the lon/lat bounding box only approximates each
    # pixel's true position.
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    if cartopy_features:
        # zorder=3 is intended to keep these overlays above the radar image.
        ax.add_feature(
            cfeature.OCEAN.with_scale("50m"),
            alpha=0.5,
            zorder=3,
        )
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.3,
            edgecolor="black",
            zorder=3,
        )
        ax.add_feature(
            cfeature.COASTLINE.with_scale("50m"),
            linewidth=0.3,
            edgecolor="black",
            zorder=3,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="cornflowerblue",
            alpha=0.5,
            linewidth=0.3,
            zorder=3,
        )
        ax.add_feature(
            cfeature.RIVERS.with_scale("50m"),
            edgecolor="cornflowerblue",
            alpha=0.5,
            linewidth=0.3,
            zorder=3,
        )

    # --- 4. Display the first frame with the custom colormap ---
    im = ax.imshow(
        frames[0, :, :],
        origin="lower",
        extent=extent,
        transform=ccrs.PlateCarree(),
        **imshow_kwargs,
    )

    if title:
        ax.set_title(title)

    # --- 5. Add a colorbar ---
    cbar = fig.colorbar(
        im, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.8
    )
    cbar.set_label("Radar Reflectivity (dBZ)")  # Change label as appropriate

    # Lead-time annotation in the upper-right corner of the map.
    time_text = ax.text(
        0.96,
        0.96,
        "",
        transform=ax.transAxes,
        ha="right",
        va="top",
        fontsize=14,
        color="white",
        bbox=dict(facecolor="black", alpha=0.5, pad=0.2),
    )

    # --- 6. Define init and animate functions ---
    def init():
        im.set_data(frames[0, :, :])
        time_text.set_text("+5min")
        return (im, time_text)

    def update(frame_idx):
        im.set_data(frames[frame_idx, :, :])
        # Frame index i corresponds to a lead time of (i + 1) * 5 minutes.
        time_text.set_text(f"+{ (frame_idx + 1) * 5 }min")
        return (im, time_text)

    # --- 7. Create the animation ---
    anim = FuncAnimation(
        fig, update, init_func=init, frames=range(T), interval=interval, blit=True
    )

    return anim


def plot_mse_map_arso(
    mse: np.ndarray, title: str = "MSE Map", **imshow_kwargs  # shape: (ny, nx)
):
    """
    Plot a per-pixel Mean Squared Error (MSE) map on a Cartopy PlateCarree map.

    The map is placed on the fixed 401 x 301 Slovenian radar grid with 1 km
    spacing generated by ``get_lon_lat``. ``mse`` must therefore have shape
    (301, 401). Borders, coastline, lakes and rivers are drawn on top.

    Parameters
    ----------
    mse : np.ndarray
        2D array with shape (ny, nx) holding the (averaged) MSE for each
        (x, y) pixel.
    title : str, optional
        Title for the plot; a falsy value omits it (default: 'MSE Map').
    **imshow_kwargs : dict
        Additional keyword arguments passed to ``ax.imshow``. When no ``norm``
        is given, ``vmin`` and ``vmax`` default to 3 and 30 unless supplied
        by the caller.

    Returns
    -------
    tuple
        A tuple (fig, ax) containing the matplotlib Figure and Axes objects.

    Raises
    ------
    ValueError
        If ``mse`` does not have the radar grid's (ny, nx) shape.
    """
    # The grid is fixed to the 1 km ARSO radar grid, not derived from mse.
    lon, lat = get_lon_lat(nx=401, ny=301, dx=1000, dy=1000)

    mse_shape = np.shape(mse)
    if lon.shape != mse_shape or lat.shape != mse_shape:
        raise ValueError(
            "Longitude and latitude shapes must match the MSE spatial dimensions (ny, nx)."
        )

    # Default colour range is 3..30 (intended scale: < 3, 6, 9, ..., 30, > 30).
    # Caller-supplied limits are honoured. The defaults are skipped when a
    # Normalize instance is passed, because Matplotlib rejects a norm combined
    # with vmin/vmax.
    if imshow_kwargs.get("norm") is None:
        imshow_kwargs.setdefault("vmin", 3)
        imshow_kwargs.setdefault("vmax", 30)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

    # NOTE(review): the grid is regular in LCC metres, not in lon/lat, so
    # stretching it over the lon/lat bounding box is an approximation.
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    ax.add_feature(cfeature.BORDERS, linewidth=0.5, edgecolor="black")
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, edgecolor="black")
    ax.add_feature(cfeature.LAKES, alpha=0.5)
    ax.add_feature(cfeature.RIVERS)

    im = ax.imshow(
        mse,
        origin="lower",
        extent=extent,
        transform=ccrs.PlateCarree(),
        **imshow_kwargs,
    )

    if title:
        ax.set_title(title)

    cbar = fig.colorbar(
        im, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.8
    )
    cbar.set_label("MSE")

    return fig, ax


def plot_pair_frames_arso(
    frame1,
    frame2,
    img_type="zm",
    title=None,
    title_frame1="Frame 1",
    title_frame2="Frame 2",
    cartopy_features=True,
    **kwargs,
):
    """
    Plot two ARSO frames side by side on Cartopy maps and return the figure.

    Both frames are drawn over the fixed 401 x 301 Slovenian radar grid with
    1 km spacing (generated by ``get_lon_lat``) and share a single colorbar.

    Parameters
    ----------
    frame1 : array-like
        2D array of shape (H, W) == (301, 401) representing the first frame.
    frame2 : array-like
        2D array of shape (H, W) == (301, 401) representing the second frame.
    img_type : str, optional
        Image type passed to ``get_cmap`` for colormap selection
        (default is 'zm').
    title : str, optional
        Figure-level title; omitted when falsy.
    title_frame1 : str, optional
        Title for the first frame.
    title_frame2 : str, optional
        Title for the second frame.
    cartopy_features : bool, optional
        Whether to add borders, coastline, lakes and rivers (default is True).
    **kwargs
        Additional arguments to pass to `ax.imshow`. ``cmap``, ``norm``,
        ``vmin`` and ``vmax`` are always overridden by ``get_cmap``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure object.

    Raises
    ------
    ValueError
        If either frame does not match the radar grid shape.
    """
    lon, lat = get_lon_lat(nx=401, ny=301, dx=1000, dy=1000)

    # Reject frames on a different grid instead of silently stretching them
    # over the wrong geographic extent.
    for name, frame in (("frame1", frame1), ("frame2", frame2)):
        if np.shape(frame) != lon.shape:
            raise ValueError(
                f"{name} shape {np.shape(frame)} must match the radar grid "
                f"shape (ny, nx) = {lon.shape}."
            )

    cmap, norm, vmin, vmax = get_cmap(img_type)
    kwargs.update({"cmap": cmap, "norm": norm, "vmin": vmin, "vmax": vmax})

    fig, axs = plt.subplots(
        1, 2, figsize=(15, 6), subplot_kw={"projection": ccrs.PlateCarree()}
    )

    # NOTE(review): the grid is regular in LCC metres, not in lon/lat, so
    # stretching it over the lon/lat bounding box is an approximation.
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]

    images = []
    for ax, frame, frame_title in zip(
        axs, (frame1, frame2), (title_frame1, title_frame2)
    ):
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        images.append(
            ax.imshow(
                frame,
                origin="lower",
                extent=extent,
                transform=ccrs.PlateCarree(),
                **kwargs,
            )
        )
        if cartopy_features:
            ax.add_feature(cfeature.BORDERS, linewidth=0.5, edgecolor="black")
            ax.add_feature(cfeature.COASTLINE, linewidth=0.5, edgecolor="black")
            ax.add_feature(cfeature.LAKES, alpha=0.5)
            ax.add_feature(cfeature.RIVERS)
        ax.set_title(frame_title)

    # Both panels share the same cmap/norm kwargs, so one colorbar serves both.
    cbar = fig.colorbar(
        images[0],
        ax=axs.ravel().tolist(),
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.8,
    )
    cbar.set_label("Radar Reflectivity (dBZ)")

    if title:
        fig.suptitle(title)

    return fig


def plot_single_frame_arso(
    frame, img_type="zm", title=None, cartopy_features=True, **kwargs
):
    """
    Plot a single ARSO frame on a Cartopy PlateCarree map and return the figure.

    The frame is drawn over the fixed 401 x 301 Slovenian radar grid with
    1 km spacing (generated by ``get_lon_lat``).

    Parameters
    ----------
    frame : array-like
        2D array of shape (H, W) == (301, 401) representing the frame.
    img_type : str, optional
        Image type passed to ``get_cmap`` for colormap selection
        (default is 'zm').
    title : str, optional
        Title for the plot; omitted when falsy.
    cartopy_features : bool, optional
        Whether to add ocean, borders, coastline, lakes and rivers
        (default is True).
    **kwargs
        Additional arguments to pass to `ax.imshow`. ``cmap``, ``norm``,
        ``vmin`` and ``vmax`` are always overridden by ``get_cmap``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure object.

    Raises
    ------
    ValueError
        If ``frame`` does not match the radar grid shape.
    """
    lon, lat = get_lon_lat(nx=401, ny=301, dx=1000, dy=1000)

    # Reject frames on a different grid instead of silently stretching them
    # over the wrong geographic extent.
    if np.shape(frame) != lon.shape:
        raise ValueError(
            f"frame shape {np.shape(frame)} must match the radar grid "
            f"shape (ny, nx) = {lon.shape}."
        )

    cmap, norm, vmin, vmax = get_cmap(img_type)
    kwargs.update({"cmap": cmap, "norm": norm, "vmin": vmin, "vmax": vmax})

    fig, ax = plt.subplots(
        1,
        1,
        figsize=(16, 12),
        subplot_kw={"projection": ccrs.PlateCarree()},
        dpi=200,
    )

    # NOTE(review): the grid is regular in LCC metres, not in lon/lat, so
    # stretching it over the lon/lat bounding box is an approximation.
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    im = ax.imshow(
        frame, origin="lower", extent=extent, transform=ccrs.PlateCarree(), **kwargs
    )

    if cartopy_features:
        # zorder=3 is intended to keep these overlays above the radar image.
        ax.add_feature(
            cfeature.OCEAN.with_scale("50m"),
            alpha=0.5,
            zorder=3,
        )
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.3,
            edgecolor="black",
            zorder=3,
        )
        ax.add_feature(
            cfeature.COASTLINE.with_scale("50m"),
            linewidth=0.3,
            edgecolor="black",
            zorder=3,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="cornflowerblue",
            alpha=0.5,
            linewidth=0.3,
            zorder=3,
        )
        ax.add_feature(
            cfeature.RIVERS.with_scale("50m"),
            edgecolor="cornflowerblue",
            alpha=0.5,
            linewidth=0.3,
            zorder=3,
        )
    if title:
        ax.set_title(title)

    cbar = fig.colorbar(
        im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.8,
    )
    cbar.set_label("Radar Reflectivity (dBZ)")

    return fig


def make_animation_comparison_arso(
    gt_frames,  # (T, H, W)
    model1_frames,
    model2_frames,
    gt_title="Ground Truth",
    model1_title="Model 1",
    model2_title="Model 2",
    img_type="zm",
    interval=200,
    cartopy_features=True,
    **imshow_kwargs,
):
    """
    Create a side-by-side animation of Ground Truth and two models using Cartopy.

    The lon/lat grid is generated from the frames' own (ny, nx) with a fixed
    1 km spacing. A lead-time label above the middle panel counts 5 minutes
    per frame, starting at "+5 min". All panels share one horizontal colorbar.

    Parameters
    ----------
    gt_frames : np.ndarray
        Shape (T, ny, nx). Ground truth frames.
    model1_frames : np.ndarray
        Shape (T, ny, nx). Frames from the first model.
    model2_frames : np.ndarray
        Shape (T, ny, nx). Frames from the second model.
    gt_title, model1_title, model2_title : str
        Titles for the subplots.
    img_type : str, optional
        Image type passed to ``get_cmap`` for colormap and normalization
        (default 'zm').
    interval : int, optional
        Delay between frames in milliseconds (default 200).
    cartopy_features : bool
        Whether to add ocean, borders, coastline, lakes and rivers.
    **imshow_kwargs
        Additional arguments for ``ax.imshow``. ``cmap``, ``norm``, ``vmin``
        and ``vmax`` are always overridden by ``get_cmap``.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object.

    Raises
    ------
    ValueError
        If the three arrays differ in shape, are not 3-D, or contain no frames.
    """
    # Validate inputs before the comparatively expensive projection below.
    if not (gt_frames.shape == model1_frames.shape == model2_frames.shape):
        raise ValueError("All input frame arrays must have the same shape.")
    frames_shape = tuple(gt_frames.shape)
    if len(frames_shape) != 3:
        raise ValueError(
            f"Frame arrays must have shape (T, ny, nx); got {frames_shape}."
        )
    T, ny, nx = frames_shape
    if T == 0:
        raise ValueError("Frame arrays must contain at least one frame.")

    lon, lat = get_lon_lat(nx=nx, ny=ny, dx=1000, dy=1000)

    cmap, norm, vmin, vmax = get_cmap(img_type)
    imshow_kwargs.update({"cmap": cmap, "norm": norm, "vmin": vmin, "vmax": vmax})

    fig, axs = plt.subplots(
        1,
        3,
        figsize=(20, 5),
        subplot_kw={"projection": ccrs.PlateCarree()},
        gridspec_kw={"wspace": 0.05, "hspace": 0.05},
    )

    # NOTE(review): the grid is regular in LCC metres, not in lon/lat, so
    # stretching it over the lon/lat bounding box is an approximation.
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]
    titles = [gt_title, model1_title, model2_title]
    data_sources = [gt_frames, model1_frames, model2_frames]
    ims = []

    for ax, data, panel_title in zip(axs, data_sources, titles):
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        if cartopy_features:
            # zorder=3 is intended to keep these overlays above the image.
            ax.add_feature(cfeature.OCEAN.with_scale("50m"), alpha=0.5, zorder=3)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.3,
                edgecolor="black",
                zorder=3,
            )
            ax.add_feature(
                cfeature.COASTLINE.with_scale("50m"),
                linewidth=0.3,
                edgecolor="black",
                zorder=3,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="cornflowerblue",
                alpha=0.5,
                linewidth=0.3,
                zorder=3,
            )
            ax.add_feature(
                cfeature.RIVERS.with_scale("50m"),
                edgecolor="cornflowerblue",
                alpha=0.5,
                linewidth=0.3,
                zorder=3,
            )

        im = ax.imshow(
            data[0, :, :],
            origin="lower",
            extent=extent,
            transform=ccrs.PlateCarree(),
            **imshow_kwargs,
        )
        ims.append(im)
        ax.set_title(panel_title)
        ax.set_xticks([])
        ax.set_yticks([])

    # One horizontal colorbar at the bottom, shared by all three panels.
    fig.subplots_adjust(bottom=0.2)
    cbar_ax = fig.add_axes([0.2, 0.15, 0.6, 0.03])
    cbar = fig.colorbar(ims[0], cax=cbar_ax, orientation="horizontal")
    cbar.set_label("Radar Reflectivity (dBZ)")
    if hasattr(norm, "boundaries"):
        # Boundary-based norms: put one tick on each distinct boundary.
        unique_sorted_boundaries = sorted(set(norm.boundaries))
        cbar.set_ticks(unique_sorted_boundaries)
        cbar.ax.set_xticklabels([str(int(b)) for b in unique_sorted_boundaries])

    # Lead-time label centred above the middle panel.
    time_text = axs[1].text(
        0.5,
        1.15,
        "",
        transform=axs[1].transAxes,
        ha="center",
        va="bottom",
        fontsize=14,
        color="black",
    )

    def init():
        for im, data in zip(ims, data_sources):
            im.set_data(data[0, :, :])
        time_text.set_text("+5 min")
        return ims + [time_text]

    def update(frame_idx):
        for im, data in zip(ims, data_sources):
            im.set_data(data[frame_idx, :, :])
        # Frame index i corresponds to a lead time of (i + 1) * 5 minutes.
        time_text.set_text(f"+{(frame_idx + 1) * 5} min")
        return ims + [time_text]

    anim = FuncAnimation(
        fig, update, init_func=init, frames=range(T), interval=interval, blit=True
    )

    return anim
