import torch
from tqdm import tqdm


def train_lore(init_embedding, lamb, f, g, mu="default",
               max_iterations=1000, tol=1e-6, seed=None,
               zero=1e-15, verbose=True):
    """
    Run the LORE training loop.

    Each iteration takes a gradient step on ``f`` with step size ``1 / mu``,
    computes the SVD of the result, shifts the singular values by
    ``(lamb / mu) * dg/dsigma`` (with ``dg/dsigma`` evaluated at the singular
    values of the current iterate), zeroes every value that is not above
    ``zero``, and rebuilds the embedding from the thresholded spectrum.

    Args:
        init_embedding (torch.Tensor): 2-D tensor of shape (N, d') used as the
            starting point. It is copied, never modified in place.
        lamb (float): Amount of regularization used.
        f (LogisticTripletLoss): A LogisticTripletLoss object instantiated
            with the train triplets. Called as ``f(X)`` and must return a
            scalar tensor differentiable w.r.t. ``X``. Must also provide
            ``estimate_lipschitz_constant(X)`` when ``mu == "default"``.
        g (g): Either SchattenPUnnormed or Linear. Called as ``g(sigma)`` on
            the singular values and must return a scalar tensor
            differentiable w.r.t. ``sigma``.
        mu (str or float, optional): Lipschitz constant used as the inverse
            step size. With "default" it is obtained from
            ``f.estimate_lipschitz_constant``. Can be set manually if needed.
            Defaults to "default".
        max_iterations (int, optional): Maximum number of iterations to run.
            Defaults to 1000.
        tol (float, optional): Threshold for early stopping. The loop stops
            when the objective changes by less than ``tol`` between two
            iterations, or when the infinity norm of the difference between
            consecutive iterates drops below ``tol`` (stationary point).
            Defaults to 1e-6.
        seed (int, optional): Random seed for (some) reproducibility.
            Defaults to None.
        zero (float, optional): Threshold for automatic 0 in singular value
            thresholding. Defaults to 1e-15.
        verbose (bool, optional): Whether to print progress information and
            show the progress bar. Defaults to True.

    Returns:
        dict: of format
            {
                'X': np.array: the final embedding
                'objectives': list: objective value at the start of each iteration
                'f_losses': list: the f loss values at each iteration
                'g_losses': list: the g loss values at each iteration
                'ranks': list: rank of the embedding after each update step
                'sigma_history': list: sorted (descending) thresholded singular
                    values after each update step
            }
    """
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    # Work on a private leaf copy so the caller's tensor is never touched.
    X = init_embedding.detach().clone().requires_grad_(True)
    objectives = []
    flosses = []
    glosses = []
    ranks = []
    sigma_history = []

    # Lipschitz constant estimation
    if isinstance(mu, str) and mu == "default":
        # Give the estimator an independent leaf tensor so that whatever
        # autograd calls it makes cannot reach back into X.
        X_probe = X.detach().clone().requires_grad_(True)
        mu = f.estimate_lipschitz_constant(X_probe)
        if verbose:
            print(f"Estimated Lipschitz constant (μ): {mu:.2f}")

    # torch.linalg.svd only accepts `driver=` for CUDA inputs; passing it
    # for a CPU tensor raises a RuntimeError.
    svd_kwargs = {"driver": "gesvd"} if X.is_cuda else {}

    pbar = tqdm(range(max_iterations), disable=not verbose)
    # Plain float so no autograd graph is kept alive across iterations.
    prev_objective = float("inf")

    for _ in pbar:
        # Reset computation graph
        X = X.detach().requires_grad_(True)

        # Forward pass
        floss = f(X)
        sigma = torch.linalg.svdvals(X)
        g_sigma = g(sigma)
        gloss = lamb * g_sigma
        current_objective = floss + gloss

        # Pull the scalars to the host once and reuse them below.
        floss_value = floss.item()
        gloss_value = gloss.item()
        objective_value = current_objective.item()

        # Track objectives
        objectives.append(objective_value)
        flosses.append(floss_value)
        glosses.append(gloss_value)
        if abs(prev_objective - objective_value) < tol:
            if verbose:
                print(f"\nConverged after {len(objectives)} iterations")
            break
        prev_objective = objective_value

        # Gradient of f w.r.t. the embedding and of g w.r.t. the spectrum.
        f_grad = torch.autograd.grad(floss, X, retain_graph=False)[0]
        g_grad = torch.autograd.grad(g_sigma, sigma, retain_graph=False)[0]

        with torch.no_grad():
            X_prev = X

            # SVD of the gradient step on f.
            U, S, V_T = torch.linalg.svd(
                X - f_grad / mu,
                full_matrices=False,
                **svd_kwargs,
            )

            # Shift the spectrum by the scaled g-gradient and zero out
            # everything that does not exceed the `zero` threshold.
            new_S = S - (lamb / mu) * g_grad
            rank_mask = new_S > zero
            masked_S = new_S * rank_mask

            # The shift is not uniform, so the thresholded values may be
            # out of order; the history stores them sorted.
            sorted_S, _ = torch.sort(masked_S, descending=True)
            sigma_history.append(sorted_S.cpu().numpy())

            # Same as U @ diag(masked_S) @ V_T without building the
            # diagonal matrix: scale the columns of U instead.
            X = (U * masked_S) @ V_T

            rank = torch.sum(rank_mask).item()
            ranks.append(rank)

            # Stationarity check: distance between consecutive iterates.
            kl_dist = torch.linalg.matrix_norm(
                X - X_prev, ord=float("inf")
            ).item()
            if kl_dist < tol:
                if verbose:
                    print(f"\nKLDist Converged after {len(objectives)} iterations")
                break

        # Update progress bar metrics
        pbar.set_postfix({
            'f_loss': f"{floss_value:.4f}",
            'g_loss': f"{gloss_value:.4f}",
            'objective': f"{objective_value:.4f}",
            'KLdist': f"{kl_dist:.4f}",
            'rank': rank
        })

    results = {
        'X': X.detach().cpu().numpy(),
        'objectives': objectives,
        'f_losses': flosses,
        'g_losses': glosses,
        'ranks': ranks,
        'sigma_history': sigma_history
    }
    return results

