"""
This module provides utility functions for graph processing."""
from typing import Tuple
import sys
import copy
import networkx as nx
import numpy as np
import cvxpy as cp
import torch
from scipy import sparse as sp_sparse
import matplotlib.pyplot as plt

from Code.lp_optim_utils import solve_mcc_log_barrier_cvxpy
from Code.post_processing_utils import get_normalized_shap_vals

def get_adj_matrix_backward(agg_attentions: torch.Tensor) -> dict:
    """
    Build the backward attention graph and solve a maximum-flow problem on it.

    Node layout of the graph (``num_nodes = (num_layers + 1) * len_tokens + 2``):
    node 0 is the target, the last node is the source, and the nodes in between
    are the token nodes, ``len_tokens`` per layer, first layer first. Edges run
    source -> last layer -> ... -> first layer -> target, and the edge from token
    ``r`` of layer ``k + 1`` to token ``c`` of layer ``k`` has capacity
    ``agg_attentions[k, r, c]``.

    Args:
        agg_attentions (torch.Tensor): Aggregated attentions of shape
            (num_layers, len_tokens, len_tokens). Must contain at least one
            value greater than 0.

    Returns:
        dict: A dictionary with the following keys:
            - 'adj_mat' (np.ndarray): Capacity matrix of shape (num_nodes, num_nodes).
            - 'scaled_adj_mat' (np.ndarray): ``adj_mat`` multiplied by a power of ten
              and rounded to integers (the input handed to the max-flow solver).
            - 'max_flow_val': Maximum flow value from source to target, scaled back.
            - 'max_flow_arr' (np.ndarray): Non-negative flow on every edge, scaled
              back; same shape as ``adj_mat``.
            - 'shapley_vals' (np.ndarray): Row sums of ``max_flow_arr`` (total
              outgoing flow of every node), shape (num_nodes,).
            - 'shapley_vals_layerwise' (np.ndarray): ``shapley_vals`` without the
              source/target entries, reshaped to (num_layers + 1, len_tokens).
            - 'normalized_shapley_vals_layerwise': Result of
              ``get_normalized_shap_vals`` applied to ``shapley_vals_layerwise``.

    Raises:
        ValueError: If ``agg_attentions`` has no entry greater than 0, because the
            scale factor is derived from the smallest positive entry.
    """
    num_layers, len_tokens = agg_attentions.shape[0], agg_attentions.shape[1]
    num_nodes = (num_layers + 1) * len_tokens + 2

    # Set the threshold of printing
    np.set_printoptions(linewidth=sys.maxsize)

    # Find the smallest value that is greater than 0 in the agg_attentions tensor.
    positive_entries = agg_attentions[torch.gt(agg_attentions, 0)]
    if positive_entries.numel() == 0:
        raise ValueError("agg_attentions must contain at least one value greater than 0.")
    # .cpu() keeps this working for tensors that live on an accelerator.
    beta_min = positive_entries.min().detach().cpu().numpy()
    print(beta_min)

    # Power of ten that lifts the smallest positive attention to at least 1, so it
    # survives the conversion to integer capacities.
    # NOTE(review): extremely small attentions give a very large scale factor;
    # confirm the scaled capacities still fit the integer type SciPy uses internally.
    beta = -np.floor(np.log10(beta_min))
    scale_factor = 10**beta

    agg_attentions = agg_attentions.detach().cpu().numpy()

    adj_mat = np.zeros((num_nodes, num_nodes))

    inf_cap = len_tokens*1 # Source and sink: infinity capacity ;)
    # inf_cap is always equal to len_tokens*1!

    # First layer -> target (node 0)
    adj_mat[1:len_tokens + 1, 0] = inf_cap

    # Source (last node) -> last layer
    adj_mat[-1, -len_tokens - 1:-1] = inf_cap

    # Layer k+1 -> layer k, one attention matrix per pair of consecutive layers
    for layer in range(num_layers):
        start = len_tokens * layer + 1
        mid = start + len_tokens
        end = mid + len_tokens
        adj_mat[mid:end, start:mid] = agg_attentions[layer]

    # Scale and round to the nearest integer. A bare astype(int) truncates, so a
    # product that lands just below an integer because of floating-point error
    # (e.g. 0.29 * 100 = 28.999...) would silently lose one unit of capacity.
    scaled_adj_mat = np.rint(scale_factor * adj_mat).astype(int)

    # Calculate the maximum flow from the source to the target
    src = num_nodes - 1
    target = 0

    graph = sp_sparse.csr_matrix(scaled_adj_mat)
    maxflow = sp_sparse.csgraph.maximum_flow(graph, src, target, method='edmonds_karp')

    max_flow_val = maxflow.flow_value * (scale_factor)**-1

    # Scale the flow back and drop the negative entries so that only the flow in
    # the direction of each edge remains.
    max_flow_arr = maxflow.flow.toarray() * (scale_factor)**-1
    max_flow_arr[max_flow_arr < 0] = 0

    # Total outgoing flow of every node, then one row per layer of token nodes
    shapley_vals = np.sum(max_flow_arr, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((num_layers+1, len_tokens))
    normalized_shapley_vals_layerwise = get_normalized_shap_vals(shapley_vals_layerwise)

    bw_graph_info = {
        'adj_mat': adj_mat,
        'scaled_adj_mat': scaled_adj_mat,
        'max_flow_val': max_flow_val,
        'max_flow_arr': max_flow_arr,
        "shapley_vals": shapley_vals,
        'shapley_vals_layerwise': shapley_vals_layerwise,
        'normalized_shapley_vals_layerwise': normalized_shapley_vals_layerwise
    }

    return bw_graph_info


def get_adj_matrix_forward(agg_attentions: torch.Tensor) -> dict:
    """
    Build the forward attention graph and solve a maximum-flow problem on it.

    Node layout of the graph (``num_nodes = (num_layers + 1) * len_tokens + 2``):
    node 0 is the source, the last node is the target, and the nodes in between
    are the token nodes, ``len_tokens`` per layer, first layer first. Edges run
    source -> first layer -> ... -> last layer -> target, and the edge from token
    ``c`` of layer ``k`` to token ``r`` of layer ``k + 1`` has capacity
    ``agg_attentions[k, r, c]`` (each attention matrix is written transposed).

    Args:
        agg_attentions (torch.Tensor): Aggregated attentions of shape
            (num_layers, len_tokens, len_tokens). Must contain at least one
            value greater than 0.

    Returns:
        dict: A dictionary with the following keys:
            - 'adj_mat' (np.ndarray): Capacity matrix of shape (num_nodes, num_nodes).
            - 'scaled_adj_mat' (np.ndarray): ``adj_mat`` multiplied by a power of ten
              and rounded to integers (the input handed to the max-flow solver).
            - 'max_flow_val': Maximum flow value from source to target, scaled back.
            - 'max_flow_arr' (np.ndarray): Non-negative flow on every edge, scaled
              back; same shape as ``adj_mat``.
            - 'shapley_vals' (np.ndarray): Row sums of ``max_flow_arr`` (total
              outgoing flow of every node), shape (num_nodes,).
            - 'shapley_vals_layerwise' (np.ndarray): ``shapley_vals`` without the
              source/target entries, reshaped to (num_layers + 1, len_tokens).
            - 'normalized_shapley_vals_layerwise': Result of
              ``get_normalized_shap_vals`` applied to ``shapley_vals_layerwise``.

    Raises:
        ValueError: If ``agg_attentions`` has no entry greater than 0, because the
            scale factor is derived from the smallest positive entry.
    """
    num_layers, len_tokens = agg_attentions.shape[0], agg_attentions.shape[1]
    num_nodes = (num_layers + 1) * len_tokens + 2

    # Set the threshold of printing
    np.set_printoptions(linewidth=sys.maxsize)

    # Find the smallest value that is greater than 0 in the agg_attentions tensor.
    positive_entries = agg_attentions[torch.gt(agg_attentions, 0)]
    if positive_entries.numel() == 0:
        raise ValueError("agg_attentions must contain at least one value greater than 0.")
    # .cpu() keeps this working for tensors that live on an accelerator.
    beta_min = positive_entries.min().detach().cpu().numpy()
    print(beta_min)

    # Power of ten that lifts the smallest positive attention to at least 1, so it
    # survives the conversion to integer capacities.
    # NOTE(review): extremely small attentions give a very large scale factor;
    # confirm the scaled capacities still fit the integer type SciPy uses internally.
    beta = -np.floor(np.log10(beta_min))
    scale_factor = 10**beta

    agg_attentions = agg_attentions.detach().cpu().numpy()

    adj_mat = np.zeros((num_nodes, num_nodes))

    inf_cap = len_tokens*1 # Source and sink: infinity capacity ;)
    # inf_cap is always equal to len_tokens*1!

    # Source (node 0) -> first layer
    adj_mat[0, 1:len_tokens + 1] = inf_cap

    # Last layer -> target (last node)
    adj_mat[-len_tokens - 1:-1, -1] = inf_cap

    # Layer k -> layer k+1, using the transposed attention matrix of layer k
    for layer in range(num_layers):
        start = len_tokens * layer + 1
        mid = start + len_tokens
        end = mid + len_tokens
        adj_mat[start:mid, mid:end] = agg_attentions[layer].T

    # Scale and round to the nearest integer. A bare astype(int) truncates, so a
    # product that lands just below an integer because of floating-point error
    # (e.g. 0.29 * 100 = 28.999...) would silently lose one unit of capacity.
    scaled_adj_mat = np.rint(scale_factor * adj_mat).astype(int)

    # Calculate the maximum flow from the source to the target
    src = 0
    target = num_nodes - 1

    graph = sp_sparse.csr_matrix(scaled_adj_mat)
    maxflow = sp_sparse.csgraph.maximum_flow(graph, src, target, method='edmonds_karp')

    max_flow_val = maxflow.flow_value * (scale_factor)**-1

    # Scale the flow back and drop the negative entries so that only the flow in
    # the direction of each edge remains.
    max_flow_arr = maxflow.flow.toarray() * (scale_factor)**-1
    max_flow_arr[max_flow_arr < 0] = 0

    # Total outgoing flow of every node, then one row per layer of token nodes
    shapley_vals = np.sum(max_flow_arr, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((num_layers+1, len_tokens))
    normalized_shapley_vals_layerwise = get_normalized_shap_vals(shapley_vals_layerwise)

    fw_graph_info = {
        'adj_mat': adj_mat,
        'scaled_adj_mat': scaled_adj_mat,
        'max_flow_val': max_flow_val,
        'max_flow_arr': max_flow_arr,
        "shapley_vals": shapley_vals,
        'shapley_vals_layerwise': shapley_vals_layerwise,
        'normalized_shapley_vals_layerwise': normalized_shapley_vals_layerwise
    }

    return fw_graph_info


def create_and_draw_graph(capacity_matrix, sequence_length:int, 
                          save_path:str, width:float, height:float, dpi:int,
                          backward:bool=True, color:str="skyblue",
                          change_pos:bool=False, label_pos:float = 0.8):
    """
    Build a directed graph from a capacity matrix, draw it with edge labels and save it as a PDF.

    Parameters:
    capacity_matrix (numpy.ndarray): The capacity matrix representing the graph;
        it is rounded to 4 decimals and every entry > 0 becomes an edge.
    sequence_length (int): Number of consecutive nodes grouped into one layer.
    save_path (str): Path of the PDF file to write.
    width (float): Figure width in centimetres.
    height (float): Figure height in centimetres.
    dpi (int): Resolution passed to ``savefig``.
    backward (bool): If True the first node is labelled as target and the last
        node as source; otherwise the other way round.
    color (str): Node colour.
    change_pos (bool): If True every node position is shifted by a small offset.
    label_pos (float): Position of the edge labels along the edges.

    Returns:
    (G, capacity_matrix): The created networkx.DiGraph and the rounded capacity matrix.
    """
    # Work on the rounded capacities from here on
    capacity_matrix = capacity_matrix.round(4)
    num_nodes = len(capacity_matrix)
    last_node = capacity_matrix.shape[0] - 1

    G = nx.DiGraph()
    G.add_nodes_from(range(num_nodes))

    # One edge per strictly positive capacity
    tails, heads = np.where(capacity_matrix > 0)
    G.add_edges_from(
        (tail, head, {'capacity': capacity_matrix[tail][head]})
        for tail, head in zip(tails, heads)
    )

    # Layer attribute used by the multipartite layout
    G.nodes[0]["layer"] = 0
    # Overwritten by the loop below whenever the graph has more than one node
    G.nodes[last_node]["layer"] = sequence_length+1
    for node in range(1, num_nodes):
        G.nodes[node]["layer"] = (node - 1) // sequence_length + 1

    # print the layer of each node
    for node, attrs in G.nodes(data=True):
        print("Node:", node, "Layer:", attrs["layer"])

    fig, _ = plt.subplots(figsize=(14, 9))

    # One column of nodes per layer
    pos = nx.multipartite_layout(G, subset_key="layer", align="vertical")
    if change_pos:
        # Adjust the positions slightly to avoid label conflicts
        pos = {node: (x+0.0005, y-0.0005) for node, (x, y) in pos.items()}

    edge_capacities = nx.get_edge_attributes(G, 'capacity')

    # Node names in LaTeX format; the two terminal nodes get s/t labels
    node_names = {node: r'$v_{' + str(node) + '}$' for node in G.nodes()}
    if backward:
        first_label, last_label = r'$v_{t}$', r'$v_{s}$'
    else:
        first_label, last_label = r'$v_{s}$', r'$v_{t}$'
    node_names[0] = first_label
    node_names[len(capacity_matrix)-1] = last_label

    # Draw nodes and edges, then the capacity labels
    nx.draw(G, pos, labels=node_names, node_size=300, node_color=color)
    nx.draw_networkx_edge_labels(
        G, pos,
        edge_labels=edge_capacities,
        verticalalignment='top',
        horizontalalignment='right',
        label_pos=label_pos,
        font_size=7,
    )

    plt.axis('off')

    # Resize to the requested size (cm -> inches) and save
    fig.set_size_inches(width / 2.54, height / 2.54)
    plt.savefig(save_path, format='pdf', dpi=dpi)

    plt.show()

    return G, capacity_matrix

def create_and_draw_graph_without_labels(capacity_matrix, sequence_length:int, save_path:str, width:float, 
                                         height:float, dpi:int, backward:bool=True, color:str="skyblue"):
    """
    Build a directed graph from a capacity matrix, draw it without edge labels and save it as a PDF.

    Every layer is framed by a dashed rectangle, and the gap between two
    consecutive inner layers is annotated with the attention matrix symbol.

    Parameters:
    capacity_matrix (numpy.ndarray): The capacity matrix representing the graph;
        it is rounded to 4 decimals and every entry > 0 becomes an edge.
    sequence_length (int): Number of consecutive nodes grouped into one layer.
    save_path (str): Path of the PDF file to write.
    width (float): Figure width in centimetres.
    height (float): Figure height in centimetres.
    dpi (int): Resolution passed to ``savefig``.
    backward (bool): If True the first node is labelled as target and the last
        node as source; otherwise the other way round.
    color (str): Node colour.

    Returns:
    (G, capacity_matrix): The created networkx.DiGraph and the rounded capacity matrix.
    """
    # round the capacity matrix to 4 decimal places
    capacity_matrix = capacity_matrix.round(4)
    num_nodes = len(capacity_matrix)
    last_node = num_nodes - 1

    # Create a directed graph with one node per row of the matrix
    G = nx.DiGraph()
    G.add_nodes_from(range(num_nodes))

    # Add an edge for every strictly positive capacity
    i_indices, j_indices = np.where(capacity_matrix > 0)
    for i, j in zip(i_indices, j_indices):
        G.add_edge(i, j, capacity=capacity_matrix[i, j])

    # Layer attribute used by the multipartite layout: node 0 is alone in layer 0,
    # every following group of sequence_length nodes forms the next layer, which
    # puts the last node in a layer of its own after the last full group.
    G.nodes[0]["layer"] = 0
    for i in range(1, num_nodes):
        G.nodes[i]["layer"] = ((i-1) // sequence_length)+1

    # print the layer of each node
    for node, data in G.nodes(data=True):
        print("Node:", node, "Layer:", data["layer"])

    fig, ax = plt.subplots(figsize=(14, 9))

    # One column of nodes per layer
    pos = nx.multipartite_layout(G, subset_key="layer", align="vertical")

    # Create a dictionary of node names in LaTeX format
    node_names = {node: r'$v_{' + str(node) + '}$' for node in G.nodes()}

    # Change the labels of the first and last nodes
    if backward:
        node_names[0] = r'$v_{t}$'
        node_names[last_node] = r'$v_{s}$'
    else:
        node_names[0] = r'$v_{s}$'
        node_names[last_node] = r'$v_{t}$'

    # Draw the graph with labels
    nx.draw(G, pos, labels=node_names, node_size=400, node_color=color)

    # Group the node positions by layer. Positions are looked up by node so the
    # result does not depend on the two dictionaries sharing an iteration order.
    layers = nx.get_node_attributes(G, 'layer')
    layer_points = {}
    for node, node_layer in layers.items():
        layer_points.setdefault(node_layer, []).append(pos[node])

    max_layer = max(layers.values())
    offset = 0.05  # Increase this value to increase the size of the rectangle

    for layer, points in layer_points.items():
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]

        # Draw a rectangle around each partite
        ax.add_patch(plt.Rectangle((min(xs)-offset, min(ys)-offset), max(xs)-min(xs)+2*offset, max(ys)-min(ys)+2*offset, fill=None, edgecolor=(0.871, 0.047, 0.384, 1.0), linestyle='--', linewidth=2))

        # Add text between two layers (only between the inner layers, not next
        # to the two single-node terminal layers)
        if 0 < layer < max_layer - 1:
            next_layer_xs = [point[0] for point in layer_points[layer + 1]]
            mid_x = (min(xs) + min(next_layer_xs)) / 2

            if backward:
                plt.text(mid_x, min(ys) - offset, r'$\bar{A}_{'+'['+str(layer)+', :, :]'+'}$', ha='center', va='top', fontweight='normal', fontstyle='italic', color=(0.0, 0.0, 0.0, 1.0), fontsize=12)
            else:
                plt.text(mid_x, min(ys) - offset, r'$\bar{A}^{T}_{'+'['+str(layer)+', :, :]'+'}$', ha='center', va='top', fontweight='normal', fontstyle='italic', color=(0.0, 0.0, 0.0, 1.0), fontsize=12)

    plt.axis('off')

    # Save the figure
    fig.set_size_inches(width / 2.54, height / 2.54)  # Convert cm to inches
    plt.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')  # Change the path to your desired path

    plt.show()

    return G, capacity_matrix



class MCCInfo:
    """
    This class represents the maximum circulation problem.

    It turns a capacity matrix into the edge-based description of a circulation
    problem: a capacity vector, a cost vector and an edge-by-vertex incidence
    matrix, after adding an extra edge from the target back to the source.
    """
    def __init__(self, capacity_matrix: np.ndarray, sequence_length: int):
        """
        Initializes a new instance of the class.

        Args:
            capacity_matrix (np.ndarray): Square capacity matrix of the graph;
                the first and last vertices are used as source/target.
            sequence_length (int): Sequence length; its square is used as the
                capacity of the added target -> source edge.

        Returns:
            None
        """
        self.capacity_matrix = capacity_matrix
        self.sequence_length = sequence_length

    def get_cost_matrix(self, source: int, target: int) -> np.ndarray:
        """
        Get the cost matrix for the circulation problem.

        Args:
        - source: Source node
        - target: Target node

        Returns:
        - Cost matrix with the shape of the stored capacity matrix: all zeros
          except for a cost of -1 on the target -> source edge.
        """
        cost_matrix = np.zeros(self.capacity_matrix.shape)
        cost_matrix[target, source] = -1
        return cost_matrix

    def add_ts_to_capacity_matrix(self, capacity_matrix, source: int, target: int) -> np.ndarray:
        """
        Return a copy of the capacity matrix with the target -> source edge added.

        Args:
        - capacity_matrix: Capacity matrix to extend; it is not modified.
        - source: Source node
        - target: Target node

        Returns:
        - Copy of ``capacity_matrix`` whose [target, source] entry is set to
          ``sequence_length * sequence_length``.
        """
        updated_capacity_matrix = copy.deepcopy(capacity_matrix)
        updated_capacity_matrix[target, source] = self.sequence_length * self.sequence_length
        return updated_capacity_matrix

    def create_incidence_matrix(self, backward: bool) -> dict:
        """
        Generate the incidence matrix, cost vector, and capacity vector of the circulation problem.

        Parameters:
            backward (bool): If True the last vertex is the source and vertex 0 the
                target; otherwise vertex 0 is the source and the last vertex the target.

        Returns:
            dict: A dictionary with the following keys:
                - "capacity_vector" (np.ndarray): Capacity of every edge.
                - "cost_vector" (np.ndarray): Cost of every edge.
                - "incidence_matrix" (np.ndarray): Shape (num_edges, num_vertices);
                  the row of an edge i -> j holds -1 in column i and 1 in column j.
                - "capacity_matrix" (np.ndarray): Capacity matrix including the
                  added target -> source edge.
                - "cost_matrix" (np.ndarray): Cost matrix.
            Edges are the entries > 0 of the extended capacity matrix, listed in
            row-major order.

        Raises:
            ValueError: If the capacity matrix and the cost matrix differ in shape.
        """
        init_capacity_matrix = self.capacity_matrix
        last_vertex = init_capacity_matrix.shape[0] - 1

        # Only the roles of the first and last vertex depend on the direction
        source, target = (last_vertex, 0) if backward else (0, last_vertex)
        cost_matrix = self.get_cost_matrix(source=source, target=target)
        capacity_matrix = self.add_ts_to_capacity_matrix(init_capacity_matrix, source=source, target=target)

        # Check that shape of capacity matrix and cost matrix are the same; if not, raise an error
        if capacity_matrix.shape != cost_matrix.shape:
            raise ValueError("The shape of capacity matrix and cost matrix must be the same.")

        # Identify the edges from the capacity matrix (tail -> head)
        tails, heads = np.where(capacity_matrix > 0)

        num_vertices = capacity_matrix.shape[0]
        num_edges = len(tails)
        edge_ids = np.arange(num_edges)

        # One row per edge: -1 at the tail, then 1 at the head (the head is
        # written last, so a self-loop row ends up holding 1)
        incidence_matrix = np.zeros((num_edges, num_vertices))
        incidence_matrix[edge_ids, tails] = -1
        incidence_matrix[edge_ids, heads] = 1

        # Per-edge cost and capacity, always as float vectors
        cost_vector = cost_matrix[tails, heads].astype(float)
        capacity_vector = capacity_matrix[tails, heads].astype(float)

        mcc_info = {
            "capacity_vector": capacity_vector,
            "cost_vector": cost_vector,
            "incidence_matrix": incidence_matrix,
            "capacity_matrix": capacity_matrix,
            "cost_matrix": cost_matrix
        }

        return mcc_info


def flow_to_matrix(flow, incidence_matrix) -> np.ndarray:
    """
    Scatter a per-edge flow vector into a dense vertex-by-vertex flow matrix.

    Parameters:
    - flow: Flow vector; entry k is the flow on the edge described by row k
      of the incidence matrix.
    - incidence_matrix: Edge-by-vertex incidence matrix whose rows hold -1 at
      the tail vertex and 1 at the head vertex of the edge.

    Returns:
    - flow_matrix: Square matrix with one row/column per vertex, where entry
      [tail, head] holds the flow of the corresponding edge and all other
      entries are 0.
    """
    n_vertices = incidence_matrix.shape[1]
    dense_flow = np.zeros((n_vertices, n_vertices))

    for edge_idx, amount in enumerate(flow):
        edge_row = incidence_matrix[edge_idx]
        # First vertex marked -1 is the tail, first vertex marked 1 is the head
        tail = np.nonzero(edge_row == -1)[0][0]
        head = np.nonzero(edge_row == 1)[0][0]
        dense_flow[tail, head] = amount

    return dense_flow


def get_shap_info(attribution_flow, len_tokens:int, backward=True, mu=1e-14, solver = cp.ECOS) -> dict:
    """
    Solve the circulation problem built from an attribution flow and derive per-node flow values.

    Parameters:
    - attribution_flow (dict): Attribution flow information; only its 'adj_mat'
      entry (the capacity matrix) is read here.
    - len_tokens (int): The length of the tokens; also the row length of the
      layerwise result.
    - backward (bool, optional): Direction flag forwarded to
      ``MCCInfo.create_incidence_matrix``. Defaults to True.
    - mu (float, optional): Parameter forwarded to ``solve_mcc_log_barrier_cvxpy``.
      Defaults to 1e-14.
    - solver (optional): Solver forwarded to ``solve_mcc_log_barrier_cvxpy``.
      Defaults to cp.ECOS.

    Returns:
    - shap_info (dict): A dictionary containing the Shapley information.
        - mcc_model_info (dict): Output of ``MCCInfo.create_incidence_matrix``
          (incidence matrix, cost vector, capacity vector and the matrices
          they were built from).
        - flow_matrix (np.ndarray): The flow vector scattered into a
          vertex-by-vertex matrix.
        - flow: The flow vector, i.e. the second element of the pair returned by the solver helper.
        - shapley_vals (np.ndarray): Row sums of the flow matrix.
        - shapley_vals_layerwise (np.ndarray): ``shapley_vals`` without its first
          and last entry, reshaped to rows of ``len_tokens`` values.
        - normalized_shapley_vals_layerwise: Result of
          ``get_normalized_shap_vals`` applied to the layerwise values.
    """
    # Edge-based description of the circulation problem
    capacity_matrix = attribution_flow.get('adj_mat')
    mcc_model_info = MCCInfo(capacity_matrix, len_tokens).create_incidence_matrix(backward=backward)

    incidence = mcc_model_info.get('incidence_matrix')
    costs = mcc_model_info.get('cost_vector')
    upper_bounds = mcc_model_info.get('capacity_vector')
    lower_bounds = np.zeros(costs.shape)  # every edge may carry zero flow

    _, flow = solve_mcc_log_barrier_cvxpy(incidence, costs, lower_bounds, upper_bounds, mu, solver=solver)
    flow_matrix = flow_to_matrix(flow, incidence)

    # Outgoing flow per vertex; dropping the first and last vertex leaves
    # the entries that are grouped into rows of len_tokens
    shapley_vals = np.sum(flow_matrix, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((-1, len_tokens))
    normalized_shapley_vals_layerwise = get_normalized_shap_vals(shapley_vals_layerwise)

    return {
        'mcc_model_info': mcc_model_info,
        'flow_matrix': flow_matrix,
        'flow': flow,
        'shapley_vals': shapley_vals,
        'shapley_vals_layerwise': shapley_vals_layerwise,
        'normalized_shapley_vals_layerwise': normalized_shapley_vals_layerwise,
    }