from typing import Dict 
import torch
import math
import torch.nn.functional as F
import os
from .common import relative_l1_error, relative_l2_error, find_optimal_alpha

def interleaved_alpha_init(cache_dic: Dict, current: Dict) -> None:
    """Ensure each module of the current stream has an alpha slot for this step.

    For every module listed in ``cache_dic['module_list'][stream]``, the
    per-module mapping ``cache_dic['cache']['loaded_alpha'][stream][module]``
    receives an empty dict under ``current['step']`` unless that step already
    has one. Existing entries are left untouched.

    Args:
        cache_dic: Cache container holding ``module_list`` and ``cache['loaded_alpha']``.
        current: Context dict providing ``stream`` and ``step``.
    """
    stream = current['stream']
    for module_name in cache_dic['module_list'][stream]:
        per_step_alphas = cache_dic['cache']['loaded_alpha'][stream][module_name]
        # setdefault only inserts when the step key is absent.
        per_step_alphas.setdefault(current['step'], {})

def dynamic_error_init(cache_dic: Dict, current: Dict) -> None:
    """Ensure each module of the current stream has an error slot for this step.

    For every module listed in ``cache_dic['module_list'][stream]``, the
    per-module mapping ``cache_dic['cache']['dynamic_error'][stream][module]``
    receives an empty dict under ``current['step']`` unless one already exists.

    Args:
        cache_dic: Cache container holding ``module_list`` and ``cache['dynamic_error']``.
        current: Context dict providing ``stream`` and ``step``.
    """
    stream = current['stream']
    for module_name in cache_dic['module_list'][stream]:
        per_step_errors = cache_dic['cache']['dynamic_error'][stream][module_name]
        # Leave an already-populated step untouched.
        per_step_errors.setdefault(current['step'], {})

def interleaved_error_update(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """Record the relative L1 error between ``feature`` and the cached previous feature.

    Does nothing if no ``"<module>-pre"`` feature is cached yet for the current
    stream/layer. Otherwise the per-step error dicts are initialised and the
    value ``relative_l1_error(pre_feature, feature)`` is stored at
    ``cache['dynamic_error'][stream][module][step][layer]``.

    Args:
        cache_dic: Cache container; ``cache[-1][stream][layer]`` holds cached features.
        current: Context dict providing ``module``, ``stream``, ``layer`` and ``step``.
        feature: Tensor to compare against the cached previous feature.
    """
    module_name = current['module']
    layer_slot = cache_dic['cache'][-1][current['stream']][current['layer']]
    pre_feature = layer_slot.get(f"{module_name}-pre", None)
    if pre_feature is None:
        # Nothing to compare against yet.
        return

    dynamic_error_init(cache_dic, current)
    errors_by_step = cache_dic['cache']['dynamic_error'][current['stream']][module_name]
    errors_by_step[current['step']][current['layer']] = relative_l1_error(pre_feature, feature)

def cal_inter_coef_distance(cache_dic: Dict, current: Dict):
    """Return the factor that rescales a raw feature difference into a per-step delta.

    Let ``a_prev`` and ``a_last`` be the two most recent entries of
    ``current['activated_steps']``. Over steps ``s`` in ``(a_prev, a_last]``,
    with ``P_s`` the running product of alphas up to ``s``, the result is
    ``P_last / sum_s P_s``. Alphas come from
    ``cache['loaded_alpha'][stream][module][s][layer]`` when
    ``cache_dic['use_alpha']`` is truthy; otherwise each alpha is 1.0, which
    makes the result ``1 / (a_last - a_prev)``.

    Adjacent activations (``a_last - a_prev == 1``) return 1.0 directly.
    If ``a_last <= a_prev``, the sum is empty and the division raises
    ``ZeroDivisionError``.

    Args:
        cache_dic: Cache container with ``use_alpha`` and, if enabled, ``cache['loaded_alpha']``.
        current: Context dict with ``activated_steps`` (and ``stream``/``module``/``layer``
            when alphas are used).
    """
    last_active = current['activated_steps'][-1]
    prior_active = current['activated_steps'][-2]
    if last_active - prior_active == 1:
        return 1.0

    def step_alpha(step):
        # Learned alpha for this step when enabled, otherwise neutral 1.0.
        if cache_dic['use_alpha']:
            return cache_dic['cache']['loaded_alpha'][current['stream']][current['module']][step][current['layer']]
        return 1.0

    running_product = 1.0
    product_total = 0
    for step in range(prior_active + 1, last_active + 1):
        running_product *= step_alpha(step)
        product_total += running_product
    return running_product / product_total

def cal_expand_coef_distance(cache_dic: Dict, current: Dict):
    """Return the coefficient that extrapolates the cached delta to the current step.

    For steps ``s`` in ``(activated_steps[-1], current['step']]``, this
    accumulates the running product of alphas and returns the sum of those
    products. With ``use_alpha`` falsy every alpha is 1.0, so the result is the
    number of skipped steps. It returns ``0`` when ``current['step']`` does not
    exceed the last activated step.

    Args:
        cache_dic: Cache container with ``use_alpha`` and, if enabled, ``cache['loaded_alpha']``.
        current: Context dict with ``activated_steps`` and ``step`` (and
            ``stream``/``module``/``layer`` when alphas are used).
    """
    first_skipped = current['activated_steps'][-1] + 1
    stop = current['step'] + 1

    accumulated = 0
    cumulative_alpha = 1.0
    for step in range(first_skipped, stop):
        factor = (
            cache_dic['cache']['loaded_alpha'][current['stream']][current['module']][step][current['layer']]
            if cache_dic['use_alpha']
            else 1.0
        )
        cumulative_alpha *= factor
        accumulated += cumulative_alpha
    return accumulated

def interleaved_cache_update(cache_dic: Dict, current: Dict, feature: torch.Tensor) -> None:
    """Refresh the cached previous feature and delta after a full computation.

    The layer slot ``cache_dic['cache'][-1][stream][layer]`` holds two entries
    per module:

    * ``"<module>-pre"``: a clone of the most recently computed feature.
    * ``"<module>-delta"``: ``(feature - previous feature) * cal_inter_coef_distance(...)``,
      written only when a previous feature was already cached.

    Args:
        cache_dic: Cache container holding the per-layer slots.
        current: Context dict providing ``module``, ``stream`` and ``layer``
            (plus whatever ``cal_inter_coef_distance`` reads).
        feature: The newly computed feature tensor.
    """
    module_name = current['module']
    layer_slot = cache_dic['cache'][-1][current['stream']][current['layer']]
    pre_key = f"{module_name}-pre"
    delta_key = f"{module_name}-delta"

    last_feature = layer_slot.get(pre_key)
    if last_feature is not None:
        layer_slot[delta_key] = (feature - last_feature) * cal_inter_coef_distance(cache_dic, current)

    # Clone so later in-place edits to `feature` by the caller do not alter the cache.
    layer_slot[pre_key] = feature.clone()


def interleaved_alpha_update(cache_dic: Dict, current: Dict, target_feature: torch.Tensor) -> None:
    """Blend a freshly fitted alpha into the stored alpha for this step/layer.

    The fitted value comes from
    ``find_optimal_alpha(prev_feature - target_feature, -delta)``, where
    ``prev_feature`` and ``delta`` are the cached ``"<module>-pre"`` and
    ``"<module>-delta"`` tensors. The stored value is then updated as an
    exponential moving average::

        alpha_new = 0.03 * alpha_fitted + 0.97 * alpha_stored

    ``alpha_stored`` defaults to 1.0 when no alpha exists yet for this layer.

    NOTE(review): the intended objective is
    ``min_a ||(prev + a * delta) - target||^2 = ||(prev - target) + a * delta||^2``.
    Passing ``-delta`` matches that objective only if
    ``find_optimal_alpha(r, d)`` minimises ``||r - a * d||^2``. Confirm this
    against ``common.find_optimal_alpha``.

    Args:
        cache_dic: Cache container with per-layer slots and ``cache['loaded_alpha']``.
        current: Context dict providing ``module``, ``stream``, ``layer`` and ``step``.
        target_feature: The freshly computed feature the extrapolation should match.
    """
    module_name, stream = current['module'], current['stream']
    layer, step = current['layer'], current['step']

    # Make sure loaded_alpha[stream][module][step] exists for every module.
    interleaved_alpha_init(cache_dic, current)

    layer_slot = cache_dic['cache'][-1][stream][layer]
    alphas_for_module = cache_dic['cache']['loaded_alpha'][stream][module_name]

    x_prev = layer_slot[f'{module_name}-pre']
    step_delta = layer_slot[f'{module_name}-delta']

    fitted_alpha = find_optimal_alpha(x_prev - target_feature, -step_delta)

    alphas_at_step = alphas_for_module[step]
    stored_alpha = alphas_at_step.get(layer, 1.0)
    alphas_at_step[layer] = fitted_alpha * 0.03 + stored_alpha * 0.97

def scaling_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """Extrapolate the current feature from the cached previous feature and delta.

    Computes ``prev_feature + delta * cal_expand_coef_distance(cache_dic, current)``
    from the ``"<module>-pre"`` and ``"<module>-delta"`` entries in
    ``cache_dic['cache'][-1][stream][layer]``. It then records the deviation of
    the prediction from ``prev_feature`` through ``interleaved_error_update``.

    Args:
        cache_dic: Cache container holding the per-layer slots and alpha/error tables.
        current: Context dict providing ``stream``, ``layer`` and ``module``
            (plus whatever the coefficient and error helpers read).

    Returns:
        torch.Tensor: The extrapolated feature.
    """
    layer_slot = cache_dic['cache'][-1][current['stream']][current['layer']]
    name = current['module']

    x_prev = layer_slot[f'{name}-pre']
    step_delta = layer_slot[f'{name}-delta']

    prediction = x_prev + step_delta * cal_expand_coef_distance(cache_dic, current)

    interleaved_error_update(cache_dic=cache_dic, current=current, feature=prediction)
    return prediction