import numpy as np

def compute_summary_stats(
    metric_dict: dict, metric_name: str, bucket_size: int = 50, clip: int = 0
):
    """
    Compute bucket-averaged mean and standard error of one metric across runs.

    The metric is stacked with ``np.asarray`` into an array whose first axis
    is reduced over (presumably one row per run/seed — TODO confirm against
    callers) and whose second axis indexes epochs. The following steps are
    applied in order:

    1. Compute the per-epoch mean and standard error across axis 0.
    2. Optionally drop the last ``clip`` epochs.
    3. Average consecutive epochs in non-overlapping buckets of
       ``bucket_size``.

    Args:
        metric_dict (dict): Mapping from metric name to metric data. The value
            must be convertible to an array of shape (n_runs, n_epochs).
        metric_name (str): Key of the metric in ``metric_dict``.
        bucket_size (int, optional): Number of consecutive epochs averaged
            into each bucket. Defaults to 50.
        clip (int, optional): Number of trailing epochs to discard before
            bucketing. Values <= 0 keep every epoch. Defaults to 0.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(mean, se)``, both 1-D with one entry
        per bucket. ``mean`` is the bucket average of the per-epoch mean.
        ``se`` is the bucket average of the per-epoch standard error,
        computed as ``std(ddof=0) / sqrt(n_runs)``.

    Raises:
        KeyError: If ``metric_name`` is not a key of ``metric_dict``.
        ValueError: If ``bucket_size`` is not positive, or if the number of
            epochs remaining after clipping is not a multiple of
            ``bucket_size``.
    """
    metric = np.asarray(metric_dict[metric_name])
    if bucket_size <= 0:
        raise ValueError(f"bucket_size must be positive, got {bucket_size}")

    # Per-epoch statistics across runs (axis 0). np.std defaults to ddof=0,
    # i.e. the population standard deviation.
    n_runs = metric.shape[0]
    mean_metric = np.mean(metric, axis=0)
    se_metric = np.std(metric, axis=0) / np.sqrt(n_runs)

    # Optionally discard the trailing `clip` epochs.
    if clip > 0:
        mean_metric = mean_metric[:-clip]
        se_metric = se_metric[:-clip]

    # reshape(-1, bucket_size) requires an exact multiple; fail with a clear
    # message instead of NumPy's generic reshape error.
    n_epochs = mean_metric.shape[0]
    if n_epochs % bucket_size != 0:
        raise ValueError(
            f"number of epochs after clipping ({n_epochs}) is not a multiple "
            f"of bucket_size ({bucket_size})"
        )

    # Average consecutive epochs in non-overlapping buckets.
    mean_metric = np.mean(mean_metric.reshape(-1, bucket_size), axis=1)
    se_metric = np.mean(se_metric.reshape(-1, bucket_size), axis=1)
    return mean_metric, se_metric