from .utils import init_solver_stats, update_state, produce_solver_info, produce_dummy_info


__all__ = [
    "naive_solver",
    "speedy_naive_solver",
]


def naive_solver(func, x0, 
        threshold=50, eps=1e-3, stop_mode='abs', indexing=None, 
        tau=1.0, return_final=True, 
        **kwargs):
    """
    Run damped fixed-point iteration ``x <- tau * func(x) + (1 - tau) * x`` with monitoring.

    After every step the absolute and relative residuals are computed and handed to
    ``update_state``, which records solver statistics and decides which estimate to keep.

    Args:
        func (callable): Map whose fixed point is sought; called as ``func(x)``.
        x0 (torch.Tensor): Initial estimate. Must have at least two dimensions:
            residual norms are taken per slice along dim 0 (``flatten(start_dim=1)``
            followed by ``norm(dim=1)``), and ``x0.shape[0]`` is passed to
            ``init_solver_stats``.
        threshold (int, optional): Maximum number of iterations. Default: 50.
        eps (float, optional): Early-stop tolerance on the ``stop_mode`` residual.
            Only consulted when ``return_final`` is False. Default: 1e-3.
        stop_mode (str, optional): Residual used for monitoring, ``'abs'`` or ``'rel'``.
            Default: 'abs'.
        indexing (list, optional): 1-based iteration numbers after which the currently
            kept estimate is recorded. Default: None.
        tau (float, optional): Damping factor. With 1.0 the step is effectively a
            plain ``func(x)`` evaluation. Default: 1.0.
        return_final (bool, optional): If True, never stop early, so all ``threshold``
            steps run. The flag is also forwarded to ``update_state``, which is
            expected to keep the latest estimate rather than the lowest-residual one.
            Default: True.
        **kwargs: Ignored.

    Returns:
        Tensor: The estimate chosen by ``update_state`` (``x0`` if ``threshold`` is 0).
        list: Estimates recorded at the ``indexing`` steps. If ``indexing`` is non-empty
            but none of its steps was reached, this holds the returned estimate.
        dict: Solver statistics built by ``produce_solver_info``.

    Examples:
        >>> f = lambda z: torch.cos(z)          # map whose fixed point we want
        >>> z0 = torch.zeros(1, 1)              # dim 0 is the batch dimension
        >>> z_star, _, _ = naive_solver(f, z0)  # run the fixed-point iterations
        >>> print(z_star)                       # the numerical solution
    """
    # The residual kind that is *not* used for stopping; its trace is padded as well.
    other_mode = 'abs' if stop_mode != 'abs' else 'rel'

    trace_dict, lowest_dict, lowest_step_dict = init_solver_stats(x0.shape[0], x0.device)
    best = x0
    recorded = []

    z_new = x0
    for step in range(1, threshold + 1):
        z_old = z_new
        z_new = tau * func(z_old) + (1 - tau) * z_old

        # Per-sample residual norms: every dimension except dim 0 is flattened.
        abs_diff = (z_new - z_old).flatten(start_dim=1).norm(dim=1)
        scale = z_new.flatten(start_dim=1).norm(dim=1) + 1e-9  # 1e-9 keeps the divisor non-zero
        rel_diff = abs_diff / scale

        # Let the shared helper update the statistics and pick the estimate to keep.
        best = update_state(
                best, z_new, step,
                stop_mode, abs_diff, rel_diff,
                trace_dict, lowest_dict, lowest_step_dict, return_final
                )

        if indexing and step in indexing:
            recorded.append(best)

        # Early exit is only allowed when the caller does not insist on the final iterate.
        converged = (not return_final) and trace_dict[stop_mode][-1].max() < eps
        if converged:
            # Fill the traces for the skipped steps with the lowest values seen so far
            # (presumably so every trace ends up with ``threshold`` entries).
            for _ in range(threshold - step):
                trace_dict[stop_mode].append(lowest_dict[stop_mode])
                trace_dict[other_mode].append(lowest_dict[other_mode])
            break

    # Guarantee at least one entry whenever ``indexing`` was requested.
    if indexing and not recorded:
        recorded.append(best)

    return best, recorded, produce_solver_info(stop_mode, lowest_dict, trace_dict, lowest_step_dict)


def speedy_naive_solver(func, x0, 
        threshold=50, tau=1.0,
        indexing=None, 
        **kwargs):
    """
    Implements a simple damped fixed-point solver without any statistics monitoring.

    Runs exactly ``threshold`` iterations of ``x <- tau * func(x) + (1 - tau) * x``,
    the same update rule as ``naive_solver``. It computes no residuals and never
    stops early, which makes each step cheaper.

    Args:
        func (callable): The function whose fixed point is sought. Called as ``func(x)``.
        x0 (torch.Tensor): The initial guess for the fixed point.
        threshold (int, optional): The number of iterations to run. Default: 50.
        tau (float, optional): Damping factor to control the step size in the solution
            direction. With 1.0 the step is effectively ``func(x)``. Default: 1.0.
        indexing (list, optional): 1-based iteration indices at which to store the
            current iterate. Default: None.
        **kwargs: Additional keyword arguments, ignored in this function.

    Returns:
        torch.Tensor: The final iterate (``x0`` itself if ``threshold`` is 0).
        list: Iterates stored at the specified iteration indices. If ``indexing`` is
            non-empty but none of its indices was reached, it holds the final iterate.
        dict: A dummy dict for solver statistics (from ``produce_dummy_info``).

    Examples:
        >>> f = lambda z: torch.cos(z)
        >>> z0 = torch.tensor([0.0])
        >>> z_star, _, _ = speedy_naive_solver(f, z0)
        >>> print(z_star)
    """
    indexing_list = []

    fx = x0
    for k in range(threshold):
        x = fx
        # Apply the damping here, exactly like ``naive_solver``. ``func`` is a plain
        # one-argument map and is not expected to accept a ``tau`` keyword.
        fx = tau * func(x) + (1 - tau) * x

        # If indexing is enabled, store the iterate at the requested (1-based) steps.
        if indexing and (k + 1) in indexing:
            indexing_list.append(fx)

    # If indexing is enabled but no iterate was stored, store the final one.
    if indexing and not indexing_list:
        indexing_list.append(fx)

    info = produce_dummy_info()
    return fx, indexing_list, info
