import numpy as np
import torch
from sklearn.decomposition import PCA
from skdim.id import MLE, TwoNN

def get_gradient_vector(model, single_sample, single_label, criterion, layer_name='fc'):
    """Return the flattened loss gradient of one sample for selected parameters.

    The sample and label each get a leading axis of size one, are passed
    through ``model`` and ``criterion``, and the loss is back-propagated.
    Gradients of every parameter whose name contains ``layer_name`` are
    flattened and concatenated in ``model.named_parameters()`` order.

    Args:
        model: Module exposing ``zero_grad()`` and ``named_parameters()``.
        single_sample: Input tensor for one example, without a batch axis.
        single_label: Target tensor for that example, without a batch axis.
        criterion: Callable mapping ``(output, target)`` to a scalar loss.
        layer_name: Substring used to select parameters by name.

    Returns:
        A 1-D NumPy array holding the concatenated gradients.

    Raises:
        ValueError: If no parameter whose name contains ``layer_name``
            received a gradient.

    Note:
        The gradients computed here remain stored on the model's parameters
        after the call returns.
    """
    model.zero_grad()
    output = model(single_sample.unsqueeze(0))
    loss = criterion(output, single_label.unsqueeze(0))
    loss.backward()

    # reshape (not view) so gradients kept in a non-contiguous layout, such as
    # channels_last convolution weights, can still be flattened.
    grads = [
        p.grad.reshape(-1)
        for name, p in model.named_parameters()
        if p.grad is not None and layer_name in name
    ]

    if not grads:
        raise ValueError(f"No gradients found for layer containing '{layer_name}'. Check layer name.")

    return torch.cat(grads).detach().cpu().numpy()

def get_coco_gradient_vector(model, img, lbl, criterion):
    """Return the flattened loss gradient of one sample for ``model.head``.

    The image is given a leading axis of size one and passed through
    ``model.backbone`` with autograd disabled, so only the head's parameters
    receive gradients. The backbone output is flattened per sample, fed to
    ``model.head``, and the loss against ``lbl`` is back-propagated.

    Args:
        model: Object exposing ``backbone`` (callable) and ``head`` (a module
            with a ``weight`` parameter and, optionally, a ``bias``).
        img: Input tensor for one example, without a batch axis.
        lbl: Target tensor for that example, without a batch axis.
        criterion: Callable mapping ``(logits, target)`` to a scalar loss.

    Returns:
        A 1-D NumPy array: the head's weight gradient followed by its bias
        gradient (the bias part is omitted when the head has no bias).

    Note:
        The gradients computed here remain stored on the head's parameters
        after the call returns.
    """
    model.head.zero_grad()
    with torch.no_grad():
        feats = model.backbone(img.unsqueeze(0))
    # reshape (not view): the backbone may hand back a non-contiguous tensor
    # (permuted or channels_last), which view() refuses to flatten.
    feats = feats.reshape(feats.size(0), -1)

    logits = model.head(feats)
    loss = criterion(logits, lbl.unsqueeze(0))
    loss.backward()

    parts = [model.head.weight.grad.detach().cpu().reshape(-1)]
    # A head built without a bias has ``bias`` set to None; skip it then.
    bias = getattr(model.head, 'bias', None)
    if bias is not None and bias.grad is not None:
        parts.append(bias.grad.detach().cpu().reshape(-1))
    return torch.cat(parts).numpy()

def estimate_id(embeddings, method='mle', pca_comps=10):
    """Estimate the intrinsic dimension of a set of embedding vectors.

    Args:
        embeddings: Array of shape ``(n_samples, n_features)`` or a single
            1-D vector (treated as one sample). Plain sequences such as
            lists are converted to a NumPy array first.
        method: ``'mle'`` or ``'twonn'`` (case-insensitive). With ``'mle'``
            the data is first reduced with PCA; ``'twonn'`` uses the
            embeddings as given.
        pca_comps: Upper bound on the number of PCA components kept before
            the MLE estimator runs. Not used by ``'twonn'``.

    Returns:
        The mean of the estimator's output as a Python ``float``. ``1.0`` is
        returned when fewer than two samples are given, or when the ``'mle'``
        path cannot keep at least one PCA component.

    Raises:
        ValueError: If ``method`` is not supported. This is only checked once
            at least two samples are present.
    """
    if not hasattr(embeddings, 'ndim'):
        # Lists/tuples lack the array attributes used below.
        embeddings = np.asarray(embeddings)
    if embeddings.ndim == 1:
        embeddings = embeddings.reshape(1, -1)
    if embeddings.shape[0] < 2:
        return 1.0

    method_key = method.lower()

    if method_key == 'mle':
        n_samples, n_features = embeddings.shape
        n_components = min(pca_comps, n_samples - 1, n_features)
        if n_components <= 0:
            return 1.0

        reduced = PCA(n_components=n_components).fit_transform(embeddings)

        # At least two samples are guaranteed above, so k is always >= 1.
        k = min(10, reduced.shape[0] - 1)

        # NOTE(review): skdim's documented MLE constructor appears to name its
        # neighbour-count argument `K`; confirm that `k=` is accepted by the
        # installed skdim version.
        estimator = MLE(k=k)

    elif method_key == 'twonn':
        estimator = TwoNN()
        reduced = embeddings  # TwoNN does not require PCA

    else:
        raise ValueError(f"ID estimation method '{method}' not supported.")

    id_estimates = estimator.fit_transform(reduced)
    # float() keeps the return type consistent with the early exits above
    # instead of leaking a NumPy scalar.
    return float(np.mean(id_estimates))