import os
import logging
import time
import functools
import datetime as dt
from typing import Any, Dict, List, Optional, Tuple
from collections import Counter

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import numba
# from numba import jit

import caching
import path_config
import tools
import plotting
import decomp
import draw_airplane

import acas.nnet

# Type aliases for the polytope data handled in this file.
# In this file a Polytope is only read through p["h"], whose keys "is_empty",
# "inequality" and "linear" are used by _add_preimage.
# NOTE(review): the full dict schemas are defined by the project modules
# (`decomp`, `tools`) and are not visible here.
Polytope = Dict[str, Any]
HRepresentation = Dict[str, Any]
VRepresentation = Dict[str, Any]
# A region is a list of polytopes (iterated polytope-by-polytope when plotting).
Region = List[Polytope]


# Very wide lines so that printed matrices are not wrapped.
np.set_printoptions(linewidth=1000)

logging_format = "%(asctime)s: %(message)s"
# logging_format = "{%(func)s: %(lineno)4s: {%(asctime)s: %(message)s"

# 15 lies between logging.DEBUG (10) and logging.INFO (20).
logging_level = 15
logging.basicConfig(level=logging_level,
                    format=logging_format)

logger = logging.getLogger(__name__)

# RGB colours, one per class. The two-letter names match the legend labels
# used in encounter_plot ('COC', 'SR', 'SL', 'WR', 'WL').
coc = (.9, .9, .9)  # light grey (pure white is the commented alternative below)
# coc = (1., 1., 1.)  # white
sr = (.0, .0, .5)  # navy
sl = (.0, .6, .0)  # green
wr = (.5, .5, .5)  # grey
wl = (.7, .9, .0)  # neon green
# Indexed by class index when plotting: colors[idx] is the colour of class idx.
colors = [coc, wl, wr, sl, sr]


def interleave_lists(l1: List[Any],
                     l2: List[Any]) -> List[Any]:
    """Interleave two lists element-wise.

    Returns ``[l1[0], l2[0], l1[1], l2[1], ...]``.  As with ``zip``, the result
    stops at the end of the shorter input, so surplus elements of the longer
    list are dropped.

    Args:
        l1: Elements placed at the even positions of the result.
        l2: Elements placed at the odd positions of the result.

    Returns:
        A new list; neither input is modified.
    """
    # A flattening comprehension is linear in the input size, whereas
    # ``sum(zip(l1, l2), ())`` rebuilds the accumulated tuple on every step.
    return [item for pair in zip(l1, l2) for item in pair]


def _matrix_rank_empty(m: np.ndarray) -> int:
    if 0 == m.shape[0]:
        mre = 0
    else:
        mre = np.linalg.matrix_rank(m)
    return mre


def _build_layer_inverse_args(idx: int,
                              dim: int) -> Dict[str, Any]:
    # layer_inverse_args = {}
    is_rational = False
    # need_v = True
    # need_v = (idx == 0)
    need_v = False
    if dim is None:
        layer_limit = None
    else:
        layer_limit = (np.full((dim, 1), -1 * np.inf),
                       np.full((dim, 1), +1 * np.inf))

    layer_inverse_args = {
        "is_rational": is_rational,
        "need_v": need_v,
        "layer_limit": layer_limit
    }
    return layer_inverse_args


def _add_preimage(ax: matplotlib.axes.Axes,
                  preimage: Region,
                  lower: np.ndarray,
                  upper: np.ndarray,
                  means: np.ndarray,
                  ranges: np.ndarray,
                  color: plotting.Color):
    """Draw every usable polytope of ``preimage`` on ``ax`` in one colour.

    The last row of each of ``lower``, ``upper``, ``means`` and ``ranges`` is
    dropped.  The remaining rows define the affine map
    ``x -> diag(ranges) @ x + means``, which is applied to the bounds (to get
    the x/y limits handed to the hull plot) and to each polytope before it is
    drawn.  After dropping the last row, ``lower``/``upper`` must therefore
    have as many rows as ``ranges``, and at least two.

    Polytopes whose H-representation is flagged ``"is_empty"``, or whose
    converted representation has rank 0, are skipped.

    Args:
        ax: Axes to draw on.
        preimage: Polytopes to draw; each is read through ``p["h"]``.
        lower: Lower bounds, one row per coordinate.
        upper: Upper bounds, one row per coordinate.
        means: Offsets of the affine map, one row per coordinate.
        ranges: Scales of the affine map, one row per coordinate.
        color: Colour passed through to ``plotting.convex_hull_plot``.
    """
    opacity = 1.0

    # Affine map built from all but the last row of ranges / means.
    scale = np.diag(ranges[:-1, :].flatten())
    shift = means[:-1]

    upper_mapped = scale @ upper[:-1, :] + shift
    lower_mapped = scale @ lower[:-1, :] + shift

    xlim = (lower_mapped[0].item(), upper_mapped[0].item())
    ylim = (lower_mapped[1].item(), upper_mapped[1].item())

    for polytope in preimage:
        h_repr = polytope["h"]
        if h_repr["is_empty"]:
            continue

        v_repr = tools.h_to_v(h_repr["inequality"], h_repr["linear"])
        if not _matrix_rank_empty(v_repr) > 0:
            # Nothing to draw for a rank-0 representation.
            continue

        mapped_v_repr = tools.apply_linear_transformation_to_v_repr(v_repr, scale, shift)
        plotting.convex_hull_plot(ax,
                                  mapped_v_repr,
                                  xlim,
                                  ylim,
                                  color,
                                  opacity)


def _outer2(x_coords: np.ndarray, y_coords: np.ndarray) -> np.ndarray:
    eps_ball = np.meshgrid(x_coords, y_coords)
    res = np.column_stack((eps_ball[0].flatten(), eps_ball[1].flatten()))
    return res


def _outer3(x_coords: np.ndarray,
            y_coords: np.ndarray,
            z_coords: np.ndarray) -> np.ndarray:
    eps_ball = np.meshgrid(x_coords, y_coords, z_coords)
    res = np.column_stack((eps_ball[0].flatten(),
                           eps_ball[1].flatten(),
                           eps_ball[2].flatten()))
    return res


def partialize_network_last_input(ws: List[np.ndarray],
                                  bs: List[np.ndarray],
                                  partialize_at: float,
                                  input_transformation: np.ndarray,
                                  output_transformation: np.ndarray) -> Tuple[list, list]:
    """Fix the network's last input to a constant and fold it into the first bias.

    The first weight matrix is right-multiplied by ``input_transformation``.
    Its last column, scaled by ``partialize_at``, is then added to the first
    bias and removed from the weights, so the returned network has one input
    fewer.  The last layer's weights and bias are left-multiplied by
    ``output_transformation``.  Intermediate layers are passed through
    unchanged (the same array objects).

    Args:
        ws: Weight matrices, one per linear layer (at least two).
        bs: Bias column vectors, one per linear layer.
        partialize_at: Value at which the last input is held fixed.
        input_transformation: Matrix applied to the inputs before the first layer.
        output_transformation: Matrix applied to the outputs of the last layer.

    Returns:
        ``(new_ws, new_bs)``: new lists of weights and biases.
    """
    layer_count = len(ws)
    assert len(bs) == layer_count
    assert layer_count > 1

    first_w = ws[0] @ input_transformation

    # Split off the column belonging to the fixed (last) input ...
    kept_columns = first_w[:, :-1]
    fixed_column = first_w[:, -1:]
    # ... and absorb its constant contribution into the bias.
    first_b = bs[0] + fixed_column * partialize_at

    last_w = output_transformation @ ws[-1]
    last_b = output_transformation @ bs[-1]

    return ([kept_columns] + ws[1:-1] + [last_w],
            [first_b] + bs[1:-1] + [last_b])


def _mylogfmt(x, _):
    filter = '${:+.0f}$'
    if 0 == x:
        filter = '${:.0f}$'
    return filter.format(x / 1000)


def encounter_plot(fig: matplotlib.figure.Figure,
                   ax: matplotlib.axes.Axes,
                   preimages: List[Region],
                   lower: np.ndarray,
                   upper: np.ndarray,
                   means: np.ndarray,
                   ranges: np.ndarray,
                   xlims: Tuple[float, float],
                   ylims: Tuple[float, float],
                   intr_sign: int,
                   intr_direction: Optional[float] = None,
                   add_legend: bool = True):
    """Draw the per-class preimages plus the two aircraft symbols on ``ax``.

    Preimage ``idx`` is drawn in ``colors[idx]``.  An aircraft symbol is added
    at the origin (direction 0, black) and another one at
    ``0.875 * (xlims[1], intr_sign * ylims[1])`` (red).  Both axes use a tick
    formatter that shows values divided by 1000.

    Args:
        fig: Figure that owns ``ax``; its layout is adjusted.
        ax: Axes to draw on.
        preimages: One region per class, in class-index order.
        lower: Lower input bounds, forwarded to ``_add_preimage``.
        upper: Upper input bounds, forwarded to ``_add_preimage``.
        means: Input offsets, forwarded to ``_add_preimage``.
        ranges: Input scales, forwarded to ``_add_preimage``.
        xlims: x-axis limits.
        ylims: y-axis limits (may be given in reversed order).
        intr_sign: Sign applied to ``ylims[1]`` when placing the red aircraft.
        intr_direction: Direction passed to ``draw_airplane.add_airplane_at``
            for the red aircraft.  When ``None`` the module-level
            ``partialize_at`` is used, which is what this function always
            read before the parameter existed.
        add_legend: Whether to add the class legend to the right of the axes.

    Raises:
        NameError: If ``intr_direction`` is ``None`` and no module-level
            ``partialize_at`` has been defined (e.g. when imported as a
            library rather than run as a script).
    """
    if intr_direction is None:
        # Backward compatibility: the direction used to be read implicitly
        # from a global that only exists when the file is run as a script.
        intr_direction = partialize_at

    y_scale = ylims[1] - ylims[0]
    airplane_scale = .05 * y_scale

    for idx, preimage in enumerate(preimages):
        _add_preimage(ax,
                      preimage,
                      lower,
                      upper,
                      means,
                      ranges,
                      colors[idx])
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)

    own_center = np.zeros((1, 2))
    intr_center = .875 * np.array([[xlims[1], intr_sign * ylims[1]]])

    draw_airplane.add_airplane_at(ax,
                                  direction=0,
                                  center=own_center,
                                  scale=airplane_scale,
                                  color="k")
    draw_airplane.add_airplane_at(ax,
                                  direction=intr_direction,
                                  center=intr_center,
                                  scale=airplane_scale,
                                  color="r")

    if add_legend:
        # Reserve room on the right for a legend placed outside the axes.
        # https://stackoverflow.com/questions/42435446/how-to-put-text-outside-python-plots
        # (One call is enough; the same adjustment used to be issued twice.)
        fig.subplots_adjust(right=0.8)

        legend_elements = [matplotlib.patches.Patch(facecolor=coc, label='COC'),
                           matplotlib.patches.Patch(facecolor=sr, label='SR'),
                           matplotlib.patches.Patch(facecolor=sl, label='SL'),
                           matplotlib.patches.Patch(facecolor=wr, label='WR'),
                           matplotlib.patches.Patch(facecolor=wl, label='WL'),
                           ]
        # Negative padding pulls the labels towards the (shrunken) patches.
        # https://stackoverflow.com/questions/41827406/matplotlib-how-to-adjust-space-between-legend-markers-and-labels
        textpad = -0.8
        leg = ax.legend(handles=legend_elements,
                        loc='center left',
                        bbox_to_anchor=(1.01, .5),
                        handletextpad=textpad)

        # Make the legend swatches square.
        patch_size = 7.0
        for p in leg.get_patches():
            p.set_width(patch_size)
            p.set_height(patch_size)

    formatter = matplotlib.ticker.FuncFormatter(_mylogfmt)

    ax.yaxis.set_major_formatter(formatter)
    ax.xaxis.set_major_formatter(formatter)
    fig.tight_layout()


def bruteforced_encounter_plot(fig,
                               ax,
                               ws,
                               bs,
                               lower,
                               upper,
                               means,
                               ranges,
                               xlims,
                               ylims,
                               intr_sign):
    """Scatter-plot the network's arg-max class on a regular 100 x 100 grid.

    The grid spans ``xlims`` x ``ylims``.  Each point is normalised with the
    first two entries of ``means`` / ``ranges``, pushed through the network
    (``tools.relu`` after every linear layer except the last) and drawn in the
    colour of the class with the highest output.

    Args:
        fig: Unused; kept for signature compatibility.
        ax: Axes to draw on.
        ws: Weight matrices, one per linear layer, acting on two inputs.
        bs: Bias column vectors, one per linear layer.
        lower: Unused; kept for signature compatibility.
        upper: Unused; kept for signature compatibility.
        means: Input offsets; only the first two entries are used.
        ranges: Input scales; only the first two entries are used.
        xlims: x-axis limits and grid extent.
        ylims: y-axis limits and grid extent.
        intr_sign: Unused; kept for signature compatibility.
    """
    num_x = 100
    num_y = 100

    xs = np.linspace(xlims[0], xlims[1], num_x)
    ys = np.linspace(ylims[0], ylims[1], num_y)

    mean_xy = means[:2].flatten()
    range_xy = ranges[:2].flatten()

    inputs_xy = _outer2(xs, ys)
    inputs_xy_normalized = (inputs_xy - mean_xy) / range_xy

    # Forward pass over *all* layers.  Samples are kept as rows, hence the
    # transposes around each affine map.  Only the first two layers used to be
    # evaluated, which is correct for a single hidden layer only.
    activations = inputs_xy_normalized
    for w, b in zip(ws[:-1], bs[:-1]):
        activations = tools.relu((w @ activations.T + b).T)
    scores = (ws[-1] @ activations.T + bs[-1]).T

    resps_argmax = np.argmax(scores, axis=1)
    for idx, color in enumerate(colors):
        rows = (idx == resps_argmax)
        ax.scatter(inputs_xy[rows, 0],
                   inputs_xy[rows, 1], color=color)
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)


def invert_relunet_fromwb(ws, bs, input_layer_bounds, cache_dir: Optional[str] = None):
    """Describe a ReLU network by its weights/biases and compute its class decompositions.

    The network is laid out as ``Linear, FlatRelu, Linear, ..., Linear``: one
    ``"Linear"`` layer per entry of ``ws`` with a ``"FlatRelu"`` layer between
    consecutive linear layers.  The layer sizes are taken from the shapes of
    the weight matrices, so the function no longer depends on the
    module-level ``num_relu_layers`` and ``dim`` that only exist when the file
    is run as a script.

    The result of ``decomp.compute_decomps`` is obtained through
    ``caching.cached_calc``.

    Args:
        ws: Weight matrices, one per linear layer; ``ws[i]`` has shape
            ``(out_features, in_features)``.
        bs: Biases, one per linear layer.
        input_layer_bounds: Bounds on the network input; stored as the first
            layer's ``"layer_limit"`` and in the inversion parameters.
        cache_dir: Directory handed to ``caching.cached_calc``.  When ``None``
            the module-level ``paths['cached_calculations']`` is used, which
            is what this function always read before the parameter existed.

    Returns:
        Whatever ``decomp.compute_decomps`` returns for all output classes.

    Raises:
        ValueError: If ``ws`` is empty or ``ws`` and ``bs`` differ in length.
        NameError: If ``cache_dir`` is ``None`` and no module-level ``paths``
            has been defined.
    """
    num_linear_layers = len(ws)
    if num_linear_layers == 0:
        raise ValueError("ws must contain at least one weight matrix")
    if len(bs) != num_linear_layers:
        raise ValueError("ws and bs must have the same length")
    num_relu_layers = num_linear_layers - 1

    types = []
    dimensions = []
    coefficients = []
    for layer_idx, (w, b) in enumerate(zip(ws, bs)):
        out_features, in_features = w.shape
        types.append("Linear")
        dimensions.append({"in_features": in_features, "out_features": out_features})
        coefficients.append({"w": w, "b": b})
        if layer_idx < num_relu_layers:
            # Every linear layer except the last is followed by a ReLU of the
            # same width.  Each layer gets its own dict (no shared aliases).
            types.append("FlatRelu")
            dimensions.append({"in_features": out_features, "out_features": out_features})
            coefficients.append({})
    out_dim = dimensions[-1]["out_features"]

    args = [_build_layer_inverse_args(idx, d["in_features"]) for idx, d in enumerate(dimensions)]
    # Only the input layer is bounded and needs a V-representation.
    args[0]["layer_limit"] = input_layer_bounds
    args[0]["need_v"] = True

    layer_info = {'args': args,
                  'types': types,
                  'dimensions': dimensions,
                  'coefficients': coefficients}

    inversion_par = {
        "cache_inversion": True,
        "desired_margin": 0.0,
        "input_layer_bounds": input_layer_bounds,
        "invert_classes": list(range(out_dim)),
        "is_rational": False
    }

    logger.info("Starting decomps")
    images_to_invert = decomp.build_class_preimages(out_dim, inversion_par)

    do_caching = True
    if not do_caching:
        return decomp.compute_decomps(layer_info, images_to_invert)

    if cache_dir is None:
        # Backward compatibility: fall back to the script-level path table.
        cache_dir = paths['cached_calculations']

    # NOTE(review): if caching.cached_calc keys on a serialisation of the
    # arguments, entries cached by the previous implementation may be
    # regenerated once -- confirm against the caching module.
    calc_fun = decomp.compute_decomps
    calc_args = (layer_info, images_to_invert)
    calc_kwargs = {}
    force_regeneration = False
    return caching.cached_calc(cache_dir,
                               calc_fun,
                               calc_args,
                               calc_kwargs,
                               force_regeneration)


def gen_ind_matrices(nx: int, ny: int, nz: int) -> Tuple[np.ndarray, np.ndarray]:
    """Enumerate the corner indices of every cell of an ``nx x ny x nz`` point grid.

    The grid has ``(nx - 1) * (ny - 1) * (nz - 1)`` cells.  Cells are numbered
    in C order of their lowest-index corner ``(ix, iy, iz)``, i.e. ``iz``
    varies fastest.  The eight corners of a cell are ordered by their
    ``(dx, dy, dz)`` offset in ``{0, 1}^3`` with ``dz`` varying fastest.

    Args:
        nx: Number of grid points along the first axis.
        ny: Number of grid points along the second axis.
        nz: Number of grid points along the third axis.

    Returns:
        A pair ``(ind_matrix, corresp_rows)`` of integer arrays:

        * ``ind_matrix`` has shape ``(cells, 3, 8)``; ``ind_matrix[c, :, k]``
          is the ``(ix, iy, iz)`` index of corner ``k`` of cell ``c``.
        * ``corresp_rows`` has shape ``(cells, 8)``; entry ``[c, k]`` is the
          flat index ``ix * ny * nz + iy * nz + iz`` of that corner.
    """
    dim = 3

    # Offsets of the 8 corners relative to a cell's lowest-index corner.
    corner_offsets = np.array([[dx, dy, dz]
                               for dx in (0, 1)
                               for dy in (0, 1)
                               for dz in (0, 1)], dtype=int)
    # Strides that turn an (ix, iy, iz) index into a flat point index.
    offsets_full = np.array([ny * nz, nz, 1], dtype=int)

    # Lowest-index corner of every cell, one row per cell, in C order (so the
    # row number is the cell's linear index).  ``dtype=int`` replaces the
    # ``np.int`` alias that was removed in NumPy 1.24.
    base = np.indices((nx - 1, ny - 1, nz - 1), dtype=int).reshape(dim, -1).T

    ind_matrix = base[:, :, np.newaxis] + corner_offsets.T[np.newaxis, :, :]
    # The flat index is linear in the corner index, so base and offset
    # contributions can be computed separately and broadcast-added.
    corresp_rows = ((base @ offsets_full)[:, np.newaxis]
                    + (corner_offsets @ offsets_full)[np.newaxis, :])

    return ind_matrix, corresp_rows


def compute_frac_of_cells_straddling_decision_boundary(x_grid: np.ndarray,
                                                       y_grid: np.ndarray,
                                                       z_grid: np.ndarray,
                                                       network: Any = None) -> float:
    """Return the fraction of grid cells whose corners do not all share one class.

    The network is evaluated on every point of the ``x_grid x y_grid x z_grid``
    lattice and each point is labelled with the arg-max of its outputs.  A
    cell (the box spanned by neighbouring grid values along each axis)
    "straddles the decision boundary" when its eight corner labels are not
    all equal.

    Args:
        x_grid: Grid values along the first input; flattened before use.
        y_grid: Grid values along the second input; flattened before use.
        z_grid: Grid values along the third input; flattened before use.
        network: Object with an ``evaluate_network_multiple(points)`` method
            that takes an ``(N, 3)`` array and returns an ``(N, classes)``
            array.  When ``None`` the module-level ``n`` is used, which is
            what this function always read before the parameter existed.

    Returns:
        The mean of the per-cell straddle indicator (``nan`` if there are no
        cells, i.e. if any grid has fewer than two values).

    Raises:
        NameError: If ``network`` is ``None`` and no module-level ``n`` has
            been defined.
    """
    if network is None:
        # Backward compatibility: fall back to the script-level network.
        network = n

    xs = np.ravel(x_grid)
    ys = np.ravel(y_grid)
    zs = np.ravel(z_grid)
    nx, ny, nz = xs.size, ys.size, zs.size

    # 'ij' indexing makes row ``ix * ny * nz + iy * nz + iz`` hold the point
    # (x[ix], y[iy], z[iz]), so the labels reshape directly to (nx, ny, nz).
    # (The default 'xy' ordering only lines up with that layout when nx == ny.)
    xx, yy, zz = np.meshgrid(xs, ys, zs, indexing="ij")
    grid = np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))

    evaluated = network.evaluate_network_multiple(grid)
    labels = np.argmax(evaluated, axis=1).reshape(nx, ny, nz)

    # Compare each of the 8 corners of every cell with the cell's lowest-index
    # corner using shifted views; this avoids materialising a per-cell gather.
    reference = labels[:-1, :-1, :-1]
    straddles_decision_boundary_cell = np.zeros(reference.shape, dtype=bool)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                corner = labels[dx:nx - 1 + dx, dy:ny - 1 + dy, dz:nz - 1 + dz]
                straddles_decision_boundary_cell |= (corner != reference)

    return np.mean(straddles_decision_boundary_cell)


# Script entry point: load one network, estimate how many grid cells straddle
# a decision boundary at several grid resolutions, then fix the last input,
# invert the resulting two-input network and save an "encounter plot".
#
# NOTE: several names assigned below (`paths`, `num_relu_layers`, `dim`, `n`,
# `partialize_at`) are read as module-level globals by functions defined
# earlier in this file, so they must be assigned before those functions run.
if __name__ == "__main__":
    # Output format handed to plotting.smart_save_fig.
    fig_format = "pgf"
    # fig_format = "png"

    paths = path_config.get_paths()
    hcas_root = paths["acas"]

    networks_dir = os.path.join(hcas_root, "networks")
    # NOTE(review): training_data_dir is not used anywhere in this file.
    training_data_dir = os.path.join(hcas_root, "TrainingData")

    # Architecture of the network to load; used to build the file name below.
    num_relu_layers = 1
    # num_relu_layers = 2
    # num_relu_layers = 3

    # neurons_per_layer = 6
    # neurons_per_layer = 8
    # neurons_per_layer = 16

    neurons_per_layer = 16
    # neurons_per_layer = 20
    # neurons_per_layer = 24

    # Remaining fields of the network file name.
    pra = 0
    tau = 0
    psi_deg = +90
    # psi_deg = +90
    # psi_deg = -135
    # psi_deg = +225
    # psi_deg = +45
    # psi_deg = -45
    dim = neurons_per_layer

    ident = "baseline"
    # ident_pattern = ident + "_pra{:d}_tau{:02d}_relulayers{:03d}_neurons{:03d}"
    # ident_pattern = ident + "_pra{:d}_tau{:02d}_relulayers{:03d}_neurons{:03d}_psi{:03d}"
    ident_pattern = ident + "_pra{:d}_tau{:02d}_relulayers{:03d}_neurons{:03d}_psi{:+03d}"
    # filename = ident_pattern.format(pra, tau, num_relu_layers, neurons_per_layer) + ".nnet"
    filename = ident_pattern.format(pra, tau, num_relu_layers, neurons_per_layer, psi_deg) + ".nnet"
    fullfilename = os.path.join(networks_dir, filename)

    n = acas.nnet.NNet(fullfilename)
    # The last entry of n.ranges / n.means is dropped; the rest become column vectors.
    # NOTE(review): presumably the dropped entry is the output normalisation -- confirm
    # against acas.nnet.
    ranges = np.vstack(n.ranges[:-1])
    means = np.vstack(n.means[:-1])

    maxes = np.vstack(n.maxes)
    mins = np.vstack(n.mins)

    # Input bounds in normalised coordinates.
    upper_raw = (maxes - means) / ranges
    lower_raw = (mins - means) / ranges

    # Bounds without the last input (the one that is fixed further below).
    upper = upper_raw[:-1]
    lower = lower_raw[:-1]

    ws_raw = n.weights
    bs_raw = [np.vstack(b) for b in n.biases]

    # Part 1: invert the full network and log, for several grid resolutions,
    # the fraction of grid cells that straddle a decision boundary.
    invert_full_3d_network = True
    if invert_full_3d_network:
        input_layer_bounds = (lower_raw, upper_raw)
        decomp_raw = invert_relunet_fromwb(ws_raw, bs_raw, input_layer_bounds)
        # NOTE(review): preimages_raw is not used further in this file.
        preimages_raw = [d[0] for d in decomp_raw]

        # Grid sizes (points per axis) to evaluate.
        # gss = [10, 20, 30, 40, 50, 60]
        # gss = [70, 80]
        # gss = [90, 100]
        # gss = [100, 150]
        gss = [20, 40, 60, 80, 100, 120, 140, 160, 180, 200]
        for gs in gss:
            # gs = 10
            nx = gs
            ny = gs
            nz = gs

            x_grid = np.linspace(lower_raw[0], upper_raw[0], nx)
            y_grid = np.linspace(lower_raw[1], upper_raw[1], ny)
            z_grid = np.linspace(lower_raw[2], upper_raw[2], nz)

            frac_of_cells_straddling_decision_boundary = compute_frac_of_cells_straddling_decision_boundary(x_grid, y_grid, z_grid)
            logger.info(f"{gs}, {frac_of_cells_straddling_decision_boundary}")

        # x_inds = np.arange(nx)
        # y_inds = np.arange(ny)
        # z_inds = np.arange(nz)

        # n.evaluate_network_multiple(
        # grid_inds = _outer3(x_inds, y_inds, z_inds)
        # # linear_ind =

    # Part 2: fix the last input at proj_psi_deg (converted to radians) and
    # invert the remaining two-input network.
    # psi = -90 * np.pi * 2 / 360
    proj_psi_deg = -90
    # proj_psi_deg = -180
    # proj_psi_deg = +180
    # proj_psi_deg = +90
    proj_psi_rad = proj_psi_deg * 2 * np.pi / 360
    partialize_at = proj_psi_rad
    # partialize_at_normalized = -.25
    # Same normalisation as for the bounds, using the last row of means / ranges.
    partialize_at_normalized = (partialize_at - means[-1]) / ranges[-1]

    assert 0 == pra, "Symmetry assumption only supports pra == 0"
    if partialize_at_normalized < 0:
        # Negate the second and third input and swap output pairs 1<->2 and 3<->4.
        input_transformation = np.diag([+1, -1, -1])
        output_transformation = np.array(([[1, 0, 0, 0, 0],
                                           [0, 0, 1, 0, 0],
                                           [0, 1, 0, 0, 0],
                                           [0, 0, 0, 0, 1],
                                           [0, 0, 0, 1, 0]]))
    else:
        input_transformation = np.eye(3)
        output_transformation = np.eye(5)

    ws, bs = partialize_network_last_input(ws_raw,
                                           bs_raw,
                                           partialize_at_normalized,
                                           input_transformation,
                                           output_transformation)
    input_layer_bounds = (lower, upper)
    decomps = invert_relunet_fromwb(ws, bs, input_layer_bounds)

    make_encounter_plot = True
    if make_encounter_plot:
        if "pgf" == fig_format:
            font_family = "serif"
            plotting.initialise_pgf_plots("pdflatex", font_family)

        # NOTE(review): plot_min / plot_max are not used further in this file.
        plot_min = -0
        plot_max = +36e3

        intr_sign = -1 * np.sign(proj_psi_rad)

        # For intr_sign == -1 the y limits are given in reversed order.
        if -1 == intr_sign:
            xlims = (-5e3, +30e3)
            ylims = (+5e3, -25e3)
        else:
            xlims = (-5e3, +30e3)
            ylims = (-5e3, +25e3)

        scl = 1.25
        fs = (3.75 * scl, 2.5 * scl)
        fig, ax = plt.subplots(1, 1, figsize=fs)

        preimages = [d[0] for d in decomps]

        # NOTE(review): `lower` / `upper` already have their last row removed
        # above, while `means` / `ranges` do not; _add_preimage drops one more
        # row from all four before combining them -- confirm the shapes agree.
        encounter_plot(fig,
                       ax,
                       preimages,
                       lower,
                       upper,
                       means,
                       ranges,
                       xlims,
                       ylims,
                       intr_sign)

        # Timestamp in the file name keeps earlier plots from being overwritten.
        # NOTE: datetime.utcnow() is deprecated since Python 3.12.
        utcnowstr = dt.datetime.utcnow().strftime("%Y_%m_%d_%H_%M_%S")
        # ident = "encounterplot_pra{}_tau{}_psi{:+03d}".format(pra, tau, proj_psi_deg)
        ident = "encounterplot_pra{}_tau{}_psi{:+03d}_{}".format(pra, tau, proj_psi_deg, utcnowstr)
        filepath = paths["plots"]
        fig_path = plotting.smart_save_fig(fig,
                                           ident,
                                           fig_format,
                                           filepath)
        print(f"Check out {fig_path}")
        # do_bruteforced = True
        # if do_bruteforced:
        #     fig, ax = plt.subplots(1, 1, figsize=fs)
