import matplotlib.pyplot as plt
import numpy as np

# Qualitative 8-colour palette (the values match ColorBrewer's "Set2").
# ``c0`` ... ``c7`` are positional shortcuts into it, one per entry.
COLORS = [
    "#66c2a5",
    "#fc8d62",
    "#8da0cb",
    "#e78ac3",
    "#a6d854",
    "#ffd92f",
    "#e5c494",
    "#b3b3b3",
]
c0, c1, c2, c3, c4, c5, c6, c7 = COLORS

# --- Axis-label text -------------------------------------------------------
BPW_LABEL = "Bits Per Weight"
SPARSITY_LABEL = "Sparsity"
INV_SPARSITY_LABEL = "Non-Zero Weights"

# --- Line-style defaults ---------------------------------------------------
DEFAULT_MARKER = "--."
NO_MARKER_LINEWIDTH = 1.1

# Label pairs for the CV and NLP plots; the bits-per-weight label comes first
# (presumably the x axis, with the metric on y — confirm in plotting code).
AXIS_LABELS_CV = [BPW_LABEL, "Top-1 Accuracy"]
AXIS_LABELS_NLP = [BPW_LABEL, "Perplexity"]

DEFAULT_LINESTYLE = "-"
# Passed as matplotlib's ``markevery`` for the "direct_rd" style.
MARKEVERY = 0.02
# Dash on/off sequence passed as ``dashes`` for the "uniform" style.
DASH_STYLE = (3, 1)

# --- Bar-chart colours, keyed by bar label ---------------------------------
COLORS_BARS = dict(
    [
        ("Hessian", "#00678a"),
        ("GPTQ Quant.", "#984464"),
        ("OPTQ-RD Quant.", "#56641a"),
        ("DeepCABAC Encode", "#5eccab"),
        ("DeepCABAC Decode", "#e6a176"),
    ]
)

# Matplotlib line keyword arguments for each compression method, keyed by the
# same method identifiers as PLOT_LABELS. Key order inside each entry is kept
# stable on purpose: matplotlib applies some of these (e.g. ``dashes``) after
# ``linestyle``.
PLOT_PARAMS_METHODS = dict(
    uniform=dict(
        linestyle="--",
        dashes=DASH_STYLE,
        marker=None,
        color=c0,
        markersize=2,
        linewidth=1,
    ),
    alpha_inv_tr=dict(
        linestyle="solid",
        color=c1,
        linewidth=1,
    ),
    gptq=dict(
        linestyle=DEFAULT_LINESTYLE,
        marker="x",
        color=c2,
        markersize=4,
    ),
    # NOTE(review): gptq_vanilla and gptq_bz2 share an identical style, so
    # the two curves are indistinguishable if drawn on the same axes.
    gptq_vanilla=dict(
        linestyle=DEFAULT_LINESTYLE,
        marker="d",
        color="grey",
        markersize=3,
    ),
    gptq_bz2=dict(
        linestyle=DEFAULT_LINESTYLE,
        marker="d",
        color="grey",
        markersize=3,
    ),
    nncodec=dict(
        linestyle=DEFAULT_LINESTYLE,
        marker="s",
        color=c4,
        markersize=2,
    ),
    direct_rd=dict(
        linestyle=DEFAULT_LINESTYLE,
        marker="<",
        color=c7,
        markersize=2,
        markevery=MARKEVERY,
    ),
    # Reference line; it has no entry in PLOT_LABELS.
    base_performance=dict(
        linestyle="-.",
        color="black",
    ),
)

# Legend text for each method identifier used in PLOT_PARAMS_METHODS.
PLOT_LABELS = dict(
    uniform="OPTQ-RD (ours)",
    alpha_inv_tr="OPTQ-RD 1/tr (ours)",
    gptq="OPTQ+DeepCABAC",
    gptq_vanilla="OPTQ",
    gptq_bz2="OPTQ+BZ2",
    nncodec="NNCodec",
    direct_rd="Direct RD",
)

# Colours for the number-of-batches sweep (the first seven entries of
# ColorBrewer's "Paired" palette).
COLORS_NBATCHES = [
    "#a6cee3",
    "#1f78b4",
    "#b2df8a",
    "#33a02c",
    "#fb9a99",
    "#e31a1c",
    "#fdbf6f",
]
# When True, every per-curve marker below is replaced by None so the curves
# are plain lines.
NO_MARKERS = True
# One style per curve, meant to be indexed positionally in sorted order
# (presumably ascending batch count — confirm against the plotting code).
# Entry i pairs the i-th marker symbol with COLORS_NBATCHES[i].
PLOT_PARAMS_NBATCHES = [
    {
        "linestyle": DEFAULT_LINESTYLE,
        "marker": None if NO_MARKERS else symbol,
        "color": shade,
        "markersize": 1.3,
        "linewidth": NO_MARKER_LINEWIDTH,
    }
    for symbol, shade in zip(
        ("^", ".", "x", "s", "d", "o", "v"),
        COLORS_NBATCHES,
    )
]

# Line styles for the train/test shift comparison. The keys suggest the same
# method evaluated with ImageNet vs COCO data — confirm against the caller.
# Entries that reuse a PLOT_PARAMS_METHODS style directly share that dict
# object; the merged entries are fresh dicts.
TRAIN_TEST_SHIFT_PARAMS = {
    "gptq_imagenet": PLOT_PARAMS_METHODS["gptq"],
    "gptq_coco": PLOT_PARAMS_METHODS["gptq"] | {"color": c7, "marker": "v"},
    # Start from the "uniform" style *without* its custom dash pattern rather
    # than overriding it with ``dashes=(None, None)``. Matplotlib applies
    # ``dashes`` after ``linestyle``, and ``set_dashes((None, None))`` forces
    # a solid line, which silently discarded the "-." requested below.
    "uniform_coco": {
        key: value
        for key, value in PLOT_PARAMS_METHODS["uniform"].items()
        if key != "dashes"
    }
    | {
        "color": c6,
        "linestyle": "-." if NO_MARKERS else "-",
        "marker": "h" if not NO_MARKERS else None,
    },
    "uniform_imagenet": PLOT_PARAMS_METHODS["uniform"],
    "nncodec": PLOT_PARAMS_METHODS["nncodec"],
}

# Legend text for each TRAIN_TEST_SHIFT_PARAMS key; "nncodec" reuses the
# general method label.
TRAIN_TEST_SHIFT_LABELS = dict(
    gptq_imagenet="OPTQ Imagenet",
    gptq_coco="OPTQ COCO",
    uniform_coco="OPTQ-RD COCO",
    uniform_imagenet="OPTQ-RD Imagenet",
    nncodec=PLOT_LABELS["nncodec"],
)

# Default figure size in inches. The height equals fwidth divided by the
# golden ratio (3.25 * 0.6180339887498949); it is kept as a literal so the
# value is bit-for-bit stable.
fwidth = 3.25
fheight = 2.0086104634371584

# Matplotlib defaults for LaTeX-rendered (Times) figures at small font sizes.
rcparams = {
    "text.usetex": True,
    "font.family": "serif",
    "text.latex.preamble": "\\usepackage{times} ",
    # Derived from fwidth/fheight so the two can never drift apart.
    "figure.figsize": (fwidth, fheight),
    "figure.constrained_layout.use": True,
    "figure.autolayout": False,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.015,
    "font.size": 8,
    "axes.labelsize": 8,
    "legend.fontsize": 6,
    "xtick.labelsize": 6,
    "ytick.labelsize": 6,
    "axes.titlesize": 8,
    "figure.dpi": 500,
    "axes.grid": True,
    "grid.alpha": 0.5,
    "lines.markersize": 2,
    "lines.linewidth": 0.7,
}
# Apply the defaults globally at import time. Pass the mapping positionally:
# splatting dotted keys such as "text.usetex" as keyword arguments only works
# because update() accepts **kwargs, and it obscures intent.
plt.rcParams.update(rcparams)
