import torch.nn as nn
import torch
import numpy as np
import nninfo
import copy
from typing import Union, Tuple, List
from nninfo.quantization import quantizer_list_factory

# Binning limits of a layer: either a finite (low, high) interval or a string tag
# ("open", "semi-open") for activations without a finite range.
Limits = Union[Tuple[float, float], str]

# Activation function applied after each layer, keyed by the names used in net_activ_funcs.
# softmax is not a functional, because then we would have to set the dimension for each run;
# 'softmax_output' is therefore the identity here, and the softmax is applied explicitly
# (over dim 1) where it is needed.
ACTIV_FUNCS_PYTORCH = {
    "input": None,
    "linear": lambda x: x,
    "dropout": lambda x: x,
    "maxpool2d": lambda x: x,
    "maxpool2d_relu": torch.relu,
    "relu": torch.relu,
    "conv2d": lambda x: x,
    "conv2d_tanh": torch.tanh,
    "conv2d_relu": torch.relu,
    "flatten": lambda x: x,
    "tanh": torch.tanh,
    "hardtanh": nn.Hardtanh(),
    "sigmoid": torch.sigmoid,
    # Explicit dim=1 (the feature/class axis of a (batch, features) tensor) avoids PyTorch's
    # deprecated implicit-dimension guess and the warning it emits on every call.
    "log_softmax": lambda x: torch.nn.functional.log_softmax(x, dim=1),
    "softmax_output": lambda x: x,
}

# Binning limits per activation name. Tuples are finite (low, high) ranges of the activation;
# the string tags are interpreted by the quantization/binning code elsewhere
# (presumably "semi-open" = bounded below at 0, "open" = unbounded — confirm there).
# NOTE(review): 'flatten' is an identity, so its true range is that of the preceding layer;
# (-1.0, 1.0) only holds after a tanh-like layer — confirm.
# NOTE(review): 'log_softmax' has no entry, so networks using it cannot produce a limits list.
ACTIV_FUNCS_BINNING_LIMITS = {
    "linear": "open",
    "dropout": "open",
    "relu": "semi-open",
    "maxpool2d": "open",
    "maxpool2d_relu": "semi-open",
    "tanh": (-1.0, 1.0),
    "conv2d": "open",
    "conv2d_tanh": (-1.0, 1.0),
    "conv2d_relu": "semi-open",
    "flatten": (-1.0, 1.0),
    "hardtanh": (-1.0, 1.0),
    "sigmoid": (0.0, 1.0),
    "softmax": (0.0, 1.0),
    "softmax_output": (0.0, 1.0),
}

# Weight initializers selectable via the network's init_str.
INITIALIZERS_PYTORCH = {
    "xavier": nn.init.xavier_uniform_,
    "he_kaiming": nn.init.kaiming_uniform_,
    "he_kaiming_normal": nn.init.kaiming_normal_,
}


class NeuralNetwork(nninfo.exp_comp.ExperimentComponent, nn.Module):
    """
    Model that is trained and analysed.

    CUDA acceleration is not implemented yet, but will certainly be possible in the future.
    """

    def __init__(
            self,
            net_layer_sizes,
            net_activ_funcs,
            init_str,
            bias=True,
            net_type="feedforward",
            **kwargs
    ):
        """
        Creates a new instance of NeuralNetwork and sets all structural parameters of the model.

        Important comment: the external indexing of the layers is 1,...,n for convenience.
        However, I could not find a way to use this indexing also for the inner structure. Maybe
        we might change that in the future to avoid errors.

        Args:
            net_layer_sizes (list of int):  Instructions for layer sizes consisting of
                neuron number for each layer.
                Structure: [input_size, lay1, ..., output_size]
            net_activ_funcs (list of str): Instructions for activation functions for each layer.
                The input_layer is a separate layer and should get 'input' as activ_func.
                Allowed names are the keys of ACTIV_FUNCS_PYTORCH (e.g. 'linear', 'relu',
                'tanh', 'sigmoid', 'hardtanh', 'conv2d_relu', 'maxpool2d', 'flatten',
                'dropout', 'softmax_output'). The name also selects the layer type:
                'conv2d*' -> Conv2d, 'maxpool2d*' -> MaxPool2d, 'flatten*' -> Flatten,
                'dropout' -> Dropout, anything else -> Linear.
            init_str (str): name of the weight initializer to be used for all weights of the network.
            bias (bool): (optional) Whether there is a bias parameter for each neuron. True as default.
            net_type (str): (optional) Type of network. Currently, only "feedforward" is implemented.
                In the future, it might be interesting to look at other types of networks,
                for example "convolutional", "bayesian", etc. "feedforward" as default.

        Keyword Args:
            noise_stddev (float): in case of a noisy neural network
        """
        # call pytorch super class
        super(NeuralNetwork, self).__init__()

        assert net_type == "feedforward", "Network types other than feedforward are currently not supported."

        self._layer_sizes = net_layer_sizes
        self._activ_func_str = net_activ_funcs
        self._bias = bias

        self._params = kwargs

        self.n_layers = len(net_layer_sizes) - 1

        # Activation callables; the leading 'input' entry is skipped, so
        # self._activ_funcs[i] belongs to self.layers[i].
        self._activ_funcs = [ACTIV_FUNCS_PYTORCH[name] for name in net_activ_funcs[1:]]

        # Layers are stored in a ModuleList because it is more flexible than nn.Sequential.
        # This is not really true any more and an nn.Sequential would be just fine,
        # but older experiments cannot be loaded if you change the naming here.
        self.layers = nn.ModuleList()  # nn.Sequential()
        for i in range(self.n_layers):
            func_name = net_activ_funcs[i + 1]
            if func_name.startswith('conv2d'):
                # The first convolution is hard-wired to 3 input channels; later ones use
                # the previous layer size as channel count. 3x3 kernels, "same" padding.
                in_channels = 3 if i == 0 else self._layer_sizes[i]
                layer = nn.Conv2d(
                    in_channels, self._layer_sizes[i + 1], 3, bias=bias, padding=(1, 1))
            elif func_name in ('maxpool2d', 'maxpool2d_relu'):
                layer = nn.MaxPool2d(2)
            elif func_name.startswith('flatten'):
                layer = nn.Flatten(1)
            elif func_name == 'dropout':
                layer = nn.Dropout(0.2)
            else:
                layer = nn.Linear(
                    self._layer_sizes[i], self._layer_sizes[i + 1], bias=bias)
            self.layers.append(layer)

        self._initializer = None
        self._init_str = init_str
        self._activities = dict()

        self.init_weights()

    def get_network_settings(self):
        """
        Is called when the experiment saves all its components at the beginning of an experiment.

        Returns:
            dict: Dictionary with all the network settings necessary to create the network again
                for rerunning and later investigation (a deep copy, safe to mutate).
        """
        param_dict = {
            "net_layer_sizes": self._layer_sizes,
            "net_activ_funcs": self._activ_func_str,
            "bias": self._bias,
            "net_type": "feedforward",
            "init_str": self._init_str,
        }
        return copy.deepcopy(param_dict)

    def get_layer_structure_dict(self):
        """
        Map every layer label to the ids of its neurons.

        Labels are "X" for the input, "L1", ..., "L<n>" for the layers and "Y" for the
        target, which has the size of the output layer. A neuron id is
        (layer_label, (k,)) with k counted from 1. For convolutional layers the
        "size" is the configured layer size (the channel count), not the number of
        spatial units.

        Returns:
            dict: {layer_label: [neuron_id, ...]}, in the order X, L1, ..., L<n>, Y.
        """
        layer_dict = dict()
        for i, size in enumerate(self._layer_sizes):
            layer_label = "X" if i == 0 else "L" + str(i)
            layer_dict[layer_label] = [(layer_label, (n + 1,)) for n in range(size)]

        layer_dict["Y"] = [("Y", (n + 1,)) for n in range(self._layer_sizes[-1])]
        return layer_dict

    def get_binning_limits(self):
        """
        Binning limits for every neuron.

        Input ("X") and target ("Y") limits are requested from the parent experiment's
        task; layer neurons get the limits of their layer's activation function
        from ACTIV_FUNCS_BINNING_LIMITS.

        returns: {neuron_id -> (low, high) or "binary"}
        """
        structure = self.get_layer_structure_dict()
        limit_dict = dict()
        for layer_label in structure:
            for neuron_id in structure[layer_label]:
                if layer_label == "X" or layer_label == "Y":
                    limit_dict[neuron_id] = self.parent.task.get_binning_limits(
                        layer_label
                    )
                elif layer_label.startswith("L"):
                    i = int(layer_label.lstrip("L"))
                    limit_dict[neuron_id] = ACTIV_FUNCS_BINNING_LIMITS[
                        self._activ_func_str[i]
                    ]
                else:
                    raise NotImplementedError(
                        "Value " + layer_label + " not allowed for layer_label."
                    )
        return limit_dict

    def get_limits_list(self) -> List[Union[Limits, None]]:
        """
        Per-layer binning limits, indexed like the quantizer list (0 = input).

        Currently returns None for the input Limits.
        returns: [None, limits of L1, ..., limits of L<n>], each (low, high) or a string tag
        """
        # Use the module-level table directly instead of going through the package
        # attribute, which depends on how/when the nninfo package was imported.
        return [None] + [ACTIV_FUNCS_BINNING_LIMITS[self._activ_func_str[layer]]
                         for layer in range(1, self.n_layers + 1)]

    def forward(self, x, quantizers):
        """
        Forward pass for the network, given a batch.

        No softmax is applied to 'softmax_output' layers here (their activation is the
        identity), unlike in probe.

        Args:
            x (torch tensor): batch from dataset to be fed into the network
            quantizers (list of callables): quantizers[0] is applied to the input,
                quantizers[i + 1] to the activated output of layer i + 1.

        Returns:
            torch tensor: output of the network
        """
        x = quantizers[0](x)

        for i in range(self.n_layers):
            x = self.layers[i](x)
            x = self._activ_funcs[i](x)
            x = quantizers[i + 1](x)

        return x

    def probe(self, x, quantizers):
        """
        Collects activations and output for a forward run.

        Same as forward, except that layers with activation 'softmax_output' get a
        softmax over dim 1 before quantization.

        Args:
            x (torch tensor): batch from dataset to be fed into the network
            quantizers (list of callables): as in forward.

        Returns:
            tuple: (output tensor of the network, activities dict mapping "L1", ..., "L<n>"
                to the quantized layer outputs as numpy arrays)
        """
        activities = {}

        x = quantizers[0](x)

        for i in range(self.n_layers):
            x = self.layers[i](x)
            x = self._activ_funcs[i](x)

            if self._activ_func_str[i + 1] == 'softmax_output':
                x = torch.softmax(x, dim=1)

            x = quantizers[i + 1](x)

            # .cpu() is a no-op for CPU tensors and allows conversion of GPU tensors.
            activities["L" + str(i + 1)] = x.detach().cpu().numpy()

        return x, activities

    def extract_activations(self, x, quantizer_params):
        """
        Extracts the activities using the input given. Used by Analysis. Outputs
        activities together in a dictionary (because of variable sizes of the layers).

        Args:
            x (torch tensor): batch from dataset to calculate the activities on. Typically
                feed the entire dataset in one large batch.
            quantizer_params: parameters passed to quantizer_list_factory, together with
                self.get_limits_list(), to build the quantizers used for this run.

        Returns:
            dict: dictionary with each layer label as key,
                labeled "L1",..,"L<n>",
                where "L1" corresponds to the output of the first layer and "L<n>" to
                the output of the final layer. Notice that this indexing is different
                from the internal layer structures indices (they are
                uncomfortable to change).
        """
        test_quantizers = quantizer_list_factory(
            quantizer_params, self.get_limits_list())

        with torch.no_grad():
            _, activities = self.probe(x, quantizers=test_quantizers)

        return activities

    def init_weights(self, randomize_seed=False):
        """
        Initialize the weights using the init_str that was set at initialization
        (one of the initializers provided in INITIALIZERS_PYTORCH).

        Args:
            randomize_seed (bool): If true, torch is re-seeded with a non-deterministic
                seed (torch.seed()) before the initialization.
        """
        self._initializer = INITIALIZERS_PYTORCH[self._init_str]
        if randomize_seed:
            torch.seed()
        self.apply(self._init_weight_fct)

    def _init_weight_fct(self, m):
        """
        Re-initialize the weights of a single module (called via self.apply).

        Only weights of Linear and Conv2d modules are touched; biases keep the
        initialization they received when the layer was constructed.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            if self._init_str == "xavier":
                # Xavier gains: 5/3 (tanh gain) for dense layers, sqrt(2) (ReLU gain) for
                # convolutions, matching nn.init.calculate_gain('tanh') / ('relu').
                gain = 5. / 3. if isinstance(m, nn.Linear) else np.sqrt(2)
                self._initializer(m.weight, gain=gain)
            else:
                self._initializer(m.weight)
        elif next(m.parameters(recurse=False), None) is not None:
            # Only warn for modules that actually own parameters which stay untouched;
            # containers and parameter-free layers (pooling, flatten, dropout) are expected.
            print("warning! not a linear layer: {} keeps its default parameter "
                  "initialization".format(type(m).__name__))

    def __call__(self, x, quantizers):
        # NOTE: overriding __call__ bypasses nn.Module.__call__ and therefore any
        # registered forward hooks.
        return self.forward(x, quantizers)

    def get_input_output_dimensions(self):
        """Return (input layer size, output layer size) as configured."""
        input_dim = self._layer_sizes[0]
        output_dim = self._layer_sizes[-1]
        return input_dim, output_dim

    def neuron_ids(self, only_real_neurons=False):
        """
        Create a simple list of all nodes of the network
        (including input, target, bias nodes).

        Args:
            only_real_neurons: Whether you want to only include neurons
                whose ids begin with 'L'. Default is False

        Returns:
            list: neuron ids; without only_real_neurons the bias node ("B", (1,))
                comes first, followed by X, L1, ..., L<n>, Y.
        """
        names = [] if only_real_neurons else [("B", (1,))]
        for layer_name, neurons in self.get_layer_structure_dict().items():
            if not only_real_neurons or layer_name.startswith("L"):
                names.extend(neurons)
        return names

    def connectome(self):
        """
        Returns:
            an empty connectome matrix: row 0 is ["input_neuron_ids", *neuron_ids];
            every following row is [neuron_id, -1.0, ..., -1.0] with one -1.0 per neuron.
        """
        neuron_list = self.neuron_ids()
        header = ["input_neuron_ids"] + neuron_list
        rows = [[neuron] + [-1.0] * len(neuron_list) for neuron in neuron_list]
        return [header] + rows

    def trainable_parameters(self):
        """
        Create a list of all trainable parameters together with the neurons they connect.

        Parameter names are expected to look like "layers.<i>.weight" / "layers.<i>.bias";
        internal layer i is the external layer "L<i+1>" whose inputs are "X" (i == 0)
        or "L<i>". Only the first two weight dimensions are enumerated.

        Returns:
            list: List of all trainable parameters.
                dim 0: parameters, in the order of self.named_parameters()
                dim 1: input neuron_id, output neuron_id, parameter_id
                Biases use the bias node ("B", (1,)) as input neuron.

        Raises:
            ValueError: if a parameter refers to a neuron id that the network does not
                have (e.g. weight dimensions that do not match the layer sizes).
        """
        # Set of valid ids: O(1) membership instead of scanning a list per weight.
        known_ids = set(self.neuron_ids())

        def checked(neuron_id):
            # Reject ids the network does not have, with the same exception type
            # list.index would raise.
            if neuron_id not in known_ids:
                raise ValueError("{} is not a neuron of this network".format(neuron_id))
            return neuron_id

        param_list = []
        param_index = 0
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            _, internal_index, wb = name.split(".")
            layer_index = int(internal_index) + 1
            output_layer = "L" + str(layer_index)
            if wb == "weight":
                input_layer = "X" if layer_index == 1 else "L" + str(layer_index - 1)
                # Weight matrices are (out_features, in_features): i -> output, j -> input.
                for i in range(param.shape[0]):
                    for j in range(param.shape[1]):
                        param_list.append([
                            checked((input_layer, (j + 1,))),
                            checked((output_layer, (i + 1,))),
                            param_index,
                        ])
                        param_index += 1
            elif wb == "bias":
                for i in range(param.shape[0]):
                    param_list.append([
                        checked(("B", (1,))),
                        checked((output_layer, (i + 1,))),
                        param_index,
                    ])
                    param_index += 1
        return param_list


class NoisyNeuralNetwork(NeuralNetwork):
    """
    NeuralNetwork variant that adds zero-mean Gaussian noise to the output of every
    layer (including the last one) in forward.

    The noise standard deviation of a layer is noise_stddev * (high - low), where
    noise_stddev is the keyword argument given to the constructor and (low, high) are
    the layer's binning limits from ACTIV_FUNCS_BINNING_LIMITS.
    """

    def forward(self, x, save_activities=False, save_before_noise=False, quantizer=None, apply_output_softmax=False):
        """
        Noisy forward pass.

        Args:
            x (torch tensor): batch to be fed into the network.
            save_activities (bool): if truthy, store each layer's output as a numpy array
                in self._activities under "L1", ..., "L<n>".
            save_before_noise (bool): store the activities before (True) or after (False)
                the noise is added.
            quantizer: currently unused.
            apply_output_softmax (bool): apply a softmax over dim 1 to layers whose
                activation is 'softmax_output'.

        Returns:
            torch tensor: noisy output of the network.

        Raises:
            ValueError: if a layer's binning limits are a string tag such as "open" or
                "semi-open" (e.g. linear or relu layers), because the noise scale is
                derived from a finite (low, high) range.
        """
        # NOTE(review): the inherited __call__(x, quantizers) passes `quantizers`
        # positionally into `save_activities`, so calling this network like its parent
        # stores activities whenever the quantizer list is non-empty — confirm intended.
        for i in range(self.n_layers):
            x = self.layers[i](x)
            x = self._activ_funcs[i](x)

            if apply_output_softmax and self._activ_func_str[i + 1] == 'softmax_output':
                x = torch.softmax(x, dim=1)

            if save_activities and save_before_noise:
                self._activities["L" + str(i + 1)] = x.detach().cpu().numpy()

            # Add Gaussian noise whose stddev is relative to the activation's range.
            # Noise is added to every layer, including the output layer.
            activ_name = self._activ_func_str[i + 1]
            limits = ACTIV_FUNCS_BINNING_LIMITS[activ_name]
            if isinstance(limits, str):
                raise ValueError(
                    "Cannot scale noise for layer L{} with activation '{}': its binning "
                    "limits '{}' have no finite range.".format(i + 1, activ_name, limits)
                )
            # empty_like keeps the dtype and device of x (the CPU default otherwise).
            sampled_noise = torch.empty_like(x).normal_(
                mean=0, std=self._params["noise_stddev"] * (limits[1] - limits[0])
            )
            x = x + sampled_noise

            if save_activities and not save_before_noise:
                self._activities["L" + str(i + 1)] = x.detach().cpu().numpy()

        return x

    def get_network_settings(self):
        """Network settings of the parent class plus the noise_stddev keyword argument."""
        param_dict = super(NoisyNeuralNetwork, self).get_network_settings()
        param_dict["noise_stddev"] = self._params["noise_stddev"]
        return param_dict
