import timeit
import functools
import logging
import sys
from typing import Callable, Iterable, List, Tuple, Union

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import plotting
import tools

# Logging: INFO and above, written to stdout (instead of the default stderr),
# with each record tagged by timestamp, process id and thread id.
logging_level = logging.INFO
logging_format = "%(asctime)s %(process)s %(thread)s: %(message)s"
logging.basicConfig(
    stream=sys.stdout,
    format=logging_format,
    level=logging_level,
)
logger = logging.getLogger(__name__)

# Type aliases used in the function signatures below.
Numeric = Union[float, int]
FigAx = Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]

# To profile the whole script and browse the results interactively:
#   python3 -m cProfile -o main.prof main.py
#   snakeviz main.prof


def _empty_size_calc(n: int):
    """Check that adding one cutting inequality to the unit-cube H-repr
    leaves nothing behind.

    Stacks the row ``[-1, -e_0^T]`` (``e_0`` from
    ``tools.nth_canonical_basis(0, dim=n)``) on top of
    ``tools.unit_cube_h_repr(n)``. It then converts the system with
    ``tools.h_to_v`` and asserts that the result has size 0.

    Raises:
        AssertionError: if the converted representation is non-empty
            (the check is skipped when Python runs with ``-O``).
    """
    cube_rows = tools.unit_cube_h_repr(n)
    first_axis_row = tools.nth_canonical_basis(0, dim=n).T
    # NOTE(review): presumably a [b, -A] row encoding -1 - x_0 >= 0, i.e.
    # x_0 <= -1, which no point of the unit cube satisfies -- confirm
    # against tools' H-representation convention.
    cutting_row = np.hstack([np.array([[-1]]), -first_axis_row])
    vertices = tools.h_to_v(np.vstack([cutting_row, cube_rows]))
    assert vertices.size == 0


def time_over_n(integer_grid: Iterable[int],
                one_arg_fun: Callable,
                num_reps: int) -> np.ndarray:
    """Measure the mean runtime of ``one_arg_fun(n)`` for each ``n``.

    Each ``n`` is timed with :func:`timeit.timeit` over ``num_reps`` calls.
    The total is divided by ``num_reps`` to get seconds per call.

    Args:
        integer_grid: Values of ``n`` to time, in order. Any iterable is
            accepted; it is materialised once, so generators work too.
        one_arg_fun: Callable invoked as ``one_arg_fun(n)``. Functions with
            more parameters can be adapted with :func:`functools.partial`.
        num_reps: Number of calls timed per ``n``; must be positive.

    Returns:
        1-D float array with one entry per element of ``integer_grid``,
        holding the average seconds per call.

    Raises:
        ValueError: if ``num_reps`` is not positive. Without this check,
            0 would divide by zero and negatives would yield negative times.
    """
    if num_reps <= 0:
        raise ValueError(f"num_reps must be positive, got {num_reps}")
    # len() below needs a sized sequence; a bare iterable/generator has none.
    ns = list(integer_grid)
    timings = np.full((len(ns),), np.nan)
    for idx, n in enumerate(ns):
        logger.info("Starting n = %s ", n)
        call = functools.partial(one_arg_fun, n)
        timings[idx] = timeit.timeit(call, number=num_reps) / num_reps
    return timings


def plot_timing(value_grid: Iterable[Numeric],
                timings: Iterable[float],
                *,
                xlabel: str = "Dimension",
                ylabel: str = "Time (seconds) / computation") -> FigAx:
    """Draw a bar chart of ``timings`` against ``value_grid``.

    The figure comes from ``plotting.wrapped_subplot(1, 1)``. Its returned
    axes container is indexed with ``[0, 0]`` here, so it presumably is a
    2-D array of axes (confirm in ``plotting``).

    Args:
        value_grid: Bar positions on the x axis (e.g. the ``n`` values).
        timings: Bar heights, one per element of ``value_grid``.
        xlabel: X-axis label; defaults to the original "Dimension".
        ylabel: Y-axis label; defaults to the original per-computation label.

    Returns:
        ``(fig, ax)``: the created figure and the axes that were drawn on.
    """
    fig, axs = plotting.wrapped_subplot(1, 1)
    ax = axs[0, 0]
    # Materialise so plain iterables/generators work; Axes.bar needs
    # array-like input, not a one-shot iterator.
    ax.bar(list(value_grid), list(timings))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return fig, ax


def test(n: int):
    """Benchmark workload: call ``tools.unit_cube_v_repr(n)`` and discard the result.

    Deliberately *not* memoised. This function is timed repeatedly per ``n``
    (``num_reps`` calls via ``time_over_n``). With a cache, every call after
    the first would be a dictionary lookup. The reported per-call time would
    then be roughly the true cost divided by ``num_reps``.
    """
    tools.unit_cube_v_repr(n)


if __name__ == "__main__":
    # Time the `test` workload at several dimensions and plot the results.
    integer_grid = [3, 5, 10, 15, 20]
    one_arg_fun = test
    num_reps = 20
    timings = time_over_n(integer_grid, one_arg_fun, num_reps)
    fig, ax = plot_timing(integer_grid, timings)
    # Without this the figure is built and then discarded when the script
    # exits. On interactive backends show() blocks until the window is closed.
    plt.show()


