import logging
import os
import tempfile
import random
import copy
import pickle
import functools
import timeit

from typing import Any, Callable, Dict, List, Tuple, Union

import pandas as pd
import torch
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import settings
import plotting
import decomp
import dnn
import tools
import data
import pytorch_models
import results
import caching
import path_config

# NOTE(review): stray reference to matplotlib's ``figure.max_open_warning``
# rcParam; nothing in this file actually sets it.
COLORS = plotting.COLORS
np.set_printoptions(linewidth=1000)

# Prefix each record with the emitting function, line number and timestamp.
# The LogRecord attribute is ``funcName``; a ``%(func)s`` placeholder does not
# exist on LogRecord and makes every record fail to format.
logging_format = "%(funcName)s: %(lineno)4s: %(asctime)s: %(message)s"

# Level 15 lies between DEBUG (10) and INFO (20).
logging_level = 15
logging.basicConfig(level=logging_level,
                    format=logging_format)

logger = logging.getLogger(__name__)


def _script_global(name: str) -> Any:
    """Return the module-level value ``name`` set by this file's ``__main__`` block.

    Raises:
        NameError: if ``name`` was never defined, e.g. because the module was
            imported rather than run as a script.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            "{!r} is only defined when this file runs as a script; "
            "pass it explicitly".format(name)) from None


def get_experiment_results(experiment_base_design_par: Dict[str, Any],
                           hidden_layer_widths: List[List[int]],
                           cache_dir: Any = None,
                           num_reps: Union[int, None] = None) -> Dict[str, Any]:
    """Evaluate one network per hidden-layer architecture and collect metrics.

    The train/test datasets are generated once from the base parameters. For
    each entry of ``hidden_layer_widths`` a network is obtained (cached) from
    ``dnn.build_dnn``; its test accuracy, mean decomposition runtime, pickled
    decomposition size and parameter count are recorded.

    Args:
        experiment_base_design_par: design parameters shared by all runs; a deep
            copy is made per architecture, so the input is not mutated.
        hidden_layer_widths: one list of layer widths per architecture.
        cache_dir: directory handed to ``caching.cached_calc``. Defaults to the
            module-level ``cache_dir`` set by the ``__main__`` block.
        num_reps: timing repetitions for ``get_inversion_runtime``. Defaults to
            the module-level ``num_reps`` set by the ``__main__`` block.

    Returns:
        Dict with keys ``accuracies``, ``hidden_layer_widths``, ``num_pars``,
        ``storages`` (pickled size in bytes) and ``times`` (mean seconds);
        every array entry is aligned with ``hidden_layer_widths``.

    Raises:
        NameError: if ``cache_dir``/``num_reps`` are omitted and the module-level
            fallbacks do not exist.
    """
    # Explicit arguments win; otherwise keep the historical behaviour of
    # reading the globals that the script entry point defines.
    if cache_dir is None:
        cache_dir = _script_global("cache_dir")
    if num_reps is None:
        num_reps = _script_global("num_reps")

    num_widths = len(hidden_layer_widths)
    times = np.empty((num_widths,))
    storages = np.empty((num_widths,))
    accuracies = np.empty((num_widths,))
    num_pars = np.empty((num_widths,))

    experiment_base_par = settings.set_experiment_parameters(experiment_base_design_par)
    data_par = experiment_base_par["data"]

    train_x, train_y = data.generate_dataset(data_par, True)
    test_x, test_y = data.generate_dataset(data_par, False)

    for h_idx, h in enumerate(hidden_layer_widths):
        print(h)
        idx_design_par = copy.deepcopy(experiment_base_design_par)
        idx_design_par['hidden_layer_widths'] = h
        idx_par = settings.set_experiment_parameters(idx_design_par)

        model = caching.cached_calc(cache_dir,
                                    dnn.build_dnn,
                                    (train_x, train_y, idx_par["dnn"]),
                                    calc_kwargs={},
                                    force_regeneraton=False)
        accuracies[h_idx] = dnn.assess_test_accuracy(model, test_x, test_y)

        times[h_idx] = caching.cached_calc(cache_dir,
                                           get_inversion_runtime,
                                           (model, idx_par["inversion"], num_reps),
                                           calc_kwargs={},
                                           force_regeneraton=False)

        storages[h_idx] = caching.cached_calc(cache_dir,
                                              get_pickled_size,
                                              (idx_design_par,),
                                              calc_kwargs={},
                                              force_regeneraton=False)

        # Total number of scalar parameters across all layers of the model.
        num_pars[h_idx] = sum(p.numel()
                              for layer in model.layers
                              for p in layer.parameters())

    experiment_results = {
        "accuracies": accuracies,
        "hidden_layer_widths": hidden_layer_widths,
        "num_pars": num_pars,
        "storages": storages,
        "times": times,
    }
    return experiment_results


def get_inversion_runtime(model: pytorch_models.Net,
                          inversion_par: Dict[str, Any],
                          num_reps: int) -> float:
    """Time ``decomp.compute_decomps`` on the layers of ``model``.

    The layer description is built once, outside the timed region; only the
    decomposition call itself is executed ``num_reps`` times under ``timeit``.

    Returns:
        Mean seconds per ``compute_decomps`` call (total time / ``num_reps``).
    """
    info = decomp.build_layer_info(model.layers, inversion_par)
    timed_call = functools.partial(decomp.compute_decomps,
                                   inversion_par=inversion_par,
                                   layer_info=info)
    total_seconds = timeit.timeit(timed_call, number=num_reps)
    return total_seconds / num_reps


def get_pickled_size(idx_design_par: Dict[str, Any]) -> int:
    """Return the size in bytes of the pickled layer decompositions for a design.

    Expands ``idx_design_par`` via ``settings.set_experiment_parameters``,
    regenerates the training set, builds a fresh network with
    ``dnn.build_dnn`` (independent of any cached model), runs
    ``decomp.compute_decomps`` on its layers and measures the result when
    pickled with ``pickle.HIGHEST_PROTOCOL``.
    """
    base_par = settings.set_experiment_parameters(idx_design_par)

    dnn_par = base_par["dnn"]
    inversion_par = base_par["inversion"]
    data_par = base_par["data"]

    train_x, train_y = data.generate_dataset(data_par, True)

    model = dnn.build_dnn(train_x, train_y, dnn_par)
    layer_info = decomp.build_layer_info(model.layers, inversion_par)
    decomps = decomp.compute_decomps(layer_info, inversion_par)

    # pickle.dumps yields the same byte stream pickle.dump would write to a
    # file, so its length equals the on-disk size -- no temporary file needed.
    return len(pickle.dumps(decomps, protocol=pickle.HIGHEST_PROTOCOL))


def layerizer(lws: List[int]) -> str:
    """Render layer widths as an inline LaTeX math string.

    Widths are chained with ``\\rightarrow``: ``[8]`` gives ``$8$`` and
    ``[4, 6]`` gives ``$4 \\rightarrow 6$``; an empty list gives ``$$``.
    """
    arrow = " \\rightarrow "
    chained = arrow.join(map(str, lws))
    return "${}$".format(chained)


if __name__ == "__main__":
    # Seed every RNG the experiments may draw from, for reproducibility.
    seeds = settings.set_seeds()
    random.seed(seeds["random"])
    torch.manual_seed(seeds["torch"])
    np.random.seed(seeds["numpy"])

    paths = path_config.get_paths()
    results_path = paths["results"]
    # get_experiment_results reads ``cache_dir`` and ``num_reps`` as module
    # globals when they are not passed explicitly; keep these names.
    cache_dir = paths['cached_calculations']
    num_reps = 2

    # Other available experiments: "splinter11", "sphere_very_high_dim".
    experiment_names = ["moons"]
    hidden_layer_widths = [[8], [12], [16],
                           [4, 4], [6, 6], [8, 8],
                           [4, 4, 4], [6, 6, 6],
                           [4, 4, 4, 4]]  # , [8, 8, 8]

    # The top-level experiment result is always recomputed; the
    # per-architecture sub-calculations inside get_experiment_results are
    # still served from the cache.
    force_regeneraton = True
    experiment_results = {}

    for experiment_name in experiment_names:
        experiment_base_design_par = settings.set_design_parameters(experiment_name)
        experiment_results[experiment_name] = caching.cached_calc(
            cache_dir,
            get_experiment_results,
            (experiment_base_design_par, hidden_layer_widths),
            calc_kwargs={},
            force_regeneraton=force_regeneraton)

    columns = ["hidden_layer_widths", "accuracies", "num_pars", "storages", "times"]
    # "\\#" is the LaTeX-escaped hash; a bare "\#" is an invalid Python string
    # escape (SyntaxWarning on 3.12+) even though it yields the same text.
    renamer = {"hidden_layer_widths": "Hidden layer widths",
               "accuracies": "Accuracy",
               "num_pars": "\\# parameters",
               "storages": "Storage (MB)",
               "times": "Time (s)"}
    # Column formatters do not depend on the experiment, so build them once.
    formatters = {
        renamer["hidden_layer_widths"]: layerizer,
        renamer["accuracies"]: "{:.3f}".format,
        renamer["num_pars"]: "{:.0f}".format,
        # storages are byte counts; display them in MB (10**6 bytes).
        renamer["storages"]: lambda x: "{:.3f}".format(x / 1000000),
        renamer["times"]: "{:.3f}".format,
    }

    present_results = True
    if present_results:
        for experiment_name in experiment_names:
            experiment_result = experiment_results[experiment_name]
            er_frame = pd.DataFrame(experiment_result).loc[:, columns]
            er_frame.rename(columns=renamer, inplace=True)
            # NOTE(review): "fig:" prefix although to_latex emits a table.
            label = "fig:{}_timings".format(experiment_name)
            # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_latex.html
            to_print = er_frame.to_latex(formatters=formatters,
                                         index=False,
                                         escape=False,
                                         label=label)
            print(to_print)
