import numpy as np
import torch
import torch.nn.functional as F

from collections import defaultdict
from sklearn.preprocessing import MinMaxScaler
from torchvision import transforms
from torch.utils.data import DataLoader
from utils.training import evaluate, AugmentedDataset

@torch.no_grad()
def compute_class_means(model, dataset, device, batch_size=64):
    """
    Computes the mean feature embedding of every class found in a dataset.

    Args:
        model: Object exposing `eval()` and `_forward_features(x)`. It is switched
            to eval mode and left in that mode.
        dataset: Dataset whose batches unpack into three items `(x, y, _)`;
            the third item is ignored.
        device (torch.device): Device the inputs are moved to before the forward pass.
        batch_size (int): Number of samples per forward pass. Defaults to 64.

    Returns:
        dict[int, Tensor]: class_id -> mean embedding (CPU tensor).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    class_to_features = defaultdict(list)

    for x, y, _ in dataloader:
        features = model._forward_features(x.to(device))  # shape: (B, D, 1, 1) or (B, D)

        if features.ndim == 4:
            features = features.squeeze(-1).squeeze(-1)

        # One host transfer per batch instead of one per sample; labels are only
        # used as Python ints, so they never need to live on the compute device.
        features = features.cpu()
        for f, label in zip(features, y.cpu()):
            class_to_features[int(label)].append(f)

    # Compute mean embedding per class
    return {
        cls: torch.stack(feats).mean(0)
        for cls, feats in class_to_features.items()
    }


def compute_distance_matrix(buffer_means, val_means, distance_type='euclidean'):
    """
    Builds the table of pairwise distances between buffer and validation class means.

    Every buffer class is compared against every validation class.

    Args:
        buffer_means (dict[int, Tensor]): class_id -> mean embedding
        val_means (dict[int, Tensor]): class_id -> mean embedding
        distance_type (str): 'euclidean' (L2 norm of the difference) or
            'cosine' (1 - cosine similarity)

    Returns:
        dict[tuple[int, int], float]: ((buffer_cls, val_cls) -> distance)

    Raises:
        ValueError: if `distance_type` is neither 'euclidean' nor 'cosine'.
    """
    def _euclidean(a, b):
        return torch.norm(a - b, p=2).item()

    def _cosine(a, b):
        # cosine_similarity works on batches, hence the added leading dimension
        similarity = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
        return 1 - similarity.item()

    # Reject unknown metrics before touching any embedding
    if distance_type not in ('euclidean', 'cosine'):
        raise ValueError(f"Unsupported distance_type: {distance_type}")

    metric = _euclidean if distance_type == 'euclidean' else _cosine

    return {
        (b_cls, v_cls): metric(b_vec, v_vec)
        for b_cls, b_vec in buffer_means.items()
        for v_cls, v_vec in val_means.items()
    }


def get_distances_per_buffer_class(distance_dict, reduction='min'):
    """
    Aggregates distances from buffer classes to validation classes.

    Args:
        distance_dict: dict of {(buffer_cls, val_cls): distance}
        reduction: 'min', 'mean', or 'max'

    Returns:
        dict of {buffer_cls: aggregated_distance}

    Raises:
        ValueError: if `reduction` is not one of the supported methods. The
            check runs before any data is processed, so an invalid value is
            reported even when `distance_dict` is empty.
    """
    reducers = {'min': torch.min, 'mean': torch.mean, 'max': torch.max}
    try:
        reduce_fn = reducers[reduction]
    except (KeyError, TypeError):
        raise ValueError(f"Unsupported reduction method: {reduction}") from None

    # Group all distances that originate from the same buffer class
    buffer_cls_distances = defaultdict(list)
    for (buffer_cls, _), dist in distance_dict.items():
        buffer_cls_distances[buffer_cls].append(dist)

    return {
        cls: reduce_fn(torch.tensor(dists)).item()
        for cls, dists in buffer_cls_distances.items()
    }


def compute_scores(trained_model, experience_set, calibration_buffer, device):
    """
    Computes buffer scores based on the distance between class means in the validation set and calibration buffer.

    For every buffer class the minimum cosine distance to any validation class mean
    is taken, the per-class distances are min-max scaled, and each buffer sample then
    receives the mean scaled distance of its group: classes of the current experience
    versus previously seen classes.

    Args:
        trained_model (torch.nn.Module): The trained model used for feature extraction and evaluation.
        experience_set (object): Current experience; must provide `classes_in_this_experience`,
            `classes_seen_so_far` and `dataset.with_transforms('eval')`.
        calibration_buffer (object): Buffer containing samples for calibration, with a `.buffer` attribute for data access.
        device (torch.device): The device (CPU or GPU) on which computations are performed.

    Returns:
        tuple: `(buffer_logits, buffer_labels, buffer_scores, buffer_class_means, class_distances)` where
            - `buffer_logits`, `buffer_labels` are the two values returned by `evaluate` on the buffer,
            - `buffer_scores` is a float32 tensor with one score per entry of `buffer_labels`,
            - `buffer_class_means` maps buffer class -> mean embedding,
            - `class_distances` maps buffer class -> min-max scaled distance.

    Raises:
        KeyError: if a current or previously seen class has no entry in the
            per-class distance table.
    """
    current_classes = experience_set.classes_in_this_experience
    buffer_logits, buffer_labels = evaluate(trained_model, calibration_buffer.buffer, device)
    val_class_means = compute_class_means(trained_model, experience_set.dataset.with_transforms('eval'), device)
    buffer_class_means = compute_class_means(trained_model, calibration_buffer.buffer, device)
    distance_matrix = compute_distance_matrix(buffer_class_means, val_class_means, distance_type='cosine')
    class_distances = get_distances_per_buffer_class(distance_matrix, reduction='min')

    # Rescale the per-class distances to [0, 1] so scores are comparable across experiences
    values = np.array(list(class_distances.values())).reshape(-1, 1)
    values = MinMaxScaler().fit_transform(values).flatten()
    class_distances = {
        cls: float(scaled_val)
        for cls, scaled_val in zip(class_distances.keys(), values)
    }

    current_task_history = [class_distances[cls] for cls in current_classes]
    score_current_task = np.mean(current_task_history)
    prev_task_history = [class_distances[cls] for cls in experience_set.classes_seen_so_far if cls not in current_classes]
    # With no previous classes there is nothing to average; fall back to 0.0
    score_prev_tasks = np.mean(prev_task_history) if prev_task_history else 0.0

    # True where the buffer sample belongs to a class of the current experience
    mask = torch.tensor(
        [lbl.item() in current_classes for lbl in buffer_labels],
        dtype=torch.bool,
    )
    # Start from the previous-task score everywhere, then overwrite current-task samples
    buffer_scores = torch.full(mask.shape, float(score_prev_tasks), dtype=torch.float32, device=mask.device)
    buffer_scores[mask] = float(score_current_task)

    return buffer_logits, buffer_labels, buffer_scores, buffer_class_means, class_distances


def select_representative_classes(test_features, buffer_class_means, threshold=0.6):
    """
    Picks the buffer classes that test samples are most often closest to.

    Each test sample is assigned to the buffer class whose mean vector has the
    smallest cosine distance to it. Classes are then ranked by how many samples
    they received and collected, most frequent first, until the collected
    classes cover at least `threshold` of all test samples.

    Args:
        test_features (Tensor): Tensor of shape (N, D) with one feature vector per
                                test sample; rows are L2-normalized inside this function.
        buffer_class_means (Dict[int, Tensor]): A dictionary mapping class indices to their
                                                corresponding mean feature vectors (shape (D,)).
        threshold (float): Cumulative frequency threshold (in [0, 1]) that determines how many
                           top classes to retain based on their assignment frequency.

    Returns:
        List[int]: Selected class indices, ordered by decreasing assignment count.
    """
    # Fix an ordering of the classes so matrix rows can be mapped back to ids
    candidate_ids = list(buffer_class_means)
    prototypes = torch.stack([buffer_class_means[cid] for cid in candidate_ids])  # shape (C, D)

    # Unit-length rows make the dot product equal to the cosine similarity
    unit_features = F.normalize(test_features, dim=1)
    unit_prototypes = F.normalize(prototypes, dim=1)
    cosine_distance = 1 - unit_features @ unit_prototypes.T

    # Nearest prototype per test sample, translated back to class ids
    nearest = cosine_distance.min(dim=1).indices
    assigned = torch.tensor([candidate_ids[idx] for idx in nearest.tolist()])

    # Rank classes by how many samples were assigned to them
    unique_ids, frequencies = torch.unique(assigned, return_counts=True)
    ranked_freqs, order = torch.sort(frequencies, descending=True)
    ranked_ids = unique_ids[order].tolist()
    covered_so_far = torch.cumsum(ranked_freqs, dim=0).tolist()

    # Keep adding classes until the requested share of samples is covered
    n_samples = len(assigned)
    chosen = []
    for cid, covered in zip(ranked_ids, covered_so_far):
        chosen.append(cid)
        if covered / n_samples >= threshold:
            break

    return chosen


def assign_scores(shape, selected_classes, class_distances):
    """
    Builds a constant score tensor from the distances of the selected classes.

    The distances of all selected classes are averaged, and that single value
    fills every position of the output.

    Args:
        shape (tuple or torch.Size): The shape of the output tensor to be created.
        selected_classes (List[int]): List of class indices that were selected.
        class_distances (Dict[int, float]): Dictionary mapping class indices to their distance scores.

    Returns:
        torch.Tensor: A float32 tensor of the given shape, filled with the mean
                      distance of the selected classes.
    """
    picked = np.asarray([class_distances[cls] for cls in selected_classes])
    return torch.full(shape, picked.mean(), dtype=torch.float32)