import torch
import torch.distributed as dist


import torch
import torch.distributed as dist

def dynamic_cache_cal_type(cache_dic, current):
    """
    Accumulate the previous step's dynamic cache error into ``cache_dic``.

    For every stream:

    1. Read the errors recorded at step ``current['step'] - 1``.
    2. For each layer, take the maximum error over the stream's modules.
    3. Average those per-layer maxima over the layers.
    4. When ``torch.distributed`` is initialized, average the result across
       ranks (``all_reduce`` SUM divided by world size).
    5. Add ``1 / len(stream_list)`` of that value to
       ``cache_dic['accumulative_error']``.

    Args:
        cache_dic (dict): Cache state. Keys read:
            - ``stream_list``: streams to process.
            - ``module_list``: mapping stream -> list of module names.
            - ``cache['dynamic_error']``: indexed as
              ``[stream][module][step][layer]``. Values may be tensors or
              Python numbers.
            Key updated in place: ``accumulative_error``.
        current (dict): Current state; only ``current['step']`` is read.

    Raises:
        ZeroDivisionError: If a stream's first module recorded no layers at
            the previous step.
    """
    stream_list = cache_dic['stream_list']
    num_streams = len(stream_list)
    dynamic_error = cache_dic['cache']['dynamic_error']
    previous_step = current['step'] - 1
    # Guard with is_available(): is_initialized() may be missing on builds
    # without distributed support.
    use_dist = dist.is_available() and dist.is_initialized()

    for stream in stream_list:
        module_list = cache_dic['module_list'][stream]
        stream_errors = dynamic_error[stream]
        # Layer keys come from the first module; the other modules are
        # expected to have recorded the same layers.
        layer_list = list(stream_errors[module_list[0]][previous_step].keys())

        # Per layer, keep the worst (max) error among the stream's modules.
        local_error = 0
        for layer in layer_list:
            local_error += max(
                stream_errors[module_name][previous_step][layer]
                for module_name in module_list
            )
        normalized_error = local_error / len(layer_list)

        # as_tensor accepts both tensors and plain numbers; clone so the
        # in-place all_reduce never touches the cached value.
        error_tensor = torch.as_tensor(normalized_error).detach().clone()
        # Prefer CUDA (needed by NCCL all_reduce) when present; otherwise keep
        # the value's own device instead of crashing on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else error_tensor.device
        error_tensor = error_tensor.to(device)

        if use_dist:
            world_size = dist.get_world_size()
            dist.all_reduce(error_tensor, op=dist.ReduceOp.SUM)  # Sum across ranks
            error_tensor = error_tensor / world_size              # Mean across ranks

        # Each stream contributes an equal share of the accumulated error.
        cache_dic['accumulative_error'] += error_tensor / num_streams

        if error_tensor < 1e-5:
            print("Warning: Local error nearly zero!")


def compute_threshold(cache_dic):
    """
    Derive the error threshold from the recorded error history.

    The threshold is the mean of all entries in ``cache_dic['history_error']``
    scaled by ``cache_dic['error_rate']``.

    Args:
        cache_dic (dict): Must provide ``history_error`` (a list of tensors
            stackable with ``torch.stack``) and ``error_rate``. ``error_rate``
            is only read when the history is non-empty.

    Returns:
        torch.Tensor | float: The scaled mean error, or ``0.0`` when no
        history has been recorded yet.
    """
    history = cache_dic['history_error']
    if len(history) == 0:
        return 0.0
    mean_error = torch.stack(history).mean()
    return mean_error * cache_dic['error_rate']

def cal_type(self):
    """
    Choose the computation type for the current step and update cache bookkeeping.

    Sets ``self.current['type']`` to ``'full'``, ``'Scaling'`` or ``'Taylor'``,
    checking these rules in order:

    - Steps ``< first_enhance`` and steps
      ``>= num_steps - last_enhance`` are always ``'full'``.
    - Otherwise, ``update_alpha`` forces ``'Scaling'``.
    - With ``taylor_cache``, the step is ``'full'`` when ``cache_counter``
      equals ``fresh_threshold - 1``; otherwise ``'Taylor'``.
    - With ``scaling_cache`` and ``dynamic_cache``, the step is ``'full'``
      when the accumulated error exceeds the history-based threshold;
      otherwise ``'Scaling'``.
    - With ``scaling_cache`` alone, the step is ``'full'`` when
      ``cache_counter`` equals ``fresh_threshold - 1``; otherwise
      ``'Scaling'``.
    - Otherwise, ``'full'``.

    When ``dynamic_cache`` is set, ``step > 1`` and the step is outside the
    final window, the previous step's dynamic error is first folded into
    ``accumulative_error`` via ``dynamic_cache_cal_type``.

    On a ``'full'`` step, or whenever ``update_alpha`` is set, the method:

    - resets ``cache_counter`` to 0,
    - resets ``accumulative_error`` to ``0.0``, and
    - appends the step to ``self.current['activated_steps']``.

    On any other step it increments ``cache_counter``.

    Debug output of the activated steps is printed only when
    ``self.cache_dic['verbose']`` is truthy.
    """
    cache_dic = self.cache_dic
    current = self.current
    step = current['step']

    # Determine whether we are inside the leading or trailing full-compute window.
    first_step = step < cache_dic['first_enhance']
    last_step = step >= cache_dic['num_steps'] - cache_dic['last_enhance']

    # Fold the errors recorded at the previous step into the accumulator
    # (skipped for steps 0 and 1 and inside the trailing window).
    if cache_dic["dynamic_cache"] and step > 1 and not last_step:
        dynamic_cache_cal_type(cache_dic, current)

    if first_step:
        current['type'] = 'full'
        # Record history only during the leading first_enhance window; these
        # samples are what compute_threshold averages later.
        if cache_dic['accumulative_error'] != 0:
            cache_dic['history_error'].append(
                cache_dic['accumulative_error'].detach().clone()
            )
    elif last_step:
        current['type'] = 'full'

    elif cache_dic['update_alpha']:
        # Force Scaling type when updating alpha values
        current['type'] = 'Scaling'

    elif cache_dic['taylor_cache']:
        # Taylor cache mode: full computation once the counter hits the refresh threshold
        current['type'] = 'full' if (
            cache_dic['cache_counter'] == cache_dic['fresh_threshold'] - 1
        ) else 'Taylor'

    elif cache_dic['scaling_cache']:
        # Threshold is recomputed every step from the recorded history.
        cache_dic['error_threshold'] = compute_threshold(cache_dic)

        if cache_dic['dynamic_cache']:
            # Full computation when accumulated error exceeds threshold
            current['type'] = 'full' if (
                cache_dic['accumulative_error'] > cache_dic['error_threshold']
            ) else 'Scaling'
        else:
            # Standard periodic refresh for non-dynamic mode
            current['type'] = 'full' if (
                cache_dic['cache_counter'] == cache_dic['fresh_threshold'] - 1
            ) else 'Scaling'
    else:
        # Default to full computation
        current['type'] = 'full'

    # Reset or increment cache counter based on computation type
    if current['type'] == 'full' or cache_dic['update_alpha']:
        cache_dic['cache_counter'] = 0
        current['activated_steps'].append(step)
        cache_dic['accumulative_error'] = 0.0  # Reset accumulated error
    else:
        cache_dic['cache_counter'] += 1

    # Debug trace, opt-in only (it was previously printed on every step).
    if cache_dic.get('verbose', False):
        print("###")
        print(current['activated_steps'])
        print(current['step'])
        print("###")