"""
Visualization utilities for Unbiased Zoo.
"""

import numpy as np
import matplotlib.pyplot as plt

def bs_visualization(f, x, estimator, bs_list=None, num_trails=10000, mu=1e-8, grad_threshold=10e-4, a=2.0):
    """Measure how the estimator batch size affects gradient-estimation error.

    For every batch size in ``bs_list`` a fresh estimator is built and
    ``estimate(f, x)`` is called ``num_trails`` times; each call is compared
    with the reference gradient ``f.grad(x)`` via the mean squared error over
    the gradient components.

    Args:
        f (callable): The function to differentiate. Must expose ``grad(x)``
            returning the reference gradient; it is evaluated once and reused
            for every trial, so it is assumed to be deterministic.
        x (numpy.ndarray): The point at which to estimate the gradient.
        estimator (str or GradEstimator): Either one of the names
            ``"uniform"``, ``"gaussian"`` or ``"unbiased"``, or an existing
            estimator instance. For an instance, its class is re-instantiated
            for each batch size with ``zoo_batch_size`` and this function's
            ``mu`` (plus ``a=estimator.a`` when the instance has an ``a``
            attribute).
        bs_list (list, optional): List of batch sizes to try. Defaults to
            None, which means the powers of two ``1, 2, ..., 512``.
        num_trails (int, optional): Number of trials for each batch size.
            Defaults to 10000.
        mu (float, optional): The perturbation size. Defaults to 1e-8.
        grad_threshold (float, optional): Threshold for gradient magnitude.
            Defaults to 10e-4. Currently not used by this function; kept so
            existing callers keep working.
        a (float, optional): Parameter for the ``"unbiased"`` estimator when
            it is selected by name. Defaults to 2.0.

    Returns:
        numpy.ndarray: Array of shape ``(num_trails, len(bs_list))`` where
        entry ``[trial, idx]`` is the MSE of one estimate made with batch
        size ``bs_list[idx]``.

    Raises:
        ValueError: If ``estimator`` is a string that is not a known name.
    """
    if bs_list is None:
        bs_list = [2**i for i in range(10)]

    # Resolve the estimator class and its extra constructor arguments once,
    # instead of re-deciding them for every batch size.
    if isinstance(estimator, str):
        if estimator not in ("uniform", "gaussian", "unbiased"):
            raise ValueError(f"Unknown estimator: {estimator}")
        # Imported here so the package is only needed when an estimator is
        # requested by name; instances bring their own class.
        from unbiased_zoo.estimators import (
            UniformEstimator, GaussianEstimator, UnbiasedEstimator
        )
        estimator_class = {
            "uniform": UniformEstimator,
            "gaussian": GaussianEstimator,
            "unbiased": UnbiasedEstimator,
        }[estimator]
        extra_kwargs = {"a": a} if estimator == "unbiased" else {}
    else:
        # Re-create the given instance's class for each batch size, carrying
        # over its ``a`` parameter when it has one.
        estimator_class = estimator.__class__
        extra_kwargs = {"a": estimator.a} if hasattr(estimator, "a") else {}

    # One estimator per entry of bs_list, in the same order as the columns
    # of the result.
    estimators = [
        estimator_class(zoo_batch_size=bs, mu=mu, **extra_kwargs)
        for bs in bs_list
    ]

    # The reference gradient does not depend on the batch size or the trial,
    # so compute it a single time rather than inside the nested loops.
    grad = f.grad(x)

    mse_total = np.zeros((num_trails, len(bs_list)))
    for idx, estimator_obj in enumerate(estimators):
        for trial in range(num_trails):
            est_grad = estimator_obj.estimate(f, x)
            # Mean squared error over all gradient components.
            mse_total[trial, idx] = np.mean((grad - est_grad) ** 2)

    return mse_total

def easy_plot(x, mse, color, label):
    """Plot the mean MSE with a 5th-95th percentile band.

    Statistics are taken along axis 0 of ``mse``, giving one value per
    column; these are drawn against ``x`` using the current pyplot state.
    The function neither adds a legend nor shows the figure; that is left
    to the caller.

    Args:
        x (list): List of x values (batch sizes), one per column of ``mse``.
        mse (numpy.ndarray): Array of MSE values; axis 0 is reduced.
        color (str): Color used for both the band and the mean line.
        label (str): Label attached to the mean line.
    """
    mse_mean = np.mean(mse, axis=0)
    # The shaded band spans the 5th to 95th percentile of the values.
    mse_upper = np.percentile(mse, q=95, axis=0)
    mse_lower = np.percentile(mse, q=5, axis=0)

    plt.fill_between(x, mse_lower, mse_upper, color=color, alpha=0.4)
    plt.plot(x, mse_mean, c=color, label=label)