import numpy as np
import dlib
from cv2 import cv2
from PIL import Image
from torchvision import datasets
import copy
import joblib
import torch
import torch.nn as nn
from attack.tiattack import load_model,cosin_metric
from torchvision import transforms
from torch.utils.data import DataLoader

# Input resolution passed to the bilinear resize for each supported
# face-recognition model, keyed by model name.
inputsize = {
    'arcface34': [112, 112],
    'arcface50': [112, 112],
    'cosface34': [112, 112],
    'cosface50': [112, 112],
    'facenet': [160, 160],
    'mobilefacenet': [112, 112],
}
           
# Preprocessing pipeline made of a single torchvision ToTensor step.
trans = transforms.Compose([transforms.ToTensor()])

def cosin_all(feature,model_name,device):
    """
        Cosine similarity between every row of `feature` and every identity
        embedding stored in the face database of `model_name`.

        The embedding file is read from disk on each call, transposed and
        moved to `device`. The result has one row per row of `feature` and one
        column per stored embedding, and is returned detached on the CPU.

        NOTE(review): assumes the PKL holds a 2-D torch tensor with one
        embedding per row, and that `feature` is 2-D and already on `device`
        -- confirm against the code that writes the file and the callers.
    """
    # PKL file holding the feature embeddings of all identities in the face database.
    pkl_path = './stmodels/{}/embeddings_{}_5752.pkl'.format(model_name,model_name)
    gallery = torch.t(joblib.load(pkl_path)).to(device)

    # Dot product of each feature row with each gallery column.
    dots = torch.mm(feature, gallery)

    # Norms shaped as a column / row vector so that their matrix product is
    # the table of pairwise norm products.
    feature_norms = torch.norm(feature, dim=1).unsqueeze(1)
    gallery_norms = torch.norm(gallery, dim=0).unsqueeze(0)
    inverse_norms = 1 / torch.mm(feature_norms, gallery_norms)

    similarities = torch.mul(dots, inverse_norms)
    return similarities.cpu().detach()
               
def load_anchors(model_name, device, target):
    """
        Load the stored embedding of identity `target` for `model_name`.

        Reads the embedding file of the model from disk, keeps the slice
        `[target:target+1]` (so the leading dimension is preserved) and moves
        it to `device`.
    """
    pkl_path = './stmodels/{}/embeddings_{}_5752.pkl'.format(model_name,model_name)
    all_embeddings = joblib.load(pkl_path)
    return all_embeddings[target:target+1].to(device)

def reward_output(adv_face_ts, threat_model, threat_name, target, device):
    """
        Similarity between an adversarial face and the anchor embedding of
        identity `target`, as seen by the threat model.

        The face tensor is moved to `device`, resized bilinearly to the input
        size registered for `threat_name`, and passed through `threat_model`.
        The resulting feature is compared with the stored anchor through
        `cosin_metric`, and the value is returned as a Python scalar.
    """
    model = threat_model.to(device)
    faces = adv_face_ts.to(device)

    target_size = (inputsize[threat_name][0], inputsize[threat_name][1])
    resized = nn.functional.interpolate(faces, target_size, mode='bilinear', align_corners=False)
    feature = model(resized)

    anchor = load_anchors(threat_name, device, target)
    similarity = cosin_metric(feature, anchor, device)
    # .item() needs the metric to contain exactly one element.
    return similarity.cpu().detach().item()

def reward_slope(adv_face_ts, params_slove, sticker,device):
    """
        Mean slope of the tanh re-parameterisation over the sticker region.

        The region starts at the (x, y) pair stored in `params_slove[0]` and
        has the width and height given by `sticker.size`; it is cut from the
        last two dimensions of the face tensor. Its values are clamped, mapped
        to w = arctanh(2*v - 1), and the mean of 1/2 - tanh(w)**2 / 2 over the
        region is returned as a tensor.
    """
    face = adv_face_ts.to(device)
    left, top = params_slove[0]
    width, height = sticker.size

    patch = face[:, :, top:top+height, left:left+width]
    # Clamp through .data so the clamp itself is not recorded by autograd.
    # NOTE(review): the upper bound 224/255 may have been meant as 254/255 -- confirm.
    patch.data = patch.data.clamp(1/255.,224/255.)

    latent = torch.arctanh(2*patch-1)
    slope = 1/2 - (torch.tanh(latent)**2)/2
    return torch.mean(slope)
    

def check_all(adv_face_ts, threat_model, threat_name, device):
    """
        Rank every identity in the face database for each adversarial face.

        The faces are processed in batches of 55: each batch is stacked,
        moved to `device`, resized bilinearly to the input size registered for
        `threat_name`, and passed through `threat_model` in eval mode. The
        resulting features are compared with all stored embeddings through
        `cosin_all`.

        Args:
            adv_face_ts: indexable collection of equally shaped face tensors
                (they are combined with `torch.stack`).
            threat_model: face-recognition model; moved to `device` and
                switched to eval mode.
            threat_name: key into `inputsize` and name of the embedding
                database used by `cosin_all`.
            device: torch device used for the forward pass.

        Returns:
            (typess, percent): `typess` holds, for every face, the indices of
            the 7 most similar identities in descending order of similarity
            (fewer if the database has fewer than 7 entries); `percent` is a
            numpy array with one row of similarities per face.
    """
    top_k = 7
    typess = []
    percent = []

    threat = threat_model.to(device)
    threat.eval()

    def collate_fn(x):
        # Keep the batch as a plain list; it is stacked explicitly below.
        return x

    loader = DataLoader(
        adv_face_ts,
        batch_size=55,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Only detached similarities leave this function, so no autograd graph
    # is needed for the forward pass.
    with torch.no_grad():
        for X in loader:
            advface_ts = torch.stack(X).to(device)
            X_op = nn.functional.interpolate(
                advface_ts,
                (inputsize[threat_name][0], inputsize[threat_name][1]),
                mode='bilinear',
                align_corners=False,
            )
            feature = threat(X_op)

            # One call per batch: cosin_all reloads the embedding file on
            # every call, so calling it per face repeated that load for
            # every single image.
            sim_all = cosin_all(feature, threat_name, device)
            _, indices = torch.sort(sim_all, dim=1, descending=True)

            typess.extend(indices[:, :top_k].tolist())
            # Iterating the 2-D array yields one similarity row per face.
            percent.extend(sim_all.numpy())
    return typess,np.array(percent)

