import logging
import functools
import os
import time
import warnings
import multiprocessing
from typing import Any, Callable, Dict, List, Tuple, Union
import itertools

import sklearn.metrics
import sklearn.datasets
import numpy as np
import scipy.linalg
import seaborn
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt

import torch
import torchvision

import plotting
import tools
import data
import inversion
import pytorch_models

# Every record is prefixed with "{source path:line} timestamp:".
logging_format = " {%(pathname)s:%(lineno)d} %(asctime)s: %(message)s"
# 15 is a custom threshold between logging.DEBUG (10) and logging.INFO (20):
# INFO and above are shown, plus anything logged explicitly at levels 15-19.
logging_level = 15
logging.basicConfig(format=logging_format, level=logging_level)

logger = logging.getLogger(__name__)


def _classify(x: np.ndarray,
              h_ineq: np.ndarray,
              model: pytorch_models.Net) -> np.ndarray:
    """Classify each observation by whether the model output lies inside an H-polytope.

    Args:
        x: Input batch; converted to a float32 tensor and passed to ``model``.
            One output row per observation is assumed.
        h_ineq: Inequality matrix. Each row ``r`` encodes the strict constraint
            ``r[0] + r[1:] @ prediction > 0`` on a model output row.
        model: Network mapping the input batch to a 2-D output array.

    Returns:
        Boolean array with one entry per output row: True when every constraint
        in ``h_ineq`` is strictly satisfied.
    """
    x_torch = torch.from_numpy(x).type(torch.FloatTensor)
    # Inference only: without no_grad the forward pass over a whole dataset
    # records an autograd graph (and keeps its activations) for nothing.
    with torch.no_grad():
        prediction = model(x_torch).detach().numpy()

    # Prepend a column of ones so the first column of h_ineq acts as the intercept.
    nr = prediction.shape[0]
    prediction_with_intercept = np.hstack((np.ones((nr, 1)), prediction))
    threshold_values = (h_ineq @ prediction_with_intercept.T).T
    return np.all(threshold_values > 0, axis=1)


def get_statistics_str(statistics: Dict[str, np.ndarray]) -> str:
    """Render per-class classification statistics as a multi-line report.

    Args:
        statistics: Mapping with per-class sequences under ``"accuracy"``,
            ``"precision"``, ``"recall"`` and ``"num"`` (observation counts),
            such as the dict returned by ``_compute_statistics``.

    Returns:
        One four-line paragraph per class, separated by newlines; an empty
        string when there are no classes.
    """
    accuracy = statistics["accuracy"]
    precision = statistics["precision"]
    recall = statistics["recall"]
    num = statistics["num"]

    template = ("Class #{} ({:.0f} observations):\n"
                "  accuracy: {:.3f}\n"
                "  precision: {:.3f}\n"
                "  recall: {:.3f}")
    paragraphs = [
        template.format(class_no, num[class_no], acc, precision[class_no], recall[class_no])
        for class_no, acc in enumerate(accuracy)
    ]
    return "\n".join(paragraphs)


def _compute_statistics(y: np.ndarray,
                        x: np.ndarray,
                        classifiers: List[Callable],
                        invert_classes: List[int]) -> Dict[str, np.ndarray]:
    """Score each per-class classifier as a one-vs-rest binary problem.

    The classifier at position ``i`` is evaluated against the label
    ``invert_classes[i]``: ground truth is ``invert_classes[i] == y`` and the
    prediction is ``classifiers[i](x)``, both flattened before scoring.

    Returns:
        Dict with per-class float arrays ``"accuracy"``, ``"precision"``,
        ``"recall"`` and ``"num"`` (count of true members of each class), plus
        ``"grand_mean_accuracy"``, ``"grand_mean_precision"`` and
        ``"grand_mean_recall"``: the per-class metrics averaged with ``num``
        as weights.
    """
    scorers = (("accuracy", sklearn.metrics.accuracy_score),
               ("precision", sklearn.metrics.precision_score),
               ("recall", sklearn.metrics.recall_score))
    shape = (len(classifiers),)
    per_class = {name: np.full(shape, np.nan)
                 for name in ("accuracy", "precision", "recall", "num")}

    for position, classifier in enumerate(classifiers):
        predicted = classifier(x).flatten()
        actual = (invert_classes[position] == y).flatten()
        per_class["num"][position] = np.sum(actual)
        for name, scorer in scorers:
            per_class[name][position] = scorer(actual, predicted)

    weights = per_class["num"]

    def _weighted_mean(values: np.ndarray) -> float:
        # Classes with more true members count proportionally more.
        return np.sum(weights * values) / np.sum(weights)

    return {
        "accuracy": per_class["accuracy"],
        "grand_mean_accuracy": _weighted_mean(per_class["accuracy"]),
        "grand_mean_precision": _weighted_mean(per_class["precision"]),
        "grand_mean_recall": _weighted_mean(per_class["recall"]),
        "num": weights,
        "precision": per_class["precision"],
        "recall": per_class["recall"],
    }


def get_xy_from_dataset(dataset: torch.utils.data.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the complete input and target arrays from a supported dataset.

    Supported datasets (including subclasses):
      * ``torch.utils.data.TensorDataset``: x is its first tensor, y its second.
      * ``torchvision.datasets.mnist.MNIST``: x is ``dataset.data``, y is
        ``dataset.targets``.

    Returns:
        ``(x, y)`` as NumPy arrays, detached from autograd.

    Raises:
        ValueError: for any other dataset type; the message names that type.
    """
    if isinstance(dataset, torch.utils.data.TensorDataset):
        x_torch = dataset.tensors[0]
        y_torch = dataset.tensors[1]
    elif isinstance(dataset, torchvision.datasets.mnist.MNIST):
        x_torch = dataset.data
        y_torch = dataset.targets
    else:
        raise ValueError("Please tell me how to extract x and y from a {}".format(type(dataset)))

    return x_torch.detach().numpy(), y_torch.detach().numpy()


def np_to_torch(x: np.ndarray) -> torch.Tensor:
    """Convert a NumPy array to a float32 CPU tensor.

    If ``x`` is already float32, the returned tensor shares its memory with ``x``.
    """
    tensor = torch.from_numpy(x)
    return tensor.float()


def torch_to_np(x: torch.Tensor) -> np.ndarray:
    """Detach ``x`` from autograd and return it as a NumPy array.

    ``.cpu()`` returns the tensor itself when it already lives on the CPU, so
    CPU behaviour is unchanged. Tensors on another device are first copied to
    host memory, which ``.numpy()`` would otherwise refuse to convert.
    """
    return x.detach().cpu().numpy()


def compute_hull_volume(vertices: np.ndarray) -> float:
    """Return the volume of the convex hull of a set of points.

    Args:
        vertices: Array with one point per row and one coordinate per column.

    Returns:
        The hull's volume (area in 2-D). Point sets that cannot span a
        full-dimensional hull return 0.0 instead of being handed to Qhull, which
        cannot build a hull from them. These are sets with fewer than
        ``max(3, dim + 1)`` points, or all points lying in a lower-dimensional
        affine subspace (collinear, coplanar, ...).
    """
    # Imported explicitly: the module-level `import scipy.linalg` does not
    # guarantee that the `scipy.spatial` submodule has been loaded.
    import scipy.spatial

    num_vertices = vertices.shape[0]
    dim = vertices.shape[1] if vertices.ndim > 1 else 1
    if num_vertices < max(3, dim + 1):
        return 0.0
    # Affine rank below `dim` means the points are flat and enclose zero volume.
    if np.linalg.matrix_rank(vertices - vertices[0]) < dim:
        return 0.0
    return scipy.spatial.ConvexHull(vertices).volume


def analyze_volume_by_class(preimage_decomps: List[list]) -> Dict[str, Any]:
    """Compute the hull volume of every polytope in each class's preimage decomposition.

    Each polytope's ``["v"]["vertices"]`` must have a first column of ones,
    which is checked and dropped before the hull volume is computed.
    Polytopes flagged ``["v"]["is_empty"]`` get volume 0.

    Returns:
        Dict mapping class index (position in ``preimage_decomps``) to a float
        array holding one volume per polytope.
    """
    def _polytope_volume(polytope: dict) -> float:
        v_form = polytope["v"]
        if v_form["is_empty"]:
            return 0.0
        points = v_form["vertices"]
        assert np.all(1 == points[:, 0])
        return compute_hull_volume(points[:, 1:])

    return {
        class_idx: np.array([_polytope_volume(polytope) for polytope in class_decomp],
                            dtype=float)
        for class_idx, class_decomp in enumerate(preimage_decomps)
    }


def analyze_data_by_class(train_x: np.ndarray,
                          preimage_decomps: List[list],
                          *,
                          tol_inequality: float = 1e-10,
                          tol_linear: float = 1e-10) -> Dict[int, np.ndarray]:
    """Count how many rows of ``train_x`` fall inside each preimage polytope.

    Each row of ``train_x`` is flattened and given a leading 1. An H-form row
    ``r`` is then evaluated as ``r[0] + r[1:] @ x``. A point is inside a
    polytope when all of the following hold:
      * every ``["h"]["inequality"]`` row evaluates to ``>= -tol_inequality``;
      * every ``["h"]["linear"]`` row evaluates to within ``tol_linear`` of 0.

    Args:
        train_x: Data with observations along the first axis.
        preimage_decomps: One list of polytope dicts per class.
        tol_inequality: Slack allowed on the inequality constraints.
        tol_linear: Slack allowed on the equality constraints.

    Returns:
        Dict mapping class index to an integer array with one count per polytope.
    """
    num_rows = train_x.shape[0]
    flat_trainx = np.reshape(train_x, (num_rows, -1))
    train_x1 = np.hstack((np.ones((num_rows, 1)), flat_trainx))
    # Hoisted: the same (features + 1, num_rows) matrix is reused for every polytope.
    train_x1_t = train_x1.T

    data_by_class = dict()
    for class_idx, class_decomp in enumerate(preimage_decomps):
        # Count directly instead of materialising a num_rows x num_polytopes mask.
        inclusion_per_polytope = np.zeros((len(class_decomp),), dtype=np.int_)
        for polytope_idx, polytope in enumerate(class_decomp):
            h_form = polytope["h"]
            # Shapes: (num_constraints, num_rows); reduce over constraints (axis 0).
            threshs_inequality = h_form["inequality"] @ train_x1_t
            threshs_linear = h_form["linear"] @ train_x1_t

            in_inequality = np.all(threshs_inequality >= -tol_inequality, axis=0)
            in_linear = np.all(np.abs(threshs_linear) <= tol_linear, axis=0)
            inclusion_per_polytope[polytope_idx] = np.count_nonzero(in_inequality & in_linear)

        data_by_class[class_idx] = inclusion_per_polytope

    return data_by_class


def _build_classifiers(decomps: list,
                       num_classes: int,
                       model: pytorch_models.Net) -> List[Callable]:
    """Build one ``_classify`` partial per class from its terminal decomposition.

    For class ``c``, the inequality matrix of the first polytope in the last
    entry of ``decomps[c]`` defines the classifier.
    """
    classifiers = []  # type: List[Callable]
    for class_idx in range(num_classes):
        terminal_decomp = decomps[class_idx][-1]
        h_ineq = terminal_decomp[0]["h"]["inequality"]
        classifiers.append(functools.partial(_classify, h_ineq=h_ineq, model=model))
    return classifiers


def _mark_empty_and_count(decomps: list,
                          num_classes: int,
                          num_layers: int,
                          count_v: bool) -> np.ndarray:
    """Flag each polytope's V-form as empty or not and count polytopes per (class, layer).

    Side effect: sets ``polytope["v"]["is_empty"]`` for every polytope in
    ``decomps[c][l]`` with ``l < num_layers``. A V-form is empty when its
    ``vertices`` array has no elements. A ``None`` vertices entry is treated
    as non-empty.

    Returns:
        Float array of shape (num_classes, num_layers). Each entry is the number
        of non-empty polytopes when ``count_v`` is true, otherwise the total
        number of polytopes.
    """
    num_nonempty = np.full((num_classes, num_layers), np.nan)
    for class_idx in range(num_classes):
        class_decomps = decomps[class_idx]
        for layer_idx in range(num_layers):
            layer_decomp = class_decomps[layer_idx]
            for polytope in layer_decomp:
                v_form = polytope["v"]
                vertices = v_form["vertices"]
                v_form["is_empty"] = vertices is not None and vertices.size == 0
            if count_v:
                num = sum(not polytope["v"]["is_empty"] for polytope in layer_decomp)
            else:
                num = len(layer_decomp)
            num_nonempty[class_idx, layer_idx] = num
    return num_nonempty


def compute_results_summary(decomps: list,
                            model: pytorch_models.Net,
                            train_x: np.ndarray,
                            train_y: np.ndarray,
                            test_x: np.ndarray,
                            test_y: np.ndarray,
                            par: Dict[str, Any],
                            *,
                            do_volume_analysis: bool = False,
                            count_v: bool = True) -> Dict[str, Any]:
    """Build per-class classifiers, score them, and summarise the decompositions.

    Args:
        decomps: One entry per inverted class. ``decomps[c]`` is a list of
            decompositions (lists of polytope dicts with ``"h"`` and ``"v"``
            forms). Its first element is used as the preimage decomposition and
            its last as the terminal one.
        model: Network whose ``layers`` determine how many decompositions are counted.
        train_x, train_y, test_x, test_y: Data used to score the classifiers.
        par: Parameters. ``par["inversion"]`` must provide ``"invert_classes"``
            and ``"input_layer_bounds"``.
        do_volume_analysis: Also compute hull volumes of the preimage polytopes.
        count_v: Count only non-empty polytopes per layer (otherwise all of them).

    Side effect:
        Sets ``["v"]["is_empty"]`` on the polytopes of ``decomps`` in place.

    Returns:
        Dict with keys ``classifiers``, ``data_by_class``,
        ``input_layer_bounds``, ``model``, ``nonzero_datas``,
        ``nonzero_volumes``, ``num_nonempty``, ``par``, ``preimage_decomps``,
        ``train_statistics``, ``test_statistics``, ``terminal_decomps`` and
        ``volume_by_class``. The two volume entries are None unless
        ``do_volume_analysis`` is set.
    """
    inversion_par = par["inversion"]
    invert_classes = inversion_par["invert_classes"]
    num_classes = len(invert_classes)

    classifiers = _build_classifiers(decomps, num_classes, model)
    train_statistics = _compute_statistics(train_y, train_x, classifiers, invert_classes)
    test_statistics = _compute_statistics(test_y, test_x, classifiers, invert_classes)

    num_layers = len(model.layers)
    # Must run before analyze_volume_by_class, which reads the "is_empty" flags.
    num_nonempty = _mark_empty_and_count(decomps, num_classes, num_layers, count_v)

    preimage_decomps = [d[0] for d in decomps]
    terminal_decomps = [d[-1] for d in decomps]

    if do_volume_analysis:
        volume_by_class = analyze_volume_by_class(preimage_decomps)
        all_volumes = np.concatenate(list(volume_by_class.values()))
        nonzero_volumes = all_volumes[all_volumes > 0]
    else:
        volume_by_class = None
        nonzero_volumes = None

    data_by_class = analyze_data_by_class(train_x, preimage_decomps)
    all_datas = np.concatenate(list(data_by_class.values()))
    nonzero_datas = all_datas[all_datas > 0]
    input_layer_bounds = inversion_par["input_layer_bounds"]

    return {
        "classifiers": classifiers,
        "data_by_class": data_by_class,
        "input_layer_bounds": input_layer_bounds,
        "model": model,
        "nonzero_datas": nonzero_datas,
        "nonzero_volumes": nonzero_volumes,
        "num_nonempty": num_nonempty,
        "par": par,
        "preimage_decomps": preimage_decomps,
        "train_statistics": train_statistics,
        "test_statistics": test_statistics,
        "terminal_decomps": terminal_decomps,
        "volume_by_class": volume_by_class
    }


def _terminal_decomp_plot(fig: matplotlib.figure.Figure,
                          axs: np.ndarray,
                          terminal_decomps: List[list],
                          train_x: np.ndarray,
                          model) -> Tuple[matplotlib.figure.Figure, np.array]:
    """Overlay each class's terminal polytope on the model outputs for ``train_x``.

    Everything is drawn into ``axs[0, 0]``. For class ``idx``:
      * the vertices of the first polytope of ``terminal_decomps[idx]`` are
        drawn with ``plotting.convex_hull_plot``;
      * the training points whose arg-max output is ``idx`` are scattered on
        top in the same colour.
    Only the first two output coordinates are scattered, so this presumably
    targets models with two outputs.

    Returns:
        The ``(fig, axs)`` that were passed in.
    """
    v_forms = [terminal_decomp[0]["v"]["vertices"] for terminal_decomp in terminal_decomps]

    # Inference only; no autograd graph needed.
    with torch.no_grad():
        logits = model(np_to_torch(train_x))
    logits_np = logits.detach().numpy()
    logits_argmax = np.argmax(logits_np, axis=1)

    ax = axs[0, 0]
    alpha = .30
    xl, yl = plotting.get_plotlims_from_pointcloud(logits_np)
    colors = plotting.get_palette_of_length(len(v_forms))
    for idx, v_form in enumerate(v_forms):
        color = colors[idx]
        plotting.convex_hull_plot(ax, v_form, xl, yl, color, alpha)
        in_class = (logits_argmax == idx)
        ax.scatter(logits_np[in_class, 0],
                   logits_np[in_class, 1], color=color)

    return fig, axs


def safe_int(n: float) -> int:
    """Convert ``n`` to ``int``, refusing to silently truncate.

    Raises:
        ValueError: if ``int(n)`` differs from ``n`` (e.g. ``2.5``). An explicit
            raise is used instead of ``assert``, which ``python -O`` strips.
    """
    intn = int(n)
    if intn != n:
        raise ValueError("Expected an integral value, got {!r}".format(n))
    return intn


def _cast_to_square(x: np.ndarray) -> np.ndarray:
    """Reshape an array of ``n * n`` elements into an ``(n, n)`` array.

    ``safe_int`` rejects element counts whose square root is not an integer.
    """
    side = safe_int(x.size ** .5)
    return np.reshape(x, (side, side))


def _do_image_plot(fig: matplotlib.figure.Figure,
                   axs: np.array,
                   invert_classes: List[int],
                   x: np.ndarray,
                   y: np.ndarray) -> Tuple[matplotlib.figure.Figure, np.array]:
    """Show one example input per inverted class, one panel per class.

    Panel ``axs[0, idx]`` shows the first row of ``x`` whose label in ``y``
    equals ``invert_classes[idx]``, reshaped to a square image and titled
    ``idx``. A class with no example gets a warning and a blank panel rather
    than an unrelated image.

    Raises:
        ValueError: if the selected row of ``x`` is not one-dimensional.
    """
    for idx, invert_class in enumerate(invert_classes):
        ax = axs[0, idx]
        matching_rows = np.flatnonzero(y == invert_class)
        if matching_rows.size == 0:
            # np.argmax over an all-False mask would silently pick row 0.
            warnings.warn("No example of class {} found; leaving panel {} empty"
                          .format(invert_class, idx))
            ax.set_axis_off()
            continue

        rowdata = x[matching_rows[0], :]
        if rowdata.ndim != 1:
            raise ValueError("Expected each row of x to be flat, got shape {}"
                             .format(rowdata.shape))

        ax.imshow(_cast_to_square(rowdata))
        ax.set_title(idx)
    return fig, axs


def positive_entries(x: np.ndarray) -> np.ndarray:
    """Return the strictly positive elements of ``x`` as a new 1-D array, in row-major order."""
    flat = x.ravel()
    keep = flat > 0
    return flat[keep]


def initialize_plot() -> Tuple[matplotlib.figure.Figure, np.array]:
    """Unimplemented placeholder.

    NOTE(review): the annotation promises a ``(figure, axes)`` tuple, but the
    function currently does nothing and returns ``None``.
    """
    return None


def _plot_saver(fig: matplotlib.figure.Figure,
                ident: str,
                plot_format: str,
                plots_dir: str) -> None:
    """Hand ``fig`` to ``plotting.smart_save_fig`` and log the path it reports.

    The call is bracketed by ``plotting.initialise_pgf_plots("pdflatex", "serif")``
    beforehand and ``plotting.finalize_plot()`` afterwards.
    """
    plotting.initialise_pgf_plots("pdflatex", "serif")
    saved_path = plotting.smart_save_fig(fig, ident, plot_format, plots_dir)
    plotting.finalize_plot()
    logger.info("Saved {}".format(saved_path))


def preimage_plot_2d(preimage_decomps: List[list],
                     train_x: np.ndarray,
                     xlims: Tuple[float, float],
                     ylims: Tuple[float, float],
                     big_plot_scale: float) -> Tuple[matplotlib.figure.Figure, np.array]:
    """Plot the vertex sets of each class's preimage polytopes together with ``train_x``.

    Collects ``polytope["v"]["vertices"]`` for every polytope, grouped per class
    in the order of ``preimage_decomps``, and passes the groups to
    ``plotting.plot_decomp``.
    """
    vertex_lists = [
        [polytope["v"]["vertices"] for polytope in class_decomp]
        for class_decomp in preimage_decomps
    ]
    fig, ax = plotting.plot_decomp(train_x, xlims, ylims, vertex_lists, big_plot_scale)
    return fig, ax


def _plot_cumulative_distribution(values: np.ndarray,
                                  plot_scale: float,
                                  grid_color: str) -> Tuple[matplotlib.figure.Figure, np.array]:
    """Plot the cumulative share of ``sum(values)`` covered by the smallest entries.

    Point ``k`` of the curve is (sum of the ``k + 1`` smallest values) / ``sum(values)``.
    """
    cumulative_share = np.cumsum(np.sort(values)) / np.sum(values)
    fig, axs = plotting.wrapped_subplot(1, 1, plot_scale)
    axs[0, 0].plot(cumulative_share)
    axs[0, 0].grid(which="major", color=grid_color)
    return fig, axs


def present_results_summary(results_summary: Dict[str, Any],
                            train_x: np.ndarray,
                            train_y: np.ndarray,
                            par: Dict[str, Any]) -> None:
    """Log the inversion results and draw the plots enabled in ``par["results"]``.

    Logs:
      * the per-layer count of non-empty polytopes for each inverted class;
      * the train and test classification statistics.

    Then draws, depending on flags:
      * the polytope-volume distribution, when volumes were computed;
      * the 2-D probability, decision-boundary and preimage plots (``do_2d_plot``);
      * the terminal-decomposition plot (``do_terminal_decomp_plot``);
      * one example image per class (``is_image``).

    Figures other than the image plot are saved through ``_plot_saver`` when
    ``save_plots`` is set.
    """
    model = results_summary["model"]

    preimage_decomps = results_summary["preimage_decomps"]
    terminal_decomps = results_summary["terminal_decomps"]
    classifiers = results_summary["classifiers"]
    train_statistics = results_summary["train_statistics"]
    test_statistics = results_summary["test_statistics"]
    num_nonempty = results_summary["num_nonempty"]
    nonzero_volumes = results_summary["nonzero_volumes"]
    nonzero_datas = results_summary["nonzero_datas"]

    results_par = par["results"]
    inversion_par = par["inversion"]

    invert_classes = inversion_par["invert_classes"]

    grid_color = "grey"
    do_terminal_decomp_plot = results_par["do_terminal_decomp_plot"]
    do_2d_plot = results_par['do_2d_plot']
    is_image = results_par['is_image']
    save_plots = results_par["save_plots"]
    plots_dir = results_par["plots_dir"]
    plot_scale = results_par["plot_scale"]
    big_plot_scale = results_par["big_plot_scale"]
    plot_format = results_par["plot_format"]

    logger.info("ReLU decomposition info")
    layers = model.layers

    # Rows are layers, columns are the inverted class labels.
    row_names = ["#{:02d}: {}".format(layer_idx, str(layer))
                 for layer_idx, layer in enumerate(layers)]
    nonempty_df = pd.DataFrame(num_nonempty.astype(int).T,
                               index=row_names,
                               columns=invert_classes)
    # None disables column-width truncation. The older -1 spelling was
    # deprecated in pandas 1.0 and is rejected by later releases.
    with pd.option_context("display.max_colwidth", None):
        to_log = nonempty_df.to_string()
    logger.info("Nonempty polytopes per layer\n" + to_log)

    train_statistics_str = get_statistics_str(train_statistics)
    test_statistics_str = get_statistics_str(test_statistics)

    logger.info("\n=== Train statistics ===\n" + train_statistics_str)
    logger.info("\n=== Test statistics ===\n" + test_statistics_str)

    do_volume_distribution_plot = True
    if do_volume_distribution_plot and (nonzero_volumes is not None):
        fig, axs = _plot_cumulative_distribution(nonzero_volumes, plot_scale, grid_color)
        if save_plots:
            _plot_saver(fig, "polytope_volume_distribution", plot_format, plots_dir)

    do_data_inclusion_plot = False
    if do_data_inclusion_plot:
        fig, axs = _plot_cumulative_distribution(nonzero_datas, plot_scale, grid_color)
        if save_plots:
            _plot_saver(fig, "polytope_data_distribution", plot_format, plots_dir)

    if do_2d_plot:
        fig, ax = plotting.bruteforced_prob_contour_plot(train_x,
                                                         train_y,
                                                         model,
                                                         False,  # include_contours
                                                         False,  # include_axes
                                                         plot_scale)
        fig.tight_layout()
        if save_plots:
            _plot_saver(fig, "bruteforced_prob_contour_plot_false", plot_format, plots_dir)

        fig, ax = plotting.bruteforced_prob_contour_plot(train_x,
                                                         train_y,
                                                         model,
                                                         True,  # include_contours
                                                         True,  # include_axes
                                                         big_plot_scale)
        if save_plots:
            _plot_saver(fig, "bruteforced_prob_contour_plot_true", plot_format, plots_dir)

        fig, ax = plotting.bruteforced_decision_boundary_plot(train_x,
                                                              train_y,
                                                              classifiers,
                                                              big_plot_scale)
        if save_plots:
            _plot_saver(fig, "bruteforced_decision_boundary", plot_format, plots_dir)

    if do_terminal_decomp_plot:
        fig, axs = plotting.wrapped_subplot(1, 1, plot_scale)
        _terminal_decomp_plot(fig, axs, terminal_decomps, train_x, model)
        if save_plots:
            _plot_saver(fig, "terminal_decomposition_plot", plot_format, plots_dir)

    if do_2d_plot and inversion_par["need_initial_v"]:
        input_layer_bounds = results_summary["input_layer_bounds"]

        lower = input_layer_bounds[0].flatten()
        upper = input_layer_bounds[1].flatten()

        xlims = (lower[0], upper[0])
        ylims = (lower[1], upper[1])

        fig, ax = preimage_plot_2d(preimage_decomps,
                                   train_x,
                                   xlims,
                                   ylims,
                                   plot_scale)
        if save_plots:
            _plot_saver(fig, "2d_inversion_plot", plot_format, plots_dir)

    if is_image:
        num_classes = len(invert_classes)
        fig, axs = plotting.wrapped_subplot(1, num_classes, plot_scale)
        fig, axs = _do_image_plot(fig, axs, invert_classes, train_x, train_y)
