"""
Neural Benchmark Evaluation Framework

Provides a unified framework for evaluating model representations against
neural recordings from multiple brain datasets.

Key features:
- Layer-wise feature extraction using forward hooks
- Automatic caching of intermediate representations
- Support for both ANN and SNN models
- Multiple similarity metrics (RSA, CCA, Regression)
"""

import os
import sys
import time
import fcntl
import pickle
import random
import string
import numpy as np
from tqdm import tqdm

import torch

from dataset import AllenNaturalScenes, MacaqueFace, MacaqueSynthetic, ImageDataset
from metric import RSAMetric, RegMetric, CCAMetric
from util import BASE_DIR


# Registries mapping the short names accepted by StaticBenchmark
# (``datasets=`` / ``metrics=``) to the classes they construct.
DatasetDict = dict(
    allen=AllenNaturalScenes,
    mac_face=MacaqueFace,
    mac_syn=MacaqueSynthetic,
)

MetricDict = dict(
    rsa=RSAMetric,
    reg=RegMetric,
    cca=CCAMetric,
)


class StaticBenchmark(object):
    """
    Benchmark for evaluating model-brain alignment on static stimuli.

    Workflow:
    1. Register forward hooks on model layers
    2. Run model inference on neural dataset stimuli
    3. Cache layer activations to disk
    4. Evaluate similarity between model and neural representations

    Args:
        datasets: List of dataset names (keys of ``DatasetDict``) to evaluate
            (default: all)
        metrics: List of metric names (keys of ``MetricDict``) to use
            (default: all)
        cache_str: Cache directory name under ``BASE_DIR/.cache``
            (default: a random 16-character alphanumeric name, generated on
            the first ``__call__``)
        timestep: Number of timesteps for SNN (None for ANN)
        mean: For SNNs, whether to average block outputs across timesteps;
            if False, each stimulus's per-timestep features are concatenated
    """
    def __init__(self, datasets=None, metrics=None, cache_str=None, timestep=None, mean=False):
        if datasets is None:
            datasets = list(DatasetDict.keys())
        # Dataset objects are constructed eagerly, once per benchmark instance.
        self.datasets = [(k, DatasetDict[k]()) for k in datasets]
        if metrics is None:
            metrics = list(MetricDict.keys())
        self.metrics = metrics
        self.cache_str = cache_str
        self.timestep = timestep
        self._mean_timestep = mean

    def _acquire_cache_lock(self):
        """
        Create and lock a new cache directory for storing layer outputs.

        If ``self.cache_str`` is None, a random name is generated and stored
        on the instance so that ``evaluate`` can locate the directory later.

        Returns:
            (lock_fd, cache_dir): the open, flock-ed lock file and the path.

        Raises:
            FileExistsError: if the cache directory already exists.
            RuntimeError: if the directory's lock is held by someone else.
        """
        if self.cache_str is None:
            self.cache_str = ''.join(
                random.choice(string.ascii_letters + string.digits) for _ in range(16)
            )
        cache_dir = os.path.join(BASE_DIR, '.cache', self.cache_str)
        # No exist_ok: an existing directory raises FileExistsError, so a run
        # never writes into a directory holding another run's .out files.
        os.makedirs(cache_dir)
        print('Current cache dir:', cache_dir)

        lock_fd = open(os.path.join(cache_dir, '.lock'), 'w')
        try:
            fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError as err:
            # Close the handle instead of leaking it when the lock is taken.
            lock_fd.close()
            raise RuntimeError("Benchmark is already running and cannot be called again.") from err
        return lock_fd, cache_dir

    def _release_cache_lock(self, lock_fd, cache_dir):
        """
        Unlock and close the lock file, then delete it from the cache dir.

        Either argument may be None (acquisition failed part-way), in which
        case the corresponding step is skipped.
        """
        if lock_fd:
            fcntl.flock(lock_fd, fcntl.LOCK_UN)
            lock_fd.close()
        if cache_dir:
            os.remove(os.path.join(cache_dir, '.lock'))

    def _flatten_output(self, output):
        """
        Flatten one block's output tensor into a (batch, features) array.

        ANN outputs are batch-major; each sample is flattened as-is.
        SNN outputs are time-major ``(T, B, ...)``. They are either averaged
        over time (``mean=True``) or rearranged to ``(B, T, ...)``, so that
        each row holds one stimulus's features for every timestep.
        """
        if self.timestep is None:
            batch_size = output.size(0)
        else:
            batch_size = output.size(1)
            if self._mean_timestep:
                output = output.mean(dim=0)
            else:
                # Move the batch axis first. Flattening the time-major tensor
                # directly to (B, -1) would put features of *different*
                # stimuli into the same row.
                output = output.transpose(0, 1)
        return output.reshape(batch_size, -1).detach().cpu().numpy()

    def _hook_layers(self, vision_encoder, cache_dir, layers_fd, hook_handles):
        """
        Register forward hooks on residual blocks to capture layer outputs.

        For every ``BasicBlock``/``Bottleneck`` submodule, opens
        ``<cache_dir>/<module name>.out`` (kept in ``layers_fd``) and
        registers a hook (kept in ``hook_handles``). The hook pickles each
        batch's flattened ``(batch, features)`` output to that file.

        Both containers are filled in place, so the caller can clean up even
        if this method fails part-way.
        """
        def hook_fn_with_name(module_name):
            # Factory binds module_name per hook (avoids late-binding closures).
            def hook_fn(_module, _input, output):
                pickle.dump(self._flatten_output(output), layers_fd[module_name])
            return hook_fn

        # Import appropriate block types based on model type
        if self.timestep is not None:
            from networks.resnet_snn import BasicBlock, Bottleneck
        else:
            from networks.resnet_ann import BasicBlock, Bottleneck

        # Register hooks on all residual blocks
        for name, layer in vision_encoder.named_modules():
            if isinstance(layer, (BasicBlock, Bottleneck)):
                layers_fd[name] = open(os.path.join(cache_dir, f'{name}.out'), 'wb')
                hook_handles.append(layer.register_forward_hook(hook_fn_with_name(name)))

    def _unhook_layers(self, layers_fd, hook_handles):
        """Remove forward hooks, then close the per-layer cache files."""
        # Detach hooks first so no hook can ever write to a closed file.
        for handle in hook_handles:
            handle.remove()
        for fd in layers_fd.values():
            fd.close()

    def __call__(self, visual_model, img_size, batch_size=256, device='cuda'):
        """
        Run model inference on all datasets and cache layer outputs.

        Args:
            visual_model: PyTorch model to evaluate
            img_size: Input image size (for resizing)
            batch_size: Batch size for inference
            device: Device the stimuli are moved to before the forward pass;
                must match the model's device (default: 'cuda')

        Returns:
            cache_dir: Path to cached layer outputs
        """
        visual_model.eval()

        lock_fd, cache_dir = None, None
        layers_fd, hook_handles = {}, []

        try:
            lock_fd, cache_dir = self._acquire_cache_lock()
            self._hook_layers(visual_model, cache_dir, layers_fd, hook_handles)

            with torch.no_grad():
                for dataset_name, dataset in self.datasets:
                    start_time = time.time()
                    stimulus_dataset = ImageDataset(dataset.stimulus_data, img_size)
                    # Keep the default sequential order: evaluate() reads the
                    # outputs back in the same order, dataset by dataset.
                    stimulus_loader = torch.utils.data.DataLoader(
                        stimulus_dataset, batch_size=batch_size
                    )

                    for (stimulus, ) in tqdm(stimulus_loader, desc=dataset_name):
                        if self.timestep is not None:
                            # Add a leading temporal dimension for SNN: (T, B, ...)
                            stimulus = stimulus.unsqueeze(0).repeat(
                                self.timestep, 1, 1, 1, 1
                            )
                        visual_model(stimulus.to(device))

                    time_elapsed = time.time() - start_time
                    print(f'Time elapsed: {time_elapsed:.2f}s for {dataset_name}')

        finally:
            self._unhook_layers(layers_fd, hook_handles)
            self._release_cache_lock(lock_fd, cache_dir)

        return cache_dir

    @staticmethod
    def _layer_sort_key(layer_name):
        """
        Return a key that orders layer names by network depth.

        Every integer in the name is compared numerically, so
        ``layer3.2`` < ``layer3.10`` < ``layer4.0``.
        """
        import re
        return [int(num) for num in re.findall(r'\d+', layer_name)]

    def evaluate(self, save_f=None, metrics=None):
        """
        Evaluate similarity between cached model outputs and neural data.

        Args:
            save_f: Binary file handle; each result tuple is pickled to it
                (optional)
            metrics: List of metrics to use (default: all configured)

        Returns:
            results: List of (layer, dataset, metric, brain_area, score) tuples

        Raises:
            ValueError: if no cache directory name is set.
            FileNotFoundError: if the cache directory does not exist.
        """
        if self.cache_str is None:
            raise ValueError('cache_str is not set; run the benchmark first or pass cache_str.')
        cache_dir = os.path.join(BASE_DIR, '.cache', self.cache_str)
        if not os.path.isdir(cache_dir):
            raise FileNotFoundError(f'Cache directory does not exist: {cache_dir}')

        if metrics is None:
            metrics = self.metrics
        metrics = [(k, MetricDict[k]) for k in metrics]

        # Collect cached layer outputs as (layer name, file path) pairs
        layers = []
        for name in os.listdir(cache_dir):
            file_path = os.path.join(cache_dir, name)
            if os.path.isfile(file_path) and name.endswith('.out'):
                layers.append((name[:-len('.out')], file_path))

        # Sort by layer depth
        layers.sort(key=lambda item: self._layer_sort_key(item[0]))

        results = []
        t = tqdm(layers, desc='Evaluating layers', total=len(layers))

        for layer_name, file_path in t:
            with open(file_path, 'rb') as f:
                # Each file holds per-batch pickles for every dataset, written
                # sequentially in self.datasets order by __call__.
                for dataset_name, dataset in self.datasets:
                    total_size = dataset.stimulus_data.size(0)
                    layer_outputs = []

                    # Load batches until this dataset's stimuli are covered
                    while total_size > 0:
                        layer_output = pickle.load(f)
                        total_size -= layer_output.shape[0]
                        layer_outputs.append(layer_output)
                    layer_outputs = np.vstack(layer_outputs)

                    # Evaluate each metric
                    for metric_name, metric_cls in metrics:
                        # Adjust CCA settings for different datasets
                        if metric_name == 'cca' and dataset_name != 'allen':
                            metric = metric_cls(neural_reduction=False)
                        else:
                            metric = metric_cls()

                        # Evaluate each brain area
                        for brain_area, area_neural_data in dataset.neural_data.items():
                            score = metric.score(layer_outputs, area_neural_data)
                            result = (layer_name, dataset_name, metric_name, brain_area, score)
                            t.set_description(
                                f'{layer_name}, {dataset_name}, {metric_name}, '
                                f'{brain_area}, {score:.4f}'
                            )
                            if save_f is not None:
                                pickle.dump(result, save_f)
                            results.append(result)

        return results
