import torch
import numpy as np
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
def calculate_forgettability_score(model, train_loader, history_predictions, device="mps"):
    """Score each training sample by how often it has been "forgotten".

    A forgetting event is a transition, in a sample's recorded prediction
    history, from a correct prediction (equal to its label) to an incorrect
    one. The score is ``forget_events / correct_predictions_with_a_successor``,
    or 0.0 when no earlier prediction in the history was correct.

    The score is computed from the history recorded by *previous* calls. The
    model's current prediction is appended to ``history_predictions`` after
    the score is computed, so it only influences the next call's result.

    Args:
        model: Module whose output's argmax over dim 1 (via ``torch.max``) is
            taken as the predicted class.
        train_loader: Iterable of ``(inputs, labels, sample_ids)`` batches.
        history_predictions: ``{sample_id: [predicted_label, ...]}``. It is
            mutated in place: each sample's current prediction is appended.
        device: Device that inputs and labels are moved to.

    Returns:
        Dict mapping each sample id to its forgettability score (float).
    """
    forgettability_scores = {}

    model.eval()  # the model is left in eval mode on return
    with torch.no_grad():
        for inputs, labels, sample_ids in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, predicted = torch.max(model(inputs), 1)

            # Convert each tensor to Python values once per batch. Per-element
            # ``.item()`` forces a device-to-host sync every call, and the
            # label used to be re-read O(len(history)) times per sample.
            for sample_id, label, pred in zip(
                sample_ids.tolist(), labels.tolist(), predicted.tolist()
            ):
                history = history_predictions.setdefault(sample_id, [])

                forget_count = 0
                correct_count = 0
                # Walk consecutive (previous, current) prediction pairs.
                for prev, curr in zip(history, history[1:]):
                    if prev == label:
                        correct_count += 1
                        if curr != label:
                            forget_count += 1

                forgettability_scores[sample_id] = (
                    forget_count / correct_count if correct_count > 0 else 0.0
                )
                history.append(pred)

    return forgettability_scores

def calculate_grand(model, train_loader, criterion, device="mps"):
    """Return the per-sample L2 norm of the loss gradient w.r.t. each input.

    NOTE: the gradient is taken with respect to the *inputs*, not the model
    parameters.

    The model is put in eval mode and left there. Gradients come from a
    single ``criterion`` call per batch. If ``criterion`` averages over the
    batch (e.g. ``reduction="mean"``), each sample's gradient is scaled by
    1/batch_size, so norms from a smaller final batch are not on the same
    scale as the rest. NOTE(review): consider a summed loss.

    Args:
        model: Module mapping an input batch to outputs accepted by ``criterion``.
        train_loader: Iterable of ``(inputs, labels, sample_ids)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        device: Device that inputs and labels are moved to.

    Returns:
        Dict mapping each sample id to its input-gradient L2 norm (float).
    """
    model.eval()
    grad_norms = {}
    for inputs, labels, sample_ids in train_loader:
        # ``.to`` returns the same object when the tensor is already on
        # ``device``. detach() gives a fresh leaf, so the loader's own tensor
        # never gets flagged requires_grad.
        inputs = inputs.to(device).detach().requires_grad_(True)
        labels = labels.to(device)

        loss = criterion(model(inputs), labels)
        # autograd.grad computes only d(loss)/d(inputs). It skips accumulating
        # parameter gradients (wasted work) and leaves the model's ``.grad``
        # fields untouched.
        (input_grads,) = torch.autograd.grad(loss, inputs)

        # Flatten everything but the batch axis. reshape (unlike view) also
        # copes with a non-contiguous gradient.
        batch_norms = input_grads.reshape(input_grads.size(0), -1).norm(p=2, dim=1)

        grad_norms.update(zip(sample_ids.tolist(), batch_norms.tolist()))

    return grad_norms

def calculate_density(latents, n_components=50, bandwidth=10, random_state=42):
    """Return a KDE log-density score for every sample in ``latents``.

    Pipeline:
      1. Flatten each sample.
      2. Project with PCA.
      3. Standardise every component to zero mean and unit variance.
      4. Fit a Gaussian kernel density estimate on the result and evaluate it
         at the same points.

    Args:
        latents: Array-like with samples on axis 0; all remaining axes are
            flattened.
        n_components: Target PCA dimensionality. It is clamped to
            ``min(n_samples, n_features)``, the largest value PCA accepts.
        bandwidth: Gaussian KDE bandwidth, in the standardised PCA space.
        random_state: Seed forwarded to ``PCA``.

    Returns:
        1-D array of ``KernelDensity.score_samples`` values (log-density),
        one per sample.
    """
    latents = latents.reshape(latents.shape[0], -1)

    # PCA raises when n_components exceeds min(n_samples, n_features); clamp
    # so small batches or low-dimensional latents still work. With the
    # defaults, inputs that have >= 50 samples and >= 50 features are
    # processed exactly as before.
    n_components = min(n_components, *latents.shape)
    pca = PCA(n_components=n_components, random_state=random_state)
    latents = pca.fit_transform(latents)

    # Put every principal component on unit variance before the KDE, so the
    # single bandwidth applies on the same scale along each axis.
    latents = StandardScaler().fit_transform(latents)

    density_model = KernelDensity(kernel="gaussian", bandwidth=bandwidth)
    density_model.fit(latents)
    return density_model.score_samples(latents)
    
def combine_scores(density_scores, forget_scores, grand_scores,
                  density_weight=-1.0, forget_weight=1.0, grad_weight=0.5):
    """Blend density, forgettability and gradient-norm scores into one score.

    Each signal is z-score normalised independently, then combined as::

        density_weight * z(density) + forget_weight * z(forget) + grad_weight * z(grad)

    With the default negative ``density_weight``, samples with lower density
    scores receive higher combined scores.

    Alignment: ``grand_scores`` is looked up by the keys of ``forget_scores``,
    in ``forget_scores``' iteration order. The two dicts therefore need not
    have been filled in the same order (e.g. two passes over a shuffled
    loader). ``density_scores`` carries no ids and is matched by position; it
    is assumed to follow that same order. NOTE(review): confirm with the
    caller.

    Args:
        density_scores: 1-D array-like, one value per sample.
        forget_scores: ``{sample_id: score}``. Its iteration order defines
            the output order.
        grand_scores: ``{sample_id: score}`` with the same keys as
            ``forget_scores``.
        density_weight, forget_weight, grad_weight: Weight of each normalised
            signal.

    Returns:
        1-D float array. Element ``i`` belongs to the ``i``-th key of
        ``forget_scores``.

    Raises:
        ValueError: If the two dicts have different key sets, or if
            ``density_scores`` is not 1-D with one entry per sample.
    """
    def zscore_normalize(arr):
        # eps keeps a constant signal from dividing by zero; it maps to zeros.
        eps = 1e-8
        arr = np.asarray(arr, dtype=float)
        return (arr - arr.mean()) / (arr.std() + eps)

    if forget_scores.keys() != grand_scores.keys():
        raise ValueError("forget_scores and grand_scores must contain the same sample ids")

    sample_ids = list(forget_scores)
    density = np.asarray(density_scores, dtype=float)
    if density.shape != (len(sample_ids),):
        raise ValueError(
            f"density_scores must be 1-D with {len(sample_ids)} entries, "
            f"got shape {density.shape}"
        )

    norm_density = zscore_normalize(density)
    norm_forget = zscore_normalize([forget_scores[sid] for sid in sample_ids])
    norm_grad = zscore_normalize([grand_scores[sid] for sid in sample_ids])

    return (
        density_weight * norm_density
        + forget_weight * norm_forget
        + grad_weight * norm_grad
    )


def calculate_contrastive_scores(latents, probs, k=10, topk=50):
    """Score samples by how far their outputs are from their latent neighbours'.

    For each sample, find its ``k`` nearest neighbours in the flattened
    latent space (Euclidean, scikit-learn ``NearestNeighbors``). The score is
    the mean L2 distance between the sample's ``probs`` entry and each
    neighbour's. A high score means the sample's output differs from those
    of samples that lie close to it in latent space.

    Args:
        latents: Array-like with samples on axis 0; other axes are flattened.
        probs: Per-sample outputs indexed like ``latents`` (presumably class
            probabilities; confirm with the caller). Each entry is flattened
            before the distance is taken.
        k: Neighbours per sample, clamped to ``[1, n_samples - 1]``.
        topk: Number of top-scoring indices to return, clamped to
            ``[0, n_samples]``.

    Returns:
        ``(scores, topk_indices)``. ``scores`` has one value per sample.
        ``topk_indices`` holds the indices of the ``topk`` largest scores in
        decreasing-score order, with ties broken by lower index.
    """
    n_samples = latents.shape[0]
    topk = max(0, min(topk, n_samples))
    if n_samples < 2:
        # No other sample to compare against, so there is no disagreement.
        return np.zeros(n_samples), np.arange(topk)

    latents = latents.reshape(n_samples, -1)
    probs = np.asarray(probs).reshape(n_samples, -1)
    n_neighbors = max(1, min(k, n_samples - 1))

    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='auto').fit(latents)
    # Querying without X makes sklearn exclude each point from its own
    # neighbour list. Taking k+1 neighbours and dropping column 0 is wrong
    # with duplicate latents, because an identical *other* point can occupy
    # column 0 instead of the query point itself.
    _, neighbor_idx = nbrs.kneighbors()

    # Process one neighbour column at a time. This stays vectorised while
    # keeping the temporary at (n_samples, n_outputs) instead of
    # (n_samples, k, n_outputs).
    total = np.zeros(n_samples)
    for column in neighbor_idx.T:
        total += np.linalg.norm(probs[column] - probs, axis=1)
    scores = total / n_neighbors

    # A stable full sort handles topk in {0, n_samples}, which made
    # argpartition raise.
    topk_indices = np.argsort(-scores, kind="stable")[:topk]
    return scores, topk_indices