import torch
from collections import namedtuple
from typing import Union, List
from tqdm import tqdm

SaeFeature = namedtuple("SaeFeature", ["acts", "indices"])


def batch_samples_and_tasks_sim_func(batch_features, task_features, sim_func=None):
    """
    Score every sample of a batch against one or several task representations.

    Args:
        batch_features: Per-sample features. A ``torch.Tensor`` is walked along
            its first dimension; any other sized, integer-indexable container
            (list, tuple, ...) is walked element by element.
        task_features: Either a dict mapping a task key to that task's
            features, or a single task representation of any other type.
        sim_func: Callable invoked as ``sim_func(sample_feature, task_feature)``.
            The default ``None`` is only usable with an empty batch, because
            it is called once per sample (and per task).

    Returns:
        list: One entry per sample. When ``task_features`` is a dict, each
        entry is a dict ``{task: similarity}`` built in sorted key order;
        otherwise each entry is a one-element list ``[similarity]``.
    """
    if isinstance(batch_features, torch.Tensor):
        batch_size = batch_features.size(0)
    else:
        # Any sized container works here; restricting this branch to `list`
        # left `batch_size` unbound for tuples and other sequences.
        batch_size = len(batch_features)

    # Both the dict check and the key ordering are the same for every sample,
    # so they are resolved once instead of on each iteration.
    per_task = isinstance(task_features, dict)
    task_names = sorted(task_features.keys()) if per_task else None

    similarity = []
    for i in tqdm(range(batch_size)):
        src_feature = batch_features[i]

        if per_task:
            sample_sim_to_tasks = {
                task: sim_func(src_feature, task_features[task])
                for task in task_names
            }
        else:
            sample_sim_to_tasks = [sim_func(src_feature, task_features)]

        similarity.append(sample_sim_to_tasks)

    return similarity


def list_extend_all_heldout_data_feature_agg(
    agg_heldout_data_feature, batch_heldout_data_feature, *args
):
    """
    Accumulate per-batch held-out features into one growing list.

    Args:
        agg_heldout_data_feature (list | None): Accumulator returned by the
            previous call, or ``None`` on the first call.
        batch_heldout_data_feature: Iterable of features for the current batch.
        *args: Accepted and ignored.

    Returns:
        tuple: ``(accumulator, -1)``. The second element is always ``-1``.
    """
    if agg_heldout_data_feature is None:
        # Start from a copy: reusing the caller's batch object as the
        # accumulator would make every later `extend` silently grow that
        # first batch as well.
        agg_heldout_data_feature = list(batch_heldout_data_feature)
    else:
        agg_heldout_data_feature.extend(batch_heldout_data_feature)

    return agg_heldout_data_feature, -1


def concat_extend_all_heldout_data_feature_agg(
    agg_heldout_data_feature, batch_heldout_data_feature, *args
):
    """
    Accumulate per-batch held-out features by concatenating along dim 0.

    Args:
        agg_heldout_data_feature (torch.Tensor | None): Accumulator returned
            by the previous call, or ``None`` on the first call.
        batch_heldout_data_feature (torch.Tensor): Features of the current
            batch.
        *args: Accepted and ignored.

    Returns:
        tuple: ``(accumulator, -1)``. On the first call the accumulator is the
        batch object itself; afterwards it is a new concatenated tensor. The
        second element is always ``-1``.
    """
    if agg_heldout_data_feature is None:
        # Nothing accumulated yet: the batch itself becomes the accumulator.
        return batch_heldout_data_feature, -1

    merged = torch.cat(
        (agg_heldout_data_feature, batch_heldout_data_feature), dim=0
    )
    return merged, -1


def weighted_topk_sae_feature_agg_func(
    sae_features: Union["SaeFeature", List["SaeFeature"]],
    k: int = -1,
    avg_level: str = "sample",
):
    """
    Aggregate SAE features into one weight per feature id, optionally Top-K.

    Activations that share a feature id are summed, the sums are normalised
    according to ``avg_level``, and either the ``k`` largest entries or all
    strictly positive entries are returned.

    Args:
        sae_features: Either a single ``SaeFeature`` whose ``acts`` and
            ``indices`` are flattened, or a list of ``SaeFeature`` whose
            ``acts`` / ``indices`` are concatenated along dim 0 (the
            concatenation must be 1-D for ``index_add_``).
        k (int): Number of top entries to keep. Values larger than the number
            of candidate ids (``max id + 1``) are clamped to that number.
            ``k <= 0`` keeps every entry whose aggregated value is ``> 0``.
        avg_level (str): ``"feature"`` divides each sum by the number of times
            that id occurred; ``"sample"`` divides by the list length (list
            input) or by ``acts.size(0)`` (single input); any other value
            keeps the raw sums.

    Returns:
        SaeFeature: ``acts`` holds the aggregated values and ``indices`` the
        corresponding feature ids.
    """
    if isinstance(sae_features, list):
        num = len(sae_features)
        flat_feature_acts = torch.cat(
            [sae_feature.acts for sae_feature in sae_features], dim=0
        )
        flat_feature_indices = torch.cat(
            [sae_feature.indices for sae_feature in sae_features], dim=0
        )
    else:
        num = sae_features.acts.size(0)
        flat_feature_acts = sae_features.acts.flatten()
        flat_feature_indices = sae_features.indices.flatten()

    # Dense accumulators cover every id from 0 up to the largest one seen.
    max_index = flat_feature_indices.max().item() + 1

    value_sum = torch.zeros(
        max_index, dtype=flat_feature_acts.dtype, device=flat_feature_acts.device
    )
    value_sum.index_add_(0, flat_feature_indices, flat_feature_acts)

    if avg_level == "feature":
        cnt_sum = torch.zeros_like(value_sum)
        cnt_sum.index_add_(0, flat_feature_indices, torch.ones_like(flat_feature_acts))
        # Ids below `max_index` that never occur have a count of 0. Clamping
        # the divisor makes them average to 0 instead of 0/0 = NaN, which
        # could otherwise end up in the Top-K result.
        avg_value = value_sum / cnt_sum.clamp(min=1)
    elif avg_level == "sample":
        avg_value = value_sum / num
    else:
        avg_value = value_sum

    if k > 0:
        # `torch.topk` raises when k exceeds the number of elements.
        topk_acts, topk_indices = torch.topk(avg_value, min(k, avg_value.numel()))
        return SaeFeature(topk_acts, topk_indices)

    # No Top-K requested: keep only the strictly positive entries.
    non_zero_indices = torch.where(avg_value > 0)[0]
    return SaeFeature(avg_value[non_zero_indices], non_zero_indices)


def concat_sae_feature(sae_features):
    """
    Join several SAE features end to end along the last dimension.

    Args:
        sae_features (list): SAE features exposing ``acts`` and ``indices``.

    Returns:
        SaeFeature | None: A feature whose ``acts`` and ``indices`` are the
        inputs' tensors concatenated along ``dim=-1``, or ``None`` when the
        input is empty.
    """
    if not len(sae_features):
        return None

    all_acts = tuple(item.acts for item in sae_features)
    all_indices = tuple(item.indices for item in sae_features)
    joined_acts = torch.cat(all_acts, dim=-1)
    joined_indices = torch.cat(all_indices, dim=-1)
    return SaeFeature(joined_acts, joined_indices)


def stack_sae_feature(sae_features):
    """
    Stack several SAE features along a new leading dimension.

    Args:
        sae_features (list): SAE features exposing ``acts`` and ``indices``.

    Returns:
        SaeFeature | None: A feature whose ``acts`` and ``indices`` are the
        inputs' tensors stacked along ``dim=0``, or ``None`` when the input
        is empty.
    """
    if not len(sae_features):
        return None

    all_acts = tuple(item.acts for item in sae_features)
    all_indices = tuple(item.indices for item in sae_features)
    stacked_acts = torch.stack(all_acts, dim=0)
    stacked_indices = torch.stack(all_indices, dim=0)
    return SaeFeature(stacked_acts, stacked_indices)


def weighted_sae_feature_sim_old(sae_feature1, sae_feature2, metric="jaccard"):
    """
    Compare two sparse SAE features given as (activation, feature id) pairs.

    Both features are expanded onto the union of their feature ids;
    activations that share an id within one feature are summed.

    Args:
        sae_feature1 (SaeFeature): First feature, with 1-D `acts` (weights) and `indices` (IDs).
        sae_feature2 (SaeFeature): Second feature, with 1-D `acts` (weights) and `indices` (IDs).
        metric (str): ``"jaccard"`` for the weighted Jaccard similarity
            (sum of element-wise minima over sum of element-wise maxima, 0.0
            when that denominator is not positive); ``"cosine"`` for the
            cosine similarity; any other value for the squared Euclidean
            distance.

    Returns:
        float: The score for the selected metric. Note that the Euclidean
        branch is a distance (smaller means closer), not a similarity.
    """
    indices1 = sae_feature1.indices
    indices2 = sae_feature2.indices

    # Acts correspond to weights
    acts1 = sae_feature1.acts
    acts2 = sae_feature2.acts

    # Union of all ids; `inverse_indices` gives each input id's slot in it.
    all_indices, inverse_indices = torch.unique(
        torch.cat([indices1, indices2]), sorted=True, return_inverse=True
    )

    # The first len(indices1) slots belong to feature 1, the rest to feature 2.
    pos1 = inverse_indices[: len(indices1)]
    pos2 = inverse_indices[len(indices1) :]

    # Dense weight vectors over the union of ids
    weights1 = torch.zeros(len(all_indices), dtype=acts1.dtype, device=acts1.device)
    weights2 = torch.zeros(len(all_indices), dtype=acts2.dtype, device=acts2.device)

    weights1.index_add_(0, pos1, acts1)
    weights2.index_add_(0, pos2, acts2)

    if metric == "jaccard":
        intersection = torch.sum(torch.min(weights1, weights2))
        union = torch.sum(torch.max(weights1, weights2))

        return (intersection / union).item() if union > 0 else 0.0
    elif metric == "cosine":
        # The vectors are 1-D, so the norm must be taken over dim 0;
        # `normalize` defaults to dim=1, which does not exist here.
        weights1 = torch.nn.functional.normalize(weights1, p=2, dim=0)
        weights2 = torch.nn.functional.normalize(weights2, p=2, dim=0)
        return torch.sum(weights1 * weights2).item()
    else:
        # Euclidean
        # NOTE: we do not sqrt since we only compare these distances
        return torch.sum(torch.square(weights1 - weights2)).item()


def sae_features_to_tensor(features, feature_num):
    """
    Expand sparse SAE features into a dense ``[len(features), feature_num]`` tensor.

    Args:
        features (list): SAE features exposing ``acts`` (weights) and
            ``indices`` (non-negative feature ids below ``feature_num``).
        feature_num (int): Width of the dense representation.

    Returns:
        torch.Tensor: Row ``i`` holds the activations of ``features[i]`` at
        their feature ids and zeros elsewhere; activations that share an id
        within one feature are summed. Device and dtype follow the first
        feature's ``acts``. An empty input yields an empty ``[0, feature_num]``
        tensor with the default dtype and device.
    """
    if len(features) == 0:
        # No feature to take device/dtype from.
        return torch.zeros(0, feature_num)

    tensor = torch.zeros(
        len(features),
        feature_num,
        device=features[0].acts.device,
        dtype=features[0].acts.dtype,
    )
    for i, feature in enumerate(features):
        # Accumulate instead of assigning: indexed assignment with repeated
        # ids keeps an arbitrary one of the values, whereas the other helpers
        # in this module sum repeated ids via `index_add_`.
        tensor[i].index_add_(
            0,
            feature.indices.reshape(-1),
            feature.acts.reshape(-1).to(tensor.dtype),
        )
    return tensor


def weighted_sae_feature_sim(
    batch_features,
    task_features_tensor,
    feature_num,
    metric="jaccard",
    aggregate_task_rep_mode=None,
):
    """
    Score each batch sample against a set of dense task features.

    Every batch sample is compared with every task row and the per-row scores
    are averaged, giving one number per batch sample.

    Args:
        batch_features (list): Sparse SAE features (``acts`` / ``indices``),
            one per batch sample; densified to ``[bsz, feature_num]``.
        task_features_tensor (torch.Tensor, [task_sample_num, feature_num]):
            Dense task features.
        feature_num (int): Width of the dense feature space.
        metric (str): ``"jaccard"`` for the weighted Jaccard similarity
            (0 where the union is not positive), ``"cosine"`` for the cosine
            similarity, any other value for the Euclidean distance.
        aggregate_task_rep_mode (str | None): ``"mean"`` or ``"max"`` first
            collapses the task rows into a single row over dim 0; any other
            value leaves the task rows untouched.

    Returns:
        list[float]: One score per batch sample, averaged over the task rows.
        For the Euclidean metric this is a distance (smaller means closer).
    """
    # [bsz, feature_num]
    batch_features = sae_features_to_tensor(batch_features, feature_num)

    if aggregate_task_rep_mode == "mean":
        task_features_tensor = task_features_tensor.mean(dim=0, keepdim=True)
    elif aggregate_task_rep_mode == "max":
        task_features_tensor = task_features_tensor.max(dim=0, keepdim=True)[0]

    if metric == "jaccard":
        intersection = torch.sum(
            torch.min(batch_features.unsqueeze(1), task_features_tensor.unsqueeze(0)),
            dim=-1,
        )  # [batch_size, task_sample_num]
        union = torch.sum(
            torch.max(batch_features.unsqueeze(1), task_features_tensor.unsqueeze(0)),
            dim=-1,
        )  # [batch_size, task_sample_num]

        # Two all-zero vectors give 0/0; score such pairs as 0 rather than
        # letting NaN poison the mean over task rows.
        jaccard = torch.where(
            union > 0, intersection / union, torch.zeros_like(union)
        )
        return torch.mean(jaccard, dim=-1).tolist()
    elif metric == "cosine":
        batch_features = torch.nn.functional.normalize(batch_features, p=2, dim=-1)
        task_features_tensor = torch.nn.functional.normalize(
            task_features_tensor, p=2, dim=-1
        )
        return torch.mean(batch_features @ task_features_tensor.T, dim=-1).tolist()
    else:
        # Euclidean: sqrt of the summed squared differences. The difference
        # is squared exactly once; squaring twice yields fourth powers.
        diff = batch_features.unsqueeze(1) - task_features_tensor.unsqueeze(0)
        distance = torch.sqrt(torch.sum(torch.square(diff), dim=-1))
        return torch.mean(distance, dim=-1).tolist()


if __name__ == "__main__":
    # Smoke check: score the same pair of features with the pairwise helper
    # and with the batched helper.
    feature_num = 131072

    feature1 = SaeFeature(
        acts=torch.tensor([2.0, 0.5, 0.2]), indices=torch.tensor([2, 1, 3])
    )
    feature2 = SaeFeature(
        acts=torch.tensor([0.8, 0.6, 0.4]), indices=torch.tensor([2, 3, 4])
    )

    print(weighted_sae_feature_sim_old(feature1, feature2))
    print(
        weighted_sae_feature_sim(
            [feature1],
            sae_features_to_tensor([feature2], feature_num),
            feature_num,
            # Passed explicitly so the call does not rely on a default metric.
            metric="jaccard",
        )
    )
