from BAD.eval.eval import evaluate
import torch
from BAD.utils import cosine_similaruty, clear_memory
from BAD.validate import get_models_scores

# Run on the first CUDA GPU when torch reports one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

def get_l2(model, dataloader, attack=None, use_in=True, progress=False, normalize_features=False):
    """Measure how far ``attack`` shifts the group-mean features of ``model``.

    Mean features are collected twice with ``get_features_mean_dict`` over
    ``dataloader``: once on the clean inputs and once on
    ``attack(data, targets)``. In the returned mapping, key ``1`` is treated
    as the "in" group and key ``0`` as the "out" group.

    Args:
        model: Object exposing ``get_features(data, normalize_features)``.
        dataloader: Passed through to ``get_features_mean_dict``.
        attack: Callable ``attack(data, targets)`` returning perturbed inputs.
            Required in practice despite the ``None`` default.
        use_in: If True, score how much the attack changes the
            out-minus-in gap between the group means. If False, score only
            how far the "out" mean (key 0) moves.
        progress: Forwarded to ``get_features_mean_dict``.
        normalize_features: Forwarded to ``model.get_features``.

    Returns:
        ``norm`` of the difference described under ``use_in``. Presumably a
        scalar; confirm against the definition of ``norm``.

    Raises:
        TypeError: If ``attack`` is not callable.
    """
    # Fail fast. Without this check, a missing attack only surfaces after a
    # full clean pass over the dataloader, as "'NoneType' object is not callable".
    if not callable(attack):
        raise TypeError(f'get_l2 requires a callable attack(data, targets), got {attack!r}')

    # NOTE(review): `get_features_mean_dict` and `norm` are neither imported at
    # the top of this file nor defined before this function. Confirm they are
    # in scope wherever get_l2 is used.

    def clean_extractor(data, targets):
        return model.get_features(data, normalize_features)

    def adv_extractor(data, targets):
        return model.get_features(attack(data, targets), normalize_features)

    # Clean pass first, then adversarial pass (same order as before).
    clean_means = get_features_mean_dict(dataloader,
                                         feature_extractor=clean_extractor,
                                         progress=progress)
    adv_means = get_features_mean_dict(dataloader, adv_extractor, progress=progress)

    if use_in:
        # Gap between the "out" (key 0) and "in" (key 1) means, before and after the attack.
        clean_gap = clean_means[0] - clean_means[1]
        adv_gap = adv_means[0] - adv_means[1]
        return norm(adv_gap - clean_gap)

    # Only the "out" group is needed here, so key 1 is never read.
    return norm(adv_means[0] - clean_means[0])

