import math
import numpy as np
import torch
import colorsys
import kdtpp.prob as prob
import matplotlib as mpl
import einops
import xml
import io

"""
copy/paste into a notebook

def init_matplotlib():
    plt.close("all")
    plt.style.use(["science", "ieee"])
    plt.rcParams.update({
        "font.size": FONTSIZE,
        "legend.fontsize": FONTSIZE,
        "axes.grid": False,
        "figure.dpi": 96*3,
        "svg.fonttype": "none",
        "xtick.top": False,
        # science sets these, but we want manual control.
        "savefig.bbox": None,
        "savefig.pad_inches": 0.0,
        "figure.constrained_layout.use": False,
        "figure.autolayout": False
    })
"""

def in_to_cm(x):
    """Convert a length given in inches to centimetres."""
    cm_per_inch = 2.54
    return cm_per_inch * x


def text_to_in(n_chars, font_size_pt, avg_em=0.5, n_spaces=0):
    """Roughly estimate the width, in inches, of a run of text.

    Every character is first assumed to be ``avg_em`` ems wide. Each of the
    ``n_spaces`` spaces is then corrected from ``avg_em`` to a fixed
    0.30 em, so ``n_chars`` is expected to already include the spaces.
    The total in ems is converted to inches via the font size
    (72 points per inch).

    Args:
        n_chars: number of characters in the text.
        font_size_pt: font size in points.
        avg_em: assumed average character width, in ems.
        n_spaces: how many of the characters are spaces.

    Returns:
        The estimated text width in inches.
    """
    space_em = 0.30
    glyph_ems = n_chars * avg_em
    space_correction_ems = n_spaces * (space_em - avg_em)
    inches_per_em = font_size_pt / 72.0
    return (glyph_ems + space_correction_ems) * inches_per_em


# def show_svg(path):
#     IPython.display.display(IPython.display.SVG(filename=path))

def num_to_color(x, cmap="cividis"):
    """Look up the hex colour a colormap assigns to a value in [0, 1].

    Args:
        x: value to map; it is normalised with ``vmin=0`` and ``vmax=1``.
        cmap: a colormap name (resolved through ``mpl.colormaps``) or an
            already-constructed colormap object, which is used as-is.

    Returns:
        The colour as a hex string, as produced by ``mpl.colors.to_hex``.
    """
    colormap = mpl.colormaps.get_cmap(cmap) if isinstance(cmap, str) else cmap
    unit_norm = mpl.colors.Normalize(vmin=0, vmax=1)
    return mpl.colors.to_hex(colormap(unit_norm(x)))


def hsv_to_hex(h, s, v):
    """Convert an HSV colour to an upper-case ``#RRGGBB`` hex string.

    Args:
        h: hue, in [0, 1].
        s: saturation, in [0, 1].
        v: value (brightness), in [0, 1].

    Returns:
        A 7-character string such as ``"#FF0000"``.

    Raises:
        AssertionError: if any component lies outside [0, 1] (or is NaN).
    """
    components = np.array([h, s, v])
    # Both bounds matter: a negative component makes colorsys return RGB
    # channels outside [0, 1], which would format as a malformed hex string
    # (a minus sign or three hex digits per channel).
    assert np.all((components >= 0.0) & (components <= 1.0)), (h, s, v)
    rgb = colorsys.hsv_to_rgb(h, s, v)
    return "#" + "".join(f"{round(255.0 * channel):02X}" for channel in rgb)

def num_to_human_str(num):
    """Format a number as a short string with a k/M/G/T magnitude suffix.

    The value is repeatedly divided by 1000 until it fits in three digits
    (or the largest suffix, "T", is reached) and is then printed with no
    decimals, e.g. ``12000 -> "12k"`` and ``3_000_000 -> "3M"``.

    Args:
        num: the number to format (int or float, may be negative).

    Returns:
        The formatted string.
    """
    suffixes = ("", "k", "M", "G", "T")
    # ".0f" rounds 999.5 and above up to "1000", so promote to the next
    # suffix from 999.5 rather than 1000. Otherwise 999_999 would render
    # as "1000k" instead of "1M".
    promote_at = 999.5
    magnitude = 0
    while abs(num) >= promote_at and magnitude < len(suffixes) - 1:
        magnitude += 1
        num /= 1000.0
    return f"{num:.0f}{suffixes[magnitude]}"

def set_short_tick_formatter(axis, threshold=2000):
    """Switch an axis to short magnitude-suffix tick labels when ticks are large.

    If the magnitude of any current major tick location is at least
    ``threshold``, the axis' major formatter is replaced by one that renders
    every tick through ``num_to_human_str``. Otherwise the axis is left
    untouched.

    Args:
        axis: a matplotlib axis (e.g. ``ax.xaxis``).
        threshold: minimum absolute tick value that triggers the switch.
    """
    major_locs = axis.get_majorticklocs()
    needs_short_labels = any(abs(loc) >= threshold for loc in major_locs)
    if not needs_short_labels:
        return

    # FuncFormatter passes (value, position); the position is not needed.
    short_formatter = mpl.ticker.FuncFormatter(
        lambda value, _pos: num_to_human_str(value)
    )
    axis.set_major_formatter(short_formatter)

class SafeEdgeLocator(mpl.ticker.Locator):
    """
    Tick locator that drops ticks lying too close to the end of the axis.

    Useful when figures are saved with tight bounding boxes (tight or
    constrained layout) and the inner plotting area determines the figure
    size: a tick label sitting right at the edge can peek out past the
    plotting area. This locator wraps another locator and removes the
    offending ticks.

    Only the high end of the view interval is filtered: a tick is kept when
    its fractional position along the interval is at most
    ``1 - clearance``, where the clearance is ``clearance_in`` expressed as
    a fraction of the axis length in inches. Ticks near the low end are
    never dropped.

    Args:
        base_locator: locator that proposes the candidate ticks. Defaults to
            ``MaxNLocator(nbins="auto", steps=[1, 2, 2.5, 5, 10])``.
        clearance_in: required clearance from the high end, in inches.
        axis: ``"x"`` to measure the axes width; any other value measures
            the axes height.
    """

    def __init__(self, base_locator=None, *, clearance_in=0.05, axis="x"):
        if not base_locator:
            base_locator = mpl.ticker.MaxNLocator(
                nbins='auto', steps=[1, 2, 2.5, 5, 10])
        self.base = base_locator
        self.clearance_in = float(clearance_in)
        self.axis_name = axis

    def __call__(self):
        """Return the filtered ticks for the axis' current view interval."""
        view_lo, view_hi = self.axis.get_view_interval()
        return self.tick_values(view_lo, view_hi)

    def tick_values(self, vmin, vmax):
        """Return the base locator's ticks minus those too close to the high end.

        If the axes currently have no positive physical length, the base
        locator's ticks are returned unfiltered.
        """
        candidates = self.base.tick_values(vmin, vmax)
        axes = self.axis.axes
        figure = axes.figure
        position = axes.get_position()  # in figure-fraction coordinates
        fig_width_in, fig_height_in = figure.get_size_inches()

        # Physical length of the relevant side of the axes, in inches.
        if self.axis_name == "x":
            axis_len_in = position.width * fig_width_in
        else:
            axis_len_in = position.height * fig_height_in

        if axis_len_in <= 0:
            return candidates  # no scaling possible

        # Express the clearance as a fraction of the axis length, then turn
        # it into the largest allowed fractional tick position.
        clearance_frac = self.clearance_in / axis_len_in
        max_position = 1.0 - clearance_frac

        low, high = (vmin, vmax) if vmin < vmax else (vmax, vmin)
        extent = high - low if high != low else 1.0

        return [
            tick for tick in candidates
            if (tick - low) / extent <= max_position
        ]



def fig_to_xml(fig, save_kwgs=None):
    """Convert a matplotlib figure to an SVG XML tree.

    The figure is saved as SVG into an in-memory buffer and parsed. The
    ``pt`` unit is then stripped from the root element's ``width`` and
    ``height`` attributes, leaving bare numbers.

    Args:
        fig: object with a matplotlib-style ``savefig(file, format=..., **kw)``.
        save_kwgs: optional extra keyword arguments forwarded to ``savefig``.

    Returns:
        An ``xml.etree.ElementTree.ElementTree`` holding the SVG document.
    """
    # A bare `import xml` does not load the `xml.etree.ElementTree`
    # submodule, so import it explicitly rather than relying on some other
    # module having done so.
    import xml.etree.ElementTree as ET

    buf = io.StringIO()
    fig.savefig(buf, format="svg", **(save_kwgs or {}))
    buf.seek(0)
    tree = ET.parse(buf)
    root = tree.getroot()
    for attr in ("width", "height"):
        value = root.get(attr)
        # Tolerate a root without explicit dimensions.
        if value is not None:
            root.set(attr, value.replace("pt", ""))
    return tree


def composite_legend(fig, legend_path, out_path):
    """Manually composite a legend onto a figure.

    The figure is rendered to SVG, the ``<g id="legend">`` group from the
    legend file is appended as the last child of the figure's SVG root, and
    the result is written to ``out_path``.

    Args:
        fig: matplotlib figure object.
        legend_path: path to a legend SVG that contains a group with id="legend".
        out_path: path to save the composite figure.

    Raises:
        AssertionError: if the legend SVG has no ``<g id="legend">`` element.
    """
    # A bare `import xml` does not load the `xml.etree.ElementTree` submodule.
    import xml.etree.ElementTree as ET

    fig_tree = fig_to_xml(fig)
    legend_tree = ET.parse(legend_path)
    legend_item = legend_tree.find('.//{http://www.w3.org/2000/svg}g[@id="legend"]')
    # Compare against None explicitly: an Element's truth value reflects
    # whether it has children (and testing it is deprecated), so a found but
    # childless group must not be treated as "missing".
    assert legend_item is not None, (
        f'no <g id="legend"> group found in {legend_path!r}; ids present: '
        f'{[el.get("id") for el in legend_tree.getroot().iter() if el.get("id")]}'
    )
    fig_tree.getroot().append(legend_item)
    fig_tree.write(out_path)


def _add_component_table(table_ax, taus, mus, sigmas, max_rows=10):
    """Fill ``table_ax`` with the ``max_rows`` highest-weight mixture components."""
    ranked = sorted(
        zip(range(len(taus)), taus, mus, sigmas),
        key=lambda row: row[1],
        reverse=True,
    )[:max_rows]
    cell_text = [
        (str(idx), f"{float(tau):.4f}", f"{float(m):.4f}", f"{float(s):.4f}")
        for idx, tau, m, s in ranked
    ]
    table_ax.table(
        cellText=cell_text,
        colLabels=["τ_id", "τ", "μ", "σ"],
        cellLoc="center",
        loc="center",
    )
    table_ax.axis("off")


def logmix_dist_plots(log_tau, mu, log_sigma, target_t, max_n_plots):
    """Visualize logmix head output.

    For each of the first ``min(max_n_plots, B)`` batch elements a figure
    with two stacked axes is produced: the mixture density over time on top
    and a table of the highest-weight components below.

    The plotting range is found in two passes. First, a rough range is
    derived from the components whose weight is at least the uniform weight
    1/M (mu -/+ 3 sigma in log-time), and the density is queried on a grid
    over that range; on this first request the density at the ground-truth
    timestep is queried as well so it can be labelled on the figure. The
    rough range is normally way too big, so it is clipped to the
    sub-interval between the 2% and 98% points of the cumulative integral
    (the central 96% of the mass) and the density is queried again on a
    grid covering that sub-interval.

    Args:
        log_tau: (B, M) tensor; a softmax over dim 1 yields the mixture weights.
        mu: (B, M) tensor of component locations, in log-time.
        log_sigma: (B, M) tensor of log component scales.
        target_t: (B,) tensor of ground-truth times. Entries <= 0 are treated
            as "no target": the density is still plotted but no target line
            or annotation is drawn. (Presumably these are segments without a
            spike; confirm against the caller.)
        max_n_plots: upper bound on the number of figures produced.

    Returns:
        A list of ``matplotlib.figure.Figure`` objects.

    TODO: could refactor to remove max_n_plots, or even to remove the
    inner loop altogether, and make the caller do it.
    """
    # `import matplotlib` alone does not guarantee the `figure` submodule is
    # loaded, so import it explicitly.
    import matplotlib.figure

    B, M = log_tau.shape
    assert log_tau.shape == mu.shape == log_sigma.shape == (B, M), (B, M)
    assert target_t.shape == (B,)
    device = log_tau.device
    median = prob.logmix_median(log_tau, mu, log_sigma)
    query_B = max(2048, B)  # can be much larger than B
    eps = 1e-4  # simultaneous events not allowed.
    tau_normed = torch.softmax(log_tau, dim=1)
    sigma = log_sigma.exp()

    # Only components with at least the uniform weight 1/M shape the rough
    # range. The threshold is a probability (not a log-probability), and is
    # capped at the row maximum so that rounding in the softmax can never
    # leave a row with no component selected.
    weight_floor = tau_normed.max(dim=1, keepdim=True)[0].clamp(max=1.0 / M)
    consider = tau_normed >= weight_floor

    # Remember, mu, sigma are in log space (log_sigma is log of log).
    # Ignored components get the neutral element of each reduction.
    log_t_min = torch.where(consider, mu - 3 * sigma, torch.inf).min(dim=1)[0]
    log_t_max = torch.where(consider, mu + 3 * sigma, -torch.inf).max(dim=1)[0]
    t_min = log_t_min.exp().clamp(min=eps)
    t_max = log_t_max.exp()
    # At this point, we should have a rough idea of t_min, t_max based on
    # the mixtures that have a significant weight. We could still have an
    # issue if these mixtures have unsuitable (inf, nan) parameters.
    # So here we provide last resort hard-coded values. They prevent
    # exceptions being thrown by the logmix_log_prob function.
    t_min = torch.where(torch.isfinite(t_min), t_min, 0)
    t_max = torch.where(
        torch.isfinite(log_t_max) & torch.isfinite(t_max), t_max, 100
    )

    def _density(i, ts):
        """Evaluate element ``i``'s mixture density at every time in ``ts``."""
        n = ts.shape[0]
        tile = lambda x: einops.repeat(x, "m -> b m", b=n)
        return prob.logmix_log_prob(
            tile(log_tau[i]), tile(mu[i]), tile(log_sigma[i]), ts
        ).exp()

    figs = []
    for i in range(min(max_n_plots, B)):
        # First pass: an evenly spaced grid over the rough range, plus the
        # target time (when there is one) as a final extra sample.
        grid = torch.linspace(
            t_min[i].item(), t_max[i].item(), steps=query_B - 1, device=device
        )
        has_target = bool(target_t[i] > 0)
        if has_target:
            ts = torch.cat([grid, target_t[i].reshape(1).to(device)])
        else:
            ts = grid
        probs = _density(i, ts)
        actual_prob = probs[-1].item() if has_target else None

        # Integrate over the evenly spaced grid only; the appended target
        # sample is not part of the grid and must not enter the integral
        # (nor be selectable as a clipping bound).
        grid_probs = probs[: grid.shape[0]]
        cum = torch.cumulative_trapezoid(
            grid_probs, dx=(grid[1] - grid[0]).item()
        )
        # Clip to the 96% body.
        from_idx = torch.searchsorted(cum, 0.02)
        to_idx = torch.searchsorted(cum, 0.98)

        # Second pass: a fine grid over the clipped interval.
        ts_v2 = torch.linspace(
            grid[from_idx].item(),
            grid[to_idx].item(),
            steps=query_B,
            device=device,
        )
        probs = _density(i, ts_v2)

        fig = matplotlib.figure.Figure(figsize=(5, 5))
        # One for plot, and one for table.
        plot_ax, table_ax = fig.subplots(2, 1)
        plot_ax.plot(ts_v2.cpu().numpy(), probs.cpu().numpy())
        plot_ax.axvline(median[i].item(), color="g", linestyle="--")
        if has_target:
            target_value = target_t[i].item()
            plot_ax.axvline(target_value, color="r", linestyle="--")
            # annotation for the line
            plot_ax.annotate(
                f"p({target_value:.3g})\n= {actual_prob:.3f}",
                xy=(target_value, plot_ax.get_ylim()[1] * 0.9),
                textcoords="offset points",
                xytext=(2, 0),
                color="r",
            )

        _add_component_table(
            table_ax,
            tau_normed[i].cpu().numpy(),
            mu[i].cpu().numpy(),
            sigma[i].cpu().numpy(),
        )
        plot_ax.set_title(f"t={target_t[i].item():.3g}")
        figs.append(fig)

    return figs
    

