from typing import Optional

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as tkr
import matplotlib.tri as tri
import imageio
from tqdm import tqdm


def plot_1xn_meshed(
    mesh_coords,
    channels,
    mesh_fields: Optional[torch.Tensor] = None,
    title: str = "",
    pred: Optional[torch.Tensor] = None,
    error: Optional[torch.Tensor] = None,
):
    """Plot ground truth, prediction and error fields on a 2-D triangulated mesh.

    One row is drawn per entry of ``channels``. The first column always shows
    the ground-truth field. A prediction column and an error column are added
    when ``pred`` / ``error`` are given.

    Args:
        mesh_coords: Node coordinates, read as ``mesh_coords[:, 0]`` (x) and
            ``mesh_coords[:, 1]`` (y). No connectivity is passed, so matplotlib
            triangulates the nodes itself.
        channels: Mapping from a display name to an index ``sl`` that selects
            columns of the field arrays. The selected columns are averaged with
            ``.mean(1)``, so ``sl`` is presumably a slice (TODO confirm).
        mesh_fields: Ground-truth fields, indexed as ``mesh_fields[:, sl]``.
            Required despite the ``None`` default.
        title: Figure super-title.
        pred: Optional predicted fields, indexed like ``mesh_fields``. They are
            drawn on the same colour scale as the ground truth.
        error: Optional error fields, indexed like ``mesh_fields``. They are
            drawn with a ``coolwarm`` colormap scaled to their own min/max.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``mesh_fields`` is ``None``.
    """
    if mesh_fields is None:
        # The ground-truth column is always drawn. Fail with a clear message
        # rather than a TypeError from subscripting None further down.
        raise ValueError("plot_1xn_meshed requires mesh_fields (ground truth)")

    n_keys = len(channels)
    # Ground truth is always shown; prediction and error columns are optional.
    n_cols = 1 + int(pred is not None) + int(error is not None)

    fig, axes = plt.subplots(
        n_keys,
        n_cols,
        figsize=(6 * n_cols, 2 * n_keys),
        squeeze=False,  # keep axes 2-D so axes[row, col] works for any grid
        layout="constrained",
    )

    triangles = tri.Triangulation(mesh_coords[:, 0], mesh_coords[:, 1])

    for idx, (key, sl) in enumerate(channels.items()):
        # Collapse the selected channel columns into a single field.
        input_slice = mesh_fields[:, sl].mean(1)
        col_idx = 0

        # Ground truth and prediction share one colour scale so that they are
        # visually comparable.
        vmin = input_slice.min().item()
        vmax = input_slice.max().item()

        if pred is not None:
            pred_slice = pred[:, sl].mean(1)
            vmin = min(vmin, pred_slice.min().item())
            vmax = max(vmax, pred_slice.max().item())

        # Ground-truth panel
        ax = axes[idx, col_idx]
        im = ax.tripcolor(
            triangles, input_slice, shading="flat", cmap="viridis", vmin=vmin, vmax=vmax
        )
        add_colorbar(ax, im, vmin, vmax)
        ax.set_title(f"GT: {key}")
        format_axis(ax)
        col_idx += 1

        # Prediction panel
        if pred is not None:
            ax = axes[idx, col_idx]
            im = ax.tripcolor(
                triangles,
                pred_slice,
                shading="flat",
                cmap="viridis",
                vmin=vmin,
                vmax=vmax,
            )
            add_colorbar(ax, im, vmin, vmax)
            ax.set_title(f"Pred: {key}")
            format_axis(ax)
            col_idx += 1

        # Error panel: its own scale, with a tick at zero on the colour bar.
        if error is not None:
            ax = axes[idx, col_idx]
            error_slice = error[:, sl].mean(1)
            # Plain Python floats, as for vmin/vmax above, so no array/tensor
            # objects are handed to matplotlib's norm and tick handling.
            err_min = error_slice.min().item()
            err_max = error_slice.max().item()
            im = ax.tripcolor(
                triangles,
                error_slice,
                shading="flat",
                cmap="coolwarm",
                vmin=err_min,
                vmax=err_max,
            )
            cbar = fig.colorbar(im, ax=ax, format=tkr.FormatStrFormatter("%.2g"))
            cbar.set_ticks([err_min, 0, err_max])
            avg_error = error_slice.mean().item()
            ax.set_title(f"Error: {key} (Avg={avg_error:.2e})")
            format_axis(ax)
            col_idx += 1

    fig.suptitle(title)
    return fig


def add_colorbar(ax, im, vmin, vmax, shrink=1):
    """Attach a colour bar for ``im`` next to ``ax`` with three fixed ticks.

    Ticks are placed at ``vmin``, at the midpoint and at ``vmax``. They are
    labelled with two significant digits (``"%.2g"``).

    Args:
        ax: Axes that the colour bar takes space from.
        im: Mappable to describe, e.g. the result of ``ax.tripcolor``.
        vmin: Lower end of the colour scale (first tick).
        vmax: Upper end of the colour scale (last tick).
        shrink: Factor applied to the colour bar's length.

    Returns:
        The created colour bar.
    """
    # Attach to the figure that owns ``ax``. ``plt.colorbar`` would use
    # pyplot's *current* figure, which need not be the one containing ``ax``.
    cbar = ax.figure.colorbar(
        im, ax=ax, format=tkr.FormatStrFormatter("%.2g"), shrink=shrink
    )
    cbar.set_ticks([vmin, (vmin + vmax) / 2, vmax])
    return cbar


def format_axis(ax):
    """Apply the common axis styling shared by every mesh panel.

    The horizontal axis is labelled "X" and the vertical axis "Y", and the
    background grid is switched off. The aspect ratio is not touched (an
    equal-aspect call was previously left commented out).
    """
    for set_label, text in ((ax.set_xlabel, "X"), (ax.set_ylabel, "Y")):
        set_label(text)
    ax.grid(False)


def plot_1xn_meshed_forming(
    initial_mesh_coords,
    mesh_coords,
    channels,
    mesh_fields: Optional[torch.Tensor] = None,
    title: str = "",
    pred: Optional[torch.Tensor] = None,
    error: Optional[torch.Tensor] = None,
    save_path: Optional[str] = "figure.png",
):
    """Plot GT / prediction / error fields on a deformed mesh, optionally saving it.

    The triangle connectivity is computed from ``initial_mesh_coords`` and
    then reused for the node positions in ``mesh_coords``. The fields are
    therefore drawn on the deformed geometry with the original mesh topology.
    One row is drawn per entry of ``channels``.

    Args:
        initial_mesh_coords: Reference node coordinates, read as ``[:, 0]`` /
            ``[:, 1]``. They are used only to build the triangle connectivity.
        mesh_coords: Current (deformed) node coordinates, read as
            ``[:, 0]`` / ``[:, 1]``. They must index the same nodes as
            ``initial_mesh_coords``.
        channels: Mapping from a display name to an index ``sl`` that selects
            columns of the field arrays. The selected columns are averaged
            with ``.mean(1)``, so ``sl`` is presumably a slice (TODO confirm).
        mesh_fields: Ground-truth fields, indexed as ``mesh_fields[:, sl]``.
            Required despite the ``None`` default.
        title: Figure super-title.
        pred: Optional predicted fields, drawn on the ground truth's colour
            scale.
        error: Optional error fields, drawn with ``coolwarm`` on their own
            min/max scale.
        save_path: Where to write the figure (600 dpi, tight bounding box).
            Defaults to ``"figure.png"`` in the working directory. Pass
            ``None`` to skip saving.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``mesh_fields`` is ``None``.
    """
    if mesh_fields is None:
        # The ground-truth column is always drawn. Fail with a clear message
        # rather than a TypeError from subscripting None further down.
        raise ValueError("plot_1xn_meshed_forming requires mesh_fields (ground truth)")

    n_keys = len(channels)
    # Ground truth is always shown; prediction and error columns are optional.
    n_cols = 1 + int(pred is not None) + int(error is not None)

    fig, axes = plt.subplots(
        n_keys,
        n_cols,
        figsize=(6 * n_cols, 2 * n_keys),
        squeeze=False,  # keep axes 2-D so axes[row, col] works for any grid
        layout="constrained",
    )

    # Triangulate the undeformed mesh once and reuse its connectivity. A
    # fresh triangulation of the deformed nodes could produce different or
    # overlapping triangles.
    triang_orig = tri.Triangulation(initial_mesh_coords[:, 0], initial_mesh_coords[:, 1])
    triangles = tri.Triangulation(
        mesh_coords[:, 0], mesh_coords[:, 1],
        triangles=triang_orig.triangles
    )

    for idx, (key, sl) in enumerate(channels.items()):
        # Collapse the selected channel columns into a single field.
        input_slice = mesh_fields[:, sl].mean(1)
        col_idx = 0

        # Ground truth and prediction share one colour scale.
        vmin = input_slice.min().item()
        vmax = input_slice.max().item()

        if pred is not None:
            pred_slice = pred[:, sl].mean(1)
            vmin = min(vmin, pred_slice.min().item())
            vmax = max(vmax, pred_slice.max().item())

        # Ground-truth panel
        ax = axes[idx, col_idx]
        im = ax.tripcolor(
            triangles, input_slice, shading="flat", cmap="viridis", vmin=vmin, vmax=vmax
        )
        add_colorbar(ax, im, vmin, vmax, 0.5)
        ax.set_title(f"GT: {key}")
        format_axis(ax)
        col_idx += 1

        # Prediction panel
        if pred is not None:
            ax = axes[idx, col_idx]
            im = ax.tripcolor(
                triangles,
                pred_slice,
                shading="flat",
                cmap="viridis",
                vmin=vmin,
                vmax=vmax,
            )
            add_colorbar(ax, im, vmin, vmax, 0.5)
            ax.set_title(f"Pred: {key}")
            format_axis(ax)
            col_idx += 1

        # Error panel: its own scale, with a tick at zero on the colour bar.
        if error is not None:
            ax = axes[idx, col_idx]
            error_slice = error[:, sl].mean(1)
            # Plain Python floats, as for vmin/vmax above.
            err_min = error_slice.min().item()
            err_max = error_slice.max().item()
            im = ax.tripcolor(
                triangles,
                error_slice,
                shading="flat",
                cmap="coolwarm",
                vmin=err_min,
                vmax=err_max,
            )
            cbar = fig.colorbar(im, ax=ax, format=tkr.FormatStrFormatter("%.2g"), shrink=0.5)
            cbar.set_ticks([err_min, 0, err_max])
            avg_error = error_slice.mean().item()
            ax.set_title(f"Error: {key} (Avg={avg_error:.2e})")
            format_axis(ax)
            col_idx += 1

    fig.suptitle(title)
    if save_path is not None:
        # Save this specific figure. plt.savefig would save pyplot's current
        # figure, which is not guaranteed to be ``fig``.
        fig.savefig(save_path, dpi=600, bbox_inches="tight")
    return fig
