import torch
import torch.nn as nn
from .scores import attn_score, similarity_score, norm_score, k_norm_score, v_norm_score


def score_evaluate(cache_dic, tokens, current) -> torch.Tensor:
    """
    Return the score tensor (B, N) for the given tokens.
    """

    # if ((not current['is_force_fresh']) and (cache_dic['force_fresh'] == 'local')):
    #    # abandoned branch, if you want to explore the local force fresh strategy, this may help.
    #    force_fresh_mask = torch.as_tensor((cache_dic['cache_index'][-1][current['layer']][current['module']] >= 2 * cache_dic['fresh_threshold']), dtype = int) # 2 because the threshold is for step, not module
    #    force_len = force_fresh_mask.sum(dim=1)
    #    force_indices = force_fresh_mask.argsort(dim = -1, descending = True)[:, :force_len.min()]
    #    force_indices = force_indices[:, torch.randperm(force_indices.shape[1])]

    # Just see more explanation in the version of DiT-ToCa if needed.

    if cache_dic["cache_type"] == "random":
        score = torch.rand(tokens.shape[0], tokens.shape[1], device=tokens.device)

    elif cache_dic["cache_type"] == "straight":
        score = torch.ones(tokens.shape[0], tokens.shape[1]).to(tokens.device)

    elif cache_dic["cache_type"] == "attention":
        # cache_dic['attn_map'][step][layer] (B, N, N), the last dimention has get softmaxed
        score = attn_score(cache_dic, current)
        # score = score + 0.0 * torch.rand_like(score, device= score.device)

    elif cache_dic["cache_type"] == "similarity":
        score = similarity_score(cache_dic, current, tokens)

    elif cache_dic["cache_type"] == "norm":
        score = norm_score(cache_dic, current, tokens)

    elif cache_dic["cache_type"] == "k-norm":
        score = k_norm_score(cache_dic, current)

    elif cache_dic["cache_type"] == "v-norm":
        score = v_norm_score(cache_dic, current)

    elif cache_dic["cache_type"] == "compress":
        score1 = torch.rand(int(tokens.shape[0] * 0.5), tokens.shape[1])
        score1 = torch.cat([score1, score1], dim=0).to(tokens.device)
        score2 = cache_dic["attn_map"][-1][current["layer"]].sum(dim=1)  # .mean(dim=0) # (B, N)
        # normalize
        score2 = score2 / score2.max(dim=1, keepdim=True)[0]
        score = 0.5 * score1 + 0.5 * score2

    # abandoned the branch, if you want to explore the local force fresh strategy, this may help.
    # if ((not current['is_force_fresh']) and (cache_dic['force_fresh'] == 'local')): # current['is_force_fresh'] is False, cause when it is True, no cut and fresh are needed
    #        #print(torch.ones_like(force_indices, dtype=float, device=force_indices.device).dtype)
    #    score.scatter_(dim=1, index=force_indices, src=torch.ones_like(force_indices, dtype=torch.float32,
    #                                                                       device=force_indices.device))

    if True and (cache_dic["force_fresh"] == "global"):
        soft_step_score = (
            cache_dic["cache_index"][-1][current["layer"]][current["module"]].float()
            / (cache_dic["fresh_threshold"])
        )
        # soft_layer_score = cache_dic['cache_index']['layer_index'][current['module']].float() / (27)
        score = score + cache_dic["soft_fresh_weight"] * soft_step_score  # + 0.1 *soft_layer_score

    return score.to(tokens.device)
