# wd_regularization.py
import torch
import numpy as np
import torch.nn.functional as F

_PRECOMPUTED_MATRICES = {}

def chebyshev_vander(x, N):
    """
    Build the Chebyshev (first kind) Vandermonde matrix of ``x`` up to degree ``N``.

    Column ``n`` of the result holds ``T_n(x)``, produced by the three-term
    recurrence ``T_n = 2 * x * T_{n-1} - T_{n-2}`` seeded with ``T_0 = 1`` and
    ``T_1 = x``. Columns are collected in a Python list and stacked once at the
    end rather than written into a preallocated tensor, so there are no
    in-place writes and the function stays compatible with functional
    transforms such as vmap.

    Args:
        x: Input tensor, shape [resolution] or [batch, resolution].
        N: Maximum degree (max_degree).

    Returns:
        Tensor of shape [*x.shape, N + 1] for N >= 0. A negative N is not
        rejected; it yields just the two columns [1, x].
    """
    columns = [torch.ones_like(x)]
    if N != 0:
        columns.append(x)
        prev, curr = columns
        # Each pass advances the recurrence by one degree; runs zero times for N < 2.
        for _ in range(N - 1):
            prev, curr = curr, 2 * x * curr - prev
            columns.append(curr)
    return torch.stack(columns, dim=-1)

def precompute_chebyshev_matrix(resolution, max_degree=40, device='cpu'):
    """Precompute (and memoize) the Chebyshev design matrix and its pseudo-inverse.

    Sample points are the ``resolution`` Chebyshev-Gauss nodes
    ``cos((2k - 1) * pi / (2 * resolution))`` for ``k = 1..resolution``,
    reordered to ascending. All tensors are float64.

    The pseudo-inverse is formed from the lightly ridge-regularized normal
    equations, ``(T^T T + 1e-8 I)^(-1) T^T``. If the inverse fails, it falls
    back to ``pinv`` of the regularized matrix.

    Results are stored in the module-level ``_PRECOMPUTED_MATRICES`` dict
    under ``(resolution, max_degree, device)``. Device values that compare
    unequal (e.g. the strings 'cuda' and 'cuda:0') get separate entries.

    Args:
        resolution: Number of sample points.
        max_degree: Highest polynomial degree in the basis.
        device: Device on which the tensors are created.

    Returns:
        dict with keys
            'alpha_values': [resolution] ascending nodes,
            'T_matrix':     [resolution, max_degree + 1] Chebyshev Vandermonde matrix,
            'T_pinv':       [max_degree + 1, resolution] map from sampled values
                            to least-squares Chebyshev coefficients.
    """
    cache_key = (resolution, max_degree, device)
    cached = _PRECOMPUTED_MATRICES.get(cache_key)
    if cached is not None:
        return cached

    print(f"Starting precomputation of Chebyshev matrix: resolution={resolution}, max_degree={max_degree}, device={device}")

    # float64 keeps the normal-equation solve accurate.
    k = torch.arange(1, resolution + 1, device=device, dtype=torch.float64)
    nodes_desc = torch.cos((2 * k - 1) * np.pi / (2 * resolution))
    alpha_values = torch.flip(nodes_desc, dims=[0])

    T_matrix = chebyshev_vander(alpha_values, max_degree)

    # T_pinv = (T^T T + eps I)^(-1) T^T; eps guards against a singular T^T T
    # (e.g. when max_degree + 1 exceeds resolution).
    T_T = T_matrix.T
    A = T_T @ T_matrix
    A_reg = A + 1e-8 * torch.eye(A.shape[0], device=device, dtype=torch.float64)

    try:
        A_inv = torch.linalg.inv(A_reg)
    except RuntimeError:
        # torch.linalg.LinAlgError subclasses RuntimeError. Catching only this
        # lets KeyboardInterrupt / SystemExit propagate instead of being swallowed.
        A_inv = torch.linalg.pinv(A_reg)

    entry = {
        'alpha_values': alpha_values,
        'T_matrix': T_matrix,
        'T_pinv': A_inv @ T_T,  # Precomputed pseudo-inverse matrix
    }
    _PRECOMPUTED_MATRICES[cache_key] = entry
    print(f"Precomputation completed: resolution={resolution}, max_degree={max_degree}, device={device}")
    return entry
def polynomial_regularization(alpha_values, model_outputs, resolution, miu, max_degree=40, have_const=False, use_norm=True, random_alpha=False, square=False, degree_mode="index", ce_reg=False):
    """Score how much weight a Chebyshev fit of ``model_outputs`` puts on high degrees.

    Each column of ``model_outputs`` is fitted (in float64) with a Chebyshev
    series of degree ``max_degree``. Each coefficient's magnitude (``|c|``, or
    ``c**2`` when ``square``) is weighted by a per-degree weight. The weighted
    sums are optionally normalized per column and then averaged over columns.

    Args:
        alpha_values: Sample locations of the rows of ``model_outputs``. Used
            only when ``random_alpha`` is True; otherwise the fixed nodes from
            ``precompute_chebyshev_matrix`` are assumed
            (NOTE(review): confirm callers sample at those nodes).
        model_outputs: [n_points, num_classes] tensor of outputs at the sample
            points. With fixed nodes, n_points must equal ``resolution``.
        resolution: Number of fixed Chebyshev nodes (fixed-alpha path only).
        miu: Not referenced by this function; kept for call-site compatibility.
        max_degree: Highest Chebyshev degree in the fit.
        have_const: If False, the T_0 (constant) coefficient is zeroed before scoring.
        use_norm: If True, divide each column's weighted sum by its total
            magnitude (clamped to >= 1e-10), giving a weighted mean degree.
        random_alpha: If True, fit at ``alpha_values`` via damped normal
            equations (damping 1e-4); otherwise use the cached pseudo-inverse.
        square: Use squared coefficients instead of absolute values.
        degree_mode: "index" weights degree n by n. Any other value weights
            degrees >= 2 by 1 and degrees 0 and 1 by 0.
        ce_reg: Not referenced by this function; kept for call-site compatibility.

    Returns:
        Scalar tensor in ``model_outputs.dtype``: the mean over columns of the
        per-column score.
    """
    device = model_outputs.device
    original_dtype = model_outputs.dtype

    # Fit in float64 for numerical stability; cast back only at the end.
    outputs_64 = model_outputs.to(torch.float64)
    coeffs = _fit_chebyshev_coeffs(alpha_values, outputs_64, resolution, max_degree, random_alpha)  # [max_degree + 1, num_classes]

    if not have_const:
        # Zero the T_0 row. The mask is built out-of-place (no index assignment).
        keep = (torch.arange(coeffs.shape[0], device=device) > 0).to(coeffs.dtype)
        coeffs = coeffs * keep.unsqueeze(1)

    magnitudes = torch.square(coeffs) if square else torch.abs(coeffs)
    weights = _degree_weights(max_degree, degree_mode, device).unsqueeze(1)  # [max_degree + 1, 1]

    numerator = torch.sum(magnitudes * weights, dim=0)  # [num_classes]
    if use_norm:
        # Clamp avoids division by zero for an all-zero column.
        denominator = torch.clamp(torch.sum(magnitudes, dim=0), min=1e-10)
        per_class = numerator / denominator
    else:
        per_class = numerator

    return torch.mean(per_class).to(original_dtype)


def _fit_chebyshev_coeffs(alpha_values, outputs_64, resolution, max_degree, random_alpha):
    """Return float64 least-squares Chebyshev coefficients, shape [max_degree + 1, num_classes]."""
    device = outputs_64.device

    if not random_alpha:
        # Fixed nodes: reuse the cached pseudo-inverse; alpha_values is ignored.
        T_pinv = precompute_chebyshev_matrix(resolution, max_degree, device)['T_pinv']
        return T_pinv @ outputs_64

    # Arbitrary sample points: solve (T^T T + lambda I) C = T^T Y.
    # alpha_values is moved onto the outputs' device so the damping term matches.
    alpha_64 = alpha_values.to(device=device, dtype=torch.float64)
    T_matrix = chebyshev_vander(alpha_64, max_degree)  # [n_points, max_degree + 1]
    T_T = T_matrix.T

    damping = 1e-4
    A_reg = T_T @ T_matrix + damping * torch.eye(T_matrix.shape[1], device=device, dtype=torch.float64)
    rhs = T_T @ outputs_64

    try:
        # solve() is more stable than forming an explicit inverse.
        return torch.linalg.solve(A_reg, rhs)
    except RuntimeError:
        print("Warning: linalg.solve failed, falling back to lstsq method")
        # Undamped least squares. The solution already has exactly
        # max_degree + 1 rows (one per column of T_matrix).
        return torch.linalg.lstsq(T_matrix, outputs_64, driver=None).solution


def _degree_weights(max_degree, degree_mode, device):
    """Per-degree weights [max_degree + 1]: n for "index", else 1 for n >= 2 and 0 below."""
    degrees = torch.arange(max_degree + 1, device=device, dtype=torch.float64)
    if degree_mode == "index":
        return degrees
    # A comparison rather than writes to indices 0 and 1, so max_degree == 0 works.
    return (degrees >= 2).to(torch.float64)