"""Create invertible mixing networks. Code bootstrapped from Kugelgen et al., 2021 and Lyu et al., 2022"""

import numpy as np
import torch
from torch import nn
from scipy.stats import ortho_group
from typing import Union
from typing_extensions import Literal

def invertible_network_hyvarinen(n: int = 20,
                                    n_layers: int = 2,
                                    n_iter_cond_thresh: int = 10000,
                                    cond_thresh_ratio: float = 0.5,
                                    weight_matrix_init: Union[Literal["pcl"], Literal["rvs"]] = 'pcl',
                                    act_fct: Union[Literal["relu"], Literal["leaky_relu"], Literal["elu"],
                                            Literal["smooth_leaky_relu"], Literal["softplus"]] = 'leaky_relu'):
    """
    Create an (approximately) invertible mixing network based on an MLP.
    Based on the mixing code by Hyvarinen et al.

    The network alternates bias-free ``n x n`` linear layers with the chosen
    activation; no activation follows the last linear layer. All parameters
    are frozen (``requires_grad = False``).

    Args:
        n: Dimensionality of the input and output data.
        n_layers: Number of linear layers in the MLP.
        n_iter_cond_thresh: How many random matrices to sample as a pool for
            estimating the condition-number threshold (only used for
            ``weight_matrix_init='pcl'``; must then be >= 1).
        cond_thresh_ratio: Quantile of the pooled (sorted) condition numbers
            used as the acceptance threshold for every layer's weight matrix.
            Lower values accept only better-conditioned matrices. The derived
            index is clamped to the valid range, so e.g. 1.0 selects the
            largest pooled condition number.
        weight_matrix_init: How to initialize the weight matrices.
            ``'pcl'``: rejection-sample column-L2-normalized uniform(-1, 1)
            matrices until the condition number is <= the threshold.
            ``'rvs'``: draw a random orthogonal matrix (``scipy.stats.ortho_group``).
            ``'expand'``: keep ``nn.Linear``'s default initialization.
        act_fct: Activation function for hidden layers.

    Returns:
        A ``torch.nn.Sequential`` mixing network with frozen parameters.

    Raises:
        ValueError: If ``act_fct`` or ``weight_matrix_init`` is unknown, or if
            ``n_iter_cond_thresh < 1`` with ``weight_matrix_init='pcl'``.
        NotImplementedError: If ``act_fct == 'max_out'``.
    """

    class SmoothLeakyReLU(nn.Module):
        """Computes ``alpha * x + (1 - alpha) * softplus(x)``."""

        def __init__(self, alpha=0.2):
            super().__init__()
            self.alpha = alpha

        def forward(self, x):
            # softplus(x) == log(1 + exp(x)), but computed stably: the naive
            # form overflows to inf once exp(x) exceeds the float range.
            return self.alpha * x + (1 - self.alpha) * torch.nn.functional.softplus(x)

    def get_act_fct(name):
        """Map an activation name to ``(module_class, constructor_kwargs)``."""
        if name == 'relu':
            return torch.nn.ReLU, {}
        elif name == 'leaky_relu':
            return torch.nn.LeakyReLU, {'negative_slope': 0.2}
        elif name == 'elu':
            return torch.nn.ELU, {'alpha': 1.0}
        elif name == 'max_out':
            raise NotImplementedError('max_out activation is not implemented')
        elif name == 'smooth_leaky_relu':
            return SmoothLeakyReLU, {'alpha': 0.2}
        elif name == 'softplus':
            return torch.nn.Softplus, {'beta': 1}
        else:
            raise ValueError(f'activation function {name} not defined.')

    def l2_normalize(Amat, axis=0):
        """Divide ``Amat`` by its L2 norms along ``axis`` (0=columns, 1=rows)."""
        l2norm = np.sqrt(np.sum(Amat * Amat, axis))
        return Amat / l2norm

    # Resolve the activation first so an invalid name fails before the
    # (potentially expensive) condition-number sampling below.
    act_cls, act_kwargs = get_act_fct(act_fct)

    cond_thresh = None
    if weight_matrix_init == 'pcl':
        if n_iter_cond_thresh < 1:
            raise ValueError(
                f'n_iter_cond_thresh must be >= 1 for pcl init, got {n_iter_cond_thresh}')
        # Empirical distribution of condition numbers of random normalized
        # matrices; the threshold is its `cond_thresh_ratio` quantile.
        cond_list = np.empty(n_iter_cond_thresh)
        for i in range(n_iter_cond_thresh):
            A = np.random.uniform(-1, 1, [n, n])
            A = l2_normalize(A, axis=0)
            cond_list[i] = np.linalg.cond(A)
        cond_list.sort()  # Ascending order
        idx = int(n_iter_cond_thresh * cond_thresh_ratio)
        idx = min(max(idx, 0), n_iter_cond_thresh - 1)  # keep index in bounds
        cond_thresh = cond_list[idx]

    layers = []
    for i in range(n_layers):

        lin_layer = nn.Linear(n, n, bias=False)

        if weight_matrix_init == 'pcl':
            # Rejection sampling: redraw until the matrix is at least as well
            # conditioned as the threshold.
            cond_a = cond_thresh + 1
            while cond_a > cond_thresh:
                weight_matrix = np.random.uniform(-1, 1, (n, n))
                weight_matrix = l2_normalize(weight_matrix, axis=0)
                cond_a = np.linalg.cond(weight_matrix)
            lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32)

        elif weight_matrix_init == 'rvs':
            weight_matrix = ortho_group.rvs(n)
            lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32)
        elif weight_matrix_init == 'expand':
            pass  # keep nn.Linear's default initialization
        else:
            raise ValueError(f'weight matrix {weight_matrix_init} not implemented')

        layers.append(lin_layer)

        # No activation after the final linear layer.
        if i < n_layers - 1:
            layers.append(act_cls(**act_kwargs))

    mixing_net = nn.Sequential(*layers)

    # The mixing is a fixed ground-truth transform: freeze all parameters.
    for p in mixing_net.parameters():
        p.requires_grad = False

    return mixing_net