from typing import Optional, Sequence
import torch
import numpy as np
import matplotlib.pyplot as plt
from abc import ABC, abstractmethod


def plot_df_cv(
    df, label=None, symbol="--.", key_x="entropy", key_y="acc", color=None, ax=None
):
    """Plot one column of ``df`` against another as a single labelled line.

    Args:
        df: Table-like object supporting ``df[key]`` lookups. When ``label``
            is None it must also expose ``df.method.iloc[0]``.
        label: Legend label. Defaults to the first entry of ``df.method``.
        symbol: Matplotlib format string for the line/marker style.
        key_x: Column used for the x values.
        key_y: Column used for the y values.
        color: Line colour forwarded to matplotlib.
        ax: Axes to draw on. If None, the current axes are used.
    """
    if label is None:
        label = df.method.iloc[0]
    # Honour the caller-supplied axes; previously ``ax`` was silently ignored
    # and everything was drawn on the current pyplot axes.
    if ax is None:
        ax = plt.gca()

    ax.plot(df[key_x], df[key_y], symbol, label=label, color=color)
    # The axis labels are fixed and do not follow key_x / key_y.
    ax.set_xlabel("Entropy")
    ax.set_ylabel("Accuracy")


def grid_plot(
    quant_weight: torch.Tensor,
    grid_lines: bool = True,
    padding: float = 1,
    ax=None,
    **kwargs,
):
    """Creates a plot of a quantised weight tensor. Each bar corresponds to exactly one grid point,
    with the height of the bar proportional to the number of weights that were quantised to that grid point.
    The width of the bar covers exactly those points on the x-axis, that would be quantised to the grid-point.
    Bar heights are normalised so that the total bar area is 1.
    The grid points are shown as dashed lines in red, but can also be turned off.

    Args:
        quant_weight (torch.Tensor): The quantised weight tensor to plot.
        grid_lines (bool, optional): Whether to plot grid lines. Defaults to True.
        padding (float, optional): The padding to add to the left and right of the grid. Defaults to 1.
        ax: The axis to plot on. If None, the current axes are used.
        **kwargs: Forwarded to ``ax.bar``.

    Raises:
        ValueError: If ``quant_weight`` has no elements.
    """
    if ax is None:
        ax = plt.gca()
    if quant_weight.numel() == 0:
        raise ValueError("Cannot plot an empty weight tensor.")

    # One unique() pass provides both the grid and the per-point counts.
    grid_points, counts = quant_weight.detach().unique(return_counts=True)
    if grid_points.numel() > 0.95 * quant_weight.numel():
        print(
            "WARNING: The weight tensor is almost continuous. This might not be the intended use of this function."
        )

    # Leave torch here: moving to CPU float64 NumPy arrays keeps the arithmetic
    # below in a single library and also works for non-CPU tensors.
    grid = grid_points.cpu().double().numpy()
    counts = counts.cpu().numpy()

    # Distance from every grid point to its left / right neighbour; the two
    # outermost points use `padding` as their missing neighbour distance.
    gaps = np.diff(grid)
    diff_left = np.concatenate(([padding], gaps))
    diff_right = np.concatenate((gaps, [padding]))

    # Every bar reaches halfway to each neighbour.
    gwidth = (diff_left + diff_right) / 2
    gborders = grid - diff_left / 2
    gpositions = gborders + gwidth / 2

    # Normalise so that sum_i height_i * width_i = 1.
    counts_scaled = counts / (gwidth * counts.sum())
    ax.bar(gpositions, counts_scaled, width=gwidth, **kwargs)

    if grid_lines:
        # A single vectorised call draws one dashed line per grid point.
        ax.vlines(
            grid,
            0,
            counts_scaled.max() * 1.1,
            "r",
            linestyles="dashed",
            linewidth=0.5,
        )


class Histogram(ABC):
    """Interface for histogram objects.

    Stores the input tensor flattened to one dimension in ``self.data``;
    subclasses implement :meth:`plot`.
    """

    def __init__(self, data: torch.Tensor) -> None:
        self.data = data.flatten()

    def get_xlim(self, data_percent: float = 0.95):
        """Return the value range spanned by the central part of the data.

        Args:
            data_percent: Fraction (between 0 and 1) of the sorted samples
                that should lie inside the returned interval.

        Returns:
            Tuple ``(low, high)`` of zero-dimensional tensors taken from the
            sorted data.

        Raises:
            ValueError: If there is no data or ``data_percent`` is outside
                ``[0, 1]``.
        """
        total = self.data.numel()
        if total == 0:
            raise ValueError("Cannot compute x-limits for an empty histogram.")
        if not 0 <= data_percent <= 1:
            raise ValueError(
                f"data_percent must be between 0 and 1, got {data_percent}"
            )

        # Positional cut-offs are only meaningful on sorted values; sort a
        # copy so that self.data keeps its original order.
        ordered = torch.sort(self.data).values

        # Number of samples trimmed from each end.
        lower = int(torch.floor(total * (1 - torch.tensor(data_percent)) / 2))
        # Never let the two cut-offs cross for very small data_percent.
        lower = min(lower, (total - 1) // 2)
        # Mirror index of `lower`; the -1 keeps it in range when lower == 0.
        upper = total - 1 - lower
        return (ordered[lower], ordered[upper])

    @abstractmethod
    def plot(self, ax=None):
        """Plot the histogram.

        The base implementation draws nothing; concrete subclasses must
        override it.
        """
        if ax is None:
            ax = plt.gca()


class QuantizedHistogram(Histogram):
    """Histogram that renders its data with ``grid_plot``."""

    def __init__(self, data: torch.Tensor, **kwargs) -> None:
        # Extra keyword arguments are handed to the base initialiser unchanged.
        super().__init__(data, **kwargs)

    def plot(self, ax=None):
        """Draw the stored data on ``ax`` via ``grid_plot``.

        ``ax`` is forwarded as given, including None.
        """
        super().plot(ax)
        weights = self.data
        grid_plot(weights, ax=ax)


class UnquantizedHistogram(Histogram):
    """Histogram of continuous values, drawn as a density with ``ax.hist``."""

    def __init__(self, data: torch.Tensor, bins="auto", **kwargs) -> None:
        """Store the data and the binning specification.

        Args:
            data: Tensor of values to histogram.
            bins: Passed as ``bins`` to ``ax.hist`` when plotting.
            **kwargs: Forwarded to the base initialiser.
        """
        super().__init__(data, **kwargs)
        self.bins = bins

    def plot(self, ax=None):
        """Draw a density-normalised histogram on ``ax`` (current axes if None)."""
        super().plot(ax)
        if ax is None:
            ax = plt.gca()
        # Tensor.numpy() refuses tensors that require grad or live on another
        # device, so detach and move to CPU before converting.
        values = self.data.detach().cpu().flatten().numpy()
        ax.hist(values, bins=self.bins, density=True)


class RdQuantizedHistogram(Histogram):
    """Histogram that overlays two ``grid_plot`` renderings on the same axes.

    ``data`` is drawn first in blue with grid lines; ``original_rated_grid``
    is drawn on top in semi-transparent orange without grid lines.
    """

    def __init__(
        self, data: torch.Tensor, original_rated_grid: torch.Tensor, **kwargs
    ) -> None:
        super().__init__(data, **kwargs)
        # Stored flattened; it is the second tensor drawn by plot().
        self.original_rated_grid = original_rated_grid.flatten()

    def plot(self, ax=None):
        """Draw both tensors on ``ax``; ``ax`` is forwarded as given."""
        super().plot(ax)
        # (tensor, grid_plot options) pairs, listed in drawing order.
        layers = (
            (self.data, {"color": "blue"}),
            (
                self.original_rated_grid,
                {"color": "orange", "alpha": 0.5, "grid_lines": False},
            ),
        )
        for weights, options in layers:
            grid_plot(weights, ax=ax, **options)


def display_histogram_grid_quantised_nets(
    original_net: torch.nn.Module,
    quant_net: torch.nn.Module,
    rd_quantised_net: torch.nn.Module,
    xlim: Optional[tuple[float, float]] = None,
):
    """Show per-parameter histograms of three networks in a three-row grid.

    One column is created for every named parameter of ``original_net``.
    From top to bottom the rows hold an ``UnquantizedHistogram`` of the
    original parameter, a ``QuantizedHistogram`` of the matching entry of
    ``quant_net`` and an ``RdQuantizedHistogram`` of the matching entry of
    ``rd_quantised_net`` overlaid with the original parameter.

    Args:
        original_net: Network whose named parameters define the columns.
        quant_net: Network whose ``state_dict`` is looked up by the same names.
        rd_quantised_net: Network whose ``state_dict`` is looked up by the
            same names.
        xlim: Optional x-limits applied to every subplot.
    """
    histograms_unquantised = []
    histograms_quantised = []
    histograms_rd_quantised = []

    # Build each state dict once instead of once per parameter.
    quant_state = quant_net.state_dict()
    rd_state = rd_quantised_net.state_dict()

    for n, p in original_net.named_parameters():
        p = p.detach()
        qp = quant_state[n].detach()
        rp = rd_state[n].detach()
        histograms_unquantised.append(UnquantizedHistogram(p))
        histograms_quantised.append(QuantizedHistogram(qp))
        histograms_rd_quantised.append(RdQuantizedHistogram(rp, p))

    # The grid is filled row by row, so concatenating in this order yields
    # one row per histogram kind.
    all_histograms = (
        histograms_unquantised + histograms_quantised + histograms_rd_quantised
    )
    axes = display_histogram_grid(
        all_histograms,
        num_rows=3,
        num_cols=len(histograms_unquantised),
        show=False,
    )
    if xlim is not None:
        for ax in axes:
            ax.set_xlim(xlim)
    plt.show()


def display_histogram_grid(
    histograms: Sequence["Histogram"],
    num_rows: int,
    num_cols: int,
    figsize: tuple[int, int] = (12, 8),
    titles: Optional[list[str]] = None,
    show: bool = True,
) -> np.ndarray:
    """
    Display a grid of histograms.

    Parameters:
    - histograms: Histogram objects; each one is drawn into its own subplot,
                  filling the grid row by row.
    - num_rows: Number of rows in the grid.
    - num_cols: Number of columns in the grid.
    - figsize: Tuple specifying the size of the figure (default is (12, 8)).
    - titles: A list of strings specifying titles for each histogram (default is None).
              If provided, it needs one entry per histogram.
    - show: Whether to call plt.show() before returning (default is True).

    Returns:
    - A flat array with all num_rows * num_cols axes of the figure.

    Raises:
    - ValueError: If there are more histograms than grid cells.
    """
    total_plots = num_rows * num_cols
    if len(histograms) > total_plots:
        raise ValueError(
            f"Number of data arrays exceeds the available grid size: {total_plots} < {len(histograms)}"
        )
    if len(histograms) < total_plots:
        print(
            f"Warning: Number of data arrays is less than the available grid size: {total_plots} > {len(histograms)}"
        )

    # squeeze=False always yields a 2-D array of axes, so flatten() also works
    # for single-row, single-column and 1x1 grids.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < len(histograms):
            histograms[i].plot(ax=ax)
            if titles is not None:
                ax.set_title(titles[i])
        else:
            ax.axis("off")  # turn off empty subplots

    plt.tight_layout()
    if show:
        plt.show()
    return axes


if __name__ == "__main__":
    # Demo: four normal samples with means 0..3, shown as raw values (top row),
    # rounded values (middle row) and rounded values overlaid with a shifted,
    # rounded copy (bottom row).
    data = [np.random.normal(loc=mu, scale=1.0, size=100) for mu in range(4)]
    data_q = [torch.tensor(sample).round() for sample in data]
    data_shifted = [torch.tensor(sample + 1.2).round() for sample in data]

    unquantised_row = [
        UnquantizedHistogram(torch.tensor(sample), bins=20) for sample in data
    ]
    quantised_row = [QuantizedHistogram(rounded) for rounded in data_q]
    rd_quantised_row = [
        RdQuantizedHistogram(rounded, shifted)
        for rounded, shifted in zip(data_q, data_shifted)
    ]
    histograms = unquantised_row + quantised_row + rd_quantised_row

    titles = [f"Data {idx+1}" for idx in range(len(histograms))]  # Example titles
    display_histogram_grid(histograms, 3, 4, figsize=(12, 8), titles=titles)
