#!/usr/bin/env python3
# eval.py - Component-wise energy analysis functionality for integrated evaluation

import argparse
import torch
import datetime
from pathlib import Path
from thop import profile
import torch.nn as nn

# Local module imports
try:
    from train import (
        load_config, build_loaders, build_model_ddp,
        setup_distributed, cleanup_distributed,
        SPIKINGJELLY_FUNCTIONAL_AVAILABLE, functional
    )
    from model.layers import SpikeInfo # Required for dummy input generation
except ImportError as e:
    print(f"Error importing local modules: {e}")
    print("Please ensure train.py, model/layers.py and other required files are in the correct directory.")
    exit(1)

# Utility functions
def ts():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``."""
    now = datetime.datetime.now()
    return now.strftime("%H:%M:%S")

class PerformanceMonitor:
    """
    Collect per-layer input statistics with forward hooks during real inference.

    Hooks are attached to every ``nn.Conv2d`` and ``nn.Linear`` layer of the
    model. Each time such a layer runs, the mean absolute value of its first
    input tensor is recorded as that layer's "input rate". If a layer runs
    several times while monitoring is enabled (for example once per time step,
    or because the layer is shared), the rate is averaged over all calls
    instead of keeping only the most recent one.

    Attributes:
        model: The monitored model, unwrapped from a ``.module`` wrapper.
        monitored_modules: Tuple of layer types that receive a hook.
        layer_stats: Maps qualified layer name -> dict with keys
            ``'input_rate'`` (float, mean over calls), ``'output_shape'``
            (shape of the latest output), ``'module'`` (the layer itself) and
            ``'calls'`` (number of hook invocations that contributed).
        hooks: Handles of the currently registered forward hooks.
    """
    def __init__(self, model):
        # Unwrap wrappers that expose the real network as `.module`
        # (e.g. DistributedDataParallel) so layer names match the bare model.
        self.model = model.module if hasattr(model, 'module') else model
        self.monitored_modules = (nn.Conv2d, nn.Linear)
        self.layer_stats = {}
        self.hooks = []
        # Reverse lookup used inside the hook: module object -> qualified name.
        self._hook_map = {module: name for name, module in self.model.named_modules()}

    def _forward_hook(self, module, inputs, output):
        """
        Record the input activation rate of ``module`` for one forward call.

        A 5D input is first averaged over its leading axis
        (presumably the time axis of a spiking network — TODO confirm the
        layout used by the model); the rate is then the mean absolute value
        of the resulting tensor.

        Args:
            module: The layer that just ran.
            inputs: Positional inputs passed to the layer (tuple).
            output: The layer's output; only its ``shape`` is stored.
        """
        module_name = self._hook_map.get(module)
        # Unknown modules (and the root module, whose registered name is
        # the empty string) are skipped.
        if not module_name:
            return

        has_tensor_input = (
            isinstance(inputs, tuple)
            and len(inputs) > 0
            and isinstance(inputs[0], torch.Tensor)
        )
        if not has_tensor_input:
            return
        in_tensor = inputs[0]
        if in_tensor.numel() == 0:
            return

        rate_tensor = in_tensor.mean(0) if in_tensor.dim() == 5 else in_tensor
        input_rate = rate_tensor.abs().mean().item()

        entry = self.layer_stats.get(module_name)
        if entry is None:
            self.layer_stats[module_name] = {
                'input_rate': input_rate,
                'output_shape': output.shape,
                'module': module,
                'calls': 1,
            }
            return

        # Running mean so repeated invocations all contribute equally instead
        # of the last call silently overwriting the earlier ones.
        entry['calls'] = entry.get('calls', 1) + 1
        entry['input_rate'] += (input_rate - entry['input_rate']) / entry['calls']
        entry['output_shape'] = output.shape

    def enable(self):
        """Start a fresh monitoring session by hooking all target layers."""
        self.disable()
        # Rebind (not clear) so statistics from an earlier session are not
        # averaged into this one, while references held by callers stay intact.
        self.layer_stats = {}
        for module in self.model.modules():
            if isinstance(module, self.monitored_modules):
                self.hooks.append(module.register_forward_hook(self._forward_hook))
        print(f"{ts()} [Monitor] Enabled monitoring for {len(self.hooks)} layers.")

    def disable(self):
        """Remove every registered hook; collected ``layer_stats`` are kept."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def _profile_component(module, sample):
    """Return ``(params, gflops)`` for one component via thop (FLOPs = 2 * MACs)."""
    macs, params = profile(module, inputs=(sample,), verbose=False)
    return params, macs * 2 / 1e9


def _owning_component(layer_name, component_names):
    """Return the component whose lower-cased name is the first path segment of ``layer_name``."""
    head = layer_name.split('.', 1)[0]
    for comp in component_names:
        if comp.lower() == head:
            return comp
    return None


def _layer_sops(module, output_shape, input_rate):
    """Estimate synaptic operations of one Conv2d/Linear call; 0 for other layer types."""
    if isinstance(module, nn.Conv2d):
        h_out, w_out = output_shape[-2], output_shape[-1]
        k_h, k_w = module.kernel_size
        # Each output channel only sees in_channels / groups input channels.
        fan_in = (module.in_channels // module.groups) * k_h * k_w
        return h_out * w_out * module.out_channels * fan_in * input_rate
    if isinstance(module, nn.Linear):
        return module.out_features * module.in_features * input_rate
    return 0


def analyze_model_performance(model, loader, T, device):
    """
    Print a component-wise report of parameters, GFLOPs, SOPs and energy.

    The model is expected to expose the sub-modules ``senc``, ``msp``, ``li``,
    ``sc`` and ``grouper``. A single sample taken from ``loader`` is used for
    both measurements:

    - GFLOPs/Parameters: thop profiling of each component in isolation.
    - SOPs/Energy: one monitored forward pass of the whole model; every
      Conv2d/Linear layer contributes ``fan-in * outputs * input_rate``
      synaptic operations, scaled by ``T`` and a per-operation energy cost.

    Args:
        model: The network to analyze (a ``.module`` wrapper is unwrapped).
        loader: Data loader; the first sample of its first batch is the input.
        T: Number of time steps used to scale the energy estimate.
        device: Device the sample input is moved to.

    Returns:
        None. The report is written to stdout.
    """
    print(f"\n{ts()} [INFO] Component-wise Performance Analysis Started")
    actual_model = model.module if hasattr(model, 'module') else model

    # Define model components for analysis
    component_modules = {
        'SENC': actual_model.senc,  # Spike Encoding
        'MSP': actual_model.msp,    # Multi-Scale Processing
        'LI': actual_model.li,      # Lateral Inhibition
        'SC': actual_model.sc,      # Spike Classification
    }
    comp_stats = {
        name: {'params': 0, 'gflops': 0, 'sops': 0, 'energy': 0}
        for name in component_modules
    }

    # Generate representative input: first sample of the first batch.
    dummy_input = next(iter(loader))[0][:1].to(device)
    if dummy_input.dim() == 4:
        dummy_input = dummy_input.unsqueeze(1) # Ensure [B,T,C,H,W] format

    # Component-wise GFLOPs and parameter analysis using thop
    print(f"{ts()} [INFO] Computing GFLOPs and parameter counts for each component")
    with torch.no_grad():
        try:
            params, gflops = _profile_component(component_modules['SENC'], dummy_input)
            comp_stats['SENC']['params'] = params
            comp_stats['SENC']['gflops'] = gflops

            # Intermediate activations that serve as inputs for the later components.
            out_senc = component_modules['SENC'](dummy_input)
            out_msp = component_modules['MSP'](out_senc)
            out_grouped = actual_model.grouper(out_msp)
            out_li = component_modules['LI'](out_grouped)

            for comp_name, comp_input in (('MSP', out_senc), ('LI', out_grouped), ('SC', out_li)):
                params, gflops = _profile_component(component_modules[comp_name], comp_input)
                comp_stats[comp_name]['params'] = params
                comp_stats[comp_name]['gflops'] = gflops
        except Exception as e:
            # Best effort: the SOPs/energy analysis below does not depend on thop.
            print(f"{ts()} [WARNING] THOP profiling failed: {e}. GFLOPs will be reported as zero.")

    # SOPs and energy estimation using performance monitor
    print(f"{ts()} [INFO] Computing SOPs and energy consumption using activation rate estimation")
    monitor = PerformanceMonitor(actual_model)
    monitor.enable()
    try:
        with torch.no_grad():
            # `functional` may be a placeholder without reset_net when SpikingJelly is
            # not installed (suggested by the SPIKINGJELLY_FUNCTIONAL_AVAILABLE flag
            # exported by train — TODO confirm); do not crash in that case.
            reset_net = getattr(functional, 'reset_net', None)
            if callable(reset_net):
                reset_net(actual_model)
            else:
                print(f"{ts()} [WARNING] functional.reset_net unavailable; network state was not reset before monitoring.")
            _ = actual_model(dummy_input)
    finally:
        monitor.disable()

    # Energy cost parameters (picojoules per operation)
    E_MAC, E_AC = 4.6, 0.9  # MAC operation energy, accumulation energy

    for layer_name, stats in monitor.layer_stats.items():
        # Determine component ownership from the first segment of the layer path,
        # so e.g. prefix 'li' cannot capture an unrelated top-level 'linear_x'.
        owner = _owning_component(layer_name, component_modules)
        if not owner:
            continue

        S_in = stats['input_rate']
        if S_in <= 1e-6:
            # Effectively silent input: contributes no synaptic operations.
            continue

        sops = _layer_sops(stats['module'], stats['output_shape'], S_in)
        if sops > 0:
            # Layers under 'senc.patch_conv' are charged at the MAC cost;
            # all other layers at the accumulate cost.
            energy_cost = E_MAC if 'senc.patch_conv' in layer_name else E_AC
            layer_energy = energy_cost * sops * T # Total energy across all time steps (pJ)
            comp_stats[owner]['sops'] += sops
            comp_stats[owner]['energy'] += layer_energy

    # Results compilation and reporting
    total_params = sum(s['params'] for s in comp_stats.values())
    total_gflops = sum(s['gflops'] for s in comp_stats.values())
    total_sops = sum(s['sops'] for s in comp_stats.values())
    total_energy = sum(s['energy'] for s in comp_stats.values())

    print("\n" + "=" * 80)
    print("Component-wise Performance Analysis Results")
    print("-" * 80)
    print(f"{'Component':<12} | {'Params (M)':>12} | {'GFLOPs':>10} | {'G-SOPs/call':>12} | {'Energy (mJ)':>12} | {'Energy (%)':>10}")
    print("-" * 80)

    for name, stats in comp_stats.items():
        p = stats['params'] / 1e6
        gflops = stats['gflops']
        gsops = stats['sops'] / 1e9
        energy_mj = stats['energy'] / 1e9 # Convert pJ to mJ
        energy_percent = (stats['energy'] / total_energy * 100) if total_energy > 0 else 0
        print(f"{name:<12} | {p:>12.4f} | {gflops:>10.4f} | {gsops:>12.4f} | {energy_mj:>12.6f} | {energy_percent:>9.2f}%")

    print("-" * 80)
    print(f"{'TOTAL':<12} | {total_params/1e6:>12.4f} | {total_gflops:>10.4f} | {total_sops/1e9:>12.4f} | {total_energy/1e9:>12.6f} | {'100.00%':>10}")
    print("=" * 80)

def extract_time_steps(cfg, loader):
    """
    Return the number of time steps from the configuration or the data.

    Resolution order:
    1. ``cfg.time_steps`` when the attribute exists and is truthy.
    2. The size of dimension 1 of the first batch's first element, when that
       element is a 5D tensor. Batches may carry any number of extra fields
       (labels, metadata, ...); only the first element is inspected.
    3. The default of 16 when the loader is empty or the sample is not a
       5D tensor.

    Args:
        cfg: Configuration object, optionally providing ``time_steps``.
        loader: Iterable of batches.

    Returns:
        int: Number of time steps.
    """
    default_steps = 16

    configured = getattr(cfg, 'time_steps', None)
    if configured:
        return configured

    try:
        batch = next(iter(loader))
    except StopIteration:
        return default_steps

    # Index instead of unpacking into exactly two names, so batches such as
    # (x, y, meta) do not raise ValueError.
    sample = batch[0] if isinstance(batch, (tuple, list)) else batch
    if isinstance(sample, torch.Tensor) and sample.dim() == 5:
        return sample.shape[1]
    return default_steps

def cli():
    """
    Parse the command line of the evaluation script.

    Returns:
        argparse.Namespace with ``cfg``, ``resume``, ``output_dir`` (all
        required), ``batch_size`` (int or None) and ``local_rank`` (int,
        default 0).
    """
    # The title is passed as `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would put the title into the usage line
    # in place of the program name.
    parser = argparse.ArgumentParser(description="STRAW Component-wise Analysis Tool")
    parser.add_argument("--cfg", required=True, help="Path to configuration file.")
    parser.add_argument("--resume", required=True, help="Path to model checkpoint.")
    parser.add_argument("--output_dir", required=True, help="Path to output directory.")
    parser.add_argument("--batch_size", type=int, default=None, help="Override evaluation batch size.")
    parser.add_argument("--local_rank", type=int, default=0, help="Local rank for distributed training.")
    return parser.parse_args()

def _strip_ddp_prefix(state_dict, prefix='module.'):
    """Return ``state_dict`` with a leading ``prefix`` removed when every key carries it."""
    if all(k.startswith(prefix) for k in state_dict.keys()):
        # Slice off only the leading prefix; str.replace would also mangle
        # inner occurrences such as 'encoder.submodule.weight'.
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def main():
    """
    Run the component-wise performance analysis end to end.

    Steps: parse arguments, load the configuration, set up (optional)
    distributed state, build loaders and model, load the checkpoint and print
    the analysis report. In a distributed launch only rank 0 loads the
    checkpoint and runs the analysis; the remaining ranks wait at a barrier
    so every process shuts down together, including when rank 0 fails.
    """
    args = cli()
    cfg = load_config(args.cfg)
    cfg.output_dir = args.output_dir

    if args.batch_size:
        cfg.data.batch_size = args.batch_size
        print(f"[INFO] Batch size overridden to {args.batch_size}.")

    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)

    # Distributed training setup
    is_dist, rank, world, lrank = setup_distributed()
    device = torch.device(f"cuda:{lrank}" if torch.cuda.is_available() else "cpu")
    is_main = rank == 0

    if is_main:
        print(f"{ts()} [INFO] Configuration: {args.cfg}, Device: {device}")

    # Data loader and model initialization
    _, val_loader, _ = build_loaders(cfg, distributed=is_dist, rank=rank, world_size=world)
    model = build_model_ddp(cfg, device, rank, world, debug=False)

    # Measure inference behavior: eval mode keeps dropout inactive and stops
    # normalization layers from updating statistics on the single-sample input.
    model.eval()

    try:
        if is_main:
            print(f"{ts()} [INFO] Loading checkpoint from {args.resume}")
            try:
                # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
                # only use checkpoints from a trusted source.
                ckpt = torch.load(args.resume, map_location=device, weights_only=False)
                state_dict = _strip_ddp_prefix(ckpt.get('model', ckpt))

                model_to_load = model.module if hasattr(model, 'module') else model
                result = model_to_load.load_state_dict(state_dict, strict=False)
                # strict=False hides mismatches; surface them so an analysis of
                # partially initialized weights does not go unnoticed.
                if result.missing_keys:
                    print(f"{ts()} [WARNING] Missing keys in checkpoint: {list(result.missing_keys)}")
                if result.unexpected_keys:
                    print(f"{ts()} [WARNING] Unexpected keys in checkpoint: {list(result.unexpected_keys)}")
                print(f"{ts()} [INFO] Model loaded successfully.")
            except Exception as e:
                print(f"{ts()} [ERROR] Failed to load checkpoint: {e}")
                # Fall through to the shared barrier/cleanup below instead of
                # returning early, which would leave other ranks waiting forever.
                return

            # Execute performance analysis on the rank that holds the loaded weights.
            T = extract_time_steps(cfg, val_loader)
            analyze_model_performance(model, val_loader, T, device)
    finally:
        if is_dist:
            # Every rank reaches this point (rank 0 also on error), so the
            # barrier cannot deadlock.
            torch.distributed.barrier()
            cleanup_distributed()

if __name__ == "__main__":
    main()
