"""
Effective Receptive Field Analysis for FluxNet_D Models

This script analyzes the effective receptive field (ERF) of FluxNet_D models
trained with different temporal resolutions (10dt, 100dt, 1000dt).

The analysis:
1. Computes gradients of intermediate feature maps (raw_fluxes) with respect to input
2. Identifies effective receptive field using 1% threshold
3. Analyzes every raw_fluxes channel in turn (sequentially, on the CPU)
4. Outputs quantitative metrics (theoretical vs. effective RF) as a markdown table

Output:
- NPY files for each channel's RF map
- Statistics text file
- Markdown table with RF size comparison
"""

import os
import sys
import torch
import torch.nn as nn
import numpy as np
import random
from torch.autograd import grad
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

# Make the FluxNet sources importable (load_model imports models.fluxnet_d_2d).
sys.path.insert(0, '/home/ml4pf/zshlan/FluxNet/src')

# ==================== Configuration ====================
# Trained checkpoint for each temporal-resolution ablation, keyed by its label.
MODEL_PATHS = {
    '10dt': '/home/ml4pf/zshlan/FluxNet/results/spinodal_decomposition/ablation_10dt/FluxNet_D_pf/best_model.pt',
    '100dt': '/home/ml4pf/zshlan/FluxNet/results/spinodal_decomposition/ablation_100dt/FluxNet_D_pf/best_model.pt',
    '1000dt': '/home/ml4pf/zshlan/FluxNet/results/spinodal_decomposition/ablation_1000dt/FluxNet_D_pf/best_model.pt',
}

# FluxNet_D constructor arguments that match each checkpoint above.
MODEL_CONFIGS = {
    '10dt': dict(base_channels=32, num_blocks=4, kernel_size=3, neighborhood_size=3),
    '100dt': dict(base_channels=32, num_blocks=4, kernel_size=5, neighborhood_size=5),
    '1000dt': dict(base_channels=32, num_blocks=6, kernel_size=7, neighborhood_size=9),
}

# Analysis parameters
THRESHOLD = 1e-2          # pixel is in the ERF if its gradient exceeds 1% of the peak
NUM_SAMPLE_POINTS = 100   # output locations sampled per channel
IMAGE_SIZE = 256          # nominal test-image size; analyze_model takes h, w from the loaded frame
RANDOM_SEED = 666         # seed for torch / numpy / random

# Root directory for all ERF outputs (per-model subfolders + markdown summary).
OUTPUT_DIR = '/home/ml4pf/zshlan/FluxNet/experiments/spinodal_decomposition/analysis_erf'


def setup_seed(seed):
    """
    Seed every random number generator used by the analysis.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    module, and configures cuDNN so convolutions are reproducible.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # deterministic=True alone is not sufficient: with benchmark mode on, cuDNN
    # may still select different convolution algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_model(checkpoint_path, model_config, device):
    """
    Build a FluxNet_D, load its trained weights and return it in eval mode.

    Args:
        checkpoint_path: Path to a file holding the model's state_dict.
        model_config: Mapping with 'base_channels', 'num_blocks',
            'kernel_size' and 'neighborhood_size'.
        device: torch device the model and the loaded weights are placed on.

    Returns:
        The FluxNet_D instance on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    # Project import, resolved through the FluxNet src directory added to sys.path.
    from models.fluxnet_d_2d import FluxNet_D

    arch_kwargs = {
        key: model_config[key]
        for key in ('base_channels', 'num_blocks', 'kernel_size', 'neighborhood_size')
    }
    network = FluxNet_D(
        in_channels=1,
        act_fn=nn.GELU,
        norm_2d=nn.BatchNorm2d,
        **arch_kwargs,
    )
    network.to(device)

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model file not found: {checkpoint_path}")

    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
    network.load_state_dict(state_dict)

    network.eval()
    return network


def compute_theoretical_receptive_field(model):
    """
    Return the theoretical receptive-field side length (pixels) of a FluxNet_D.

    The kernel size is read from ``model.res_blocks[0][0].conv[1]``. It is
    assumed to be shared by the initial layer and by both convolutions of
    every residual block, all with stride 1. The initial layer contributes
    ``k`` pixels, and each of the ``2 * num_blocks`` block convolutions widens
    the field by ``k - 1``.

    Args:
        model: FluxNet_D instance exposing ``num_blocks`` and ``res_blocks``.

    Returns:
        int: side length of the square theoretical receptive field.
    """
    n_blocks = model.num_blocks
    k = model.res_blocks[0][0].conv[1].kernel_size[0]

    per_block_growth = 2 * (k - 1)
    return k + sum(per_block_growth for _ in range(n_blocks))


def get_channel_names(model):
    """
    Return a descriptive name for every channel of the raw_fluxes output.

    Channel layout, in order:
    - outflow percentage (sigmoid): 1 channel
    - outflow distribution (softmax): one channel per neighbor offset
    - inflow percentage (sigmoid): 1 channel
    - inflow distribution (softmax): one channel per neighbor offset

    Neighbor offsets (dy, dx) cover the ``neighborhood_size`` x
    ``neighborhood_size`` square around the cell, excluding the cell itself.
    They are listed in row-major order (dy outer, dx inner), with signs always
    shown (e.g. ``dy-1_dx+0``).

    Args:
        model: Object exposing ``neighborhood_size`` (int).

    Returns:
        list[str]: channel names in channel-index order.
    """
    radius = model.neighborhood_size // 2

    # Row-major neighbor offsets; the centre cell (0, 0) has no flux channel.
    offsets = [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if (dy, dx) != (0, 0)
    ]

    channel_names = []
    for branch in ("outflow", "inflow"):
        channel_names.append(f"{branch}_percentage_sigmoid")
        channel_names.extend(
            f"{branch}_dist_softmax_dy{dy:+d}_dx{dx:+d}" for dy, dx in offsets
        )
    return channel_names


def compute_gradient_for_channel(model, input_tensor, output_point, channel_idx):
    """
    Return d raw_fluxes[0, channel_idx, y, x] / d input for one output location.

    The forward pass replays the backbone up to ``raw_fluxes``:
    ``first_conv`` -> for each ``(main_path, fusion_conv)`` in ``res_blocks``:
    ``fusion_conv(cat([main_path(f), f], dim=1))`` -> ``flux_conv``.

    Args:
        model: Module exposing ``first_conv``, ``res_blocks`` (iterable of
            ``(main_path, fusion_conv)`` pairs) and ``flux_conv``.
        input_tensor: Model input. It is not modified, and no gradient
            history is attached to it.
        output_point: ``(y, x)`` location in ``raw_fluxes``.
        channel_idx: Index of the ``raw_fluxes`` channel to differentiate.

    Returns:
        torch.Tensor with the same shape as ``input_tensor``.
    """
    # Differentiate w.r.t. a detached copy so the caller's tensor stays untouched.
    leaf = input_tensor.detach().clone().requires_grad_(True)

    # enable_grad keeps this working even when called inside torch.no_grad().
    with torch.enable_grad():
        features = model.first_conv(leaf)
        for main_path, fusion_conv in model.res_blocks:
            branch_out = main_path(features)
            features = fusion_conv(torch.cat([branch_out, features], dim=1))

        raw_fluxes = model.flux_conv(features)

        y, x = output_point
        target_output = raw_fluxes[0, channel_idx, y, x]

        # The graph is used exactly once. Without retain_graph, autograd can
        # free intermediate buffers as backward proceeds, which lowers peak memory.
        (input_gradient,) = grad(target_output, leaf)

    return input_gradient


def calculate_effective_rf_size(grad_map, threshold=0.01):
    """
    Count the pixels that belong to the effective receptive field.

    A pixel is counted when its value is strictly greater than ``threshold``
    times the largest value in ``grad_map``. An all-zero map therefore yields 0.

    Args:
        grad_map: 2-D array of gradient magnitudes (e.g. a normalized
            gradient crop).
        threshold: Fraction of the peak value a pixel must exceed.

    Returns:
        Number of pixels above the cut-off, as a NumPy integer.
    """
    cutoff = threshold * grad_map.max()
    above_cutoff = grad_map > cutoff
    return np.sum(above_cutoff)


def _extract_centered_window(grad_map, y, x, rf_radius, size, h, w):
    """Copy the size x size window starting at (y - rf_radius, x - rf_radius), zero-filling pixels outside the h x w image."""
    window = np.zeros((size, size))
    top, left = y - rf_radius, x - rf_radius

    # Source region is clipped to the image, to the rf_radius neighborhood of
    # (y, x), and to the window extent.
    row0, row1 = max(0, top), min(h, y + rf_radius + 1, top + size)
    col0, col1 = max(0, left), min(w, x + rf_radius + 1, left + size)

    if row1 > row0 and col1 > col0:
        window[row0 - top:row1 - top, col0 - left:col1 - left] = grad_map[row0:row1, col0:col1]
    return window


def analyze_single_channel(model, input_tensor, points, channel_idx,
                           theoretical_rf, rf_radius, h, w, threshold):
    """
    Measure the effective receptive field (ERF) of one raw_fluxes channel.

    For each (y, x) in ``points``:
    - take the input gradient of ``raw_fluxes[0, channel_idx, y, x]``;
    - normalize its absolute value by the image-wide maximum;
    - crop a ``theoretical_rf`` x ``theoretical_rf`` window centered on (y, x);
    - count the pixels above ``threshold`` times the window maximum.

    Args:
        model: FluxNet_D model, in eval mode.
        input_tensor: Model input; gradients are taken w.r.t. ``[0, 0]``.
        points: Non-empty list of ``(y, x)`` output locations.
        channel_idx: raw_fluxes channel to analyze.
        theoretical_rf: Side length of the crop window.
        rf_radius: Offset from a point to the window's top-left corner.
        h, w: Image height and width used to clip the crop.
        threshold: Relative ERF threshold (e.g. 0.01 for 1%).

    Returns:
        dict with:
          'avg_rf': mean crop over all points, rescaled to max 1 (if non-zero);
          'avg_rf_size': sqrt of the mean ERF pixel count, i.e. the side of an
              equivalent square;
          'ratio': mean ERF pixel count as a percentage of theoretical_rf ** 2.

    Raises:
        ValueError: if ``points`` is empty, since the averages would be undefined.
    """
    if len(points) == 0:
        raise ValueError("analyze_single_channel requires at least one sample point")

    accumulated_rf = np.zeros((theoretical_rf, theoretical_rf))
    rf_sizes = []

    for (y, x) in points:
        input_grad = compute_gradient_for_channel(model, input_tensor, (y, x), channel_idx)
        grad_magnitude = torch.abs(input_grad[0, 0]).detach().cpu().numpy()

        # Normalize by the image-wide peak; leave an all-zero map unchanged.
        peak = grad_magnitude.max()
        grad_normalized = grad_magnitude / peak if peak != 0 else grad_magnitude

        rf_centered = _extract_centered_window(
            grad_normalized, y, x, rf_radius, theoretical_rf, h, w
        )
        accumulated_rf += rf_centered
        rf_sizes.append(calculate_effective_rf_size(rf_centered, threshold))

    avg_rf = accumulated_rf / len(points)
    avg_rf_normalized = avg_rf / avg_rf.max() if avg_rf.max() > 0 else avg_rf

    avg_rf_size = np.mean(rf_sizes)
    ratio = avg_rf_size / (theoretical_rf * theoretical_rf) * 100

    return {
        'avg_rf': avg_rf_normalized,
        'avg_rf_size': np.sqrt(avg_rf_size),
        'ratio': ratio
    }


def _load_test_image(h5_folder, image_index):
    """Read frame ``image_index`` of ``phi_data`` from the first ``.h5`` file (sorted by name) in ``h5_folder``."""
    # Only this helper needs h5py, so it is imported where it is used.
    import h5py

    # os.listdir order is arbitrary; sort so every run picks the same file.
    h5_files = sorted(name for name in os.listdir(h5_folder) if name.endswith(".h5"))
    if not h5_files:
        raise FileNotFoundError(f"No .h5 files found in: {h5_folder}")

    with h5py.File(os.path.join(h5_folder, h5_files[0]), 'r') as f:
        # Read only the requested frame rather than the whole dataset.
        return f['phi_data'][image_index]


def _sample_points(h, w, margin, num_points):
    """Draw ``num_points`` random (y, x) locations at least ``margin`` pixels from every border of an h x w image."""
    if h - margin - 1 < margin or w - margin - 1 < margin:
        raise ValueError(
            f"Image of size {h}x{w} is too small for a sampling margin of {margin}"
        )
    # Uses Python's global `random` state (seeded in main); y is drawn before x.
    return [
        (random.randint(margin, h - margin - 1), random.randint(margin, w - margin - 1))
        for _ in range(num_points)
    ]


def _write_statistics(stats_file, stats, total_channels):
    """Write the human-readable per-channel and branch-average report for one model."""
    theoretical_rf = stats['theoretical_rf']
    outflow_avg_size = stats['outflow_avg_erf']
    inflow_avg_size = stats['inflow_avg_erf']
    all_avg_size = stats['all_avg_erf']

    with open(stats_file, 'w', encoding='utf-8') as f:
        f.write(f"Receptive Field Analysis for {stats['model_name']}\n")
        f.write("=" * 100 + "\n\n")
        f.write(f"Theoretical RF Size: {theoretical_rf}x{theoretical_rf}\n")
        f.write(f"Number of Channels: {total_channels}\n")
        f.write(f"Number of Sample Points: {NUM_SAMPLE_POINTS}\n")
        f.write(f"Threshold: {THRESHOLD}\n\n")

        f.write("Individual Channel Results:\n")
        f.write("=" * 100 + "\n")
        f.write(f"{'Ch':<4} {'Channel Name':<50} {'Effective RF':>15} {'Ratio':>10}\n")
        f.write("-" * 100 + "\n")

        for i, result in enumerate(stats['channel_results']):
            f.write(f"{i:02d}   {result['name']:<50} "
                    f"{result['avg_rf_size']:>6.2f}x{result['avg_rf_size']:<6.2f} "
                    f"{result['ratio']:>9.2f}%\n")

        f.write("=" * 100 + "\n\n")
        f.write("Branch Averages:\n")
        f.write(f"  Outflow branch average ERF: {outflow_avg_size:.2f}x{outflow_avg_size:.2f}\n")
        f.write(f"  Inflow branch average ERF: {inflow_avg_size:.2f}x{inflow_avg_size:.2f}\n")
        f.write("\nAll Channels Average:\n")
        f.write(f"  Effective RF Size: {all_avg_size:.2f}x{all_avg_size:.2f}\n")
        f.write(f"  Ratio: {stats['all_avg_ratio']:.2f}%\n")


def analyze_model(model_name, model_path, model_config, device):
    """
    Analyze the effective receptive field of one FluxNet_D checkpoint.

    Loads the model and measures the ERF of every raw_fluxes channel on one
    test frame. Saves one NPY map per channel plus a statistics text file
    under ``OUTPUT_DIR/erf_results_<model_name>``.

    Args:
        model_name: Label of the model (e.g. '10dt'), used in file names.
        model_path: Checkpoint path passed to load_model.
        model_config: Architecture config passed to load_model.
        device: Device used for loading; the analysis itself runs on CPU.

    Returns:
        dict with 'model_name', 'theoretical_rf', 'outflow_avg_erf',
        'inflow_avg_erf', 'all_avg_erf', 'all_avg_ratio' and 'channel_results'.

    Raises:
        FileNotFoundError: if the checkpoint or any test ``.h5`` file is missing.
    """
    print(f"\n{'='*80}")
    print(f"Analyzing model: {model_name}")
    print(f"{'='*80}")

    model_output_dir = os.path.join(OUTPUT_DIR, f'erf_results_{model_name}')
    os.makedirs(model_output_dir, exist_ok=True)

    model = load_model(model_path, model_config, device)
    model = model.cpu()  # Use CPU for gradient computation
    model.eval()

    theoretical_rf = compute_theoretical_receptive_field(model)
    rf_radius = theoretical_rf // 2
    print(f"Theoretical receptive field: {theoretical_rf}x{theoretical_rf}")

    test_image_index = 10
    model_evaluation_H5_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/test"
    test_image = _load_test_image(model_evaluation_H5_folder, test_image_index)
    input_tensor = torch.from_numpy(test_image).float().unsqueeze(0).unsqueeze(0)

    h, w = test_image.shape

    # Keep sample points far enough from the border that their full
    # theoretical RF lies inside the image.
    margin = max(5, rf_radius)
    points = _sample_points(h, w, margin, NUM_SAMPLE_POINTS)

    total_channels = model.total_channels
    channel_names = get_channel_names(model)

    print(f"Total channels: {total_channels}")
    print(f"Sample points: {NUM_SAMPLE_POINTS}")

    channel_results = []
    for channel_idx in range(total_channels):
        channel_name = channel_names[channel_idx]
        print(f"  [{channel_idx+1}/{total_channels}] Analyzing channel: {channel_name}")

        result = analyze_single_channel(
            model, input_tensor, points, channel_idx,
            theoretical_rf, rf_radius, h, w, THRESHOLD
        )
        result['name'] = channel_name
        channel_results.append(result)

        safe_channel_name = channel_name.replace('(', '_').replace(')', '_').replace(',', '_')
        npy_path = os.path.join(model_output_dir, f"{model_name}_rf_ch{channel_idx:02d}_{safe_channel_name}.npy")
        np.save(npy_path, result['avg_rf'])

        print(f"    Effective RF: {result['avg_rf_size']:.2f}x{result['avg_rf_size']:.2f}, Ratio: {result['ratio']:.2f}%")

    num_neighbors = model.num_neighbors

    # Outflow branch: channel 0 (sigmoid) + channels 1..num_neighbors (softmax).
    outflow_results = channel_results[:num_neighbors + 1]
    # Inflow branch: channel num_neighbors+1 (sigmoid) + remaining channels (softmax).
    inflow_results = channel_results[num_neighbors + 1:]

    stats = {
        'model_name': model_name,
        'theoretical_rf': theoretical_rf,
        'outflow_avg_erf': np.mean([r['avg_rf_size'] for r in outflow_results]),
        'inflow_avg_erf': np.mean([r['avg_rf_size'] for r in inflow_results]),
        'all_avg_erf': np.mean([r['avg_rf_size'] for r in channel_results]),
        'all_avg_ratio': np.mean([r['ratio'] for r in channel_results]),
        'channel_results': channel_results
    }

    stats_file = os.path.join(model_output_dir, f"{model_name}_rf_statistics.txt")
    _write_statistics(stats_file, stats, total_channels)

    print(f"\nResults saved to: {model_output_dir}")

    return stats


def generate_markdown_table(all_stats, num_sample_points=None, threshold=None,
                            image_size=None, random_seed=None):
    """
    Build the markdown summary comparing theoretical and effective RF sizes.

    Args:
        all_stats: Mapping from model label to the statistics dict returned by
            analyze_model. Rows are emitted in the order '10dt', '100dt',
            '1000dt'; labels that are missing are skipped, and other labels
            are ignored.
        num_sample_points: Sample-point count reported in the summary
            (default: NUM_SAMPLE_POINTS).
        threshold: Relative ERF threshold reported in the summary
            (default: THRESHOLD).
        image_size: Image side length reported in the summary
            (default: IMAGE_SIZE).
        random_seed: Seed reported in the summary (default: RANDOM_SEED).

        The defaults keep the report in sync with the configuration
        constants instead of hard-coded copies.

    Returns:
        The markdown document as a string.
    """
    if num_sample_points is None:
        num_sample_points = NUM_SAMPLE_POINTS
    if threshold is None:
        threshold = THRESHOLD
    if image_size is None:
        image_size = IMAGE_SIZE
    if random_seed is None:
        random_seed = RANDOM_SEED

    # e.g. 0.01 -> "1%"; ':g' strips trailing zeros and float noise.
    threshold_pct = f"{threshold * 100:g}%"

    md_content = """# Effective Receptive Field Analysis Results

## Summary Table

| Model | Theoretical RF | Outflow Branch ERF | Inflow Branch ERF | Average ERF | ERF/TRF Ratio |
|-------|----------------|--------------------|--------------------|-------------|---------------|
"""

    for model_name in ['10dt', '100dt', '1000dt']:
        if model_name in all_stats:
            stats = all_stats[model_name]
            md_content += f"| FluxNet-D ({model_name}) | {stats['theoretical_rf']}×{stats['theoretical_rf']} | "
            md_content += f"{stats['outflow_avg_erf']:.2f}×{stats['outflow_avg_erf']:.2f} | "
            md_content += f"{stats['inflow_avg_erf']:.2f}×{stats['inflow_avg_erf']:.2f} | "
            md_content += f"{stats['all_avg_erf']:.2f}×{stats['all_avg_erf']:.2f} | "
            md_content += f"{stats['all_avg_ratio']:.2f}% |\n"

    md_content += f"""
## Notes

- **Theoretical RF (TRF)**: Calculated based on model architecture (kernel sizes and number of layers)
- **Effective RF (ERF)**: Measured empirically using gradient-based analysis with {threshold_pct} threshold
- **ERF/TRF Ratio**: Percentage of theoretical receptive field that is effectively utilized

## Analysis Parameters

- Sample points: {num_sample_points}
- Threshold: {threshold_pct} of maximum gradient magnitude
- Image size: {image_size}×{image_size}
- Random seed: {random_seed}

"""

    return md_content


def main():
    """
    Run the ERF analysis for every configured model and write the summary.

    Seeds all RNGs with RANDOM_SEED and analyzes the '10dt', '100dt' and
    '1000dt' models in that order on the CPU. Writes the combined markdown
    table to ``OUTPUT_DIR/erf_analysis_summary.md``.
    """
    banner = "=" * 80
    print(banner)
    print("Effective Receptive Field Analysis for FluxNet_D Models")
    print(banner)

    setup_seed(RANDOM_SEED)
    # All gradient computation happens on the CPU.
    device = torch.device('cpu')
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Insertion order (10dt, 100dt, 1000dt) also fixes the RNG consumption order.
    all_stats = {
        name: analyze_model(name, MODEL_PATHS[name], MODEL_CONFIGS[name], device)
        for name in ('10dt', '100dt', '1000dt')
    }

    summary_markdown = generate_markdown_table(all_stats)
    md_path = os.path.join(OUTPUT_DIR, 'erf_analysis_summary.md')
    with open(md_path, 'w', encoding='utf-8') as out_file:
        out_file.write(summary_markdown)

    print("\n" + banner)
    print("Analysis complete!")
    print(f"Summary saved to: {md_path}")
    print(banner)


if __name__ == '__main__':
    main()
