import logging
import os
import pprint
import random
from typing import Any, Callable, Dict, List, Tuple, Union

import torch
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import settings
import plotting
import decomp
import dnn
import tools
import data
import pytorch_models
import results
import caching
import path_config

# NOTE(review): stray reminder about matplotlib's ``figure.max_open_warning``
# rcParam; nothing in this file actually sets it.
COLORS = plotting.COLORS
# Allow up to 1000 characters per line when NumPy prints arrays, so wide
# arrays are printed on fewer, longer lines.
np.set_printoptions(linewidth=1000)

# logging_format = "%(asctime)s: %(message)s"
# ``funcName`` is the LogRecord attribute that holds the calling function's
# name. The previous ``%(func)s`` does not exist on LogRecord, so formatting
# failed for every record and the handler printed a "--- Logging error ---"
# traceback instead of the message.
logging_format = "{%(funcName)s: %(lineno)4s: {%(asctime)s: %(message)s"

# 15 sits between DEBUG (10) and INFO (20): INFO and above are emitted while
# plain DEBUG messages are suppressed.
logging_level = 15
logging.basicConfig(level=logging_level,
                    format=logging_format)

logger = logging.getLogger(__name__)

# Loose aliases for the dict-based geometric objects passed around by this
# script; the dict structure itself is not enforced here.
Polytope = Dict[str, Any]
HRepresentation = Dict[str, Any]
VRepresentation = Dict[str, Any]
Region = List[Polytope]


def build_par_logging_string(par: Dict[str, Any],
                             experiment_name: Union[str, None] = None) -> str:
    """Render the main experiment parameters as a human-readable block of text.

    Args:
        par: Nested parameter dict with the sub-dicts ``"design"``, ``"data"``,
            ``"dnn"`` and ``"inversion"``; the keys read from each are the ones
            listed in the body below.
        experiment_name: Name shown on the ``Experiment:`` line. When omitted,
            a module-level ``experiment_name`` global is used if one exists
            (the ``__main__`` section of this script defines it), otherwise
            ``"unknown"``.

    Returns:
        The lines joined with ``"\\n"`` (no trailing newline).

    Raises:
        KeyError: if any expected key is missing from ``par``.
    """
    if experiment_name is None:
        # The original implementation silently depended on the global that
        # only exists when this file runs as a script; importing callers got a
        # NameError. Keep the global as a fallback for backward compatibility.
        experiment_name = globals().get("experiment_name", "unknown")

    design_par = par["design"]
    data_par = par["data"]
    dnn_par = par["dnn"]
    inversion_par = par["inversion"]

    to_join = [
        "Design parameters",
        "    Experiment: {}".format(experiment_name),
        "        device: {}".format(design_par["device"]),

        "Data parameters",
        "    Input Dimension: {}".format(data_par["input_dims"]),
        "    # of classes: {}".format(data_par["num_classes"]),
        "    n_samples: {}".format(data_par["n_samples"]),
        "        noise: {}".format(data_par["noise"]),

        "Fitting settings",
        "           Epochs: {}".format(dnn_par["epochs"]),
        "            Optim: {}".format(dnn_par["optimizer_name"]),
        "        Criterion: {}".format(dnn_par["criterion_name"]),

        "Inversion settings",
        "            is_rational: {}".format(inversion_par["is_rational"]),
        "         invert_classes: {}".format(inversion_par["invert_classes"]),
    ]
    return "\n".join(to_join)


def np_to_torch(x: np.ndarray) -> torch.Tensor:
    """Convert a NumPy array into a float32 CPU tensor.

    ``torch.from_numpy`` shares memory with ``x``; the float cast then returns
    that same tensor when ``x`` is already float32 and a converted copy
    otherwise.
    """
    shared = torch.from_numpy(x)
    return shared.float()


def _compute_preimage_volumes(decomps: Any) -> np.ndarray:
    """Return the summed volume of the preimage partition of each decomposition.

    For every entry of ``decomps``, element ``[0]`` is iterated as a list of
    polytopes ``p``. A polytope with ``p["v"]["is_empty"]`` contributes 0.
    Otherwise the volume of ``p["v"]["vertices"]`` (leading column dropped)
    is computed with ``tools.compute_hull_volume``.

    Args:
        decomps: Sequence of decompositions as returned by
            ``decomp.compute_decomps``.

    Returns:
        1-D float array holding one total volume per decomposition.

    Raises:
        ValueError: if a vertex array's leading column is not all ones.
    """
    # The loop variable is deliberately not named ``decomp``: that name is the
    # imported module, and shadowing it made the module unreachable afterwards.
    totals = np.full((len(decomps),), np.nan)
    for idx_decomp, decomposition in enumerate(decomps):
        preimage_partition = decomposition[0]
        volumes = np.full((len(preimage_partition),), np.nan)
        for idx_p, p in enumerate(preimage_partition):
            if p["v"]["is_empty"]:
                volumes[idx_p] = 0
                continue
            vertices = p["v"]["vertices"]
            # NOTE(review): the leading column looks like a homogeneous flag
            # (1 = vertex; a 0 would presumably be an unbounded ray, for which
            # a hull volume is meaningless) — confirm. Raise instead of assert
            # so the check survives ``python -O``.
            if not np.all(1 == vertices[:, 0]):
                raise ValueError(
                    "Preimage {} of decomposition {} has a non-vertex row "
                    "(leading column != 1)".format(idx_p, idx_decomp))
            volumes[idx_p] = tools.compute_hull_volume(vertices[:, 1:])
        totals[idx_decomp] = np.sum(volumes)
    return totals


if __name__ == "__main__":
    # Select the experiment configuration; uncomment an alternative to switch.
    experiment_name = "moons"
    # experiment_name = "moons_multi_layer"
    # experiment_name = "splinter4"
    # experiment_name = "splinter11"
    # experiment_name = "sphere3"
    # experiment_name = "sphere_high_dim"
    # experiment_name = "sphere_very_high_dim"
    # experiment_name = "simple_conv_image"
    # experiment_name = "simple_avgpool_and_conv2d_image"
    # experiment_name = "simple_mnist"

    # Seed Python's, PyTorch's and NumPy's RNGs from the configured seeds.
    seeds = settings.set_seeds()
    random.seed(seeds["random"])
    torch.manual_seed(seeds["torch"])
    np.random.seed(seeds["numpy"])

    design_par = settings.set_design_parameters(experiment_name)
    par = settings.set_experiment_parameters(design_par)

    data_par = par["data"]
    train_x, train_y = data.generate_dataset(data_par, True)
    test_x, test_y = data.generate_dataset(data_par, False)

    present_results = True

    logger.info(par)
    # NOTE(review): par_hash is computed but not used below.
    par_hash = caching.hash_par(par)

    par_logging_string = build_par_logging_string(par)
    logger.info(par_logging_string)

    paths = path_config.get_paths()
    # NOTE(review): results_path is not used below.
    results_path = paths["results"]
    cache_dir = paths["cached_calculations"]

    dnn_par = par["dnn"]
    inversion_par = par["inversion"]

    # Build the model through caching.cached_calc; presumably a cached result
    # in cache_dir is reused unless force_regeneration is True — confirm.
    calc_fun = dnn.build_dnn
    calc_args = (train_x, train_y, dnn_par)
    calc_kwargs: Dict[str, Any] = {}

    force_regeneration = False
    # force_regeneration = True

    model = caching.cached_calc(cache_dir,
                                calc_fun,
                                calc_args,
                                calc_kwargs,
                                force_regeneration)

    test_accuracy = dnn.assess_test_accuracy(model, test_x, test_y)
    logger.info("Test accuracy: {:.3f}".format(test_accuracy))

    layers = model.layers
    layer_info = decomp.build_layer_info(layers, inversion_par)
    decomps = decomp.compute_decomps(layer_info, inversion_par)

    results_summary = results.compute_results_summary(decomps,
                                                      model,
                                                      train_x,
                                                      train_y,
                                                      test_x,
                                                      test_y,
                                                      par)
    if present_results:
        results.present_results_summary(results_summary,
                                        train_x,
                                        train_y,
                                        par)

    # Sanity check: sum the volumes of all preimages per decomposition. This
    # is only meaningful when the input domain is bounded.
    confirm_volume_conservation = True

    input_layer_bounds = inversion_par["input_layer_bounds"]
    has_finite_input_bounds = np.all(np.isfinite(input_layer_bounds[0])) and \
        np.all(np.isfinite(input_layer_bounds[1]))
    if confirm_volume_conservation and has_finite_input_bounds:
        preimage_volumes = _compute_preimage_volumes(decomps)

        print("Total volume: {}".format(np.sum(preimage_volumes)))

        for idx, piv in enumerate(preimage_volumes):
            print("Class {}: {}".format(idx, piv))
    elif confirm_volume_conservation:
        logger.info("Skipping volume check: input layer bounds are not finite.")