"""Module containing code for experiment described in the paper"""

import numpy as np
import pandas as pd
import torch
from torch import nn
import torchvision
from scipy.linalg import qr


# Pick the compute device once, at import time: the first CUDA GPU when one is
# available, the CPU otherwise. The choice is echoed so that runs are traceable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Running module.py with device: {device}")

class MLP(torchvision.ops.MLP):
    """MLP built on ``torchvision.ops.MLP`` with extra analysis utilities.

    On top of the underlying sequential model, this class provides:

    * optional skip connections between hidden layers (see ``forward``),
    * a boundary matrix ``B`` describing which parameters enter (+1) and
      leave (-1) each hidden neuron,
    * per-neuron squared weight norms and the derived "hyperbolae" values,
    * simple norm-based pruning orders and a neuron-pruning helper.

    Attributes:
        in_channels (int): Number of input features.
        hidden_channels (list): Width of every layer after the input, the
            last entry being the output width.
        skip_connections (np.ndarray or None): Array of shape (k, 2) holding
            the ``(source, target)`` pairs, or None when there are none.
        bias (bool): Whether the linear layers carry a bias.
        n_weights (int): Total number of weight entries.
        n_biases (int): Number of bias entries, last layer excluded.
        n_skip_connections (int): Number of skip edges (one per source neuron
            of every skip connection).
        B (torch.Tensor): Boundary matrix, one row per hidden neuron and one
            column per weight, hidden bias and skip edge (in that order).
        t_skip (torch.Tensor): Vector of ones, one entry per skip edge.
        df_pruning (pd.DataFrame or None): Filled by ``create_pruningOrders``.
    """

    def __init__(self, in_channels, hidden_channels, *,
                 bias=False,
                 skip_connections=None,
                 activation_layer=nn.ReLU):
        """Initializes the MLP with the given parameters.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (list): List of integers representing the number of hidden channels
                for each layer, last number is the output.
            bias (bool, optional): Whether to use bias in the layers. Defaults to False.
            skip_connections (list, optional): List of tuples representing skip connections.
                [(i, j)] will add the output of hidden layer i to the output of hidden layer j.
                A single (i, j) pair is accepted too; None or an empty list means
                no skip connection. Defaults to None.
            activation_layer (callable, optional): Activation constructor forwarded to
                ``torchvision.ops.MLP``. Defaults to ``nn.ReLU``.
        """
        super().__init__(in_channels, hidden_channels, bias=bias, activation_layer=activation_layer)
        self.in_channels = in_channels
        # Kept as a list so it can be concatenated with [in_channels] later on,
        # even when the caller handed in a tuple.
        self.hidden_channels = list(hidden_channels)
        self.skip_connections = self._as_skipArray(skip_connections)
        self.bias = bias
        self.n_biases = None
        self.n_weights = None
        self.n_skip_connections = None
        self.B = self._init_boundaryMatrix()  # pylint: disable=invalid-name
        self._fill_boundaryMatrix()
        self.t_skip = torch.ones(self.n_skip_connections, requires_grad=False)
        self.df_pruning = None

    @staticmethod
    def _as_skipArray(skip_connections):
        """Converts the user-supplied skip connections to a (k, 2) array, or None if empty."""
        if skip_connections is None:
            return None
        a_skip = np.array(skip_connections)
        if a_skip.size == 0:
            # An empty list means "no skip connections"; it cannot be indexed as [:, 0].
            return None
        # Promote a lone (i, j) pair to shape (1, 2); proper (k, 2) input is untouched.
        return np.atleast_2d(a_skip)

    def create_pruningOrders(self):
        """Creates the pruning orders for the neurons in the network.

        For every hidden neuron, the squared norm of its incoming weights and
        the squared norm of its outgoing weights are combined in two ways
        (product and maximum). Both scores are ranked over all hidden neurons
        and the result is stored in ``self.df_pruning`` with columns
        ``layer``, ``neuron``, ``norm_mul``, ``norm_max``, ``order_prune_mul``
        and ``order_prune_max``.
        """
        l_output_norms, l_input_norms = self.compute_squaredlayerNorms()
        # Entry 0 is the input layer and the last entry the output layer: hidden layers only.
        l_hidden_norms = zip(l_output_norms[1:-1], l_input_norms[1:-1])
        d = {"layer": [], "neuron": [], "norm_mul": [], "norm_max": []}
        for i_layer, (a_out, a_in) in enumerate(l_hidden_norms):
            a_mul = a_out * a_in
            a_max = np.maximum(a_out, a_in)
            n_neurons = len(a_mul)
            d["layer"].extend([i_layer] * n_neurons)
            d["neuron"].extend(range(n_neurons))
            d["norm_mul"].extend(a_mul)
            d["norm_max"].extend(a_max)
        df_pruning = pd.DataFrame(d)
        df_pruning["order_prune_mul"] = df_pruning["norm_mul"].rank()
        df_pruning["order_prune_max"] = df_pruning["norm_max"].rank()
        self.df_pruning = df_pruning

    def prune(self, prune_layer, prune_neuron):
        """Prune the neuron at the given layer and neuron index.

        Args:
            prune_layer (int): Index of the linear layer producing the neuron
                (0 is the first linear layer).
            prune_neuron (int): Index of the neuron inside that layer.

        The neuron's incoming weights (and bias entry, if any) are zeroed in
        linear layer ``prune_layer`` and its outgoing weights are zeroed in
        linear layer ``prune_layer + 1``. Layers that do not exist are
        silently skipped.
        """
        l_linear = [module for module in self if isinstance(module, nn.Linear)]
        for i_layer, layer in enumerate(l_linear):
            if i_layer == prune_layer:
                layer.weight.data[prune_neuron, :] = 0  # prune neuron input
                if layer.bias is not None:
                    layer.bias.data[prune_neuron] = 0
            elif i_layer == prune_layer + 1:
                layer.weight.data[:, prune_neuron] = 0  # prune neuron output

    def linearize_theta(self):
        """Get the parameter vector of the network.

        Returns:
            theta (torch.Tensor): all flattened weights (in layer order)
                followed by all flattened biases except those of the last
                layer, i.e. the same ordering as the weight and bias columns
                of ``self.B``.
        """
        l_weights = []
        l_biases = []
        for name, t_param in self.named_parameters():
            if "weight" in name:
                l_weights.append(t_param.flatten())
            elif "bias" in name:
                l_biases.append(t_param.flatten())
        # The last layer's bias has no column in B, hence the [:-1].
        theta = torch.concatenate(l_weights + l_biases[:-1])
        return theta

    def forward(self, x):  # pylint: disable=arguments-renamed
        """Runs the network, adding skip connections after the linear layers.

        A skip connection ``(i, j)`` adds the i-th recorded activation to the
        output of the linear layer that comes after ``j`` recorded activations.

        NOTE: only the outputs of ``nn.ReLU`` modules are recorded, so skip
        connections rely on the default ``activation_layer``.

        Args:
            x (torch.Tensor): input batch.

        Returns:
            torch.Tensor: output of the last module.
        """
        l_activations = []
        for module in self:
            x = module(x)
            if isinstance(module, nn.Linear):
                if self.skip_connections is not None:
                    i_layer = len(l_activations)
                    is_target = self.skip_connections[:, 1] == i_layer
                    for j_layer in self.skip_connections[is_target, 0]:
                        # Out-of-place on purpose: never mutate the linear output in place.
                        x = x + l_activations[j_layer]
            elif isinstance(module, nn.ReLU):  # stock all activations
                l_activations.append(x)
        return x

    def compute_squaredlayerNorms(self):
        """Computes the squared norms for each layer.

        Returns:
            l_output_norms (list): entry k holds, for every neuron of node
                layer k (0 is the input layer), the squared norm of its
                outgoing weights. The last entry (output layer) is None.
            l_input_norms (list): entry k holds, for every neuron of node
                layer k, the squared norm of its incoming weights, bias
                included when ``self.bias`` is set. The first entry (input
                layer) is None.
        """
        l_output_norms = []
        l_input_norms = []
        for layer in self:
            if not isinstance(layer, nn.Linear):
                continue
            W_out = layer.weight.detach().cpu().numpy()
            if self.bias:
                # The bias counts as one more incoming weight of each neuron.
                W_in = torch.concatenate([layer.weight, layer.bias.reshape(-1, 1)], dim=1
                                         ).detach().cpu().numpy()
            else:
                W_in = W_out
            l_output_norms.append((W_out**2).sum(axis=0))
            l_input_norms.append((W_in**2).sum(axis=1))
        # Shift the two lists so that index k refers to the same node layer in both.
        l_output_norms = l_output_norms + [None]
        l_input_norms = [None] + l_input_norms
        return l_output_norms, l_input_norms

    def compute_hyperbolae(self, return_norms=False):
        """Computes the hyperbolae for each hidden neuron in the network.

        The hyperbola value of a neuron is its squared input norm minus its
        squared output norm.

        Args:
            return_norms (bool, optional): Whether or not to return squared output and input norms.
                Defaults to False.

        Returns:
            l_c_k (list): a list of arrays. Each array represent a layer and each array
                stores the hyperbolae value for each neuron it contains. The
                entries of the input and output layers are None.
            l_output_norms, l_input_norms (list): only when ``return_norms``
                is True, as returned by ``compute_squaredlayerNorms``.
        """
        l_output_norms, l_input_norms = self.compute_squaredlayerNorms()
        l_c_k = [
            None if (a_in is None or a_out is None) else a_in - a_out
            for a_out, a_in in zip(l_output_norms, l_input_norms)
        ]
        if return_norms:
            return l_c_k, l_output_norms, l_input_norms
        return l_c_k

    def _init_boundaryMatrix(self):
        """Initialize the boundary matrix B for the MLP.

        Counts weights, hidden biases and skip edges, stores the counts on
        the instance, prints a summary and returns an all-zero matrix with
        one row per hidden neuron and one column per counted parameter.
        """
        n_rows = sum(self.hidden_channels[:-1])  # number of neurons in hidden layers
        n_weights, n_biases = 0, 0
        for key, t_value in self.state_dict().items():
            if "weight" in key:
                n_weights += t_value.numel()
            elif "bias" in key:
                n_biases += t_value.numel()
        n_biases = n_biases - self.bias*self.hidden_channels[-1]  # last layer neurons aren't in B
        if self.skip_connections is None:
            n_skip_connections = 0
        else:
            # One skip edge per neuron of the source layer of every connection.
            n_skip_connections = sum(self.hidden_channels[int(i_source)]
                                     for i_source in self.skip_connections[:, 0])
        n_cols = n_weights + n_biases + n_skip_connections
        print(f"Number of hidden nodes: {n_rows}",
            f"\nNumber of weights: {n_weights}",
            f"\nNumber of biases: {n_biases}",
            f"\nNumber of skip connections: {n_skip_connections}",
            f"\nTotal number of edges: {n_cols}")
        B = torch.zeros((n_rows, n_cols))
        self.n_biases = n_biases
        self.n_weights = n_weights
        self.n_skip_connections = n_skip_connections
        return B

    def _fill_boundaryMatrix(self):
        """Fills the boundary matrix self.B with 1 and -1
        corresponding to the topology of the network.

        Column layout: flattened weights (layer by layer, row-major), then the
        hidden biases, then one group of columns per skip connection, in the
        order the connections were given.
        """

        # weights
        l_dim = [self.in_channels] + list(self.hidden_channels)
        a_fan_in, a_fan_out = np.array(l_dim[:-1]), np.array(l_dim[1:])
        # n_weights_before_layer[k]: index of the first weight of linear layer k.
        n_weights_before_layer = np.concatenate([np.array([0]), a_fan_in * a_fan_out]).cumsum()
        row = 0
        for idx_hidden_layer, n_neurons in enumerate(l_dim[1:-1]):
            n_in = l_dim[idx_hidden_layer]
            n_next = l_dim[idx_hidden_layer + 2]
            for idx_neuron in range(n_neurons):
                # Incoming weights: row idx_neuron of this layer's (row-major) weight matrix.
                start_plus = int(n_weights_before_layer[idx_hidden_layer]) + idx_neuron * n_in
                self.B[row, start_plus:start_plus + n_in] = 1
                # Outgoing weights: column idx_neuron of the next layer's weight matrix,
                # i.e. every n_neurons-th entry starting at idx_neuron.
                start_minus = int(n_weights_before_layer[idx_hidden_layer + 1]) + idx_neuron
                self.B[row, start_minus:start_minus + n_neurons * n_next:n_neurons] = -1
                row += 1

        # biases
        if self.bias:
            self.B[:, self.n_weights:self.n_weights+self.n_biases] = torch.eye(self.n_biases)

        # skip connections
        if self.skip_connections is not None:
            # Each connection owns the next free group of columns. Using a running
            # offset keeps the columns inside B and avoids overlaps between connections.
            start_col = self.n_weights + self.n_biases
            for a_connection in self.skip_connections:
                i_source, i_target = int(a_connection[0]), int(a_connection[1])
                end_col = start_col + self.hidden_channels[i_source]
                # The edge leaves the source neurons ...
                start_row = sum(self.hidden_channels[:i_source])
                end_row = start_row + self.hidden_channels[i_source]
                self.B[start_row:end_row, start_col:end_col] = - torch.eye(end_row-start_row)
                # ... and enters the target neurons.
                start_row = sum(self.hidden_channels[:i_target])
                end_row = start_row + self.hidden_channels[i_target]
                self.B[start_row:end_row, start_col:end_col] = torch.eye(end_row-start_row)
                start_col = end_col

    def send_toCone(self, norm_target=0.1):
        """Sends the parameter to the cone, arbitrarily i.e. it will not be
        an observationally equivalent parametrization.

        The parameters are overwritten so that:

        * every neuron of the first layer has squared input norm ``norm_target``,
        * every bias beyond the first layer is zero,
        * every neuron feeding the last layer has squared output norm ``norm_target``,
        * every hidden-to-hidden weight matrix is replaced by a random matrix
          whose squared entries form ``norm_target`` times a doubly stochastic
          matrix, with random signs.

        Args:
            norm_target (float, optional): target squared norm. Defaults to 0.1.

        Raises:
            ValueError: if a hidden-to-hidden weight matrix is not square.
        """
        state_dict = self.state_dict()
        l_weight_keys = [key for key in state_dict if "weight" in key]
        l_bias_keys = [key for key in state_dict if "bias" in key]
        assert self.bias == (len(l_bias_keys)>0)

        l_output_norms, l_input_norms = self.compute_squaredlayerNorms()

        def _like(a_values, t_reference):
            """Turns a NumPy array into a tensor matching t_reference's dtype and device."""
            return torch.as_tensor(a_values, dtype=t_reference.dtype, device=t_reference.device)

        # fix input norm of first layer
        key_first_weight = l_weight_keys[0]
        alpha = np.sqrt(norm_target/l_input_norms[1])
        t_first = state_dict[key_first_weight]
        state_dict[key_first_weight] = t_first * _like(alpha, t_first).reshape(-1, 1)
        if self.bias:
            t_bias = state_dict[l_bias_keys[0]]
            state_dict[l_bias_keys[0]] = t_bias * _like(alpha, t_bias)

        # put all biases beyond input layer to 0
        for key_bias in l_bias_keys[1:]:
            state_dict[key_bias] = torch.zeros_like(state_dict[key_bias])

        # fix output norms of the last layer
        key_last_weight = l_weight_keys[-1]
        alpha = np.sqrt(norm_target/l_output_norms[len(l_weight_keys)-1])
        t_last = state_dict[key_last_weight]
        state_dict[key_last_weight] = t_last * _like(alpha, t_last)

        # replace all hidden to hidden matrices with doubly stochastic matrices
        for key_weight in l_weight_keys[1:-1]:
            t_old = state_dict[key_weight]
            n_rows, n_cols = t_old.shape
            if n_rows != n_cols:
                raise ValueError("Hidden to hidden matrix is not square")
            M_doubly_stochastic = generate_randDoublyStochastic(n_rows)
            # Random +-1 signs, drawn on the CPU generator and then moved next to the weights.
            t_sign = (torch.randint(0, 2, (n_rows, n_cols))*2-1).to(t_old.device)
            t_magnitude = torch.as_tensor(np.sqrt(M_doubly_stochastic*norm_target),
                                          device=t_old.device)
            state_dict[key_weight] = t_sign * t_magnitude

        self.load_state_dict(state_dict)

class Shallow(MLP):
    """This class defines a simple shallow ReLU network
    with hyperbolae utilities.

    The network has one hidden layer of ``n_hidden`` ReLU neurons, a single
    output and no biases.
    """

    def __init__(self, n_hidden=2, n_input=2):
        """Builds the network.

        Args:
            n_hidden (int, optional): number of hidden neurons. Defaults to 2.
            n_input (int, optional): number of input features. Defaults to 2.
        """
        super().__init__(in_channels=n_input, hidden_channels=[n_hidden, 1],
                         activation_layer=nn.ReLU, bias=False)

    def _linear_layers(self):
        """Returns the linear layers of the network, in forward order."""
        return [module for module in self if isinstance(module, nn.Linear)]

    def rescale_hidden(self, alpha):
        """Scales the output weights of hidden neurons by alpha and the input weights by 1/alpha.

        Args:
            alpha (array): one scaling factor per hidden neuron.
        """
        l_linear = self._linear_layers()
        w_in, w_out = l_linear[0].weight, l_linear[-1].weight
        # Invert in double precision first, then cast: same rounding as inverting
        # the caller's array before converting it to the parameter dtype.
        t_alpha = torch.as_tensor(alpha, dtype=torch.float64).detach()
        # Follow each parameter's own device/dtype rather than the module-level
        # `device`, so the model can be rescaled wherever it currently lives.
        t_scale_out = t_alpha.to(device=w_out.device, dtype=w_out.dtype)
        t_scale_in = (1 / t_alpha).to(device=w_in.device, dtype=w_in.dtype)
        w_out.data = w_out.data * t_scale_out
        # Row k of the first layer holds the input weights of hidden neuron k.
        w_in.data = w_in.data * t_scale_in.reshape(-1, 1)

    def compute_alpha(self, c):
        """Computes the alpha vector necessary to reach the hyperbolae vector c.

        With o and i the current squared output and input norms of a hidden
        neuron, alpha solves ``i / alpha**2 - alpha**2 * o = c``.

        Args:
            c (array): the hyperbolae vector to reach, for each hidden neuron.

        Returns:
            alpha (array)
        """
        l_output_norms, l_input_norms = self.compute_squaredlayerNorms()
        a_out, a_in = l_output_norms[1], l_input_norms[1]
        neg_c = -c
        # Positive root of the quadratic in alpha**2.
        alpha_squared = (neg_c + np.sqrt(4 * a_out * a_in + neg_c**2)) / (2 * a_out)
        alpha = np.sqrt(alpha_squared)
        return alpha

    def set_finalLayerBits(self, l_sgn):
        """Sets the sign of the output weights of the last layer.

        Args:
            l_sgn (list[+-1]): the sign to set for each output weight.
        """
        # The underlying model is a Sequential with numeric child names, so the
        # output layer is looked up as the last linear module.
        w_out = self._linear_layers()[-1].weight
        t_sgn = torch.sign(torch.as_tensor(l_sgn, dtype=w_out.dtype, device=w_out.device))
        w_out.data = torch.abs(w_out.data) * t_sgn

def generate_randDoublyStochastic(n):
    """Draws a random n x n doubly stochastic matrix.

    A complex Gaussian (Ginibre) matrix is QR-decomposed, the phases of the
    diagonal of R are absorbed into Q, and the squared moduli of the entries
    of the resulting unitary matrix are returned. Rows and columns of a
    unitary matrix have unit norm, so every row and column of the result
    sums to one.

    Args:
        n (int): size of the matrix.

    Returns:
        np.ndarray: real, non-negative array of shape (n, n).
    """
    # Real part first, imaginary part second (both from NumPy's global RNG).
    re_entries = np.random.normal(0, 1, (n, n))
    im_entries = np.random.normal(0, 1, (n, n))

    # QR factorisation of the complex Gaussian matrix yields a unitary factor.
    unitary, upper = qr(re_entries + 1j * im_entries)

    # Rescale each column of Q by the phase of the matching diagonal entry of R.
    diag_upper = np.diagonal(upper)
    unitary = unitary * (diag_upper / np.abs(diag_upper))

    # Squared moduli of a unitary matrix form a doubly stochastic matrix.
    return np.abs(unitary) ** 2
