import os
import cv2
import torch
import numpy as np
import time
import joblib
from config import Config
from PIL import Image
from matplotlib import pyplot as plt
from torch.nn import DataParallel
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
from facenet_pytorch import MTCNN, InceptionResnetV1
import argparse
from models import *
from attack import stick
from mtcnn_pytorch_master.test import crop_face


# Image -> tensor conversion shared by every face crop in this module.
# torchvision's ToTensor yields a (C, H, W) float tensor, scaling 8-bit
# images into [0, 1].
trans = transforms.Compose([transforms.ToTensor()])

# Spatial input size [height, width] fed to each face-recognition model
# (used as the `size` argument of the bilinear resize in miattack_face).
inputsize = {
    'arcface50': [112, 112],
    'cosface50': [112, 112],
    'arcface34': [112, 112],
    'cosface34': [112, 112],
    'facenet': [160, 160],
    'mobilefacenet': [112, 112],
}

def reward_slope(adv_face_ts, params_slove, sticker, device, lo=1/255., hi=224/255.):
    """Mean tanh-space slope of the pixels inside the sticker region.

    Writing a pixel as ``p = (tanh(w) + 1) / 2``, its slope is
    ``dp/dw = (1 - tanh(w)**2) / 2 = (1 - (2p - 1)**2) / 2 = 2p(1 - p)``.
    This computes that slope for every pixel of the sticker region (after
    clamping to ``[lo, hi]``) and returns the mean.

    The clamp is straight-through: the forward pass uses the clamped values,
    but the gradient reaches ``adv_face_ts`` as if no clamp were applied.
    The input tensor itself is never modified.

    Args:
        adv_face_ts: image batch indexed as ``[:, :, y, x]``.
        params_slove: sequence whose ``[0]`` is the sticker's ``(x, y)``
            top-left corner.
        sticker: object with ``.size == (width, height)`` of the sticker.
        device: device the image is moved to before slicing.
        lo, hi: clamp bounds. The defaults match the original hard-coded
            values. NOTE(review): ``224/255`` may have been intended as
            ``254/255``; confirm before changing.

    Returns:
        0-dim tensor: mean slope, differentiable w.r.t. ``adv_face_ts``.
    """
    advface_ts = adv_face_ts.to(device)
    x, y = params_slove[0]
    stk_w, stk_h = sticker.size
    region = advface_ts[:, :, y:y + stk_h, x:x + stk_w]
    # Straight-through clamp: same forward values and gradient as the previous
    # `.data` assignment, without mutating autograd-tracked data.
    clamped = region + (region.clamp(lo, hi) - region).detach()
    # tanh(arctanh(u)) == u, so the former round trip through arctanh/tanh is
    # dropped; the value and gradient are mathematically unchanged.
    centered = 2 * clamped - 1
    x_wv = 0.5 - centered ** 2 / 2
    return torch.mean(x_wv)

def load_model(model_name, device):
    """Load a pretrained face-recognition model by name.

    Args:
        model_name: model identifier. Only ``'facenet'`` is implemented
            (facenet_pytorch InceptionResnetV1 pretrained on VGGFace2).
        device: torch device the model is moved to.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        ValueError: if ``model_name`` has no loader. Previously such names
            silently returned ``None``, which only failed later when called.
    """
    if model_name == 'facenet':
        return InceptionResnetV1(pretrained='vggface2').eval().to(device)
    # Add loaders for the other names listed in `inputsize` (arcface*,
    # cosface*, mobilefacenet) here.
    raise ValueError('unsupported model name: {!r}'.format(model_name))
        
def load_anchors(model_name, device, target):
    """Fetch the stored embedding of identity ``target`` for ``model_name``.

    Reads ``./stmodels/<model_name>/embeddings_<model_name>_5752.pkl``, which
    holds the identity feature embeddings of the whole face database.

    Returns:
        ``embeddings[target:target+1]`` moved to ``device``. Slicing, rather
        than indexing, keeps the leading row dimension. Presumably this is
        a (1, D) tensor; confirm against the pickle's contents.
    """
    pkl_path = './stmodels/{}/embeddings_{}_5752.pkl'.format(model_name, model_name)
    # NOTE(review): joblib.load unpickles arbitrary objects; only load trusted files.
    database = joblib.load(pkl_path)
    return database[target:target + 1].to(device)

def make_stmask(face, sticker, x, y):
    """Build the sticker mask tensor for ``face`` via ``stick.make_masktensor``.

    ``face`` must expose ``.size == (width, height)``. The sticker is placed
    with its top-left corner at ``(x, y)``. Presumably the mask is non-zero
    only inside the sticker area; see ``attack.stick`` to confirm.
    """
    face_w, face_h = face.size
    return stick.make_masktensor(face_w, face_h, sticker, x, y)

def crop_imgs(imgs, w, h):
    """Crop a face out of each image and convert every crop to a tensor.

    Args:
        imgs: sequence of images accepted by ``crop_face``.
        w, h: target crop size passed to ``crop_face``.

    Returns:
        ``(crops, tensors)``: the cropped images and their ``trans`` tensors,
        both in input order.
    """
    crops, tensors = [], []
    for image in imgs:
        face = crop_face(image, w, h)
        crops.append(face)
        tensors.append(trans(face))
    return crops, tensors

def cosin_metric(prd, src, device):
    """Row-wise cosine similarity between two embedding batches.

    Row ``i`` of the result is ``cos(prd[i], src[i])``. Rows are broadcast,
    so either input may be a single row. For example, one target anchor can
    be compared against several predictions. The former per-row Python loop
    either failed or produced mismatched pairs in that case.

    Args:
        prd: (N, D) tensor of predicted embeddings.
        src: (N, D) tensor of reference embeddings (or a single (1, D) row).
        device: device the result is moved to.

    Returns:
        (N, 1) tensor of similarities, differentiable w.r.t. both inputs.
        A zero-norm row yields nan/inf, as before.
    """
    dots = torch.sum(prd * src, dim=1, keepdim=True)
    norms = torch.norm(prd, dim=1, keepdim=True) * torch.norm(src, dim=1, keepdim=True)
    return (dots / norms).to(device)

def miattack_face(params_slove, model_names,
                  img, label, target, device, sticker,
                  width, height, emp_iterations, adv_img_folder, targeted = True):
    """Optimise a sticker region of a face with a momentum iterative attack.

    The masked sticker region is updated against an ensemble of face models.
    The face embedding moves towards the stored embedding of identity
    ``target`` when ``targeted=True``, and away from it otherwise.

    Args:
        params_slove: sequence containing:
            ``[0]``: the sticker's ``(x, y)`` top-left position.
            ``[1]``: per-model loss weights, aligned with ``model_names``.
            ``[2]``: the per-iteration step size.
        model_names: names understood by ``load_model``, ``load_anchors`` and
            ``inputsize``.
        img: input face; must expose ``.size == (w, h)`` and be accepted by
            ``trans`` (presumably a PIL image).
        label: unused.
        target: index of the target identity in the anchor embedding database.
        device: torch device for models and tensors.
        sticker: sticker passed to ``make_stmask`` / ``reward_slope``.
        width, height: face size; ``img`` is cropped when its size differs.
        emp_iterations: number of attack iterations.
        adv_img_folder: unused; nothing is written to disk (the caller
            receives the image as ``im``).
        targeted: maximise similarity to ``target`` if True, else minimise it.

    Returns:
        ``(adv_face_ts, im, mask)``:
            ``adv_face_ts``: the adversarial face batch as a detached CPU tensor.
            ``im``: the same face as an 8-bit PIL image.
            ``mask``: the sticker mask.
    """
    x, y = params_slove[0]
    weights = params_slove[1]
    epsilon = params_slove[2]
    flag = 1 if targeted else -1

    w, h = img.size
    if w != width or h != height:
        crops_result, crops_tensor = crop_imgs([img], width, height)
    else:
        crops_result = [img]
        crops_tensor = [trans(img)]
    X_ori = torch.stack(crops_tensor).to(device)
    # X_ori is already on `device`, so zeros_like puts delta there too.
    delta = torch.zeros_like(X_ori, requires_grad=True)

    fr_models, anchors = [], []
    for name in model_names:
        model = load_model(name, device)
        # Only input gradients are needed; freezing the weights stops every
        # loss.backward() from accumulating unused parameter gradients.
        for param in model.parameters():
            param.requires_grad_(False)
        fr_models.append(model)
        anchors.append(load_anchors(name, device, target))

    mask = make_stmask(crops_result[0], sticker, x, y)
    mask_dev = mask.to(device)  # loop-invariant, so moved once
    grad_momentum = 0
    for itr in range(emp_iterations):   # iterations in the generation of adversarial examples
        X_adv = X_ori + delta
        X_adv.retain_grad()  # X_adv is not a leaf; keep its .grad
        accm = 0
        print('---iter {}---'.format(itr), end=' ')
        for i, name in enumerate(model_names):
            # Use the file's `F` alias; a bare `nn` may only be reachable via
            # `from models import *`.
            X_op = F.interpolate(X_adv, (inputsize[name][0], inputsize[name][1]),
                                 mode='bilinear', align_corners=False)
            feature = fr_models[i](X_op)
            l_sim = cosin_metric(feature, anchors[i], device)
            print(name, ':', '{:.4f}'.format(l_sim.item()), end=' ')
            accm += l_sim * weights[i]
        slope = reward_slope(X_adv, params_slove, sticker, device)
        loss = flag * accm + 0.1 * slope
        print('L_sim = {:.4f},L_slope = {:.4f}'.format(flag * accm.item(), slope.item()), end='\r')
        loss.backward()

        # MI operation: gradient normalised by its mean absolute value, then
        # accumulated into the momentum with decay factor 1.0.
        grad_c = X_adv.grad.clone()
        grad_momentum = grad_c / torch.mean(torch.abs(grad_c), (1, 2, 3), keepdim=True) + 1.0 * grad_momentum

        # Signed ascent step limited to the sticker mask, kept in [0, 1].
        with torch.no_grad():
            stepped = (X_adv + epsilon * torch.sign(grad_momentum) * mask_dev).clamp(0, 1)
            delta.copy_(stepped - X_ori)

    adv_face_ts = (X_ori + delta).detach().cpu()
    # Round rather than truncate: float error such as 0.99999 * 255 would
    # otherwise lose an intensity level, even on pixels outside the sticker.
    adv_final = np.rint(adv_face_ts[0].numpy() * 255).clip(0, 255).astype(np.uint8)
    im = Image.fromarray(np.transpose(adv_final, (1, 2, 0)))

    return adv_face_ts, im, mask

