import numpy as np
import networkx as nx
from grakel import GraphKernel

def graphlet_kernel(graphs, k=3):
    """
    Compute the graphlet kernel matrix for a list of graphs.

    Uses grakel's graphlet-sampling kernel. Kernel-specific options (here the
    graphlet size ``k``) must be given inside the kernel specification dict;
    ``GraphKernel`` does not accept them as constructor keyword arguments.

    Args:
        graphs (list): List of networkx graphs.
        k (int): Size of the graphlets.

    Returns:
        kernel_matrix (numpy.ndarray): Graphlet kernel matrix.
    """
    # NOTE(review): grakel's fit_transform expects grakel-compatible graph
    # input (grakel Graph objects or [edges, node_labels, ...] iterables);
    # confirm networkx graphs are converted (e.g. grakel.graph_from_networkx)
    # before they reach this function.
    kernel = GraphKernel(kernel={"name": "graphlet_sampling", "k": k})
    return kernel.fit_transform(graphs)

def wl_subtree_kernel(graphs, h=3):
    """
    Compute the Weisfeiler-Lehman subtree kernel matrix for a list of graphs.

    In grakel's ``GraphKernel`` specification, "weisfeiler_lehman" is a
    framework that must be followed by a base kernel. "subtree_wl" is the base
    that yields the WL subtree kernel. The iteration count belongs inside the
    framework's spec dict, not in the ``GraphKernel`` constructor.

    Args:
        graphs (list): List of networkx graphs.
        h (int): Number of iterations for the Weisfeiler-Lehman algorithm.

    Returns:
        kernel_matrix (numpy.ndarray): Weisfeiler-Lehman subtree kernel matrix.
    """
    # NOTE(review): WL relabelling needs node labels, and grakel expects
    # grakel-compatible graph input; confirm networkx graphs are converted
    # (with their label attribute) before they reach this function.
    kernel = GraphKernel(
        kernel=[{"name": "weisfeiler_lehman", "n_iter": h}, {"name": "subtree_wl"}]
    )
    return kernel.fit_transform(graphs)

def shortest_path_kernel(graphs):
    """
    Compute the shortest path kernel matrix for a list of graphs.

    The kernel is specified with grakel's dict form (``{"name": ...}``), the
    same specification style used for the other kernels in this module.

    Args:
        graphs (list): List of networkx graphs.

    Returns:
        kernel_matrix (numpy.ndarray): Shortest path kernel matrix.
    """
    # NOTE(review): grakel expects grakel-compatible graph input, not raw
    # networkx graphs; confirm conversion (e.g. grakel.graph_from_networkx)
    # happens upstream.
    kernel = GraphKernel(kernel={"name": "shortest_path"})
    return kernel.fit_transform(graphs)

def get_kernel_matrix(graphs, pattern_subgraphs, weights):
    """
    Compute the kernel matrix for a list of graphs using different pattern subgraphs.

    For each pattern, a kernel matrix is computed over that pattern's
    subgraphs and flattened into ``len(graphs)`` rows. The per-pattern blocks
    are stacked side by side into a feature matrix ``F``. Every column is
    scaled by its weight, and the Gram matrix ``F_w @ F_w.T`` is returned.

    Args:
        graphs (list): List of networkx graphs. Only its length is used here.
        pattern_subgraphs (dict): Maps a pattern name ('graphlet',
            'wl_subtree' or 'shortest_path') to the sampled subgraphs for
            that pattern.
        weights (array-like): Either one weight per pattern, in the iteration
            order of ``pattern_subgraphs`` and applied to every column that
            pattern contributes, or one weight per column of the stacked
            feature matrix.

    Returns:
        kernel_matrix (numpy.ndarray): Combined kernel matrix of shape
            (len(graphs), len(graphs)).

    Raises:
        ValueError: If ``pattern_subgraphs`` is empty, contains an unknown
            pattern, or ``weights`` has an incompatible length.
    """
    known_patterns = ('graphlet', 'wl_subtree', 'shortest_path')
    if not pattern_subgraphs:
        raise ValueError('pattern_subgraphs must contain at least one pattern')
    # Fail fast on a bad pattern name before running any (potentially
    # expensive) kernel computation.
    for pattern in pattern_subgraphs:
        if pattern not in known_patterns:
            raise ValueError(f'Unknown pattern: {pattern}')

    num_graphs = len(graphs)
    blocks = []
    for pattern, subgraphs in pattern_subgraphs.items():
        if pattern == 'graphlet':
            pattern_matrix = graphlet_kernel(subgraphs)
        elif pattern == 'wl_subtree':
            pattern_matrix = wl_subtree_kernel(subgraphs)
        else:  # 'shortest_path' (validated above)
            pattern_matrix = shortest_path_kernel(subgraphs)
        # NOTE(review): flattening into num_graphs rows requires the pattern's
        # kernel matrix size to be divisible by num_graphs; confirm how the
        # sampled subgraphs are meant to align with the input graphs.
        blocks.append(np.asarray(pattern_matrix).reshape(num_graphs, -1))

    features = np.hstack(blocks)

    weights = np.asarray(weights)
    block_widths = [block.shape[1] for block in blocks]
    if weights.shape == (len(blocks),):
        # One weight per pattern: apply it to every column of that block.
        column_weights = np.repeat(weights, block_widths)
    elif weights.shape == (features.shape[1],):
        # One weight per stacked column (the original calling convention).
        column_weights = weights
    else:
        raise ValueError(
            f'weights must have length {len(blocks)} (one per pattern) or '
            f'{features.shape[1]} (one per feature column), got shape '
            f'{weights.shape}'
        )

    # Broadcasting is equivalent to `features @ np.diag(column_weights)`
    # without building a dense d x d diagonal matrix.
    weighted = features * column_weights
    return weighted @ weighted.T