import torch
import torch.distributed as dist

def relative_l1_error(A, B, eps=1e-8):
    """Return the L1 error of ``A`` against ``B``, relative to the L1 norm of ``B``.

    Computes ``sum(|A - B|) / (sum(|B|) + eps)`` over all elements.

    Args:
        A: Tensor being evaluated; must broadcast against ``B``.
        B: Reference tensor whose L1 norm normalizes the error.
        eps: Small constant added to the denominator so an all-zero ``B``
            does not produce NaN/inf (mirrors ``relative_l2_error``).

    Returns:
        A 0-dim tensor holding the relative L1 error.
    """
    l1_error = torch.sum(torch.abs(A - B))
    l1_norm_B = torch.sum(torch.abs(B))
    # Guard the division: without eps, B == 0 yields 0/0 (NaN) or x/0 (inf).
    return l1_error / (l1_norm_B + eps)

def relative_l2_error(A, B, eps=1e-8):
    """Return the L2 norm of ``A - B`` divided by the L2 norm of ``B``.

    Args:
        A: Tensor being evaluated; must broadcast against ``B``.
        B: Reference tensor whose L2 norm normalizes the error.
        eps: Added to the denominator to avoid dividing by zero.

    Returns:
        A 0-dim tensor holding ``||A - B||_2 / (||B||_2 + eps)``.
    """
    residual = A - B
    residual_norm = residual.norm(p=2)
    reference_norm = B.norm(p=2)
    return residual_norm / (reference_norm + eps)

def find_optimal_alpha(D, E, eps=1e-8, alpha_min=0.5, alpha_max=1.5):
    r"""Return the clamped scale ``alpha`` that minimizes ``||D - \alpha E||_2``.

    The unconstrained least-squares solution is ``<D, E> / <E, E>``; it is
    then clamped to ``[alpha_min, alpha_max]``.

    Args:
        D: Target tensor.
        E: Tensor to be scaled; must broadcast against ``D``.
        eps: Added to ``<E, E>`` so an all-zero ``E`` does not divide by zero.
        alpha_min: Lower clamp bound for the returned scale.
        alpha_max: Upper clamp bound for the returned scale.

    Returns:
        A 0-dim tensor holding the clamped scale.
    """
    numerator = torch.sum(D * E)
    denominator = torch.sum(E * E)
    alpha = numerator / (denominator + eps)
    return torch.clamp(alpha, min=alpha_min, max=alpha_max)

def _alpha_to_scalar(val):
    """Convert a tensor alpha to a Python number; pass non-tensors through."""
    if torch.is_tensor(val):
        # Multi-element tensors are reduced to their mean before averaging.
        return val.item() if val.numel() == 1 else val.mean().item()
    return val


def _average_alpha_dicts(gathered_dicts, stream_list, module_list):
    """Average per-rank ``stream -> module -> step -> layer`` alpha dicts.

    Each (stream, module, step, layer) entry is averaged only over the ranks
    that actually contain it, so ranks may hold different steps/layers.
    """
    avg_dict = {
        stream: {module: {} for module in module_list[stream]}
        for stream in stream_list
    }
    for stream in stream_list:
        for module in module_list[stream]:
            # Step dicts of every rank that recorded this stream/module.
            per_rank = [
                d[stream][module]
                for d in gathered_dicts
                if stream in d and module in d[stream]
            ]
            all_steps = set().union(*per_rank)
            for step in sorted(all_steps):
                per_step = [steps[step] for steps in per_rank if step in steps]
                all_layers = set().union(*per_step)
                averaged = {}
                for layer in sorted(all_layers):
                    values = [
                        _alpha_to_scalar(layers[layer])
                        for layers in per_step
                        if layer in layers
                    ]
                    # `values` is never empty: `layer` came from the union above.
                    averaged[layer] = sum(values) / len(values)
                avg_dict[stream][module][step] = averaged
    return avg_dict


def save_alpha_dict(cache_dic, task, name="loaded_alpha"):
    """Save the alpha dict stored at ``cache_dic['cache'][name]`` to disk.

    The output file is ``alpha_dict_{task}.pth`` in the current working
    directory.

    * When ``torch.distributed`` is initialized, every rank contributes its
      dict via ``all_gather_object``; rank 0 averages the values entry by
      entry (converting tensors to Python scalars) and writes the file, and
      all ranks synchronize on a barrier afterwards.
    * Otherwise the local dict is written as-is (no averaging or scalar
      conversion is applied).

    Args:
        cache_dic: Dict with a ``'cache'`` entry; the distributed path also
            reads ``'stream_list'`` and ``'module_list'``.
        task: Identifier interpolated into the output file name.
        name: Key under ``cache_dic['cache']`` holding the alpha dict.

    Raises:
        ValueError: If ``name`` is missing from ``cache_dic['cache']``
            (checked on both the distributed and single-process paths).
    """
    # Validate up front so both paths fail the same way, before any collective.
    if name not in cache_dic['cache']:
        raise ValueError(f"Invalid cache_dic structure {name}")
    local_alpha = cache_dic['cache'][name]
    save_path = f"alpha_dict_{task}.pth"

    # is_available() guards builds where torch.distributed exposes no other API.
    if not (dist.is_available() and dist.is_initialized()):
        torch.save(local_alpha, save_path)
        print(f"Saved averaged ALPHA_DICT to {save_path}")
        return

    # Gather every rank's alpha dict onto every rank.
    gathered_dicts = [None] * dist.get_world_size()
    dist.all_gather_object(gathered_dicts, local_alpha)

    # Only rank 0 computes the average and saves.
    if dist.get_rank() == 0:
        avg_dict = _average_alpha_dicts(
            gathered_dicts, cache_dic['stream_list'], cache_dic['module_list']
        )
        torch.save(avg_dict, save_path)
        print(f"Saved averaged ALPHA_DICT to {save_path}")
    dist.barrier()  # Wait for rank 0 to finish saving