import colorsys

from matplotlib import rc
import matplotlib.pyplot as plt
import plotnine as p9
import math
import numpy as np

# Short axis labels: the LaTeX math expression only.
short_label_decoder_XE = r"$\operatorname{H}_\theta[Y|Z]$"
short_label_residual_information = r"$\operatorname{I}[X;Y \mathbin{\vert} Z]$"
short_label_preserved_information = r"$\operatorname{I}[X;Z]$"
short_label_encoding_entropy = r"$\operatorname{H}[Z]$"

# Alias kept for callers that refer to the encoding entropy by its symbol.
label_H_Z = short_label_encoding_entropy

# Long axis labels: a plain-text name followed by the matching short math label.
long_label_decoder_XE = "Decoder Cross-Entropy " + short_label_decoder_XE
long_label_residual_information = "Residual Information " + short_label_residual_information
long_label_preserved_information = "Preserved Information " + short_label_preserved_information
long_label_encoding_entropy = "Encoding Entropy " + short_label_encoding_entropy


# Marker shape per regularizer (Matplotlib marker codes).
regularizer_shapes = {
    "entropy_via_variance_Z": "^",
    "entropy_via_variance_Z__Y": "v",
    "mean_squared_Z": "o",
    "kraskov_Z": ">",
    "kraskov_Z__Y": "<",
    "weight_decay": "s",
}

# LaTeX legend label per regularizer. The kraskov_* entries are disabled here
# (commented out), so they have a shape but no label.
regularizer_labels = {
    "entropy_via_variance_Z": r"$\log \operatorname{Var} \left [Z \right ]$",
    "entropy_via_variance_Z__Y": r"$\log \operatorname{Var} \left [ Z \mathbin{\vert} Y \right ]$",
    "mean_squared_Z": r"$\operatorname{\mathbb{E}} Z^2$",
    # "kraskov_Z": r"Density estimate $Z$",
    # "kraskov_Z__Y": r"Density estimate $Z \vert Y$",
    "weight_decay": "Weight Decay",
}

# LaTeX label per training objective. Every label containing a backslash is a
# raw string: a plain "\min" is an invalid escape sequence (DeprecationWarning,
# SyntaxWarning since Python 3.12).
objective_labels = {
    "decoder_uncertainty": r"$\min H[Y|Z]$",
    "prediction": r"$\min H_\theta[Y|X]$",  # long form: \text{Prediction Cross-Entropy} \min H_\theta[Y|X]
    "decoder": r"$\min H_\theta[Y|Z]$",  # long form: \text{Decoder Cross-Entropy} \min H_\theta[Y|Z]
}

# LaTeX label per logged variable.
variable_labels = {
    "xe_decoder": r"$H_\theta[Y|Z]$",
    "xe_prediction": r"$H_\theta[Y|X]$",
    "discrete_decoder_uncertainty": r"$H[Y|Z]$",
    "error_p": r"$\mathbb{E}_{\hat{p}(x, y)} p(\hat{y} \neq y)$",
    # NOTE(review): an entropy-based error bound is usually 1 - exp(-H); with
    # H >= 0 the expression 1 - exp(H) is <= 0 — confirm the label matches the
    # quantity actually plotted.
    "error_bound": r"$1 - \exp(H[Y|Z])$",
    "error_bound_decoder_xe": r"$1 - \exp(H_\theta[Y|Z])$",
    "error": r"$\arg \max$ error",
}

# Hex colour palette used by the colour tables below.
COLOR_DARK_BLUE = "#005aff"
COLOR_BRIGHT_BLUE = "#00a2ff"
COLOR_DARK_ORANGE = "#ff6c00"
COLOR_BRIGHT_ORANGE = "#ffb000"
COLOR_PINK = "#ff339f"
COLOR_GREEN = "#009926"
COLOR_PURPLE = "#66001d"  # NOTE(review): this hex is a dark maroon rather than a purple
COLOR_GRAY = "#bebebe"

# Line colour per logged variable (keys match variable_labels).
variable_colors = dict(
    xe_decoder=COLOR_BRIGHT_BLUE,
    xe_prediction=COLOR_PINK,
    discrete_decoder_uncertainty=COLOR_DARK_BLUE,
    error_p=COLOR_PURPLE,
    error_bound=COLOR_GREEN,
    error_bound_decoder_xe=COLOR_GREEN,
    error=COLOR_GRAY,
)

# Colour per regularizer; the kraskov_* regularizers are disabled.
regularizer_colors = dict(
    entropy_via_variance_Z=COLOR_DARK_BLUE,
    entropy_via_variance_Z__Y=COLOR_DARK_ORANGE,
    mean_squared_Z=COLOR_GREEN,
    # kraskov_Z=COLOR_BRIGHT_BLUE,
    # kraskov_Z__Y=COLOR_BRIGHT_ORANGE,
    weight_decay=COLOR_GRAY,
)

# Display order for variables and objectives.
variable_order = [
    "discrete_decoder_uncertainty",
    "xe_decoder",
    "xe_prediction",
    "error_p",
    "error_bound",
    "error_bound_decoder_xe",
    "error",
]
objective_order = ["decoder_uncertainty", "decoder", "prediction"]

# Human-readable name per data source.
source_labels = dict(train="Train Set", test="Test Set", combined="IP")

# Long labels for the information-bottleneck plots. An alternative xe_decoder
# label once used was: Residual Information $\operatorname{I}[Y;X \mathbin{\vert} Z]$
# / Decoder Uncertainty $\operatorname{H_\theta}[Y|Z]$.
IB_variable_labels = dict(
    xe_decoder=r"Decoder Cross-Entropy $\operatorname{H_\theta}[Y|Z]$",
    xe_prediction=r"Prediction Cross-Entropy $\operatorname{H_\theta}[Y|X]$",
    continuous_H_Z=r"Preserved Information $\operatorname{I}[X;Z]$",
)

# Colours for the information-bottleneck plots ("correct_prob" has no label above).
IB_variable_colors = dict(
    xe_decoder=COLOR_DARK_BLUE,
    xe_prediction=COLOR_PINK,
    continuous_H_Z=COLOR_DARK_ORANGE,
    correct_prob=COLOR_GREEN,
)

# Short math labels for the IB regularizer plots.
IBR_variable_labels = dict(
    continuous_H_Z=r'$H [ Z ]$',
    continuous_H_Z__Y=r'$ H [Z | Y ] $',
    mean_squared_Z=r"$\operatorname{\mathbb{E}} Z^2$",
)


def create_hue_gradient(num_points=29, hue_offset=0.65, hue_divisor=33,
                        lightness=0.6, saturation=0.7,
                        min_luminance=0.4, max_luminance=0.9,
                        luminance_scale=1.1845138579738865):
    """Build a list of RGB colours that sweep hue while ramping luminance.

    Colour ``i`` starts from the HLS colour with hue
    ``(1 + hue_offset + i / hue_divisor) % 1`` and the fixed ``lightness`` and
    ``saturation``. It is then rescaled so that its luminance equals the ``i``-th
    of ``num_points`` evenly spaced targets in ``[min_luminance, max_luminance]``,
    divided by ``luminance_scale``. Luminance uses the weights 0.2126 / 0.7152 /
    0.0722 (the Rec. 709 coefficients) on the RGB channels.

    The defaults reproduce the original fixed 29-colour gradient exactly.

    Args:
        num_points: Number of colours to generate.
        hue_offset: Hue of the first colour (before wrapping into [0, 1)).
        hue_divisor: Hue advances by ``1 / hue_divisor`` per colour.
        lightness: HLS lightness of the starting colours.
        saturation: HLS saturation of the starting colours.
        min_luminance: Target luminance of the first colour (before scaling).
        max_luminance: Target luminance of the last colour (before scaling).
        luminance_scale: Divisor applied to every rescaled channel.
            NOTE(review): the origin of the default value is not documented;
            presumably it keeps all channels within [0, 1] — confirm.

    Returns:
        A list of ``[r, g, b]`` lists, one per point. Channels are not clipped.
    """
    targets = np.linspace(min_luminance, max_luminance, num=num_points)
    gradient = []
    for i, dst_luminance in enumerate(targets):
        rgb = colorsys.hls_to_rgb((1 + hue_offset + i / hue_divisor) % 1, lightness, saturation)
        luminance = 0.2126 * rgb[0] + 0.7152 * rgb[1] + 0.0722 * rgb[2]
        # Same operation order as before (divide, multiply, divide) so the
        # default output is bit-identical.
        gradient.append([channel / luminance * dst_luminance / luminance_scale for channel in rgb])
    return gradient


# Default gradient built with create_hue_gradient's default parameters.
hue_gradient = create_hue_gradient()

# Matplotlib style and settings (that also affect PlotNine)

# Render text with LaTeX; the preamble loads the packages available to the
# LaTeX labels (e.g. \operatorname from amsmath, \mathbb from amsfonts).
rc("text", usetex=True)
rc('text.latex',
   preamble=r"""
\usepackage{bbm}
\usepackage{amsmath}
\usepackage{amsfonts}
""")


def _resolve_mpl_style(name):
    """Return the installed Matplotlib style name corresponding to *name*.

    Matplotlib 3.6 renamed its bundled seaborn styles to ``seaborn-v0_8*`` and
    Matplotlib 3.8 removed the old names. When *name* is not available but its
    renamed variant is, the renamed variant is returned. Otherwise *name* is
    returned unchanged, so ``plt.style.use`` still reports unknown styles.
    """
    if name in plt.style.available:
        return name
    renamed = name.replace("seaborn", "seaborn-v0_8", 1)
    return renamed if renamed in plt.style.available else name


plt.style.use([
    _resolve_mpl_style(style)
    for style in ["seaborn", "seaborn-colorblind", "seaborn-paper", "seaborn-white"]
])


def int_labeller(values):
    """Format each tick value as a LaTeX math-mode integer label.

    ``int`` truncates toward zero, so ``2.7`` becomes ``"$2$"`` and ``-3.9``
    becomes ``"$-3$"``.
    """
    return list(map("${}$".format, map(int, values)))


def get_pow10(value, exponent):
    """Return ``value * 10**exponent``.

    For negative exponents, *value* is divided by the positive power of ten
    rather than multiplied by ``10**exponent``. ``10**n`` is an exact integer,
    whereas ``10**-n`` is an inexact float. For example, ``get_pow10(3, -1)``
    gives ``0.3`` rather than ``0.30000000000000004``.
    """
    if exponent < 0:
        return value / 10**-exponent
    else:
        return value * 10 ** exponent


def _pow10_label(point):
    """Return the LaTeX label for a single finite, strictly positive tick value."""
    exponent = math.floor(math.log10(point))
    # Mantissa, nominally in [1, 10).
    value = point / 10 ** exponent
    if abs(value - 1.) < 1e-5:
        # (Near-)exact power of ten.
        return "$1$" if exponent == 0 else f"$10^{{{exponent}}}$"
    if abs(exponent) < 4:
        # Moderate magnitude: plain decimal with the mantissa rounded to one decimal.
        return f"${get_pow10(round(value, 1), exponent)}$"
    # Large or small magnitude: scientific notation. Use the same one-decimal
    # mantissa as above (":g" drops a trailing ".0"). The former ":.0f" turned
    # e.g. 2.5e5 into "2 \times 10^5".
    return f"${round(value, 1):g} \\times 10^{{{exponent}}}$"


def pow10_labeller(points):
    """Format tick values as LaTeX labels that emphasise powers of ten.

    - Exact powers of ten become ``$1$`` or ``$10^{k}$``.
    - Values whose exponent ``k`` satisfies ``|k| < 4`` are written as plain
      decimals, with the mantissa rounded to one decimal place.
    - Other values become ``$m \\times 10^{k}$``.
    - Zero becomes ``$0$``, and negative values get a leading minus sign.
    - Non-finite values (NaN, inf) get an empty label.

    Args:
        points: Iterable of numeric tick values.

    Returns:
        A list of label strings, one per input value.
    """
    labels = []
    for point in points:
        if not math.isfinite(point):
            labels.append("")
        elif point == 0:
            labels.append("$0$")
        elif point < 0:
            # Label the magnitude, then insert the sign inside math mode.
            labels.append("$-" + _pow10_label(-point)[1:])
        else:
            labels.append(_pow10_label(point))
    return labels


def plt_save(output_path, **kwargs):
    """Save the current Matplotlib figure with a transparent, tightly cropped layout.

    Activates the colorblind seaborn style first. Matplotlib 3.6 renamed it to
    ``seaborn-v0_8-colorblind`` and Matplotlib 3.8 removed the old name, so the
    renamed style is used when the old one is unavailable. It then calls
    ``plt.savefig`` with ``transparent=True``, ``bbox_inches="tight"`` and
    ``pad_inches=0``.

    Args:
        output_path: Destination passed to ``plt.savefig``.
        **kwargs: Extra ``plt.savefig`` keyword arguments. Repeating
            ``transparent``, ``bbox_inches`` or ``pad_inches`` raises ``TypeError``.
    """
    style = "seaborn-colorblind"
    if style not in plt.style.available:
        style = "seaborn-v0_8-colorblind"
    plt.style.use(style)
    plt.savefig(output_path, transparent=True, bbox_inches="tight", pad_inches=0, **kwargs)


# PlotNine utils


def p9_save(plot, filename, **kwargs):
    """Save a PlotNine plot with a transparent background, tight bbox and no padding.

    Extra keyword arguments are forwarded to ``plot.save``. Repeating one of the
    fixed options raises ``TypeError``.
    """
    fixed_options = dict(transparent=True, bbox_inches="tight", pad_inches=0)
    plot.save(filename, **fixed_options, **kwargs)


# Base PlotNine theme: seaborn "whitegrid" style in the "paper" context,
# installed as the plotnine default theme via theme_set.
p9_base_theme = p9.theme_seaborn(context="paper", style="whitegrid")
p9.theme_set(p9_base_theme)
