import numpy as np
import torch


def init_solver_stats(bsz, device, init_loss=1e8):
    """
    Create empty bookkeeping structures for a batched fixed-point solver.

    Every tracked quantity exists in an 'abs' and a 'rel' flavour. Each entry is a
    fresh 1-D tensor of length ``bsz``; no two entries share storage.

    Args:
        bsz (int): Batch size.
        device (torch.device): Device on which the tensors are allocated.
        init_loss (float, optional): Starting value for traces and running minima. Default 1e8.

    Returns:
        tuple: ``(trace_dict, lowest_dict, lowest_step_dict)`` where
            - ``trace_dict[mode]`` is a list holding one ``init_loss``-filled tensor,
            - ``lowest_dict[mode]`` is an ``init_loss``-filled tensor,
            - ``lowest_step_dict[mode]`` is a zero-filled float tensor.
    """
    def _filled(value):
        # New length-bsz tensor; dtype follows torch.tensor(value) inference.
        return torch.tensor(value, device=device).repeat(bsz)

    modes = ('abs', 'rel')
    trace_dict = {mode: [_filled(init_loss)] for mode in modes}
    lowest_dict = {mode: _filled(init_loss) for mode in modes}
    lowest_step_dict = {mode: _filled(0.) for mode in modes}

    return trace_dict, lowest_dict, lowest_step_dict


def batch_masked_mixing(mask, mask_var, orig_var):
    """
    Helper function. Select, per batch element, between ``mask_var`` and ``orig_var``.

    The mask is first reshaped so it broadcasts over the trailing axes of the tensor
    argument. Batch elements where the mask is True then come from ``mask_var``, and
    the rest come from ``orig_var``.

    Selection uses ``torch.where`` rather than ``mask * a + ~mask * b``. With
    arithmetic mixing, any NaN/inf in the *unselected* operand leaks into the result,
    because ``0 * nan`` and ``0 * inf`` are NaN. That would corrupt running minima as
    soon as one sample diverges.

    Args:
        mask (torch.Tensor): A tensor of shape (B,). Non-bool masks are converted with ``.bool()``.
        mask_var (torch.Tensor or number): Shape (B, ...) or a scalar; taken where mask is True.
        orig_var (torch.Tensor or number): Shape (B, ...) or a scalar; taken where mask is False.

    Returns:
        torch.Tensor: The per-sample selection. Its dtype is
        ``torch.result_type(mask_var, orig_var)``.

    Raises:
        ValueError: If neither ``mask_var`` nor ``orig_var`` is a tensor.
    """
    if torch.is_tensor(mask_var):
        ref = mask_var
    elif torch.is_tensor(orig_var):
        ref = orig_var
    else:
        raise ValueError('Either mask_var or orig_var should be a Pytorch tensor!')

    axes_to_align = ref.dim() - 1
    aligned_mask = mask.bool().view(mask.shape[0], *([1] * axes_to_align))

    # torch.where needs operands of a common dtype on older PyTorch versions;
    # mirror the promotion the arithmetic formulation used to perform.
    out_dtype = torch.result_type(mask_var, orig_var)

    def _as_tensor(var):
        if torch.is_tensor(var):
            return var.to(out_dtype)  # differentiable cast; no-op if dtype matches
        return torch.tensor(var, dtype=out_dtype, device=ref.device)

    return torch.where(aligned_mask, _as_tensor(mask_var), _as_tensor(orig_var))


def update_state(
        lowest_xest, x_est, nstep,
        stop_mode, abs_diff, rel_diff,
        trace_dict, lowest_dict, lowest_step_dict,
        return_final=False
        ):
    """
    Record one solver iteration and refresh the per-sample best-so-far bookkeeping.

    Side effects: appends ``abs_diff``/``rel_diff`` to ``trace_dict['abs']``/``['rel']``.
    It also replaces the ``'rel'`` and ``'abs'`` entries of ``lowest_dict`` and
    ``lowest_step_dict`` with mixtures that keep the smaller diff, and the step it
    occurred at, for each batch element.

    Args:
        lowest_xest (torch.Tensor): Best estimate so far (per sample) under ``stop_mode``.
        x_est (torch.Tensor): Current estimated solution.
        nstep (int): Current step number.
        stop_mode (str): Which diff ('rel' or 'abs') decides the best estimate.
        abs_diff (torch.Tensor): Absolute difference between estimates.
        rel_diff (torch.Tensor): Relative difference between estimates.
        trace_dict (dict): Lists of per-step diffs, keyed by mode.
        lowest_dict (dict): Running minimum diff per mode.
        lowest_step_dict (dict): Step at which each running minimum occurred.
        return_final (bool, optional): If True, every sample is treated as improved,
            so the current values are taken unconditionally. Default False.

    Returns:
        torch.Tensor: The updated best estimate, detached from the autograd graph.
    """
    current = {'abs': abs_diff, 'rel': rel_diff}
    for key in ('abs', 'rel'):
        trace_dict[key].append(current[key])

    for mode in ('rel', 'abs'):
        # Bool mask of samples that beat their previous best; adding return_final
        # (a bool) forces the mask to all-True when set.
        improved = (current[mode] < lowest_dict[mode]) + return_final
        if mode == stop_mode:
            lowest_xest = batch_masked_mixing(improved, x_est, lowest_xest).clone().detach()
        lowest_dict[mode] = batch_masked_mixing(improved, current[mode], lowest_dict[mode])
        lowest_step_dict[mode] = batch_masked_mixing(improved, nstep, lowest_step_dict[mode])

    return lowest_xest


def produce_solver_info(
        stop_mode, lowest_dict, trace_dict, lowest_step_dict
        ):
    """
    Package accumulated solver bookkeeping into a single statistics dict.

    Args:
        stop_mode (str): Mode whose step record is reported as 'nstep' ('rel' or 'abs').
        lowest_dict (dict): Running minimum diff per mode.
        trace_dict (dict): Lists of per-step diff tensors per mode; each list is
            stacked along a new dim 1.
        lowest_step_dict (dict): Step at which each running minimum occurred.

    Returns:
        dict: Keys 'abs_lowest', 'rel_lowest', 'abs_trace', 'rel_trace', 'nstep'.
    """
    abs_best = lowest_dict['abs']
    rel_best = lowest_dict['rel']
    abs_history = torch.stack(trace_dict['abs'], dim=1)
    rel_history = torch.stack(trace_dict['rel'], dim=1)

    return dict(
        abs_lowest=abs_best,
        rel_lowest=rel_best,
        abs_trace=abs_history,
        rel_trace=rel_history,
        nstep=lowest_step_dict[stop_mode],
    )


def produce_final_info(
        z, fz, nstep=0,
        ):
    """
    Build solver statistics from a single final estimate and its function value.

    If ``z`` is not a tensor, it and ``fz`` are treated as sequences of tensors. Each
    part is flattened from dim 1 on, and the parts are concatenated per sample.

    Args:
        z (torch.Tensor or sequence of torch.Tensor): Final fixed point estimate.
        fz (torch.Tensor or sequence of torch.Tensor): Function evaluation at ``z``.
        nstep (int, optional): Total number of solver steps. Default 0.

    Returns:
        dict: 'abs_lowest' is the per-sample L2 norm of ``fz - z``. 'rel_lowest' is
        that norm divided by (``||fz|| + 1e-8``). 'abs_trace' and 'rel_trace' are
        those values with a trailing singleton dim. 'nstep' is ``nstep`` repeated
        over the batch as a float tensor.
    """
    def _flat(t):
        return t.flatten(start_dim=1)

    if not torch.is_tensor(z):
        z = torch.cat(list(map(_flat, z)), dim=1)
        fz = torch.cat(list(map(_flat, fz)), dim=1)

    residual_norm = _flat(fz - z).norm(dim=1)
    fz_norm = _flat(fz).norm(dim=1)
    rel_residual = residual_norm / (fz_norm + 1e-8)  # epsilon guards against ||fz|| == 0
    steps = nstep * torch.ones(z.shape[0], device=z.device)

    return {
        'abs_lowest': residual_norm,
        'rel_lowest': rel_residual,
        'abs_trace': residual_norm[:, None],
        'rel_trace': rel_residual[:, None],
        'nstep': steps,
    }


def produce_dummy_info():
    """
    Return a placeholder statistics dict with the same keys as the real solver info.

    Every entry is a separate one-element tensor holding -1.0.

    Returns:
        dict: Keys 'abs_lowest', 'rel_lowest', 'abs_trace', 'rel_trace', 'nstep'.
    """
    keys = ('abs_lowest', 'rel_lowest', 'abs_trace', 'rel_trace', 'nstep')
    # The comprehension builds a new tensor per key, so entries never alias.
    return {key: torch.tensor([-1.]) for key in keys}
