"""
Functions for extracting features and logits from models
"""

import torch
import numpy as np
import torch.nn.functional as F
from torch.nn import ModuleList

def extract_features(model, x, batch_size=100, dataset='cifar100'):
    """
    Extract pooled features and logits from a model in mini-batches.

    A forward hook is attached to a dataset-specific layer of the underlying
    network ('relu' for 'cifar100', 'avgpool' for 'tiny_imagenet'). For every
    batch the hooked output and the model output are collected. Hooked
    outputs with more than two dimensions are averaged over dims 2 and 3.
    The model is switched to eval mode, and the hook is removed again before
    the function returns (also when the forward pass raises).

    Args:
        model (nn.Module): model to extract features from
        x (torch.Tensor): input data, batched along dim 0
        batch_size (int): number of samples per forward pass
        dataset (str): 'cifar100' or 'tiny_imagenet'; selects the hooked layer

    Returns:
        tuple: (features, logits) CPU tensors concatenated over all batches.
            Two empty tensors are returned when x holds no samples.

    Raises:
        ValueError: if dataset is not supported
        RuntimeError: if the hooked layer never produced an activation for
            non-empty input
    """
    model.eval()

    # Pick the wrapped network and the layer to hook (WRN model structure).
    if dataset == 'cifar100':
        base_model = model.model if hasattr(model, 'model') else model
        layer_name = 'relu'
    elif dataset == 'tiny_imagenet':
        base_model = model.model.model if hasattr(model.model, 'model') else model.model
        layer_name = 'avgpool'
    else:
        raise ValueError(f"Dataset {dataset} not supported")

    # Run batches on the device the model lives on instead of forcing the
    # default GPU; fall back to CUDA-if-available for parameterless models.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    activations = {}

    def hook(_module, _inputs, output):
        activations[layer_name] = output.detach()

    features = []
    logits_list = []
    n_samples = x.size(0)

    handle = getattr(base_model, layer_name).register_forward_hook(hook)
    try:
        with torch.no_grad():
            for start in range(0, n_samples, batch_size):
                # Drop the previous batch's activation so a hook that does
                # not fire can never leak stale features into this batch.
                activations.clear()
                batch = x[start:start + batch_size].to(device)
                logits = model(batch)

                if layer_name not in activations:
                    continue

                batch_features = activations[layer_name]
                # Reduce spatial dimensions (if present)
                if batch_features.dim() > 2:
                    batch_features = torch.mean(batch_features, dim=[2, 3])

                features.append(batch_features.cpu())
                logits_list.append(logits.cpu())
    finally:
        # Always detach the hook so repeated calls do not stack hooks.
        handle.remove()

    if not features:
        if n_samples == 0:
            return torch.empty(0), torch.empty(0)
        raise RuntimeError(
            f"Layer '{layer_name}' produced no activations during the forward pass"
        )

    features_tensor = torch.cat(features, dim=0)
    logits_tensor = torch.cat(logits_list, dim=0)

    return features_tensor, logits_tensor

def extract_features_by_layer(model, x, layer_name, batch_size=500, dataset='cifar100'):
    """
    Extract features from a specific layer, together with the model logits.

    The layer is looked up on the wrapped network selected by `dataset`.
    `layer_name` may be a dotted path (e.g. 'block1.layer.0'); each segment
    is resolved as an attribute first and, failing that, as an integer index
    when the segment is numeric. A forward hook on that layer captures its
    output for every batch and is removed before the function returns (also
    when the forward pass raises).

    Captured outputs are reduced as follows: tensors with more than three
    dimensions are adaptively average-pooled to 1x1 and squeezed;
    3-D tensors are squeezed on dim 2 when that dim has size 1, otherwise
    averaged over dim 1; other shapes are kept unchanged.

    Note: unlike extract_features, this function does not call
    model.eval(); put the model in the desired mode before calling it.

    Args:
        model (nn.Module): model to extract features from
        x (torch.Tensor or np.ndarray): input data, batched along axis 0;
            NumPy input is converted to a float tensor per batch
        layer_name (str): (dotted) name of the layer to extract features from
        batch_size (int): number of samples per forward pass
        dataset (str): 'cifar100' or 'tiny_imagenet'

    Returns:
        tuple: (features, logits) CPU tensors concatenated over all batches.
            Two empty tensors are returned when x holds no samples.

    Raises:
        ValueError: if dataset is not supported or the layer cannot be found
        RuntimeError: if the hooked layer never produced an activation for
            non-empty input
    """
    if dataset == 'cifar100':
        base_model = model.model if hasattr(model, 'model') else model
    elif dataset == 'tiny_imagenet':
        base_model = model.model.model if hasattr(model.model, 'model') else model.model
    else:
        raise ValueError(f"Dataset {dataset} not supported")

    captured = {}

    def capture_output(_module, _inputs, output):
        # Some layers return tuples; the first element is taken as the feature.
        first = output[0] if isinstance(output, tuple) else output
        captured[layer_name] = first.detach()

    # Resolve the (possibly dotted) layer path and attach the hook.
    try:
        feature_layer = base_model
        for part in layer_name.split('.'):
            if hasattr(feature_layer, part):
                feature_layer = getattr(feature_layer, part)
            elif part.isdigit():
                feature_layer = feature_layer[int(part)]
            else:
                raise ValueError(f"Layer '{layer_name}' not found")
        hook_handle = feature_layer.register_forward_hook(capture_output)
    except (AttributeError, IndexError, KeyError, TypeError) as err:
        # Bad index / non-indexable container / non-module target all mean
        # the requested layer does not exist.
        raise ValueError(f"Layer '{layer_name}' not found") from err

    # Run batches on the device the model lives on instead of forcing the
    # default GPU; fall back to CUDA-if-available for parameterless models.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    features = []
    logits_list = []
    x_size = x.shape[0] if isinstance(x, np.ndarray) else x.size(0)

    try:
        with torch.no_grad():
            for start in range(0, x_size, batch_size):
                batch = x[start:start + batch_size]
                # Convert NumPy input BEFORE moving it to the device;
                # ndarrays have no device-transfer methods.
                if isinstance(batch, np.ndarray):
                    batch = torch.from_numpy(batch).float()
                batch = batch.to(device)

                # Drop the previous batch's activation so a hook that does
                # not fire can never leak stale features into this batch.
                captured.clear()
                outputs = model(batch)

                if hasattr(outputs, 'logits'):
                    logits = outputs.logits
                elif isinstance(outputs, dict):
                    # Plain dicts expose logits by key, not by attribute.
                    logits = outputs['logits']
                else:
                    logits = outputs

                if layer_name not in captured:
                    continue

                batch_features = captured[layer_name]
                if batch_features.dim() > 3:
                    batch_features = F.adaptive_avg_pool2d(batch_features, 1).squeeze(-1).squeeze(-1)
                elif batch_features.dim() == 3:  # for vit and swin
                    if batch_features.shape[2] == 1:  # swin
                        batch_features = batch_features.squeeze(2)
                    else:  # vit
                        batch_features = batch_features.mean(dim=1)

                features.append(batch_features.cpu())
                logits_list.append(logits.cpu())
    finally:
        # Always detach the hook, even if the forward pass raised.
        hook_handle.remove()

    if not features:
        if x_size == 0:
            return torch.empty(0), torch.empty(0)
        raise RuntimeError(
            f"Layer '{layer_name}' produced no activations during the forward pass"
        )

    features_tensor = torch.cat(features, dim=0)
    logits_tensor = torch.cat(logits_list, dim=0)
    return features_tensor, logits_tensor