from util_funcs import (
    batch_samples_and_tasks_sim_func,
    weighted_sae_feature_sim,
    SaeFeature,
    sae_features_to_tensor,
)
from functools import partial
from tqdm import tqdm
import torch

# Width passed to sae_features_to_tensor / weighted_sae_feature_sim;
# presumably the size of the SAE feature dictionary (2**17) — TODO confirm.
FEATURE_NUM = 2**17
# Number of source samples scored per weighted_sae_feature_sim call.
BATCH_SIZE = 32


def topk_sae_feature(sae_feature, k=-1):
    """Keep only the ``k`` largest activations of an SAE feature.

    Args:
        sae_feature: ``SaeFeature`` carrying parallel ``acts`` / ``indices``
            tensors; selection happens along their last dimension.
        k: Number of activations to keep. ``k <= 0`` disables truncation and
            returns ``sae_feature`` itself unchanged. A ``k`` larger than the
            number of stored activations is clamped, so every stored
            activation is kept instead of ``torch.topk`` raising.

    Returns:
        A new ``SaeFeature`` whose ``acts`` are the selected activations in
        descending order, with ``indices`` reordered to stay aligned with them.
    """
    if k <= 0:
        return sae_feature

    acts, indices = sae_feature.acts, sae_feature.indices
    # torch.topk raises if k exceeds the size of the selected dimension; a
    # sample with fewer stored activations than k simply keeps all of them.
    k = min(k, acts.size(-1))

    topk_acts, order = torch.topk(acts, k, dim=-1, largest=True)
    # Reorder the feature ids with the same permutation so they stay paired
    # with their activations.
    topk_indices = torch.gather(indices, dim=-1, index=order)

    return SaeFeature(acts=topk_acts, indices=topk_indices)


def _parse_weighted_sae_metric(metric):
    """Split a ``weighted_sae`` metric spec into (dist_metric, aggregate mode).

    The spec looks like ``<name>[-<dist_metric>][@<aggregate_mode>]``. The
    distance metric defaults to ``"jaccard"`` and the aggregation mode to
    ``None`` when the corresponding part is absent.
    """
    aggregate_task_rep_mode = None
    if "@" in metric:
        metric, aggregate_task_rep_mode = metric.split("@")
    parts = metric.split("-")
    dist_metric = parts[1] if len(parts) >= 2 else "jaccard"
    return dist_metric, aggregate_task_rep_mode


def _disk_data_to_sae_features(data, k):
    """Convert loaded feature data into top-``k`` truncated ``SaeFeature``s.

    Two layouts are accepted:

    * ``{"acts": [...], "indices": [...]}`` -> a list with one ``SaeFeature``
      per (acts, indices) pair (the per-sample layout);
    * ``{task: {"acts": ..., "indices": ...}, ...}`` -> a dict mapping each
      task to a single-element list holding that task's ``SaeFeature``.
    """
    if "acts" in data:
        return [
            topk_sae_feature(SaeFeature(acts=acts, indices=indices), k)
            for acts, indices in zip(data["acts"], data["indices"])
        ]
    return {
        task: [
            topk_sae_feature(
                SaeFeature(
                    acts=data[task]["acts"],
                    indices=data[task]["indices"],
                ),
                k,
            )
        ]
        for task in data
    }


def _batched_weighted_sae_sim(
    src_feature, task_features, dist_metric, aggregate_task_rep_mode
):
    """Score every source feature against every task, BATCH_SIZE at a time.

    Returns a list with one ``{task: similarity}`` dict per source feature,
    in the same order as ``src_feature``.
    """
    task_features_tensor = {
        task: sae_features_to_tensor(features, FEATURE_NUM)
        for task, features in task_features.items()
    }
    num_samples = len(src_feature)
    similarity = []

    # Context manager guarantees the progress bar is closed, even on error.
    with tqdm(total=(num_samples - 1) // BATCH_SIZE + 1) as pbar:
        for start in range(0, num_samples, BATCH_SIZE):
            end = min(start + BATCH_SIZE, num_samples)
            batch_similarity = {
                task: weighted_sae_feature_sim(
                    src_feature[start:end],
                    task_tensor,
                    FEATURE_NUM,
                    dist_metric,
                    aggregate_task_rep_mode=aggregate_task_rep_mode,
                )
                for task, task_tensor in task_features_tensor.items()
            }
            # Transpose {task: per-sample sims} into per-sample {task: sim}.
            similarity.extend(
                {task: sims[i] for task, sims in batch_similarity.items()}
                for i in range(end - start)
            )
            pbar.update(1)

    return similarity


def _cosine_sim_func(src_features, task_features):
    """Score each source row against the mean representation of each task.

    Args:
        src_features: 2-D tensor, one row per source sample.
        task_features: dict mapping task -> ``[task_sample_num, dim]`` tensor.

    Returns:
        A list with one ``{task: float}`` dict per source row.

    NOTE(review): this is a plain dot product with the mean task vector; it
    equals cosine similarity only if the inputs are L2-normalised upstream —
    TODO confirm.
    """
    task_sims = {
        # Mean-pool the task samples, then dot with every source row.
        task: (src_features @ task_feature.mean(dim=0)).tolist()
        for task, task_feature in task_features.items()
    }
    return [
        {task: sims[i] for task, sims in task_sims.items()}
        for i in range(src_features.size(0))
    ]


def get_metric_func(metric, k=-1):
    """Return a similarity function for the metric named ``metric``.

    Args:
        metric: ``"cosine"``, or any string containing ``"weighted_sae"`` of
            the form ``<name>[-<dist_metric>][@<aggregate_mode>]``.
        k: Top-k truncation applied to every SAE feature before scoring in
            the ``weighted_sae`` branch (``k <= 0`` keeps all activations).

    Returns:
        * ``weighted_sae``: ``f(batch_disk_data, task_disk_data)`` returning a
          list with one ``{task: similarity}`` dict per source sample.
        * ``cosine``: ``f(src_features, task_features)`` returning a list with
          one ``{task: float}`` dict per row of ``src_features``.
        * ``None`` for any other metric name.
    """
    if "weighted_sae" in metric:
        dist_metric, aggregate_task_rep_mode = _parse_weighted_sae_metric(metric)
        return lambda batch_disk_data, task_disk_data: _batched_weighted_sae_sim(
            _disk_data_to_sae_features(batch_disk_data, k),
            _disk_data_to_sae_features(task_disk_data, k),
            dist_metric,
            aggregate_task_rep_mode,
        )
    if metric == "cosine":
        return _cosine_sim_func
    return None
