import torch
import numpy as np
import warnings

def make_attention_based_mask(attention_weights, opt, epoch):
    """Mask out the most-attended features of every sample in the batch.

    A per-feature importance score is obtained by taking the maximum of the
    attention weights over the last axis and then over the head axis. The
    scores are standardized per sample, and the ``k`` highest-scoring
    features of each sample are set to 0 in the returned mask (all other
    entries are 1).

    Args:
        attention_weights: 4-D tensor unpacked as
            ``(batch_size, num_heads, num_features, <last>)``.
        opt: Mapping of options. Reads ``opt['k_aug']`` (fraction of features
            to mask, clamped to ``[0, 1]``) and ``opt['score_method']``
            (``"none"`` disables neighbor retrieval). When retrieval is
            enabled, ``opt`` is forwarded to ``retrieval_neighbors``.
        epoch: Forwarded unchanged to ``retrieval_neighbors``.

    Returns:
        A tuple ``(mask, all_neighbors)``. ``mask`` is an int64 tensor of
        shape ``(batch_size, num_features)``. ``all_neighbors`` is the result
        of ``retrieval_neighbors`` on the standardized importance scores, or
        ``None`` when ``opt['score_method'] == "none"``.
    """
    batch_size, _, num_features, _ = attention_weights.shape

    # Strongest attention each feature receives: max over the last axis,
    # then max over heads -> (batch_size, num_features).
    per_head_max, _ = attention_weights.max(dim=3)
    feature_importance, _ = per_head_max.max(dim=1)

    # Standardize every sample's scores; the epsilon avoids division by zero
    # when all features of a sample share the same score.
    mean = feature_importance.mean(dim=1, keepdim=True)
    std = feature_importance.std(dim=1, keepdim=True) + 1e-6
    feature_importance = (feature_importance - mean) / std

    # Clamp the requested fraction to [0, 1] and always mask at least one feature.
    k_percent = min(max(opt['k_aug'], 0), 1)
    k = max(1, int(k_percent * num_features))

    _, top_k_indices = torch.topk(feature_importance, k, dim=1)

    # Allocate the mask once and zero the top-k columns of every row.
    # (Re-creating the mask per sample would discard all rows but the last.)
    mask = torch.ones(
        (batch_size, num_features),
        dtype=torch.long,
        device=attention_weights.device,
    )
    mask.scatter_(1, top_k_indices, 0)

    # Retrieval is optional; make sure the second return value is always bound.
    all_neighbors = None
    if opt['score_method'] != "none":  # cosine or l2
        all_neighbors = retrieval_neighbors(feature_importance, opt, epoch)

    return mask, all_neighbors


def retrieval_neighbors(feature_importance, opt, epoch):
    """Find, for every sample, its closest other samples within the batch.

    Each row of ``feature_importance`` is scored against all rows using the
    metric selected by ``opt['score_method']``. For ``"cosine"`` the highest
    similarities are the closest; for ``"l2"`` the smallest distances are.
    The sample itself is removed from its own neighbor list.

    Args:
        feature_importance: 2-D tensor with one row per sample.
        opt: Mapping of options. Reads ``opt['score_method']`` (``"cosine"``
            or ``"l2"``) and ``opt['fusion_num']`` (number of neighbors to
            return per sample).
        epoch: Unused; kept for interface compatibility.

    Returns:
        A list with one 1-D index tensor per sample, ordered from closest to
        farthest. Each tensor holds at most ``opt['fusion_num']`` indices
        (fewer when the batch does not contain enough other samples).

    Raises:
        ValueError: If ``opt['score_method']`` is neither ``"cosine"`` nor
            ``"l2"``.
    """
    score_method = opt['score_method']
    if score_method == "cosine":
        score_fn, largest = cosine_similarity, True
    elif score_method == "l2":
        # Distances: the nearest neighbors are the *smallest* scores.
        score_fn, largest = l2_distance, False
    else:
        raise ValueError(
            f"Unsupported score_method {score_method!r}; expected 'cosine' or 'l2'"
        )

    n_samples = feature_importance.shape[0]
    mixing_num = opt['fusion_num']
    # One extra slot for the sample itself (it is its own best match), but
    # never ask topk for more entries than the batch contains.
    k = min(mixing_num + 1, n_samples)

    all_neighbors = []
    for target_idx in range(n_samples):
        score = score_fn(feature_importance[target_idx], feature_importance)
        ranked = torch.topk(score, k, largest=largest).indices
        # Drop the sample by index rather than by position so that ties with
        # duplicate rows cannot leave it in its own neighbor list.
        ranked = ranked[ranked != target_idx]
        all_neighbors.append(ranked[:mixing_num])

    return all_neighbors


def cosine_similarity(x, y, eps=1e-8):
    """Compute the cosine similarity between a vector and every row of a matrix.

    Args:
        x: Query tensor; as called in this file it is a 1-D vector of shape
            ``(d,)``. Its norm is taken over all of its elements.
        y: 2-D tensor of shape ``(n, d)`` holding one candidate per row.
        eps: Lower bound applied to the product of the norms so that
            zero-norm inputs yield a similarity of 0 instead of NaN.

    Returns:
        Tensor with the similarity of ``x`` to each row of ``y``
        (shape ``(n,)`` for a 1-D ``x``).
    """
    # ``.T`` is a no-op for tensors with fewer than two dimensions and is
    # deprecated there, so only transpose when there is something to transpose.
    x_t = x.T if x.dim() >= 2 else x
    dot_product = torch.matmul(y, x_t)

    norm_x = torch.norm(x)
    norm_y = torch.norm(y, dim=1)
    # Guard against 0/0 for all-zero vectors.
    denominator = torch.clamp(norm_x * norm_y, min=eps)
    return dot_product / denominator


def l2_distance(x, y):
    """Return the Euclidean distance from ``x`` to each row of ``y``.

    ``x`` is broadcast against ``y`` and the 2-norm of the difference is
    taken along dimension 1.
    """
    difference = y - x
    return difference.norm(dim=1)