import argparse
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import random
import seaborn as sns

import models
from util.datasets import build_dataset

# Module-level store of captured activations, keyed by step name.
# Written by the hooks built in make_hook() and by the patched
# MS_Attention_linear.forward (values are detached CPU tensors); read by
# plot_distributions(), which sets each entry to None once it is plotted.
activation_dict = {}


# Function to create hook for capturing data at each step
def make_hook(step_name):
    """Build a forward hook that records a module's tensor output.

    The returned callable has the ``register_forward_hook`` signature
    ``hook(module, inputs, output)``. When ``output`` is a ``torch.Tensor`` a
    detached CPU copy of it is stored in the module-level ``activation_dict``
    under ``step_name`` (replacing any earlier value). Any other output type
    (tuples, dicts, ...) is ignored. The hook always returns None, so the
    module's output is never altered.
    """

    def _record(_module, _inputs, result):
        if not isinstance(result, torch.Tensor):
            return
        activation_dict[step_name] = result.detach().cpu()

    return _record


# Register hooks for all modules in sequential order
def register_sequential_hooks(model):
    """Register capture hooks on the model's modules in forward-pass order.

    A hard-coded list of dotted module paths is walked. Purely numeric path
    components index into a container and other components are attribute
    lookups. Each resolved module gets a ``make_hook`` forward hook:

    * an ``nn.Sequential`` gets one hook for the whole container (key
      ``seq_<path>``) plus one hook per child (key ``seq_item_<path>.<i>``);
    * any other module gets one hook keyed by its path.

    Every module object is hooked at most once, even if several paths reach
    it. Paths that cannot be resolved, or that resolve to something that is
    not an ``nn.Module``, are reported and skipped. Afterwards
    ``modify_ms_attention_linear()`` is called so attention internals are
    captured too.

    Args:
        model: the network to instrument. Attribute names such as
            ``downsample1_1``, ``ConvBlock1_1``, ``block3``, ``block4`` and
            ``head`` are looked up on it; missing ones are skipped with a
            message.
    """
    module_order = ["downsample1_1"]
    module_order += _conv_block_paths("ConvBlock1_1.0")
    module_order.append("downsample1_2")
    module_order += _conv_block_paths("ConvBlock1_2.0")
    module_order.append("downsample2")
    module_order += _conv_block_paths("ConvBlock2_1.0")
    module_order += _conv_block_paths("ConvBlock2_2.0")
    module_order.append("downsample3")
    for i in range(6):
        module_order += _transformer_block_paths(f"block3.{i}")
    module_order.append("downsample4")
    for i in range(2):
        module_order += _transformer_block_paths(f"block4.{i}")
    module_order += ["head", "spike"]

    registered_modules = set()  # id() of every module that already has a hook
    hook_count = 0

    def _hook_once(module, key):
        """Register a capture hook on ``module`` unless it already has one."""
        nonlocal hook_count
        if id(module) in registered_modules:
            return
        module.register_forward_hook(make_hook(key))
        registered_modules.add(id(module))
        hook_count += 1

    for module_path in module_order:
        try:
            curr_module = _resolve_module_path(model, module_path)
        except (AttributeError, IndexError, TypeError) as e:
            # TypeError: a numeric component hit a module that is not indexable.
            print(f"Could not find module {module_path}: {e}")
            continue
        if not isinstance(curr_module, nn.Module):
            print(f"Could not find module {module_path}: "
                  f"resolved to non-module {type(curr_module).__name__}")
            continue

        if isinstance(curr_module, nn.Sequential):
            _hook_once(curr_module, f"seq_{module_path}")
            for i, submodule in enumerate(curr_module):
                _hook_once(submodule, f"seq_item_{module_path}.{i}")
        else:
            _hook_once(curr_module, module_path)

    print(f"Registered {hook_count} hooks in sequential order")

    # Add detailed hooks for MS_Attention_linear instances
    modify_ms_attention_linear()


# Submodules of a separable-conv unit, in forward order.
_SEPCONV_PARTS = ("spike1", "pwconv1", "spike2", "dwconv", "spike3", "pwconv2")
# ConvBlock layers that follow its separable-conv unit, in forward order.
_CONVBLOCK_TAIL = ("spike1", "conv1", "bn1", "spike2", "conv2", "bn2")
# Attention submodules of a transformer block, in forward order.
_ATTN_PARTS = ("head_spike", "q_conv", "q_spike", "k_conv", "k_spike",
               "v_conv", "v_spike", "attn_spike", "proj_conv")
# MLP submodules of a transformer block, in forward order.
_MLP_PARTS = ("fc1_spike", "fc1_conv", "fc1_bn", "fc2_spike", "fc2_conv", "fc2_bn")


def _conv_block_paths(prefix):
    """Return dotted paths of one ConvBlock: ``<prefix>.Conv.*`` then its tail layers."""
    return ([f"{prefix}.Conv.{part}" for part in _SEPCONV_PARTS]
            + [f"{prefix}.{part}" for part in _CONVBLOCK_TAIL])


def _transformer_block_paths(prefix):
    """Return dotted paths of one transformer block: conv, attn, then mlp submodules."""
    return ([f"{prefix}.conv.{part}" for part in _SEPCONV_PARTS]
            + [f"{prefix}.attn.{part}" for part in _ATTN_PARTS]
            + [f"{prefix}.mlp.{part}" for part in _MLP_PARTS])


def _resolve_module_path(root, path):
    """Follow a dotted path from ``root``; numeric parts index, others use getattr.

    Raises AttributeError, IndexError or TypeError when a component cannot be
    resolved.
    """
    module = root
    for part in path.split('.'):
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module


# Modify MS_Attention_linear to capture intermediate steps
def modify_ms_attention_linear():
    """Patch ``MS_Attention_linear.forward`` so it records its intermediates.

    The method is replaced on the class, so every instance, existing or
    future, uses the replacement. The replacement re-implements the
    attention computation step by step. After each step it stores a detached
    CPU copy of the intermediate in ``activation_dict`` under
    ``attn_details<layer_id>_<step>``.

    ``layer_id`` is a small integer assigned to each instance on its first
    patched forward call: 0, 1, 2, ... in call order. This makes it unique
    per instance and stable across runs.

    The pre-patch ``forward`` is kept as
    ``MS_Attention_linear._viz_original_forward`` so it can be restored. It
    is recorded only on the first call, so patching twice never loses it.
    """
    import itertools

    from models import MS_Attention_linear

    if not hasattr(MS_Attention_linear, "_viz_original_forward"):
        MS_Attention_linear._viz_original_forward = MS_Attention_linear.forward

    # Hands out 0, 1, 2, ... to attention instances in first-forward order.
    layer_ids = itertools.count()

    def new_forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        C_v = int(C * self.lamda_ratio)

        # Why a counter: the previous `id(self) % 1000` key could collide between
        # layers (aligned addresses leave few residues) and silently overwrite
        # one layer's activations with another's.
        layer_id = getattr(self, "_viz_layer_id", None)
        if layer_id is None:
            layer_id = next(layer_ids)
            self._viz_layer_id = layer_id
        layer_prefix = f"attn_details{layer_id}_"

        def record(step, tensor):
            """Store a detached CPU copy of ``tensor`` under this layer's prefix."""
            activation_dict[f"{layer_prefix}{step}"] = tensor.detach().cpu()

        def split_heads(t, channels):
            """Flatten spatial dims and split ``channels`` into heads.

            Goes through (B, N, num_heads, channels // num_heads) and returns
            (B, num_heads, N, channels // num_heads).
            """
            return (
                t.flatten(2)
                .transpose(-1, -2)
                .reshape(B, N, self.num_heads, channels // self.num_heads)
                .permute(0, 2, 1, 3)
                .contiguous()
            )

        x = self.head_spike(x)

        q = self.q_conv(x)
        k = self.k_conv(x)
        v = self.v_conv(x)
        record("q_conv_out", q)
        record("k_conv_out", k)
        record("v_conv_out", v)

        # NOTE(review): only q and k pass through a spike here; self.v_spike is
        # never called, so a hook on attn.v_spike will not fire. Confirm this
        # matches the model's own forward.
        q = self.q_spike(q)
        k = self.k_spike(k)
        record("q_spike_out", q)
        record("k_spike_out", k)

        q = split_heads(q, C)
        k = split_heads(k, C)
        v = split_heads(v, C_v)
        record("q_reshaped", q)
        record("k_reshaped", k)
        record("v_reshaped", v)

        # NOTE(review): fixed 4.0 scale on q and k; confirm against the model.
        q_scaled = q * 4.0
        k_scaled = k * 4.0
        record("q_scaled", q_scaled)
        record("k_scaled", k_scaled)

        attn = q_scaled @ k_scaled.transpose(-2, -1)
        record("attn_raw", attn)

        from models import fp16_optimized_exp2_softmax
        x = fp16_optimized_exp2_softmax(attn)
        record("attn_softmax", x)

        x = x @ v
        record("attn_output", x)

        x = x.transpose(2, 3).reshape(B, C_v, N).contiguous()
        record("reshaped_before_spike", x)

        x = self.attn_spike(x)
        record("after_attn_spike", x)

        x = x.reshape(B, C_v, H, W)
        record("reshaped_spatial", x)

        x = self.proj_conv(x).reshape(B, C, H, W)
        record("final_output", x)

        return x

    # Replace the forward method
    MS_Attention_linear.forward = new_forward
    print("Modified MS_Attention_linear.forward to capture intermediate steps")


# Plot distribution graphs
def plot_distributions(output_dir):
    """Save one histogram-plus-statistics figure per captured activation.

    Every non-empty tensor in ``activation_dict`` is written to
    ``<output_dir>/<folder>/<key>_distribution.png``, where ``<folder>`` is
    derived from the key by ``_target_folder``.

    ``distribution_index.txt`` in ``output_dir`` lists per-folder counts,
    every saved figure, and a summary of how many tensors are more than 50%
    zeros or contain negative values.

    Statistics (including NaN/Inf/zero counts) are computed over the whole
    tensor. Only the histogram is built from finite values, randomly
    subsampled when there are more than ``_MAX_HIST_POINTS`` of them. Each
    plotted entry in ``activation_dict`` is set to None afterwards to free
    memory.

    Args:
        output_dir: directory for figures and the index file (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    for folder in _ORDERED_FOLDERS:
        os.makedirs(os.path.join(output_dir, folder), exist_ok=True)

    all_keys = list(activation_dict.keys())
    index_file = os.path.join(output_dir, "distribution_index.txt")

    # Per-folder totals head the index file, so count them in a first pass.
    folder_counters = {folder: 0 for folder in _ORDERED_FOLDERS}
    for key in all_keys:
        data = activation_dict[key]
        if data is not None and data.numel() > 0:
            folder_counters[_target_folder(key)] += 1

    with open(index_file, 'w', encoding='utf-8') as f:
        f.write("Network Layer Distribution Visualization Index\n")
        f.write("=" * 50 + "\n\n")
        f.write("Categorized by network layer structure:\n\n")
        for folder in _ORDERED_FOLDERS:
            f.write(f"{folder}: {folder_counters[folder]} distributions\n")
        f.write("\nComplete distribution list:\n\n")

    total_distributions = 0
    zero_heavy_distributions = 0
    negative_distributions = 0
    percentiles = [1, 5, 25, 50, 75, 95, 99]

    for key in all_keys:
        data = activation_dict[key]
        if data is None or data.numel() == 0:
            print(f"Skipping {key} - no data")
            continue

        total_distributions += 1
        target_folder = _target_folder(key)
        filename = key.replace('.', '_').replace('/', '_')

        flat_data = _to_float_numpy(data)
        element_count = flat_data.size

        # Statistics use the full tensor so rare NaN/Inf values and extremes
        # are never lost to subsampling.
        zero_count = np.sum(flat_data == 0)
        zero_percent = (zero_count / element_count) * 100
        mean_val = np.mean(flat_data)
        median_val = np.median(flat_data)
        min_val = np.min(flat_data)
        max_val = np.max(flat_data)
        std_val = np.std(flat_data)
        nan_count = np.sum(np.isnan(flat_data))
        inf_count = np.sum(np.isinf(flat_data))
        percentile_values = np.percentile(flat_data, percentiles)

        if zero_count / element_count > 0.5:
            zero_heavy_distributions += 1
        if min_val < 0:
            negative_distributions += 1

        # The histogram/KDE only makes sense on finite values; subsample huge tensors.
        plot_data = flat_data[np.isfinite(flat_data)]
        if plot_data.size > _MAX_HIST_POINTS:
            print(f"Sampling {key} data for visualization (very large tensor)")
            indices = np.random.choice(plot_data.size, size=_MAX_HIST_POINTS, replace=False)
            plot_data = plot_data[indices]

        stats_text = f"""
        Statistics - {key}:

        Shape: {data.shape}
        Mean: {mean_val:.6f}
        Median: {median_val:.6f}
        Min: {min_val:.6f}
        Max: {max_val:.6f}
        Std Dev: {std_val:.6f}

        Zero count: {zero_count} ({zero_percent:.2f}%)
        NaN values: {nan_count}
        Inf values: {inf_count}

        Percentiles:
        """
        for p, val in zip(percentiles, percentile_values):
            stats_text += f"    {p}%: {val:.6f}\n"

        save_path = os.path.join(output_dir, target_folder, f"{filename}_distribution.png")
        fig = plt.figure(figsize=(12, 8))
        try:
            # Top: histogram (+ KDE) of the values.
            plt.subplot(2, 1, 1)
            if plot_data.size:
                sns.histplot(plot_data, bins=100, kde=True)
            plt.title(f'Single Batch Distribution - {key}')
            plt.xlabel('Value')
            plt.ylabel('Frequency')

            # Bottom: text panel with the statistics.
            plt.subplot(2, 1, 2)
            plt.axis('off')
            plt.text(0.1, 0.5, stats_text, fontsize=12, family='monospace')

            plt.tight_layout()
            plt.savefig(save_path, dpi=300)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        with open(index_file, 'a', encoding='utf-8') as f:
            rel_path = os.path.join(target_folder, f"{filename}_distribution.png")
            f.write(f"{key}: {rel_path}\n")

        print(f"Saved distribution graph: {key} -> {target_folder} directory")

        # Clear data to free memory
        activation_dict[key] = None

    with open(index_file, 'a', encoding='utf-8') as f:
        f.write("\n\nDistribution Analysis Summary\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Total distributions: {total_distributions}\n")
        if total_distributions > 0:
            f.write(
                f"Distributions with >50% zero values: {zero_heavy_distributions} ({zero_heavy_distributions / total_distributions * 100:.2f}%)\n")
            f.write(
                f"Distributions with negative values: {negative_distributions} ({negative_distributions / total_distributions * 100:.2f}%)\n")

    print("\nDistribution Statistics Summary:")
    print(f"Total distributions: {total_distributions}")
    if total_distributions > 0:
        print(
            f"Distributions with >50% zero values: {zero_heavy_distributions} ({zero_heavy_distributions / total_distributions * 100:.2f}%)")
        print(
            f"Distributions with negative values: {negative_distributions} ({negative_distributions / total_distributions * 100:.2f}%)")


# Above this many finite values, the histogram is drawn from a random subsample.
_MAX_HIST_POINTS = 10000000

# Output sub-directories in forward-pass order; "other" catches unmatched keys.
_ORDERED_FOLDERS = (
    "01_downsample1_1",  # First downsampling
    "02_convblock1_1",  # First convolution block
    "03_downsample1_2",  # Second downsampling
    "04_convblock1_2",  # Second convolution block
    "05_downsample2",  # Third downsampling
    "06_convblock2_1",  # Third convolution block
    "07_convblock2_2",  # Fourth convolution block
    "08_downsample3",  # Fourth downsampling
    "09_block3",  # block3 (contains 6 submodules)
    "10_downsample4",  # Fifth downsampling
    "11_block4",  # block4 (contains 2 submodules)
    "12_head",  # Classifier head
    "other",  # Other modules
)

# (substring, folder) rules tested in order; the first match wins.
_FOLDER_RULES = (
    ("downsample1_1", "01_downsample1_1"),
    ("ConvBlock1_1", "02_convblock1_1"),
    ("downsample1_2", "03_downsample1_2"),
    ("ConvBlock1_2", "04_convblock1_2"),
    ("downsample2", "05_downsample2"),
    ("ConvBlock2_1", "06_convblock2_1"),
    ("ConvBlock2_2", "07_convblock2_2"),
    ("downsample3", "08_downsample3"),
    ("block3", "09_block3"),
    ("downsample4", "10_downsample4"),
    ("block4", "11_block4"),
)

# Spike layers inside attention that must not be classified as the head.
_NON_HEAD_SPIKES = ("head_spike", "attn_spike", "q_spike", "k_spike", "v_spike")


def _target_folder(key):
    """Return the output sub-directory for an ``activation_dict`` key."""
    for needle, folder in _FOLDER_RULES:
        if needle in key:
            return folder
    if "head" in key or ("spike" in key and not any(s in key for s in _NON_HEAD_SPIKES)):
        return "12_head"
    return "other"


def _to_float_numpy(tensor):
    """Flatten ``tensor`` into a float32/float64 NumPy array.

    ``Tensor.numpy()`` rejects bfloat16, and NumPy reductions such as
    ``np.std`` accumulate float16 input in float16 (overflowing past 65504).
    Therefore any dtype other than float32/float64 is cast to float32 first.
    """
    flat = tensor.detach().flatten()
    if flat.dtype not in (torch.float32, torch.float64):
        flat = flat.float()
    return flat.cpu().numpy()


def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Only ``--resume`` is mandatory. ``--pin_mem`` and ``--no_pin_mem`` both
    write the single ``pin_mem`` flag, which defaults to True. ``-h/--help``
    is unavailable because the parser is created with ``add_help=False``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser('Distribution Visualization Script', add_help=False)
    add = parser.add_argument

    # --- model ---
    add('--model', type=str, default='Efficient_Spiking_Transformer_l', metavar='MODEL',
        help='Model name to train')
    add('--model_mode', type=str, default='ms', help='Mode of model to train')
    add('--input_size', type=int, default=224, help='Input image size')
    add('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')

    # --- dataset / data loader ---
    add('--data_path', type=str, default='/dev/shm/imagenet-zdh/ImageNet-1K', help='Dataset path')
    add('--nb_classes', type=int, default=1000, help='Number of classes')
    add('--batch_size', type=int, default=1, help='Batch size - only process one batch')
    add('--num_workers', type=int, default=4)
    add('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
    add('--no_pin_mem', dest='pin_mem', action='store_false')
    parser.set_defaults(pin_mem=True)

    # --- output ---
    add('--output_dir', default='./single_batch_distribution', help='Path to save outputs')

    # --- checkpoint (mandatory) ---
    add('--resume', default='', required=True, help='Resume from checkpoint')

    # --- device / reproducibility ---
    add('--device', default='cuda', help='Device to use for training/testing')
    add('--seed', type=int, default=0)

    # --- distributed settings ---
    add('--world_size', type=int, default=1, help='Number of distributed processes')
    add('--local_rank', type=int, default=-1)
    add('--dist_on_itp', action='store_true')
    add('--dist_url', default='env://', help='URL used to set up distributed training')

    # --- number of time steps (assigned to model.T by main) ---
    add('--time_steps', type=int, default=1)

    return parser.parse_args()


def main():
    """Run the single-batch activation-distribution visualization.

    Steps:
    1. Parse arguments and seed the torch/NumPy/`random` RNGs.
    2. Build the validation dataset and the model, then load the
       ``--resume`` checkpoint non-strictly.
    3. Install the capture hooks and run one forward pass on the first
       validation batch.
    4. Plot every captured activation into ``--output_dir``.

    Exceptions raised while fetching the batch, running the model or
    plotting are printed along with a traceback instead of propagating.
    """
    import traceback

    args = get_args()

    rule = "=" * 50
    print("\n" + rule)
    print("Single Batch Distribution Visualization Tool - Efficient_Spiking_Transformer_l")
    print(rule + "\n")

    script_dir = os.path.dirname(os.path.realpath(__file__))
    print(f"Working directory: {script_dir}")
    print(f"Parameters:\n{args}".replace(", ", ",\n"))

    device = torch.device(args.device)

    # Seed every RNG in use so the batch and any sampling are reproducible.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(args.seed)

    # Fall back to CPU when CUDA was requested but is not present.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("Warning: CUDA requested but not available. Switching to CPU.")
        args.device = 'cpu'
        device = torch.device(args.device)

    if args.output_dir:
        out_path = Path(args.output_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        print(f"Will save distribution graphs to: {os.path.abspath(args.output_dir)}")

    print("Building dataset...")
    dataset_val = build_dataset(is_train=False, args=args)

    # Cap the batch size so the single forward pass stays within GPU memory.
    actual_batch_size = min(args.batch_size, 4)
    print(f"Using batch size {actual_batch_size} for visualization")

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=torch.utils.data.SequentialSampler(dataset_val),
        batch_size=actual_batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    print(f"Creating model: {args.model}")
    model_factory = models.__dict__[args.model]
    model = model_factory()
    model.T = args.time_steps

    print("Model structure:")
    print(model)

    if args.resume:
        if not os.path.isfile(args.resume):
            print(f"No checkpoint found at: {args.resume}")
            return
        print(f"Loading checkpoint: {args.resume}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location='cpu')
        # Accept {'model': sd}, {'state_dict': sd} or a bare state dict.
        checkpoint_model = checkpoint
        for wrapper_key in ('model', 'state_dict'):
            if wrapper_key in checkpoint:
                checkpoint_model = checkpoint[wrapper_key]
                break
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(f"Checkpoint loaded: {msg}")

    model.to(device)
    model.eval()

    register_sequential_hooks(model)

    try:
        print("Getting single batch of images...")
        images, _ = next(iter(data_loader_val))
        images = images.to(device)

        print(f"Performing forward pass on a single batch ({len(images)} images)...")
        with torch.no_grad():
            # Free cached blocks before the big forward pass.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("Starting model inference...")
            model(images)
            print("Model inference complete, all intermediate activations captured")
            print(f"Captured {len(activation_dict)} different activation distributions")

        print("Forward pass complete. Plotting distribution graphs...")
        plot_distributions(args.output_dir)

        print(f"Distribution analysis complete! All graphs saved to {args.output_dir} directory")

    except RuntimeError as e:
        message = str(e)
        is_oom = "CUDA out of memory" in message or "CUBLAS_STATUS_ALLOC_FAILED" in message
        if is_oom:
            print("CUDA out of memory! Try reducing batch size and run again")
            print(f"Error details: {e}")
        else:
            print(f"Runtime error: {e}")
        traceback.print_exc()
    except Exception as e:
        print(f"Error during processing: {e}")
        traceback.print_exc()


# Script entry point; importing this module does not run main().
if __name__ == '__main__':
    main()