# PYTHONPATH=. python3 analyses/num_parameters_plot.py --dim 11
# PYTHONPATH=. python3 analyses/num_parameters_plot.py --dim 12 --fig_format pgf
# cp plots/num_parameters_dim9.pgf submission/used_plots/
# cp plots/num_parameters_dim11.pgf submission/used_plots/
"""
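Plot trainable-parameter counts and criterion values of the shallowest
networks built for subsets of indices at a given dimension.

Copy the generated figures into the submission:

    cp plots/d1_plot_12.pgf submission/used_plots/
    cp plots/model_index_size_12.pgf submission/used_plots/
    cp plots/model_size_error_12.pgf submission/used_plots/

Regenerate the figures with:

    python3 num_parameters_plot.py --dim 12 --fig_scl 1.7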

"""

import argparse
import math
import os
import pickle
import logging
import warnings
import sys
from typing import Iterable, List, Tuple

import seaborn
import torch
import numpy as np
import matplotlib.lines
import matplotlib.pyplot as plt

sys.path.append("")

import tools
import networks
import depth_analysis

import utils.plotting
import utils.path_config
import utils.logging

# Promote every warning to an exception for the whole run.
warnings.filterwarnings("error")

standard_streamhandler = utils.logging.get_standard_streamhandler()

# 15 lies between logging.DEBUG (10) and logging.INFO (20).
logging_level = 15

logger = logging.getLogger(__name__)
logger.setLevel(logging_level)
logger.addHandler(standard_streamhandler)

# Return type of the plotting functions below: the figure together with the
# 2-D array of axes produced by ``plt.subplots(..., squeeze=False)``.
# ``np.ndarray`` is the array type; ``np.array`` is only a factory function.
FigAx = Tuple[matplotlib.figure.Figure, np.ndarray]

# Name of the seaborn palette used by all plots in this module.
palette_name = "deep"
# Figure width over height: figures are created with
# ``figsize=(w_o_h * figscl, figscl)``.
w_o_h = 1.5
if False:
    # Debugging aid (disabled): preview the first 8 colours of the palette.
    palette = seaborn.color_palette(palette_name, 8)
    seaborn.palplot(palette)


def count_parameters(model) -> int:
    """Return the total number of elements over all trainable parameters.

    Only parameters yielded by ``model.parameters()`` whose ``requires_grad``
    flag is set contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def do_model_size_analysis(dim: int) -> dict:
    """Compute, or load from the on-disk cache, the model-size analysis.

    For every element ``r`` of ``tools.powerset(range(2, dim))`` other than the
    empty tuple, the index set ``ks = [0, 1] + list(r)`` is passed to
    ``depth_analysis.get_dk_situation``.  The shallowest network for the
    resulting ``argmin`` (with a zero appended) is built, canonicalized, and
    its trainable parameters are counted.

    The result is cached as ``model_size_analysis{dim}.pkl`` in the
    ``"results"`` directory reported by ``utils.path_config.get_paths()``; if
    that file already exists it is loaded instead of recomputing.

    Args:
        dim: Dimension for which the analysis is run.

    Returns:
        A dict with the keys ``"parameter_counts"``, ``"criterions"``,
        ``"incl_d1"`` (whether ``dim - 1`` is in the index set),
        ``"ks_lens"`` (length of each index set) and ``"dim"``.  The four
        list entries are aligned by position.
    """
    paths = utils.path_config.get_paths()
    model_filename = f"model_size_analysis{dim}.pkl"
    model_filedir = paths["results"]
    model_fullfilename = os.path.join(model_filedir, model_filename)

    if os.path.exists(model_fullfilename):
        logger.info(f"Loading from '{model_fullfilename}'")
        # NOTE: unpickling can execute arbitrary code; this cache file must
        # only ever come from a trusted, locally generated run.
        with open(model_fullfilename, 'rb') as fh:
            return pickle.load(fh)

    z = torch.zeros(1)
    rsets = tools.powerset(tuple(range(2, dim)))
    if () in rsets:
        rsets.remove(())
    num_sets = len(rsets)
    # Check the expected number of subsets before the expensive loop rather
    # than after it, so a mismatch fails fast.
    assert 2 ** (dim - 2) - 1 == num_sets

    parameter_counts = []
    criterions = []
    incl_d1 = []
    ks_lens = []

    for rset in rsets:
        ks = [0, 1] + list(rset)
        soln = depth_analysis.get_dk_situation(dim, ks)
        betas = torch.cat((soln["argmin"], z)).to(torch.float32)

        shallowest_network = networks.build_shallowest_network(betas, dim)
        shallowest_network_fused = networks.canonicalize_network(shallowest_network)

        parameter_counts.append(count_parameters(shallowest_network_fused))
        criterions.append(soln["criterion"])
        incl_d1.append(dim - 1 in ks)
        ks_lens.append(len(ks))

    model_size_analysis = {
        "parameter_counts": parameter_counts,
        "criterions": criterions,
        "incl_d1": incl_d1,
        "ks_lens": ks_lens,
        "dim": dim
    }

    # Make sure the cache directory exists so the computed results are not
    # lost to a missing directory at write time.
    if model_filedir:
        os.makedirs(model_filedir, exist_ok=True)
    with open(model_fullfilename, 'wb') as fh:
        pickle.dump(model_size_analysis, fh, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info(f"Dumping to '{model_fullfilename}'")
    return model_size_analysis


def model_size_error_scatterplot(model_size_analysis: dict,
                                 figscl: float) -> FigAx:
    """Scatter parameter count against criterion, coloured by index-set size.

    The criterion is drawn on a logarithmic x-axis and the parameter count on
    the y-axis.  Each distinct value in ``model_size_analysis["ks_lens"]``
    gets its own colour from the module palette.

    Args:
        model_size_analysis: Dict with at least the keys
            ``"parameter_counts"``, ``"criterions"`` and ``"ks_lens"``.
        figscl: Figure height; the width is ``w_o_h * figscl``.

    Returns:
        The figure and its 1x1 array of axes.
    """
    counts = torch.tensor(model_size_analysis["parameter_counts"])
    errors = torch.tensor(model_size_analysis["criterions"])
    set_sizes = torch.tensor(model_size_analysis["ks_lens"])
    distinct_sizes = sorted(set(set_sizes.tolist()))

    fig, axs = plt.subplots(1, 1,
                            figsize=(w_o_h * figscl, figscl),
                            squeeze=False)
    ax = axs[0, 0]
    colours = seaborn.color_palette(palette_name, len(distinct_sizes))

    for size, colour in zip(distinct_sizes, colours):
        # Boolean mask selecting the models whose index set has this size.
        mask = set_sizes == size
        ax.semilogx(errors[mask], counts[mask], ".",
                    color=colour, markersize=2)

    fig.tight_layout(pad=.2)
    return fig, axs


def model_index_size_plot(model_size_analysis: dict,
                          figscl: float) -> FigAx:
    """Plot parameter count against model index with shaded bands.

    The parameter counts are drawn as a black line in stored order.  The
    x-axis is split at every position where ``ks_lens`` changes value, and
    each segment ending at such a position is shaded with its own palette
    colour.  All shaded bands except the first and the last carry a text
    label.

    Args:
        model_size_analysis: Dict with at least the keys ``"dim"``,
            ``"parameter_counts"`` and ``"ks_lens"``.
        figscl: Figure height; the width is ``w_o_h * figscl``.

    Returns:
        The figure and its 1x1 array of axes.

    Raises:
        AssertionError: If ``ks_lens`` does not change value exactly
            ``dim - 3`` times.
    """
    dim = model_size_analysis["dim"]
    parameter_counts = model_size_analysis["parameter_counts"]
    ks_lens = model_size_analysis["ks_lens"]
    num_models = len(ks_lens)

    # Indices at which the index-set length differs from its predecessor.
    length_steps = torch.tensor(ks_lens).diff()
    boundaries = (1 + torch.argwhere(length_steps).flatten()).tolist()

    fig, axs = plt.subplots(1, 1,
                            figsize=(w_o_h * figscl, figscl),
                            squeeze=False)
    ax = axs[0, 0]
    ax.plot(parameter_counts, linewidth=.9, color="k")

    # Capture the y-limits of the line plot before any patches are added.
    y_lo, y_hi = ax.get_ylim()
    frac = .1
    y_near_top = frac * y_lo + (1 - frac) * y_hi
    y_near_bottom = (1 - frac) * y_lo + frac * y_hi

    assert len(boundaries) == dim - 3
    colours = seaborn.color_palette(palette_name, dim - 3)

    last_idx = len(boundaries) - 1
    band_starts = [0] + boundaries[:-1]
    for idx, (start, stop) in enumerate(zip(band_starts, boundaries)):
        band = matplotlib.patches.Rectangle((start, y_lo),
                                            stop - start,
                                            y_hi - y_lo,
                                            linewidth=1,
                                            facecolor=colours[idx],
                                            alpha=.9)
        ax.add_patch(band)
        if 0 < idx < last_idx:
            # Labels sit near the top, except for bands ending in the last
            # quarter of the x-range, where they sit near the bottom.
            if stop <= 3 * num_models // 4:
                label_y = y_near_top
            else:
                label_y = y_near_bottom
            ax.text((start + stop) / 2, label_y, f"{ks_lens[stop]}",
                    color="k", ha='center')

    ax.set_xlim(0, num_models - 1)
    fig.tight_layout(pad=.2)
    return fig, axs


def d1_plot(model_size_analysis: dict,
            figscl: float) -> FigAx:
    """Plot parameter count and criterion for the entries flagged in ``incl_d1``.

    Only entries whose ``"incl_d1"`` flag is true are shown, numbered
    consecutively along the x-axis.  Parameter counts are drawn as a black
    line on the left axis; criterions are drawn as black dots on a
    logarithmic twin axis on the right.

    Args:
        model_size_analysis: Dict with at least the keys
            ``"parameter_counts"``, ``"incl_d1"`` and ``"criterions"``.
        figscl: Figure height; the width is ``w_o_h * figscl``.

    Returns:
        The figure and the 1x1 array holding the left axis (the twin axis is
        not part of the returned array).
    """
    counts = torch.tensor(model_size_analysis["parameter_counts"])
    has_d1 = torch.tensor(model_size_analysis["incl_d1"])
    errors = torch.tensor(model_size_analysis['criterions'])

    # One x position per selected entry.
    xs = torch.arange(has_d1.sum())
    selected_counts = counts[has_d1]
    selected_errors = errors[has_d1]

    fig, axs = plt.subplots(1, 1,
                            figsize=(w_o_h * figscl, figscl),
                            squeeze=False)
    left_ax = axs[0, 0]
    right_ax = left_ax.twinx()
    left_ax.plot(xs, selected_counts, "k")
    right_ax.semilogy(xs, selected_errors, "k.")
    return fig, axs


def versus_dim_plot(model_size_analysis: dict,
                    figscl: float) -> FigAx:
    """Plot criterion and parameter count of the full index set per dimension.

    For each dimension ``d`` in ``range(5, 12)`` the full index set
    ``list(range(d))`` is passed to ``depth_analysis.get_dk_situation``, the
    shallowest network is built and canonicalized, and its trainable
    parameters are counted.  Criterion (left axis) and parameter count (twin
    right axis) are both drawn on logarithmic y-axes against the dimension.

    Args:
        model_size_analysis: Unused; kept so the signature matches the other
            plotting functions.
        figscl: Figure height; the width is ``w_o_h * figscl``.

    Returns:
        The figure and the 1x1 array holding the left axis (the twin axis is
        not part of the returned array).
    """
    z = torch.zeros(1)

    min_dim = 5
    max_dim = 12  # exclusive upper bound: the last dimension plotted is 11
    dims = list(range(min_dim, max_dim))

    parameter_counts = []
    criterions = []

    for dim in dims:
        ks = list(range(dim))
        soln = depth_analysis.get_dk_situation(dim, ks)
        betas = torch.cat((soln["argmin"], z)).to(torch.float32)

        logger.info(f"building shallowest network for d = {dim}")
        shallowest_network = networks.build_shallowest_network(betas, dim)
        shallowest_network_fused = networks.canonicalize_network(shallowest_network)

        parameter_counts.append(count_parameters(shallowest_network_fused))
        criterions.append(soln["criterion"])

    fig, axs = plt.subplots(1, 1,
                            figsize=(w_o_h * figscl, figscl),
                            squeeze=False)
    ax1 = axs[0, 0]
    ax1.semilogy(dims, criterions)
    ax2 = ax1.twinx()
    ax2.semilogy(dims, parameter_counts)
    return fig, axs


def plot_model_size_analysis(model_size_analysis: dict,
                             figscl: float) -> Tuple[list, list]:
    """Create all standard figures for one model-size analysis.

    Args:
        model_size_analysis: Analysis dict; its ``"dim"`` entry is used as the
            suffix of every figure identifier.
        figscl: Figure height passed on to each plotting function.

    Returns:
        Two position-aligned lists: the figures and their identifiers, in the
        order d1 plot, model-index/size plot, size/error scatter plot.
    """
    dim = model_size_analysis["dim"]

    # Each entry pairs a plotting function with the identifier of its figure.
    plots = (
        (d1_plot, f"d1_plot_{dim}"),
        (model_index_size_plot, f"model_index_size_{dim}"),
        (model_size_error_scatterplot, f"model_size_error_{dim}"),
    )

    all_figs = []
    all_idents = []
    for make_plot, ident in plots:
        fig, _ = make_plot(model_size_analysis, figscl)
        all_figs.append(fig)
        all_idents.append(ident)

    return all_figs, all_idents


if __name__ == "__main__":
    # Script entry point: run (or load) the model-size analysis for the
    # requested dimension and save every resulting figure.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dim", type=int, default=9)
    parser.add_argument("--fig_format", type=str, default="pgf")
    parser.add_argument("--fig_scl", type=float, default=1.6)

    args = parser.parse_args()
    dim = args.dim
    fig_format = args.fig_format
    figscl = args.fig_scl

    paths = utils.path_config.get_paths()
    filepath = paths["plots"]
    model_size_analysis = do_model_size_analysis(dim)

    # The pgf backend has to be initialised before any figure is created.
    if "pgf" == fig_format:
        texsystem = "pdflatex"
        font_family = "serif"
        utils.plotting.initialise_pgf_plots(texsystem, font_family)

    all_figs, all_idents = plot_model_size_analysis(model_size_analysis, figscl)

    # Each figure is saved under its own identifier.
    for fig, ident in zip(all_figs, all_idents):
        fig_path = utils.plotting.smart_save_fig(fig, ident, fig_format, filepath)
        logger.info(f"Saving {fig_path}")
