import torch
from dynamic_prune.prune_scheduler.base_prune_scheduler import BasePruneScheduler
from dynamic_prune.config import cfg

def _adjust_similarity_threshold(cfg, current_similarity_score, window_size=5, percentile=5):
    """Track recent similarity scores and re-derive the prune threshold from them.

    Appends ``current_similarity_score`` to ``cfg['similarity_history']`` and
    keeps only the newest ``window_size`` entries. Once the window is full,
    ``cfg['similarity_threshold']`` is set to the ``percentile``-th percentile
    of the window. Until then the threshold is left unchanged.

    Args:
        cfg: Mutable mapping holding scheduler state. ``'similarity_history'``
            is created as an empty list if it does not exist yet.
        current_similarity_score: Newest similarity score. This is either a
            single-element tensor (anything exposing ``.item()``) or a plain
            number.
        window_size: Number of most recent scores kept in the window
            (default 5). Must be >= 1.
        percentile: Percentile in [0, 100] of the window used as the new
            threshold (default 5).

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    # numpy is not imported at module level in this file, so import it here.
    import numpy as np

    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    if 'similarity_history' not in cfg:
        cfg['similarity_history'] = []
    history = cfg['similarity_history']

    # Unwrap tensors / numpy scalars so the history only holds Python floats.
    if hasattr(current_similarity_score, 'item'):
        current_similarity_score = current_similarity_score.item()
    history.append(float(current_similarity_score))

    # Sliding window: drop everything older than the newest ``window_size`` scores.
    if len(history) > window_size:
        del history[:-window_size]

    if len(history) < window_size:
        # Not enough samples yet for a stable estimate; keep the current threshold.
        return

    cfg['similarity_threshold'] = float(np.percentile(np.asarray(history), percentile))


def calc_similarity(f1: torch.Tensor, f2: torch.Tensor, mode="cosine"):
    """Return a scalar similarity score between two feature tensors.

    Both modes reduce along ``dim=1`` and then average the per-row results
    with ``torch.mean``. The inputs therefore need at least two dimensions
    and compatible shapes (presumably ``(rows, features)`` — confirm against
    callers).

    Args:
        f1: First feature tensor.
        f2: Second feature tensor.
        mode: Which score to compute.
            ``"cosine"`` (default): mean cosine similarity along ``dim=1``.
            ``"norm"``: ``1 / (1 + mean L2 distance along dim=1)``. This is 1
            for identical inputs and decreases toward 0 as they diverge.

    Returns:
        A 0-dim tensor holding the score.

    Raises:
        NotImplementedError: if ``mode`` is not ``"cosine"`` or ``"norm"``.
    """
    if mode == "cosine":
        import torch.nn.functional as F
        return torch.mean(F.cosine_similarity(f1, f2, dim=1))
    if mode == "norm":
        return 1 / (1 + torch.mean(torch.norm(f1 - f2, p=2, dim=1)))
    raise NotImplementedError(
        f"Unsupported similarity mode: {mode!r} (expected 'cosine' or 'norm')"
    )


class SimilarityBasedPruneScheduler(BasePruneScheduler):
    """
    A scheduler that decides, step by step, whether to prune.

    The decision is based on the similarity between the two most recent
    visual embeddings. This class keeps its state in the module-level ``cfg``
    mapping imported from ``dynamic_prune.config``, not on the instance.

    Keys read or written here:
        ``visual_embedding_list``, ``recent_prune_flags``,
        ``similarity_threshold``, ``is_prune``, ``pre_is_prune``,
        ``current_action_step``, ``prune_count``.
    Via ``_adjust_similarity_threshold`` it also uses ``similarity_history``.
    """

    def is_prune(self, projected_patch_embeddings, **kwargs):
        """Record the newest embedding and set ``cfg['is_prune']`` for this step.

        The first embedding ever seen always triggers a prune. After that, a
        prune is triggered when the similarity between the two most recent
        embeddings falls below ``cfg['similarity_threshold']``.

        There is one exception, a cooldown: no prune happens if either of the
        last two recorded flags is set. This only applies once two flags have
        been recorded. On every prune, ``cfg['pre_is_prune']`` and
        ``cfg['prune_count']`` are updated. Afterwards the threshold is
        re-estimated from recent similarity scores.

        Args:
            projected_patch_embeddings: Indexable. Only element ``[0]`` is
                stored and compared (presumably the first/only batch item —
                TODO confirm).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            None. The decision is communicated through ``cfg['is_prune']``.
        """
        embeddings = cfg["visual_embedding_list"]
        # NOTE(review): this list grows by one entry per call and is never
        # trimmed here, although only the last two entries are read below.
        # If nothing else consumes the full history, consider bounding it.
        embeddings.append(projected_patch_embeddings[0])

        if 'recent_prune_flags' not in cfg:
            cfg['recent_prune_flags'] = []
        recent_flags = cfg['recent_prune_flags']

        # Similarity between the newest and the previous embedding. It is
        # computed once and reused for the prune decision and the threshold update.
        similarity = None
        if len(embeddings) == 1:
            cfg['is_prune'] = True
        else:
            similarity = calc_similarity(embeddings[-1], embeddings[-2])
            in_cooldown = len(recent_flags) >= 2 and any(recent_flags[-2:])
            cfg['is_prune'] = bool(similarity < cfg['similarity_threshold']) and not in_cooldown

            # NOTE(review): the first step's forced prune is never recorded in
            # recent_flags, and the cooldown only activates once two flags
            # exist. Early steps can therefore prune back-to-back; confirm
            # this is intended.
            recent_flags.append(cfg['is_prune'])
            if len(recent_flags) > 2:
                recent_flags.pop(0)

        if cfg['is_prune']:
            cfg["pre_is_prune"] = cfg['current_action_step']
            cfg['prune_count'] += 1
            print('prune')

        if similarity is not None:
            _adjust_similarity_threshold(cfg, similarity)
