import matplotlib.pyplot as plt 
import numpy as np
from graspologic.simulations import sbm
from sklearn import metrics
from typing import Union
import random as rd

def plot_graph(graph: np.ndarray):
    '''
    Plot the adjacency matrix of a graph, or of every layer of a multiplex graph.

    A 2-D array is drawn as a single figure. A 3-D array is treated as a stack
    of layers along axis 0, and each layer gets its own figure. Values are shown
    with the "summer" colormap on a fixed [0, 1] colour scale.

    Args:
        graph (np.ndarray): The adjacency matrix (2-D) or the stack of adjacency
            matrices (3-D, layers first). Array-likes are converted with
            ``np.asarray``.

    Raises:
        ValueError: If ``graph`` is neither 2-D nor 3-D.
    '''
    graph = np.asarray(graph)
    if graph.ndim == 2:
        # Promote a single layer to a one-layer stack so both cases share one loop.
        layers = graph[np.newaxis]
    elif graph.ndim == 3:
        layers = graph
    else:
        # Raise a real exception: raising a plain str is itself a TypeError.
        raise ValueError(
            f'You dont give a graph: expected a 2-D or 3-D array, got {graph.ndim}-D'
        )
    for layer in layers:
        plt.figure()
        plt.imshow(layer, cmap="summer", vmin=0, vmax=1)


def join_all_layer(*args):
    '''
    Join separately stored layers into a single multiplex graph.

    The arrays in ``args`` are concatenated along axis 0. For 3-D inputs of
    shape (k_i, N, N), the result has shape (sum(k_i), N, N).

    NOTE: 2-D inputs are concatenated row-wise, giving shape (sum(n_i), N).
    They are not stacked along a new layer axis, so pass single layers as
    (1, N, N) arrays.

    Args:
        *args (np.ndarray): The layer arrays to join, in order.

    Returns:
        multiplex_graph (np.ndarray): The concatenation of all inputs along axis 0.
    '''
    layers = tuple(args)
    return np.concatenate(layers, axis=0)

def generate_layers(n_: np.ndarray,
                    p_: np.ndarray,
                    nb_layers: int = 1):
    '''
    Generate a multiplex graph whose layers are independent draws from the same
    stochastic block model (SBM), using ``graspologic.simulations.sbm``.
    To install the library, please use 'pip install graspologic'.

    Args:
        n_ (np.ndarray): Number of individuals in each of the K blocks, shape (K,).
            Array-likes (e.g. lists) are accepted.
        p_ (np.ndarray): Block-to-block connection probabilities, shape (K, K).
            Array-likes are accepted.
        nb_layers (int): Number of layers to draw from the same distribution.

    Returns:
        layers_ (np.ndarray): Array of shape (nb_layers, N, N), with N = n_.sum().
            layers_[i] holds the result of the i-th ``sbm`` call.

    Raises:
        ValueError: If n_ is not 1-D, p_ is not a square 2-D array, or the
            two parameters disagree on the number of blocks.
    '''
    n_ = np.asarray(n_)
    p_ = np.asarray(p_)
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if n_.ndim != 1:
        raise ValueError(f'n_ must be a 1-D array of block sizes, got shape {n_.shape}')
    if p_.ndim != 2 or p_.shape[0] != p_.shape[1]:
        raise ValueError(f'p_ must be a square (K, K) probability matrix, got shape {p_.shape}')
    if n_.shape[0] != p_.shape[0]:
        raise ValueError('n_ and p_ must have the same number of groups')
    n_vertices = int(n_.sum())
    layers_ = np.zeros((nb_layers, n_vertices, n_vertices))
    for i in range(nb_layers):
        layers_[i] = sbm(n=n_, p=p_)
    return layers_


def nmi(label_truth: Union[list, np.ndarray],
        label_predict: Union[list, np.ndarray]
        ):
    '''
    Normalized mutual information (NMI) between two labelings, as a percentage.

    Wraps ``sklearn.metrics.normalized_mutual_info_score`` and multiplies its
    score by 100.

    Args:
        label_truth (Union[list, np.ndarray]): The ground-truth labels.
        label_predict (Union[list, np.ndarray]): The predicted labels.

    Returns:
        The NMI score scaled by 100.
    '''
    score = metrics.normalized_mutual_info_score(label_truth, label_predict)
    percent = score * 100
    return percent
    

def _argmax_function(vec:np.array,
                     axis_:np.int32= 1):
    '''
    Compute the argmax of an np.array, and then computes its dummy representation

    Args:

        vec (np.array): an array
        axis_ (np.int32): the axis where we want to compute its argmax
    
    Returns:

        vec (np.array): dummy representation of the argmax vector
    '''
    vec_argmax = vec.argmax(axis=axis_)
    vec = np.zeros((vec.shape[0],vec.shape[1]),dtype=np.float128)
    vec[np.arange(vec.shape[0]),vec_argmax] = 1 
    return vec

def bar_plot_figures(values: list,
                     labels: list,
                     x_axis_label: str,
                     y_axis_label: str,
                     fontsize: int = 20):
    '''
    Draw a bar chart of ``values`` (e.g. per-method performance) in a new 10x10 figure.

    Args:
        values (list): The height of each bar.
        labels (list): The name shown under each bar, one per value.
        x_axis_label (str): The label of the x axis.
        y_axis_label (str): The label of the y axis.
        fontsize (int): The font size used for both axis labels.
    '''
    _, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.bar(labels, values)
    for set_label, text in ((ax.set_xlabel, x_axis_label),
                            (ax.set_ylabel, y_axis_label)):
        set_label(text, fontsize=fontsize)


def random_st(type_: str,
              size: tuple,
              random_func=rd,
              a: int = 0,
              b: int = 1):
    '''
    Fill an array with values drawn one at a time from a random generator.

    Each entry is ``getattr(random_func, type_)(a, b)``; with the defaults this
    is ``rd.uniform(0, 1)``. Values are drawn in row-major order, with exactly
    one draw per entry.

    Args:
        type_ (str): Name of the generator method to call, e.g. 'uniform' or 'randint'.
        size (tuple): Shape of the output array. A bare int is treated as a 1-D shape.
        random_func: Object providing the ``type_`` method. This is the ``random``
            module by default, or e.g. a seeded ``random.Random`` instance.
        a (int): First argument passed to the generator (the lower bound for
            uniform/randint).
        b (int): Second argument passed to the generator (the upper bound for
            uniform/randint).

    Returns:
        vector (np.ndarray): Array of shape ``size`` containing the drawn values.
    '''
    sampler = getattr(random_func, type_)
    shape = (size,) if isinstance(size, (int, np.integer)) else tuple(size)
    n_values = int(np.prod(shape, dtype=np.int64))
    # Use a plain comprehension instead of np.vectorize. Without `otypes`,
    # vectorize calls the function once extra to infer the output dtype, which
    # wastes a random draw and shifts seeded streams. It also rejects
    # empty inputs.
    values = [sampler(a, b) for _ in range(n_values)]
    vector = np.array(values).reshape(shape)
    return vector