# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F # Needed for DigitModel, ResNet
import math                    # Needed for VGG init
import re                      # Needed for epoch range parsing
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import os
import copy
from tqdm import tqdm
import gc
import random
from torch.utils.data import DataLoader, Subset, TensorDataset, Dataset
import torchvision.transforms as transforms
import torchvision.models as tv_models # Keep for standard ResNet if needed


# --- Model Definitions (Based on User Snippets) ---

# ResNet BasicBlock and conv3x3 (Standard Implementation)
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 (spatial size kept when stride is 1)."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """Two-conv residual block with an identity or 1x1-projection shortcut.

    Main path: conv3x3 (optionally strided) -> BN -> ReLU -> conv3x3 -> BN; the shortcut
    output is added in place and a final ReLU is applied.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)  # separate ReLU module per activation site
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        out_planes = self.expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        # An empty Sequential acts as the identity; a projection is used only when the
        # spatial size or channel count changes.
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )
        self.relu2 = nn.ReLU(inplace=True)  # applied after the residual sum

    def forward(self, x):
        main = self.relu1(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return self.relu2(main)

class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem (no initial max-pool), four residual stages,
    global average pooling and a linear classifier.

    Args:
        block: Residual block class; must expose an ``expansion`` class attribute.
        num_blocks: Number of blocks in each of the four stages (any indexable sequence).
        num_classes: Classifier output size.
        feature_output: If True, ``forward`` always returns ``(logits, features)``.
        in_channels: Number of input image channels.
    """

    def __init__(self, block=BasicBlock, num_blocks=(2, 2, 2, 2), num_classes=10, feature_output=False, in_channels=3):
        super().__init__()
        self.in_planes = 64  # running channel count, advanced by _make_layer
        self.outputfeature = feature_output
        self.conv1 = conv3x3(in_channels, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)  # single ReLU after bn1
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride`` (downsampling)."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, output_feature=False):
        """Return logits, or ``(logits, pooled features)`` if feature output is requested.

        Features are returned when either the constructor flag or *output_feature* is True.
        """
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer4(self.layer3(self.layer2(self.layer1(out))))
        flat_features = torch.flatten(self.avgpool(out), 1)
        logits = self.linear(flat_features)
        if self.outputfeature or output_feature:
            return logits, flat_features
        return logits

def ResNet18(num_classes=10, feature_output=False, in_channels=3):
    """Build a ResNet-18: ``BasicBlock`` units, two per stage across four stages."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(
        BasicBlock,
        stage_depths,
        num_classes=num_classes,
        feature_output=feature_output,
        in_channels=in_channels,
    )

# User's VGG definition
# VGG layer plans keyed by depth, as consumed by vgg.make_layers: an int is a 3x3 conv
# with that many output channels, 'M' a 2x2 max-pool and 'V' a 2x2 average-pool.
defaultcfg = {
    16: (
        [64] * 2 + ['M']
        + [128] * 2 + ['M']
        + [256] * 3 + ['M']
        + [512] * 3 + ['M']
        + [512] * 3 + ['V']
    ),
    # Add other depths if needed
}

class vgg(nn.Module):
    """VGG-style classifier built from a layer plan, with BatchNorm after every conv.

    Args:
        args: Optional namespace. If given, its ``dataset``, ``vgg_depth`` and
            ``num_classes`` attributes (when present) override the keyword arguments.
        dataset: Dataset name; only used to decide whether to print a channel warning.
        depth: Key into ``defaultcfg`` used when ``cfg`` is None.
        init_weights: Apply ``_initialize_weights`` after construction.
        cfg: Explicit layer plan (ints = conv widths, 'M' = max-pool, 'V' = avg-pool).
        num_classes: Classifier output size.
        in_channels: Number of input image channels (always honoured).
    """

    def __init__(self, args=None, dataset='cifar10', depth=16, init_weights=True, cfg=None, num_classes=10, in_channels=3):
        super(vgg, self).__init__()
        if args is not None:
            dataset = getattr(args, 'dataset', dataset)
            depth = getattr(args, 'vgg_depth', depth)
            num_classes = getattr(args, 'num_classes', num_classes)
            # Keep the caller's in_channels instead of overriding it; the digit domains
            # presumably mix grayscale and colour sources, so flag the assumption.
            if dataset == 'domain_digits':
                print(f"Warning: Assuming {in_channels} input channels for VGG on {dataset}.")

        if cfg is None:
            cfg = defaultcfg[depth]
        self.cfg = cfg
        self.feature_extractor = self.make_layers(cfg, True, in_channels=in_channels)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))  # fixed 7x7 map regardless of input size
        # The classifier width is the channel count of the last conv in cfg. A cfg with only
        # pooling layers leaves the channel count unchanged, so fall back to in_channels.
        conv_widths = [v for v in cfg if isinstance(v, int)]
        last_conv_channels = conv_widths[-1] if conv_widths else in_channels
        classifier_input_features = last_conv_channels * 7 * 7

        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_features, 4096), nn.ReLU(True), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False, in_channels=3):
        """Translate a layer plan into an ``nn.Sequential``.

        Ints become bias-free 3x3/padding-1 convs, followed by BatchNorm2d when
        *batch_norm* is set and then an in-place ReLU. 'M' is a 2x2 max-pool and
        'V' a 2x2 average-pool.
        """
        layers = []
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif v == 'V':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers.extend([conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)])
                else:
                    layers.extend([conv2d, nn.ReLU(inplace=True)])
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x, output_feature=False):
        """Return logits, or ``(logits, flattened pooled features)`` when *output_feature* is True."""
        pooled_features = self.avgpool(self.feature_extractor(x))
        flattened_features = torch.flatten(pooled_features, 1)
        y = self.classifier(flattened_features)
        if output_feature:
            return y, flattened_features
        return y

    def _initialize_weights(self):
        """Initialise the weights.

        Conv: N(0, sqrt(2/n)) with n = k_h * k_w * out_channels, bias 0.
        BatchNorm: weight 1, bias 0.
        Linear: N(0, 0.01), bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

# User's DigitModel Definition
class DigitModel(nn.Module):
    """Three-conv / three-FC CNN sized for 32x32 inputs.

    ``forward`` returns logits, or ``(logits, feature)`` where ``feature`` is the
    2048-d activation after ``fc1 -> bn4 -> relu4``. The pair is returned when either
    the constructor's ``feature_output`` or the call's ``output_feature`` is True.
    Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(self, num_classes=10, feature_output=False, in_channels=3, **kwargs):
        super().__init__()
        self.output_feature = feature_output
        # Conv stage: two conv/BN/ReLU/max-pool units, then a third conv without pooling.
        self.conv1 = nn.Conv2d(in_channels, 64, 5, 1, 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 64, 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 5, 1, 2)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        # Two 2x max-pools take a 32x32 input down to 8x8 with 128 channels.
        self._fc1_in_features = 128 * 8 * 8
        self.flatten = nn.Flatten()
        # FC stage: 8192 -> 2048 -> 512 -> num_classes, BN + ReLU between layers.
        self.fc1 = nn.Linear(self._fc1_in_features, 2048)
        self.bn4 = nn.BatchNorm1d(2048)
        self.relu4 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(2048, 512)
        self.bn5 = nn.BatchNorm1d(512)
        self.relu5 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x, output_feature=False):
        pooled_units = (
            (self.conv1, self.bn1, self.relu1, self.pool1),
            (self.conv2, self.bn2, self.relu2, self.pool2),
        )
        h = x
        for conv, bn, act, pool in pooled_units:
            h = pool(act(bn(conv(h))))
        h = self.flatten(self.relu3(self.bn3(self.conv3(h))))
        flat_size = h.shape[1]
        if flat_size != self._fc1_in_features:
            # Only warns; fc1 will then raise on the shape mismatch.
            print(f"Warning: Flattened size mismatch. Got {flat_size}, expected {self._fc1_in_features}. Check input size or model.")
        feature = self.relu4(self.bn4(self.fc1(h)))
        logits = self.fc3(self.relu5(self.bn5(self.fc2(feature))))
        if self.output_feature or output_feature:
            return logits, feature
        return logits

# --- End Model Definitions ---

# --- User's Project Specific Imports ---
# Bound unconditionally: the real import of this name is currently commented out, and it
# must exist whether or not the utils import below succeeds.
project_args_parser = None
try:
    # from utils.options import args_parser as project_args_parser
    from utils.init_data_model import init_data, init_model, init_data_methodone
except ImportError as e:
    print(f"Warning: Could not import user-specific utils: {e}")

    def init_data(args):
        """Stand-in when ``utils`` is unavailable: always returns ``([], [], None)``."""
        return [], [], None

    def init_data_methodone(args):
        """Stand-in when ``utils`` is unavailable: same result as ``init_data``."""
        return init_data(args)

    def init_model(args):
        """Stand-in when ``utils`` is unavailable: delegates to ``get_model_instance`` (looked up at call time)."""
        return get_model_instance(args)
# --- End User Specific Imports ---


# --- Helper Function for Model Instantiation ---
def get_model_instance(args):
    """Instantiate the model named by ``args.model`` using the classes defined in this file.

    Reads ``args.model`` (required) and, when present, ``args.dataset``,
    ``args.num_classes`` (default 10) and ``args.feature_output`` (default False).
    Recognised names: 'cnn'/'digit' -> DigitModel, 'vgg16' -> vgg (depth 16),
    'resnet18' -> ResNet18.

    Returns:
        The new ``nn.Module``, or None (after printing a warning) for an unknown model name.
    """
    print(f"Attempting to instantiate model: {args.model}")
    num_classes = getattr(args, 'num_classes', 10)
    feature_output = getattr(args, 'feature_output', False)
    dataset = getattr(args, 'dataset', None)

    # All datasets are currently fed as 3-channel input. For domain_digits, MNIST/USPS are
    # 1-channel while SVHN/Synth/MNIST-M are 3-channel, so the data pipeline (init_data)
    # must convert grayscale to 3 channels.
    in_channels = 3
    if dataset == 'domain_digits':
        print(f"Warning: Assuming 3 input channels for dataset '{dataset}'. Ensure data preprocessing matches.")

    model_name = args.model
    if model_name in ('cnn', 'digit'):
        print("Instantiating DigitModel.")
        return DigitModel(num_classes=num_classes, feature_output=feature_output, in_channels=in_channels)
    if model_name == 'vgg16':
        print("Instantiating VGG16 model.")
        # vgg has no feature_output flag; callers use forward(x, output_feature=True).
        return vgg(args=args, depth=16, num_classes=num_classes, in_channels=in_channels)
    if model_name == 'resnet18':
        print("Instantiating ResNet18 model.")
        return ResNet18(num_classes=num_classes, feature_output=feature_output, in_channels=in_channels)
    print(f"Warning: Unknown model type '{args.model}'.")
    return None


# --- Calibration DataLoader Construction ---
def _client0_dataset(all_loaders):
    """Return ``all_loaders[0].dataset``, or None if there is no usable first loader/dataset."""
    if not all_loaders or getattr(all_loaders[0], 'dataset', None) is None:
        return None
    return all_loaders[0].dataset


def _sample_to_tensor(data):
    """Coerce one sample to a tensor; return None if conversion fails or gives <2 or >4 dims.

    Tensors are returned unchanged (their rank is checked by the caller).
    """
    if isinstance(data, torch.Tensor):
        return data
    try:
        if "PIL." in str(type(data)):
            data = transforms.ToTensor()(data)
        elif isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
            if data.dim() == 2:
                data = data.unsqueeze(0)
            if data.dim() == 3 and data.shape[-1] in [1, 3]:
                data = data.permute(2, 0, 1)  # trailing size-1/3 axis treated as channels (HWC -> CHW)
        else:
            data = transforms.ToTensor()(data)
    except Exception:
        return None
    if data.dim() < 2 or data.dim() > 4:
        return None
    return data


def _label_to_tensor(label):
    """Return *label* as a tensor with a leading length-1 axis, or None for unsupported types."""
    if not isinstance(label, (torch.Tensor, int, float, np.number)):
        return None
    label_tensor = label.clone().detach() if isinstance(label, torch.Tensor) else torch.tensor(label)
    # Scalar-like labels (3, tensor(3), tensor([3])) all become shape [1] so they concatenate.
    if label_tensor.numel() == 1:
        return label_tensor.reshape(1)
    return label_tensor.unsqueeze(0)


def _build_global_mix_dataset(all_loaders, num_samples):
    """Gather random samples spread evenly across clients into a TensorDataset.

    Each client contributes ``max(1, num_samples // len(all_loaders))`` samples, or fewer
    if it is smaller. Returns None, after printing a warning, when nothing usable is
    gathered, so the caller can fall back to client 0.
    """
    num_clients_available = len(all_loaders)
    samples_per_client = max(1, num_samples // num_clients_available)
    mixed_indices = []
    for client_i, loader in enumerate(all_loaders):
        client_dataset = getattr(loader, 'dataset', None)
        if client_dataset is None or not hasattr(client_dataset, '__len__'):
            continue
        client_size = len(client_dataset)
        if client_size == 0:
            continue
        indices = list(range(client_size))
        random.shuffle(indices)
        mixed_indices.extend((client_i, idx) for idx in indices[:samples_per_client])
    if not mixed_indices:
        print("Warning: Could not gather samples for global mix. Falling back to client 0.")
        return None

    print(f"Gathering {len(mixed_indices)} samples from {num_clients_available} clients for global mix.")
    mixed_data, mixed_labels = [], []
    for loader_idx, sample_idx in mixed_indices:
        # Best effort: any sample that cannot be read or converted is skipped.
        try:
            item = all_loaders[loader_idx].dataset[sample_idx]
            if not isinstance(item, (list, tuple)) or len(item) < 2:
                continue
            data = _sample_to_tensor(item[0])
            label_tensor = _label_to_tensor(item[1])
            if data is None or label_tensor is None or data.dim() < 2:
                continue
            mixed_data.append(data.unsqueeze(0))
            mixed_labels.append(label_tensor)
        except Exception:
            continue
    if not mixed_data:
        print("Warning: Failed to process samples for global mix. Falling back to client 0.")
        return None

    print(f"Successfully processed {len(mixed_data)} samples for global TensorDataset.")
    try:
        return TensorDataset(torch.cat(mixed_data, dim=0), torch.cat(mixed_labels, dim=0).long())
    except Exception as e:
        print(f"Error creating TensorDataset: {e}. Falling back to client 0.")
        return None


def get_calibration_dataloader_from_existing(args, all_loaders, client_idx=None, global_mix=False, num_samples=128):
    """Build a calibration DataLoader from datasets already attached to client loaders.

    Args:
        args: Namespace providing ``batch_size``.
        all_loaders: Sequence of per-client loaders; only their ``.dataset`` is used.
        client_idx: Client whose dataset is used when ``global_mix`` is False.
        global_mix: If True, draw random samples spread evenly over all clients.
        num_samples: Requested sample count. Without ``global_mix``, the first
            ``num_samples`` items of the chosen dataset are used.

    Returns:
        A non-shuffled DataLoader over the selected samples, or None when no usable
        dataset exists. An invalid ``client_idx``, or any failure while building the
        global mix, falls back to client 0 (still capped at ``num_samples``).
    """
    if global_mix:
        print("Creating global mixed calibration dataset...")
        if not all_loaders:
            return None
        target_dataset = _build_global_mix_dataset(all_loaders, num_samples)
        if target_dataset is not None:
            # Use every gathered sample (this can exceed num_samples when there are more
            # clients than requested samples, since each client contributes at least one).
            num_samples = len(target_dataset)
        else:
            target_dataset = _client0_dataset(all_loaders)
            if target_dataset is None:
                return None
    elif client_idx is not None and all_loaders and 0 <= client_idx < len(all_loaders):
        print(f"Using calibration data from Client {client_idx}")
        target_dataset = getattr(all_loaders[client_idx], 'dataset', None)
        if target_dataset is None:
            return None
    else:
        print("Warning: Invalid client index/loader. Falling back to client 0.")
        target_dataset = _client0_dataset(all_loaders)
        if target_dataset is None:
            return None

    num_samples = min(num_samples, len(target_dataset))
    if num_samples <= 0:
        return None
    calibration_subset = Subset(target_dataset, list(range(num_samples)))
    # Keep batches full when possible, but never let drop_last discard every sample
    # (an empty loader would yield no importance scores at all).
    drop_last = num_samples >= args.batch_size
    calibration_loader = DataLoader(calibration_subset, batch_size=args.batch_size, shuffle=False, drop_last=drop_last, num_workers=0)
    if len(calibration_loader) == 0 and len(calibration_subset) > 0:
        print("Warning: Calibration DataLoader 0 batches.")
    elif len(calibration_loader) > 0:
        print(f"Calibration loader created: {len(calibration_subset)} samples -> {len(calibration_loader)} batches.")
    return calibration_loader


# --- Core Importance Calculation (Revised for Conv/Linear) ---
def _batch_inputs(batch_data, device):
    """Extract the input tensor from a loader batch (a tensor, or the first element of a list/tuple) on *device*."""
    if isinstance(batch_data, (list, tuple)):
        first = batch_data[0] if len(batch_data) > 0 else None
        return first.to(device) if isinstance(first, torch.Tensor) else None
    if isinstance(batch_data, torch.Tensor):
        return batch_data.to(device)
    return None


def _first_multidim_tensor(output):
    """Pick the tensor a hook should score: the output itself, or the first rank>1 tensor in a tuple output."""
    if isinstance(output, tuple):
        return next((t for t in output if isinstance(t, torch.Tensor) and t.dim() > 1), None)
    if isinstance(output, torch.Tensor) and output.dim() > 1:
        return output
    return None


def _channel_activation_score(module, activation):
    """Return the mean |activation| per output channel/feature for one batch, or None if not applicable.

    Conv2d outputs (rank 4, N x C x H x W) are averaged over dims (0, 2, 3). Linear outputs
    (rank >= 2) are averaged over every dim except the last.
    """
    if activation.numel() == 0:
        return None
    if isinstance(module, nn.Conv2d) and activation.dim() == 4:
        return activation.abs().mean(dim=(0, 2, 3))
    if isinstance(module, nn.Linear) and activation.dim() >= 2:
        return activation.abs().mean(dim=tuple(range(activation.dim() - 1)))
    return None


def _normalize_layerwise(scores, score_type):
    """Min-max normalise each layer's score vector to [0, 1] independently."""
    normalized = {}
    for name, score in scores.items():
        min_v, max_v = torch.min(score), torch.max(score)
        if torch.isnan(min_v) or torch.isnan(max_v):
            print(f"[{score_type}] Warning: NaN in score for layer {name}. Skipping layer-wise normalization.")
            norm_score = score
        elif max_v > min_v:
            norm_score = (score - min_v) / (max_v - min_v + 1e-8)
        else:
            # Constant layer: all-zero stays 0, any other constant maps to 1.
            norm_score = torch.zeros_like(score) if max_v == 0 else torch.ones_like(score)
        normalized[name] = torch.clamp(norm_score, 0, 1)
    return normalized


def _resolve_bound(stat, name, fallback):
    """Return the bound for layer *name*: its entry if *stat* is a per-layer dict, else *stat* itself."""
    if isinstance(stat, dict):
        return stat[name] if name in stat else fallback
    return stat


def _normalize_globally(scores, global_stats, score_type):
    """Min-max normalise every layer against shared bounds ``global_stats['min'/'max']``.

    The bounds may be scalars (float or tensor, as returned by ``calculate_global_stats``)
    applied to every layer, or dicts mapping layer name to a bound. A layer missing from a
    dict uses its own min/max.
    """
    g_min, g_max = global_stats['min'], global_stats['max']
    per_layer = isinstance(g_min, dict) or isinstance(g_max, dict)
    if not per_layer and bool(torch.any(torch.as_tensor(g_max) <= torch.as_tensor(g_min))):
        print(f"[{score_type}] Warning: Global max <= global min. Skipping global normalization.")
        return dict(scores)
    normalized = {}
    for name, score in scores.items():
        lo = torch.as_tensor(_resolve_bound(g_min, name, torch.min(score)), dtype=score.dtype, device=score.device)
        hi = torch.as_tensor(_resolve_bound(g_max, name, torch.max(score)), dtype=score.dtype, device=score.device)
        if bool(torch.any(hi <= lo)):
            print(f"[{score_type}] Warning: Global max <= global min for layer {name}. Using unnormalized scores.")
            normalized[name] = score
            continue
        clamped = torch.maximum(torch.minimum(score, hi), lo)  # clip into the global range first
        normalized[name] = torch.clamp((clamped - lo) / (hi - lo + 1e-8), 0, 1)
    return normalized


@torch.no_grad()
def calculate_importance_scores(model, dataloader, device, args, score_type='local', normalize_globally=False, global_stats=None):
    """Compute per-channel activation importance for every Conv2d and Linear layer.

    For each batch, a forward hook reduces each layer's output to mean |activation| per
    output channel/feature. The reduction happens immediately inside the hook, so later
    in-place ops (e.g. ``ReLU(inplace=True)``) cannot alter it. Per-layer means are then
    averaged over the batches that produced a score for that layer, and normalised.

    Args:
        model: Model to analyse; it is moved to *device* and left in eval mode.
        dataloader: Iterable of batches (tensor, or list/tuple whose first item is the input).
        device: Device to run the forward passes on.
        args: Namespace; ``exclude_final_layer`` (optional) skips Linear layers whose last
            name component is one of fc/linear/classifier/fc3/out.
        score_type: Label used in log messages.
        normalize_globally: Normalise with *global_stats* instead of per layer.
        global_stats: Dict with 'min' and 'max'. Each may be a scalar (float or tensor) or a
            dict keyed by layer name.

    Returns:
        Dict mapping layer name to a 1-D score tensor on *device*. Values are in [0, 1]
        unless normalisation was skipped (NaN scores, degenerate bounds, missing stats).
        Returns an empty dict when no batch could be processed. All hooks are removed
        before returning, even on error.
    """
    if dataloader is None or len(dataloader) == 0:
        print(f"DataLoader is empty for score calculation ({score_type}). Returning empty scores.")
        return {}

    model.eval()
    model.to(device)

    batch_scores = {}  # scratch space: the hooks write this batch's per-layer scores here
    target_layers = {}
    hook_handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            activation = _first_multidim_tensor(output)
            if activation is None:
                return
            try:
                score = _channel_activation_score(module, activation.detach())
            except Exception as e:
                print(f"[{score_type}] Error processing batch activation for layer {name}: {e}. Skipping this batch's score for this layer.")
                return
            if score is not None:
                # Only the reduced vector leaves the device. If a module runs several
                # times per forward, the last call wins.
                batch_scores[name] = score.float().cpu()
        return hook

    print(f"[{score_type}] Registering hooks for Conv2d and Linear layers...")
    exclude_final = getattr(args, 'exclude_final_layer', False)
    final_layer_names = ('fc', 'linear', 'classifier', 'fc3', 'out')
    for name, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        is_final_layer = isinstance(module, nn.Linear) and name.split('.')[-1] in final_layer_names
        if exclude_final and is_final_layer:
            print(f"[{score_type}] Skipping hook for likely final layer: {name}")
            continue
        target_layers[name] = module
        hook_handles.append(module.register_forward_hook(make_hook(name)))

    if not target_layers:
        print(f"[{score_type}] No target layers found for activation calculation. Returning empty scores.")
        return {}

    print(f"[{score_type}] Calculating activations batch by batch and accumulating scores...")
    layer_scores_sum = {}
    layer_batch_counts = {}
    successful_batches = 0
    try:
        for batch_idx, batch_data in tqdm(enumerate(dataloader), total=len(dataloader), desc=f"[{score_type}] Calibration"):
            batch_scores.clear()
            try:
                data = _batch_inputs(batch_data, device)
                if data is None or data.dim() < 2:
                    print(f"[{score_type}] Skipping batch {batch_idx}: Invalid data.")
                    continue
                model(data)
            except Exception as e:
                print(f"[{score_type}] Fwd pass or batch processing error batch {batch_idx}: {e}")
                continue
            successful_batches += 1
            for name, score in batch_scores.items():
                if name in layer_scores_sum:
                    layer_scores_sum[name] += score
                else:
                    layer_scores_sum[name] = score.clone()
                layer_batch_counts[name] = layer_batch_counts.get(name, 0) + 1
    finally:
        # Always detach hooks so the model is left unmodified, even if iteration is interrupted.
        for handle in hook_handles:
            handle.remove()
        hook_handles.clear()
        batch_scores.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"[{score_type}] Processed {successful_batches} batches.")
    if successful_batches == 0:
        print(f"[{score_type}] No successful batches processed. Cannot calculate scores.")
        return {}

    print(f"[{score_type}] Calculating average scores...")
    # Divide by the number of batches that actually produced a score for each layer.
    calculated_scores = {
        name: (total / layer_batch_counts[name]).to(device)
        for name, total in layer_scores_sum.items()
    }

    print(f"[{score_type}] Applying normalization...")
    if normalize_globally:
        print(f"[{score_type}] Applying global normalization...")
        if global_stats is None or 'min' not in global_stats or 'max' not in global_stats:
            print(f"[{score_type}] Warning: Global stats not provided or incomplete. Skipping global normalization.")
            return calculated_scores
        return _normalize_globally(calculated_scores, global_stats, score_type)
    print(f"[{score_type}] Applying layer-wise normalization...")
    return _normalize_layerwise(calculated_scores, score_type)


# --- Helper Function for Global Stats ---
@torch.no_grad()
def calculate_global_stats(all_scores_dicts):
    """Return the global min and max over every finite value in a collection of score dicts.

    Args:
        all_scores_dicts: Iterable of dicts mapping layer name -> score tensor; None or
            empty tensors are ignored, as are NaN/inf entries.

    Returns:
        ``{'min': float, 'max': float}``, or ``{'min': 0.0, 'max': 1.0}`` when no finite
        value is found.
    """
    fallback = {'min': 0.0, 'max': 1.0}
    flat_chunks = [
        score.cpu().float().flatten()
        for scores in all_scores_dicts
        for score in scores.values()
        if score is not None and score.numel() > 0
    ]
    if not flat_chunks:
        return fallback
    values = torch.cat(flat_chunks)
    values = values[torch.isfinite(values)]
    if values.numel() == 0:
        return fallback
    global_min = values.min().item()
    global_max = values.max().item()
    print(f"Global Stats Calculated: Min={global_min:.4f}, Max={global_max:.4f}")
    return {'min': global_min, 'max': global_max}


# --- Plotting Functions (Revised Labeling, Added Distribution Plot) ---
# Helper to get layer type string
def get_layer_type_str(module):
    """Return a 4-character tag: "Conv" for Conv2d, "Line" for Linear, "Othe" for anything else."""
    for layer_cls, tag in ((nn.Conv2d, "Conv"), (nn.Linear, "Line")):
        if isinstance(module, layer_cls):
            return tag
    return "Othe"

# Function to plot score distribution for a single layer
def plot_score_distribution(scores, layer_name, layer_type, title_prefix, save_dir):
    """Save a histogram (with KDE overlay) of a single layer's scores as a PNG.

    Args:
        scores: score tensor for one layer; ``None`` or an empty tensor is a no-op.
            Non-finite values (NaN/Inf) are left out of the histogram.
        layer_name: layer name, used in the title and the output filename.
        layer_type: short layer-type tag shown in the title.
        title_prefix: title prefix; lower-cased with spaces replaced by
            underscores to form the filename prefix.
        save_dir: directory the PNG is written to. This function does not
            create it, so it is expected to exist already.
    """
    if scores is None or scores.numel() == 0:
        return
    scores_np = scores.detach().cpu().flatten().numpy()
    # NaN/Inf make the automatic histogram bin range undefined; plot finite values only.
    scores_np = scores_np[np.isfinite(scores_np)]
    if scores_np.size == 0:
        return
    fig = plt.figure(figsize=(8, 5))
    try:
        sns.histplot(scores_np, kde=True, bins=30)
        plt.title(f"{title_prefix}: {layer_name} ({layer_type}) - Score Distribution (N={scores.numel()})")
        plt.xlabel("Normalized Importance Score")
        plt.ylabel("Frequency")
        plt.tight_layout()
        filename = f"{title_prefix.lower().replace(' ','_')}_{layer_name}_distribution.png"
        plt.savefig(os.path.join(save_dir, filename))
    finally:
        # Release the figure even if plotting or saving raises, so figures don't pile up.
        plt.close(fig)


def plot_heatmap(scores_dict, model_modules_dict, title, save_path, distribution_threshold=256):
    """Plot a per-layer importance heatmap plus histograms for wide layers.

    Each non-``None`` entry of ``scores_dict`` becomes one heatmap row. Rows are
    right-padded with NaN (masked out) up to the widest layer. The heatmap is
    skipped when the widest layer has more than 4096 entries. Every layer with
    more than ``distribution_threshold`` scores also gets a histogram, saved under
    ``<dirname(save_path)>/distributions``. Heatmap failures are reported and do
    not abort the rest; the same applies to each histogram.

    Args:
        scores_dict: layer name -> score tensor (``None`` entries are ignored).
        model_modules_dict: layer name -> module. Used only to tag y-axis labels
            with the layer type; names missing from it are labelled "N/A".
        title: title suffix, also used as the histogram filename prefix.
        save_path: output path of the heatmap PNG.
        distribution_threshold: a layer needs strictly more scores than this
            to get a histogram.
    """
    if not scores_dict:
        return
    # Natural sort: compare the integers embedded in each name first, then the name itself.
    layer_names = sorted(scores_dict.keys(), key=lambda x: ([int(s) for s in re.findall(r'\d+', x)], x))

    # Build rows, y-tick labels and layer names in one pass so they stay aligned.
    valid_layer_names_for_plot = []
    yticklabels = []
    all_scores_list = []
    for name in layer_names:
        scores = scores_dict[name]
        if scores is None:
            continue
        module = model_modules_dict.get(name)
        layer_type_str = get_layer_type_str(module) if module is not None else "N/A"
        valid_layer_names_for_plot.append(name)
        yticklabels.append(f"{name} ({layer_type_str})")
        all_scores_list.append(scores.detach().cpu().numpy())
    if not all_scores_list:
        return

    save_dir = os.path.dirname(save_path)

    # Plot Heatmap
    fig = None
    try:
        max_len = max(len(s) for s in all_scores_list)
        if max_len == 0:
            return
        if max_len > 4096:  # Arbitrary usability limit: wider rows are unreadable.
            print(f"Skipping heatmap for '{title}' - max channels ({max_len}) too large for effective visualization.")
        else:
            padded_scores = np.array([np.pad(s, (0, max_len - len(s)), 'constant', constant_values=np.nan) for s in all_scores_list])
            fig = plt.figure(figsize=(15, max(6, len(valid_layer_names_for_plot) * 0.4)))
            sns.heatmap(padded_scores, cmap="viridis", yticklabels=yticklabels, cbar=True, annot=False, mask=np.isnan(padded_scores))
            plt.yticks(rotation=0, fontsize=8)
            plt.ylabel("Layer")
            plt.xlabel("Channel / Neuron Index (Padded)")
            plt.title(f"Channel/Neuron Importance Scores: {title}")
            plt.tight_layout()
            if save_dir:  # os.makedirs('') raises when save_path is a bare filename
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight')
            print(f"Saved heatmap to {save_path}")
    except Exception as e:
        print(f"Could not create heatmap for {title}: {e}")
    finally:
        # Close the figure on every path, including failures inside sns.heatmap/savefig.
        if fig is not None:
            plt.close(fig)

    # Plot Distributions for wide layers
    distribution_save_dir = os.path.join(save_dir, "distributions")
    os.makedirs(distribution_save_dir, exist_ok=True)
    print(f"Checking layers for distribution plots (threshold: {distribution_threshold})...")
    for name in valid_layer_names_for_plot:
        scores = scores_dict[name]
        if scores.numel() <= distribution_threshold:
            continue
        module = model_modules_dict.get(name)
        layer_type_str = get_layer_type_str(module) if module is not None else "N/A"
        try:
            plot_score_distribution(scores, name, layer_type_str, title, distribution_save_dir)
        except Exception as e:
            # Best-effort, like the heatmap: one bad layer must not abort the remaining plots.
            print(f"Could not create distribution plot for {title} / {name}: {e}")


def plot_difference_heatmap(scores1_dict, scores2_dict, model_modules_dict, title, save_path, distribution_threshold=256):
    """Plot |scores1 - scores2| per layer as a heatmap, plus histograms for wide layers.

    Only layers present in both dicts with non-empty tensors on both sides are
    compared. Scores are flattened. When the two sides differ in length, only the
    common prefix is compared and a warning is printed. The heatmap is skipped
    when the widest layer has more than 4096 entries. Every layer with more than
    ``distribution_threshold`` differences gets a histogram under
    ``<dirname(save_path)>/distributions``. Plot failures are reported, not raised.

    Args:
        scores1_dict, scores2_dict: layer name -> score tensor.
        model_modules_dict: layer name -> module, used only to tag y-axis labels
            (missing names are labelled "N/A").
        title: title suffix, also used (with " Difference") as histogram filename prefix.
        save_path: output path of the heatmap PNG.
        distribution_threshold: a layer needs strictly more values than this
            to get a histogram.
    """
    if not scores1_dict or not scores2_dict:
        return
    # Natural sort: compare the integers embedded in each name first, then the name itself.
    common_layers = sorted(set(scores1_dict.keys()) & set(scores2_dict.keys()), key=lambda x: ([int(s) for s in re.findall(r'\d+', x)], x))
    if not common_layers:
        return

    print(f"Calculating differences for common layers: {common_layers}")
    diff_scores_dict = {}
    for layer_name in common_layers:
        scores1 = scores1_dict.get(layer_name)
        scores2 = scores2_dict.get(layer_name)
        if not isinstance(scores1, torch.Tensor) or not isinstance(scores2, torch.Tensor):
            continue
        flat1 = scores1.detach().cpu().flatten()
        flat2 = scores2.detach().cpu().flatten()
        len1, len2 = flat1.numel(), flat2.numel()
        if len1 == 0 or len2 == 0:
            continue
        if len1 != len2:
            min_len = min(len1, len2)
            print(f"Warning: layer {layer_name} has {len1} vs {len2} scores; comparing the first {min_len} only.")
            flat1, flat2 = flat1[:min_len], flat2[:min_len]
        diff_scores_dict[layer_name] = torch.abs(flat1 - flat2)

    if not diff_scores_dict:
        return
    layer_names = list(diff_scores_dict)  # insertion order follows the sorted common_layers
    all_scores_list = [diff_scores_dict[name].numpy() for name in layer_names]
    save_dir = os.path.dirname(save_path)

    # Plot Heatmap
    fig = None
    try:
        max_len = max(len(s) for s in all_scores_list)
        if max_len > 4096:  # Arbitrary usability limit: wider rows are unreadable.
            print(f"Skipping difference heatmap for '{title}' - max channels ({max_len}) too large.")
        else:
            padded_scores = np.array([np.pad(s, (0, max_len - len(s)), 'constant', constant_values=np.nan) for s in all_scores_list])
            yticklabels = []
            for name in layer_names:
                module = model_modules_dict.get(name)
                layer_type_str = get_layer_type_str(module) if module is not None else "N/A"
                yticklabels.append(f"{name} ({layer_type_str})")
            fig = plt.figure(figsize=(15, max(6, len(layer_names) * 0.4)))
            # Absolute differences are >= 0: use a sequential colormap anchored at 0.
            # A diverging map centered at 0 would leave its negative half unused.
            sns.heatmap(padded_scores, cmap="magma", vmin=0, yticklabels=yticklabels, cbar=True, annot=False, mask=np.isnan(padded_scores))
            plt.yticks(rotation=0, fontsize=8)
            plt.ylabel("Layer")
            plt.xlabel("Channel / Neuron Index (Padded & Aligned)")
            plt.title(f"Importance Score Difference: {title}")
            plt.tight_layout()
            if save_dir:  # os.makedirs('') raises when save_path is a bare filename
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, bbox_inches='tight')
            print(f"Saved difference heatmap to {save_path}")
    except Exception as e:
        print(f"Could not create difference heatmap for {title}: {e}")
    finally:
        # Close the figure on every path, including failures inside sns.heatmap/savefig.
        if fig is not None:
            plt.close(fig)

    # Plot Difference Score Distributions for wide layers
    distribution_save_dir = os.path.join(save_dir, "distributions")
    os.makedirs(distribution_save_dir, exist_ok=True)
    print(f"Checking layers for difference distribution plots (threshold: {distribution_threshold})...")
    for name in layer_names:
        diff_scores = diff_scores_dict[name]
        if diff_scores.numel() <= distribution_threshold:
            continue
        module = model_modules_dict.get(name)
        layer_type_str = get_layer_type_str(module) if module is not None else "N/A"
        try:
            plot_score_distribution(diff_scores, name, layer_type_str, f"{title} Difference", distribution_save_dir)
        except Exception as e:
            # Best-effort: one bad layer must not abort the remaining plots.
            print(f"Could not create difference distribution plot for {title} / {name}: {e}")



# --- Main Execution Logic (process a single epoch or an inclusive epoch range) ---

def _torch_load_trusted(path):
    """Unpickle a whole object (e.g. an ``nn.Module`` or a list of them) saved with ``torch.save``.

    Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True`` and refuses
    arbitrary pickled objects such as complete models, so the opt-out is explicit.
    Releases older than 1.13 do not accept the keyword and get the plain call.
    NOTE(review): this runs pickle on the file. Only use it on trusted checkpoints;
    these are presumably the ones written under ./save by this project's training run.
    """
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch: unexpected keyword argument 'weights_only'.
        return torch.load(path, map_location='cpu')


def _load_local_models_for_epoch(args, models_dir, epoch):
    """Load the local (client) models paired with global epoch ``epoch``.

    Tries, in order:
      1. a single file holding a list of models: ``local_epoch-{epoch-1}.pth``,
         then ``local_epoch-{epoch}.pth`` (the first one that exists);
      2. one file per client, ``local_epoch-{e}_client-{i}.pth`` for
         ``i in range(args.num_users)``. This uses ``e = epoch - 1`` first and
         falls back to ``e = epoch`` only if nothing was loaded.

    Returns:
        list of ``nn.Module`` in eval mode (possibly empty).
    """
    local_models = []
    list_candidates = [os.path.join(models_dir, f'local_epoch-{e}.pth') for e in (epoch - 1, epoch)]
    list_path = next((p for p in list_candidates if os.path.exists(p)), None)
    if list_path:
        try:
            print(f"Loading local models list from: {list_path}")
            loaded_object = _torch_load_trusted(list_path)
            if isinstance(loaded_object, list) and all(isinstance(m, nn.Module) for m in loaded_object):
                local_models = loaded_object
                print(f"Loaded list of {len(local_models)} models.")
                for m in local_models:
                    m.eval()
            else:
                print("Warning: Loaded object not list of nn.Module.")
        except Exception as e:
            print(f"Error loading list {list_path}: {e}")
    if local_models:
        return local_models

    print("Attempting to load individual client models...")
    for epoch_to_load_local in (epoch - 1, epoch):
        if epoch_to_load_local == epoch:
            print(f"Trying epoch {epoch}...")
        for client_idx in range(args.num_users):
            ind_local_path = os.path.join(models_dir, f'local_epoch-{epoch_to_load_local}_client-{client_idx}.pth')
            if not os.path.exists(ind_local_path):
                continue
            try:
                local_model = _torch_load_trusted(ind_local_path)
            except Exception as e:
                # Report instead of skipping silently, so a broken checkpoint is visible.
                print(f"Warning: Could not load {ind_local_path}: {e}")
                continue
            if isinstance(local_model, nn.Module):
                local_model.eval()
                local_models.append(local_model)
        if local_models:
            break
    if len(local_models) < args.num_users:
        print(f"Warning: Loaded {len(local_models)} individual models.")
    return local_models


def _normalize_raw_scores(raw_scores, normalize_globally, global_stats):
    """Map each layer's raw score tensor into [0, 1] (best-effort).

    Global mode scales with ``global_stats['min']`` / ``global_stats['max']`` and
    clamps to [0, 1]; a degenerate global range (max <= min) keeps the raw score.
    Layer-wise mode applies min-max scaling with the layer's own extremes. A
    constant layer becomes all zeros (constant 0) or all ones; NaN extremes keep
    the raw score. Whenever the result contains NaN/Inf, the raw score is kept.
    ``None`` and empty tensors are passed through unchanged.
    """
    normalized = {}
    for name, score in raw_scores.items():
        if score is None or score.numel() == 0:
            # torch.min/max raise on empty tensors; there is nothing to scale.
            normalized[name] = score
            continue
        if normalize_globally:
            g_min, g_max = global_stats['min'], global_stats['max']
            if g_max > g_min:
                norm_score = torch.clamp((score - g_min) / (g_max - g_min + 1e-8), 0, 1)
            else:
                norm_score = score  # Fallback to raw
        else:  # Layer-wise
            min_v, max_v = torch.min(score), torch.max(score)
            if torch.isnan(min_v) or torch.isnan(max_v):
                norm_score = score
            elif max_v > min_v:
                norm_score = (score - min_v) / (max_v - min_v + 1e-8)
            elif max_v == min_v and max_v == 0:
                norm_score = torch.zeros_like(score)
            else:
                norm_score = torch.ones_like(score)
        has_bad_values = torch.isnan(norm_score).any() or torch.isinf(norm_score).any()
        normalized[name] = score if has_bad_values else norm_score
    return normalized


def process_epoch(args, epoch, device):
    """Compute, normalize and plot channel-importance scores for one epoch.

    Loads ``global_epoch-{epoch}.pth`` and the matching local models from
    ``./save/test/<dataset>/learning/<model>/models``. Each model is scored on
    calibration data built from the module-level ``train_loaders``. Heatmaps and
    global-vs-local difference heatmaps are written to
    ``./visualizations/activation/<dataset>/<model>/epoch_<epoch>``.
    Missing inputs are reported and the epoch is skipped; the function always
    returns ``None``.
    """
    print(f"\n===== Processing Epoch {epoch} =====")
    save_sub_dir = args.model
    base_path = os.path.join('.', 'save', 'test', args.dataset, 'learning', save_sub_dir)
    models_dir = os.path.join(base_path, 'models')
    global_model_path = os.path.join(models_dir, f'global_epoch-{epoch}.pth')

    local_models = _load_local_models_for_epoch(args, models_dir, epoch)
    if not local_models:
        print(f"Error: No local models found for epoch {epoch}. Skipping.")
        return
    num_available_local = len(local_models)
    if num_available_local != args.num_users:
        print(f"Warning: Model count mismatch ({num_available_local} vs {args.num_users}).")

    if not os.path.exists(global_model_path):
        print(f"Error: Global model not found: {global_model_path}. Skipping.")
        return
    model_modules_dict = {}  # layer name -> Conv2d/Linear module, used to label plots
    try:
        print(f"Loading global model epoch {epoch}...")
        net_glob = _torch_load_trusted(global_model_path)
        net_glob.eval()
        net_glob.to(device)
        for name, module in net_glob.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                model_modules_dict[name] = module
    except Exception as e:
        print(f"Error loading global model epoch {epoch}: {e}. Skipping.")
        return

    global train_loaders  # populated once by main()
    if 'train_loaders' not in globals() or not train_loaders:
        print("Error: train_loaders missing. Skipping.")
        return
    num_loaders_to_use = min(num_available_local, len(train_loaders))
    if num_loaders_to_use < num_available_local:
        print(f"Warning: Only {num_loaders_to_use} dataloaders available.")

    global_dataloader = get_calibration_dataloader_from_existing(args, train_loaders[:num_loaders_to_use], global_mix=True)
    local_dataloaders = [get_calibration_dataloader_from_existing(args, train_loaders, client_idx=i) for i in range(num_loaders_to_use)]
    if global_dataloader is None or any(ld is None for ld in local_dataloaders):
        print("Error creating dataloaders. Skipping.")
        return

    normalize_globally = getattr(args, 'normalize_globally', False)

    print(f"\n--- Calculating Scores for Epoch {epoch} ---")
    print("Calculating Global Scores (Raw)...")
    global_scores_raw = calculate_importance_scores(net_glob, global_dataloader, device, normalize_globally=False)

    local_scores_raw_list = []
    print("Calculating Local Scores (Raw)...")
    for i, client_model in enumerate(local_models):
        if i < len(local_dataloaders) and local_dataloaders[i] is not None:
            print(f"--- Client {i} ---")
            client_model.eval()
            client_model.to(device)
            local_scores_raw_list.append(calculate_importance_scores(client_model, local_dataloaders[i], device, normalize_globally=False))
            # Move the model back off the accelerator before scoring the next client.
            client_model.to('cpu')
            torch.cuda.empty_cache()
            gc.collect()
        else:
            local_scores_raw_list.append({})  # placeholder keeps client indices aligned

    global_stats = None
    if normalize_globally:
        global_stats = calculate_global_stats([d for d in [global_scores_raw, *local_scores_raw_list] if d])

    print("Applying Normalization...")
    global_scores = _normalize_raw_scores(global_scores_raw, normalize_globally, global_stats) if global_scores_raw else {}
    local_scores_list = [
        _normalize_raw_scores(raw, normalize_globally, global_stats) if raw else {}
        for raw in local_scores_raw_list
    ]

    # --- Generate Visualizations ---
    output_dir = os.path.join('.', 'visualizations', 'activation', args.dataset, args.model, f'epoch_{epoch}')
    os.makedirs(output_dir, exist_ok=True)
    print(f"\nSaving visualizations for epoch {epoch} to: {output_dir}")

    dist_threshold = getattr(args, 'distribution_threshold', 256)  # Get threshold from args or default

    if global_scores:
        plot_heatmap(global_scores, model_modules_dict, f"Global Model (Epoch {epoch})", os.path.join(output_dir, "global_importance.png"), dist_threshold)
    plotted_local_count = 0
    for client_idx, client_scores in enumerate(local_scores_list):
        if client_scores:
            plot_heatmap(client_scores, model_modules_dict, f"Local Client {client_idx} (Epoch {epoch})", os.path.join(output_dir, f"local_client_{client_idx}_importance.png"), dist_threshold)
            plotted_local_count += 1
    print(f"Plotted heatmaps for {plotted_local_count} local clients.")

    if global_scores:
        plotted_diff_count = 0
        for client_idx, client_scores in enumerate(local_scores_list):
            if client_scores:
                plot_difference_heatmap(global_scores, client_scores, model_modules_dict, f"Global vs Local Client {client_idx} (Epoch {epoch})", os.path.join(output_dir, f"diff_global_vs_local_{client_idx}.png"), dist_threshold)
                plotted_diff_count += 1
        print(f"Plotted difference heatmaps for {plotted_diff_count} local clients vs global.")
    else:
        print("Skipping Global vs Local difference plots - no global scores.")

    print(f"===== Finished Epoch {epoch} =====")


def main(args):
    """Seed the RNGs, load the client training loaders once, then process each requested epoch.

    ``args.epoch`` is either a single epoch (``'10'``) or an inclusive range
    (``'5-10'``). The loaded loaders are published as the module-level
    ``train_loaders`` so that ``process_epoch`` can use them. Invalid input or a
    failed dataset load is reported and the function returns ``None``.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() and args.gpu is not None else "cpu")
    print(f"Using device: {device}")
    print(args)

    global train_loaders  # Make loaders global for access within process_epoch
    print("Loading datasets (once)...")
    try:
        if getattr(args, 'dataset_fullparti', False):
            loaded_data = init_data(args)
        else:
            loaded_data = init_data_methodone(args)
        if not isinstance(loaded_data, (list, tuple)) or len(loaded_data) < 2:
            raise ValueError("init_data structure error")
        train_loaders, _ = loaded_data[0], loaded_data[1]
        if not isinstance(train_loaders, list) or not train_loaders:
            raise ValueError("Invalid training loaders")
        print(f"Loaded {len(train_loaders)} training data loaders.")
    except Exception as e:
        print(f"Error dataset initialization: {e}. Exiting.")
        return

    epoch_str = str(args.epoch).strip()  # Ensure string
    if '-' in epoch_str:
        try:
            start_str, end_str = epoch_str.split('-')
            start_epoch, end_epoch = int(start_str), int(end_str)
        except ValueError:
            print(f"Error: Invalid epoch range '{epoch_str}'.")
            return
        # Explicit check rather than `assert`, so validation survives `python -O`.
        if start_epoch > end_epoch:
            print(f"Error: Invalid epoch range '{epoch_str}'.")
            return
        epochs_to_process = list(range(start_epoch, end_epoch + 1))
        print(f"Processing epoch range: {start_epoch}-{end_epoch}")
    else:
        try:
            epochs_to_process = [int(epoch_str)]
        except ValueError:
            print(f"Error: Invalid epoch format '{epoch_str}'.")
            return
        print(f"Processing single epoch: {epoch_str}")

    for epoch in epochs_to_process:
        process_epoch(args, epoch, device)

    print(f"\nVisualization process finished for all specified epochs.")

if __name__ == "__main__":

    def _str2bool(value):
        """Parse common CLI spellings of a boolean ('true'/'false', 'yes'/'no', '1'/'0')."""
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'")

    args = None
    if 'project_args_parser' in globals() and project_args_parser:
        print("Using project's args_parser.")
        try:
            args = project_args_parser()
        except Exception as e:
            print(f"Error calling project_args_parser: {e}. Using fallback.")
            args = None
    if args is None:
        print("Warning: Using fallback argparse definition.")
        parser = argparse.ArgumentParser()
        # Essential Args
        parser.add_argument('--dataset', type=str, default='domain_digits', choices=['domain_digits', 'office-caltech10', 'DomainNet'])
        parser.add_argument('--model', type=str, default='cnn', choices=['cnn', 'vgg16', 'resnet18', 'digit'])
        parser.add_argument('--num_users', type=int, default=5)
        parser.add_argument('--epoch', type=str, required=True, help="Epoch number (e.g., '10') or range (e.g., '5-10') to process")
        parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
        parser.add_argument('--seed', type=int, default=42)
        # Args potentially needed by user's functions (add more if needed)
        parser.add_argument('--batch_size', type=int, default=32, help="Batch size for calibration loader")
        parser.add_argument('--num_classes', type=int, default=10, help='Number of classes')
        parser.add_argument('--data_dir', type=str, default='./data/', help='Data directory')
        parser.add_argument('--iid', action='store_true', help='Assume iid data distribution')
        # NOTE(review): with set_defaults(iid=True), iid is always True; the flag cannot disable it.
        parser.set_defaults(iid=True)
        # type=_str2bool: without it, "--dataset_fullparti False" yields the truthy string "False".
        parser.add_argument('--dataset_fullparti', type=_str2bool, default=True, help="if data full participant (default: True)")
        parser.add_argument('--n_train', type=int, default=450, help="num of train set")
        # Add vgg_depth if user wants to control VGG variant via args
        parser.add_argument('--vgg_depth', type=int, default=16, choices=[11, 13, 16, 19], help='VGG depth')
        # Add flag for global normalization experiment
        parser.add_argument('--normalize_globally', action='store_true', help='Normalize scores globally across all layers/models instead of layer-wise')
        parser.add_argument('--percent', type=float, default=1, help='percentage of dataset to train')

        # Add other args from user's snippet if they are needed by init_data/model
        parser.add_argument('--local_ep', type=int, default=10)
        parser.add_argument('--local_bs', type=int, default=64)
        parser.add_argument('--lr', type=float, default=0.01)
        parser.add_argument('--momentum', type=float, default=0.9)
        parser.add_argument('--verify', type=str, default='normal', choices=['normal', 'marker', 'backdoor'], help="Verification Methods")
        # Add any other necessary args from the user's full parser
        args = parser.parse_args()

    if args.gpu == -1:
        args.gpu = None  # main() maps gpu=None to the CPU device
    main(args)
