"""
Neural Recording Datasets

Datasets containing neural recordings and corresponding visual stimuli
for evaluating model-brain alignment.

Datasets:
1. Allen Natural Scenes: Mouse visual cortex (V1 and higher visual areas) responses to natural images
2. Macaque Face: Macaque IT (AM region) responses to face stimuli
3. Macaque Synthetic: Macaque V4/IT responses to synthetic stimuli
"""

import os
import torch
import numpy as np
import torchvision.transforms as transforms
from util import DATA_DIR


class AllenNaturalScenes(object):
    """
    Allen Brain Observatory Natural Scenes Dataset.

    Contains mouse visual cortex responses to natural scene images
    across multiple visual areas (V1 and higher visual areas).

    Brain areas:
    - VISp (V1): Primary visual cortex
    - VISl: Lateral visual area
    - VISrl: Rostrolateral visual area
    - VISal: Anterolateral visual area
    - VISpm: Posteromedial visual area
    - VISam: Anteromedial visual area

    Args:
        threshold: Minimum per-neuron reliability score (read from
            ``shr_{area}.npy``, described as a signal-to-noise measure) a
            neuron must reach to be kept (default: 0.8). Pass ``None`` to
            keep every neuron without loading the score file, the same
            convention ``MacaqueFace`` uses.

    Attributes:
        neural_data: dict mapping each name in ``brain_areas`` to its
            selected responses, divided by 50 and averaged over axis 1.
        stimulus_data: whatever ``torch.load`` returns for
            ``stimulus_allen_natural_scenes_224.pt`` (224x224 stimuli).
    """
    brain_areas = ['visp', 'visl', 'visrl', 'visal', 'vispm', 'visam']
    base_dir = os.path.join(DATA_DIR, 'allen_natural_scenes')

    def __init__(self, threshold=0.8):
        self.neural_data = {}
        for area in self.brain_areas:
            # Raw responses. Indexed below with neurons on axis 2 and trials
            # on axis 1; axis 0 is presumably the stimulus index (TODO confirm).
            area_neural_data = np.load(os.path.join(self.base_dir, f'{area}_25.npy'))
            if threshold is not None:
                # One reliability score per neuron; keep neurons at or above threshold.
                shr = np.load(os.path.join(self.base_dir, f'shr_{area}.npy'))
                area_neural_data = area_neural_data[:, :, shr >= threshold]
            # Normalize by trial duration (hard-coded as 50) and average across trials.
            area_neural_data = area_neural_data / 50
            area_neural_data = area_neural_data.mean(axis=1)
            self.neural_data[area] = area_neural_data

        # Load stimulus images (224x224 resolution)
        self.stimulus_data = torch.load(
            os.path.join(self.base_dir, 'stimulus_allen_natural_scenes_224.pt')
        )


class MacaqueFace(object):
    """
    Macaque Face Processing Dataset.

    Responses recorded from macaque IT cortex (the AM region) to face
    stimuli, intended for probing face-selective representations.

    Args:
        threshold: Neurons whose entry in ``noise_ceiling.npy`` is below
            this value are dropped (default: 0.1). ``None`` disables the
            selection and keeps every neuron.

    Attributes:
        neural_data: dict from area name to its (optionally filtered)
            response array; neurons are selected along axis 1.
        stimulus_data: whatever ``torch.load`` returns for
            ``stimulus_macaque_face_224.pt``.
    """
    brain_areas = ['AM']
    base_dir = os.path.join(DATA_DIR, 'macaque_face')

    def __init__(self, threshold=0.1):
        self.neural_data = {}
        for area in self.brain_areas:
            responses = np.load(os.path.join(self.base_dir, f"{area}.npy"))
            if threshold is not None:
                # A single noise-ceiling file (not one per area) supplies a
                # value for each column of the response array.
                ceiling_path = os.path.join(self.base_dir, "noise_ceiling.npy")
                keep_mask = np.load(ceiling_path) >= threshold
                responses = responses[:, keep_mask]
            self.neural_data[area] = responses

        stimulus_path = os.path.join(self.base_dir, 'stimulus_macaque_face_224.pt')
        self.stimulus_data = torch.load(stimulus_path)


class MacaqueSynthetic(object):
    """
    Macaque V4/IT Synthetic Stimuli Dataset.

    Responses from macaque V4 and IT cortex to synthetic visual stimuli,
    intended for studying hierarchical visual processing.

    Brain areas:
    - V4: Mid-level visual area
    - IT: Inferior temporal cortex (high-level visual area)

    Attributes:
        neural_data: dict from area name to the array stored in
            ``<area>.npy``, loaded without any neuron selection.
        stimulus_data: whatever ``torch.load`` returns for
            ``stimulus_macaque_synthetic_224.pt``.
    """
    brain_areas = ['V4', 'IT']
    base_dir = os.path.join(DATA_DIR, 'macaque_synthetic')

    def __init__(self):
        # No reliability filtering here: each area's file is used as-is.
        self.neural_data = {
            area: np.load(os.path.join(self.base_dir, f'{area}.npy'))
            for area in self.brain_areas
        }

        stimulus_path = os.path.join(self.base_dir, 'stimulus_macaque_synthetic_224.pt')
        self.stimulus_data = torch.load(stimulus_path)


class ImageDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset that serves stimulus images one at a time.

    Each item is returned as a 1-tuple ``(image,)``.

    Args:
        images_pt: Tensor of images; item ``i`` is ``images_pt[i]`` and the
            dataset length is ``images_pt.size(0)``.
        resize: Optional target size handed to ``transforms.Resize``; with
            ``None`` (the default) images are returned unchanged.
    """
    def __init__(self, images_pt, resize=None):
        self.images_pt = images_pt
        # Only build a resize transform when a target size was requested.
        self.transform = None if resize is None else transforms.Resize(resize)

    def __getitem__(self, index):
        image = self.images_pt[index]
        if self.transform is not None:
            image = self.transform(image)
        return (image,)

    def __len__(self):
        return self.images_pt.size(0)
