import logging
import os
from typing import Any, Callable, Dict, List, Tuple, Union

import sklearn.metrics
import sklearn.datasets
import numpy as np
import scipy.linalg
import seaborn

import matplotlib
import matplotlib.pyplot as plt

import torch
import torchvision

import plotting
import tools
import data
import pytorch_models
import results
import caching
import path_config


def set_design_parameters(experiment_name: str) -> Dict[str, Any]:
    """Assemble the design-level settings for a named experiment.

    Every experiment starts from shared defaults (non-image input of
    dimension [2], a single hidden layer of width 10, 500 epochs, SGD) and
    then applies the overrides registered for ``experiment_name``.

    Args:
        experiment_name: One of "moons", "moons_multi_layer", "splinter4",
            "splinter11", "checkerboard3", "sphere3", "sphere_high_dim",
            "sphere_very_high_dim", "simple_conv_image",
            "simple_avgpool_and_conv2d_image" or "simple_mnist".

    Returns:
        A dictionary with the keys "cache_fit_dnn", "criterion_name",
        "device", "do_inversion", "epochs", "experiment_name",
        "generating_function_name", "hidden_layer_widths", "input_dims",
        "input_layer_bounds", "is_image", "model_name", "need_initial_v",
        "noise", "num_classes", "optimizer_name" and "use_cuda".
        "input_layer_bounds" is a (lower, upper) pair of float column
        vectors with one row per flattened input coordinate.

    Raises:
        ValueError: If ``experiment_name`` is not a known experiment.
    """
    # (lower, upper) scalar limits applied to every input coordinate.
    unbounded = (-np.inf, np.inf)
    moons_box = (-3.0, 3.0)
    unit_box = (0.0, 1.0)

    # Settings shared by all experiments unless overridden below; names
    # starting with "sphere" get the smaller default noise level.
    settings = {
        "is_image": False,
        "input_dims": [2],
        "epochs": 500,
        "hidden_layer_widths": [10],
        "noise": .05 if experiment_name.startswith("sphere") else .2,
        "optimizer_name": "sgd",
    }

    overrides_by_experiment = {
        "moons": {
            "num_classes": 2,
            "generating_function_name": "moons",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": moons_box,
            "noise": .2,
            "epochs": 1000,
            "optimizer_name": "adam",
        },
        "moons_multi_layer": {
            "num_classes": 2,
            "generating_function_name": "moons",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [10, 10, 10],
            "bound_range": moons_box,
            "noise": .2,
            "epochs": 1000,
            "optimizer_name": "adam",
        },
        "splinter4": {
            "num_classes": 4,
            "generating_function_name": "splinters",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": unbounded,
        },
        "splinter11": {
            "num_classes": 11,
            "generating_function_name": "splinters",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": unbounded,
        },
        "checkerboard3": {
            "num_classes": 3,
            "generating_function_name": "checkerboard",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": unbounded,
        },
        "sphere3": {
            "num_classes": 3,
            "generating_function_name": "spheres",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": unbounded,
        },
        "sphere_high_dim": {
            "input_dims": [10],
            "num_classes": 3,
            "generating_function_name": "spheres",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": unbounded,
        },
        "sphere_very_high_dim": {
            "input_dims": [400],
            "num_classes": 3,
            "generating_function_name": "spheres",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [12],
            "bound_range": unbounded,
        },
        "simple_conv_image": {
            "input_dims": [3, 2, 4],
            "is_image": True,
            "num_classes": 2,
            "generating_function_name": "simple_image",
            "model_name": "simple_relunet_with_input_conv",
            "bound_range": unbounded,
        },
        "simple_avgpool_and_conv2d_image": {
            "input_dims": [3, 7, 7],
            "is_image": True,
            "num_classes": 3,
            "generating_function_name": "simple_image",
            "model_name": "simple_relunet_with_input_conv_and_avgpool",
            "bound_range": unbounded,
        },
        "simple_mnist": {
            "input_dims": [1, 28, 28],
            "is_image": True,
            "num_classes": 10,
            "generating_function_name": "mnist",
            "model_name": "simple_relunet",
            "hidden_layer_widths": [5],
            "epochs": 5,
            "bound_range": unit_box,
        },
    }

    if experiment_name not in overrides_by_experiment:
        raise ValueError("Unknown experiment: {}".format(experiment_name))
    settings.update(overrides_by_experiment[experiment_name])

    # The bounds are column vectors over the flattened input.
    lower_limit, upper_limit = settings.pop("bound_range")
    flat_input_dim = np.prod(settings["input_dims"])
    input_layer_bounds = (np.full((flat_input_dim, 1), lower_limit),
                          np.full((flat_input_dim, 1), upper_limit))

    # CUDA is always requested; it is used only when torch reports it.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    experiments_needing_initial_v = {
        "moons",
        "moons_multi_layer",
        "checkerboard3",
        "splinter4",
        "splinter11",
        "sphere3",
        "simple_mnist",
    }

    return {
        "cache_fit_dnn": True,
        "criterion_name": "cross_entropy",
        "device": device,
        "do_inversion": True,
        "epochs": settings["epochs"],
        "experiment_name": experiment_name,
        "generating_function_name": settings["generating_function_name"],
        "hidden_layer_widths": settings["hidden_layer_widths"],
        "input_dims": settings["input_dims"],
        "input_layer_bounds": input_layer_bounds,
        "is_image": settings["is_image"],
        "model_name": settings["model_name"],
        "need_initial_v": experiment_name in experiments_needing_initial_v,
        "noise": settings["noise"],
        "num_classes": settings["num_classes"],
        "optimizer_name": settings["optimizer_name"],
        "use_cuda": use_cuda,
    }


def set_seeds() -> Dict[str, int]:
    """Return the seed value stored under "random", "torch" and "numpy".

    All three keys share the same integer seed. A fresh dictionary is
    created on every call.
    """
    shared_seed = 2
    return {source: shared_seed for source in ("random", "torch", "numpy")}


def set_experiment_parameters(design_par: Dict[str, Any]) -> Dict[str, Any]:
    """Expand design parameters into the full grouped experiment configuration.

    Args:
        design_par: Design settings; the keys read here are "noise",
            "input_dims", "generating_function_name", "num_classes",
            "hidden_layer_widths", "model_name", "criterion_name",
            "optimizer_name", "device", "epochs", "use_cuda",
            "cache_fit_dnn", "experiment_name", "need_initial_v",
            "input_layer_bounds", "do_inversion" and "is_image".

    Returns:
        A dictionary with the groups "design" (``design_par`` itself),
        "data", "dnn", "inversion" and "results".

    Raises:
        ValueError: If ``design_par["optimizer_name"]`` is neither "adam"
            nor "sgd".
    """
    paths = path_config.get_paths()

    # --- data generation -------------------------------------------------
    sample_noise = design_par["noise"]
    data_par = {
        "input_dims": design_par["input_dims"],
        "generating_function_name": design_par["generating_function_name"],
        "n_samples": 1000,
        "noise": sample_noise,
        "num_classes": design_par["num_classes"],
    }

    # --- network and training --------------------------------------------
    hidden_layer_widths = design_par["hidden_layer_widths"]
    net_input_dims = design_par["input_dims"]
    output_width = design_par["num_classes"]
    net_name = design_par["model_name"]
    loss_name = design_par["criterion_name"]

    optimizer_name = design_par["optimizer_name"]
    if optimizer_name not in ("adam", "sgd"):
        raise ValueError("Unknown optimizer")
    if optimizer_name == "adam":
        optim_kwargs = {"lr": 0.005, "betas": (0.9, 0.999)}
    else:
        optim_kwargs = {"lr": 0.01, "momentum": 0.5}

    train_device = design_par["device"]
    epochs = design_par["epochs"]

    # Extra DataLoader options are only supplied when CUDA is in use.
    if design_par["use_cuda"]:
        dataloader_kwargs = {"num_workers": 1, "pin_memory": True}
    else:
        dataloader_kwargs = {}

    # Short runs log every epoch, longer ones every 50 epochs.
    log_every_epoch = 1 if epochs <= 10 else 50

    dnn_par = {
        "model_name": net_name,
        "input_dims": net_input_dims,
        "hidden_layer_widths": hidden_layer_widths,
        "output_width": output_width,
        "batch_size": 128,
        "cache_fit_dnn": design_par["cache_fit_dnn"],
        "criterion_name": loss_name,
        "dataloader_kwargs": dataloader_kwargs,
        "device": train_device,
        "do_caching": True,
        "epochs": epochs,
        "log_every_epoch": log_every_epoch,
        "optimizer_name": optimizer_name,
        "optim_kwargs": optim_kwargs,
        "shuffle": True
    }

    # --- inversion ----------------------------------------------------------
    # A network counts as large when the product of its hidden widths
    # exceeds 50 (several hidden layers) or 20 (exactly one hidden layer).
    n_hidden_layers = len(hidden_layer_widths)
    if n_hidden_layers > 1:
        is_large = np.prod(hidden_layer_widths) > 50
    elif n_hidden_layers == 1:
        is_large = np.prod(hidden_layer_widths) > 20
    else:
        is_large = False

    which_experiment = design_par["experiment_name"]
    if which_experiment == 'simple_mnist':
        invert_classes = [3, 4]
    elif which_experiment in ["moons", "moons_multi_layer"] and is_large:
        invert_classes = [0]
    else:
        invert_classes = list(range(design_par["num_classes"]))

    need_initial_v = design_par["need_initial_v"]
    input_layer_bounds = design_par["input_layer_bounds"]
    inversion_par = {
        "do_inversion": design_par["do_inversion"],
        "cache_inversion": True,
        "desired_margin": 0.0,
        "need_initial_v": need_initial_v,
        "input_layer_bounds": input_layer_bounds,
        "invert_classes": invert_classes,
        "is_rational": False
    }

    # --- results and plotting ----------------------------------------------
    do_2d_plot = design_par["input_dims"] == [2]
    is_image = design_par["is_image"]
    do_terminal_decomp_plot = output_width == 2

    save_plots = True
    plot_format = "pgf"
    # Plots may only be saved in pgf format.
    assert (not save_plots) or plot_format == "pgf", \
        "Do not want to save in a format besides pgf"

    results_par = {
        "additional_ident": None,
        "big_plot_scale": 2.50,
        "do_2d_plot": do_2d_plot,
        "do_terminal_decomp_plot": do_terminal_decomp_plot,
        "is_image": is_image,
        "plots_dir": paths["plots"],
        "save_plots": save_plots,
        "save_results": True,
        "plot_format": plot_format,
        "plot_scale": 1.35,
    }

    return {
        "design": design_par,
        "data": data_par,
        "dnn": dnn_par,
        "inversion": inversion_par,
        "results": results_par,
    }


if __name__ == "__main__":
    # When run as a script, build the design parameters for the
    # "splinter4" experiment.
    experiment_name = "splinter4"
    design_par = set_design_parameters(experiment_name)
