from typing import Dict
import torch
import math

def chebyshev_polynomial(n: int, x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the Chebyshev polynomial of the first kind T_n at x.

    The value is built with the three-term recurrence
    T_k(x) = 2*x*T_{k-1}(x) - T_{k-2}(x), starting from T_0 = 1 and T_1 = x.

    :param n: Order of the polynomial; must be non-negative
    :param x: Input tensor (evaluation points)
    :return: T_n(x), same shape as x. For n == 1 the input tensor itself is
        returned (no copy is made).
    :raises ValueError: If n is negative
    """
    if n < 0:
        # Without this guard a negative order would skip the recurrence and
        # silently come back as T_1 (= x).
        raise ValueError(f"Chebyshev order must be non-negative, got {n}")
    if n == 0:
        return torch.ones_like(x)

    # (t_prev, t_curr) holds (T_{k-1}, T_k); start at k = 1.
    t_prev, t_curr = torch.ones_like(x), x
    for _ in range(n - 1):
        t_prev, t_curr = t_curr, 2 * x * t_curr - t_prev
    return t_curr

def normalize_interval(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
    """
    Affinely map values from the interval [a, b] onto [-1, 1].

    a is sent to -1 and b is sent to +1. The width b - a is used as a
    divisor, so a and b are expected to differ.

    :param x: Input values
    :param a: Lower bound of the original interval
    :param b: Upper bound of the original interval
    :return: x rescaled so that [a, b] corresponds to [-1, 1]
    """
    width = b - a
    offset = x - a
    # Scale the offset to [0, 2], then shift down to [-1, 1].
    return 2 * offset / width - 1

def denormalize_interval(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
    """
    Affinely map values from [-1, 1] back onto the interval [a, b].

    -1 is sent to a and +1 is sent to b.

    :param x: Normalized values in [-1, 1]
    :param a: Lower bound of the target interval
    :param b: Upper bound of the target interval
    :return: x rescaled so that [-1, 1] corresponds to [a, b]
    """
    width = b - a
    shifted = x + 1
    # Scale [0, 2] to [0, width], then translate by the lower bound.
    return shifted * width / 2 + a

def chebyshev_coefficient_estimation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Rebuild the cached coefficient table for the current layer/module.

    Entry 0 is the freshly computed feature. Each higher entry i + 1 is a
    weighted difference between the new entry i and the previously cached
    entry i, with weight 2 / (i + 1). Estimation stops at the first order
    for which no previous entry is cached, or as soon as
    current['step'] is not below
    current['num_steps'] - cache_dic['first_enhance'] + 1.

    The new table replaces
    cache_dic['cache'][-1][current['layer']][current['module']]; nothing is
    returned.

    NOTE: this is a simplified estimate. The spacing between activated
    steps does not enter the formula; only the 2 / (i + 1) weight scales
    the differences.

    :param cache_dic: Cache dictionary (reads 'cache', 'max_order',
        'first_enhance').
    :param current: Current step information (reads 'layer', 'module',
        'step', 'num_steps').
    :param feature: Current function value, stored as coefficient 0.
    """
    layer_cache = cache_dic['cache'][-1][current['layer']]
    module = current['module']

    updated_chebyshev_coefficients = {0: feature}

    for order in range(cache_dic['max_order']):
        # One lookup serves both the existence check and the value.
        prev_coeff = layer_cache[module].get(order)
        if prev_coeff is None:
            break

        within_window = current['step'] < (current['num_steps'] - cache_dic['first_enhance'] + 1)
        if not within_window:
            break

        curr_coeff = updated_chebyshev_coefficients[order]
        weight = 2.0 / (order + 1)
        updated_chebyshev_coefficients[order + 1] = weight * (curr_coeff - prev_coeff)

    layer_cache[module] = updated_chebyshev_coefficients

def chebyshev_approximation(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Evaluate the cached Chebyshev expansion at the current step.

    The offset current['step'] - current['activated_steps'][-1] is mapped
    from [-cache_dic['interval'], 0] to [-1, 1], and the result is
    sum_i c_i * T_i(x) over the coefficients cached for the current
    layer/module.

    :param cache_dic: Cache dictionary (reads 'interval' and 'cache').
    :param current: Current step information (reads 'step',
        'activated_steps', 'layer', 'module').
    :return: Approximated function value, in the dtype of the cached
        coefficients. If no coefficients are cached, the integer 0 is
        returned.
    """
    # Signed offset from the most recent activated step.
    x_raw = current['step'] - current['activated_steps'][-1]

    # The offset is expected to lie in [-interval, 0] (it is typically
    # negative), so that range is what gets mapped onto [-1, 1].
    interval_length = cache_dic['interval']
    a, b = -interval_length, 0
    x_normalized = normalize_interval(torch.tensor(float(x_raw)), a, b)

    coefficients = cache_dic['cache'][-1][current['layer']][current['module']]
    if coefficients:
        # Evaluate the polynomials on the same device as the coefficients.
        device = next(iter(coefficients.values())).device
        x_normalized = x_normalized.to(device)

    # f(x) ≈ Σ c_i * T_i(x)
    output = 0
    for order, coeff in coefficients.items():
        # T_i stays a 0-dim tensor: broadcasting applies it to every element
        # of coeff, and under PyTorch's type-promotion rules the product
        # keeps coeff's dtype. Expanding T_i to coeff's full shape would
        # instead promote half-precision coefficients to float32.
        output += coeff * chebyshev_polynomial(order, x_normalized)

    return output

def chebyshev_cache_init(cache_dic: Dict, current: Dict):
    """
    Reset the coefficient store for the current layer/module.

    The entry cache_dic['cache'][-1][layer][module] is replaced with an
    empty dict, but only when current['step'] equals
    current['num_steps'] - 1; on every other step nothing is touched.

    :param cache_dic: Cache dictionary.
    :param current: Current step information.
    """
    at_init_step = current['step'] == (current['num_steps'] - 1)
    if not at_init_step:
        return

    layer_slot = cache_dic['cache'][-1][current['layer']]
    layer_slot[current['module']] = {}

# Aliases for backward compatibility with the Taylor implementation: the
# Chebyshev functions are also exposed under the Taylor-style names so that
# code written against those names can call them without changes.
derivative_approximation = chebyshev_coefficient_estimation
taylor_formula = chebyshev_approximation  
taylor_cache_init = chebyshev_cache_init 