import os
import cv2
import torch
import numpy as np
import time
import joblib
from config import Config
from PIL import Image
from matplotlib import pyplot as plt
from torch.optim import lr_scheduler
from torch.nn import DataParallel
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
from facenet_pytorch import MTCNN, InceptionResnetV1
import argparse
from attack.tiattack import miattack_face,crop_imgs,load_model

from rl_solve.attack_utils import loc_space,agent_output,vector_processor,generate_actions,actions2params
from rl_solve.agent import UNet
from rl_solve.reward import reward_output, reward_slope, check_all
from visualizer import Visualizer


def attack_process(img, sticker, threat_model, threat_name, model_names, label, target, device,
                   width ,height, emp_iterations, adv_img_folder, targeted = True,
                   sapce_thd=50,pg_m=5,max_iter=10000):
    '''Search for an adversarial sticker placement on one face image with an RL agent.

    A UNet agent outputs a placement policy over the valid pasting area. Each
    iteration samples ``pg_m`` placements and builds the adversarial face for
    each one via ``miattack_face``. It then scores each face with
    ``reward_output`` against the threat model and updates the agent with a
    REINFORCE-style loss ``-sum(log_pi * (reward - baseline))``; ``baseline``
    is fixed at 0.0.

    Args:
        img: input image; passed to ``crop_imgs`` as a single-element list.
        sticker: sticker passed to ``loc_space`` and ``miattack_face``.
        threat_model: model under attack, queried via ``check_all`` / ``reward_output``.
        threat_name: name of the threat model, forwarded alongside it.
        model_names: model names forwarded to ``miattack_face``; their count is
            used as the agent's ``sgmodel`` size (presumably surrogate models — confirm).
        label: true label of ``img``.
        target: target label; replaced by ``label`` when ``targeted`` is False.
        device: torch device for the agent and working tensors.
        width, height: crop size passed to ``crop_imgs`` / ``miattack_face``;
            ``width`` is also passed to ``actions2params``.
        emp_iterations: forwarded to ``miattack_face``.
        adv_img_folder: forwarded to ``miattack_face``.
        targeted: targeted attack (reach ``target``) vs. untargeted (leave ``label``).
        sapce_thd: ``threshold`` for ``loc_space`` (misspelled name kept for callers).
        pg_m: number of placements sampled per iteration; must be >= 1.
        max_iter: maximum number of agent-update iterations.

    Returns:
        ``(success, num_iter, [best_face, best_reward, best_params, best_adv_face_ts])``.
        ``success`` is True only on early stop. For a targeted attack, that means
        the threat model's top label for the best face equals ``target``. For an
        untargeted attack, it means that label differs from ``label``.

    Raises:
        ValueError: if ``pg_m < 1`` (no samples means no loss to back-propagate).

    Live plotting uses the module-level ``opt.display`` flag and ``visualizer``
    object, which exist only when this file is run as a script.
    '''
    if pg_m < 1:
        raise ValueError('pg_m must be >= 1, got {}'.format(pg_m))

    # ``opt`` / ``visualizer`` are defined only in the ``__main__`` block. Look
    # ``opt`` up defensively so that importing and calling this function does
    # not raise NameError on the display check.
    script_opt = globals().get('opt')
    display = script_opt is not None and bool(script_opt.display)

    # ---------------------- Face image initialization ----------------------
    crops_result, crops_tensor = crop_imgs([img], width, height)
    init_face = crops_result[0]
    clean_ts = torch.stack(crops_tensor).to(device)
    space = loc_space(init_face, sticker, threshold=sapce_thd)    # valid pasting area mask
    space_ts = torch.from_numpy(space).to(device)

    n_models = len(model_names)
    sim_labels, sim_probs = check_all(crops_tensor, threat_model, threat_name, device)

    # Untargeted: the "target" becomes the true label, which the attack pushes away from.
    target = target if targeted else label
    start_label = [label, target]
    start_gap = sim_probs[0][start_label]
    print('start_label: {} start_gap: {}'.format(start_label, start_gap))

    # ------------------------- Agent initialization -------------------------
    print('Initializing the agent......')
    agent = UNet(inputdim=init_face.size[0], sgmodel=n_models, feature_dim=20).to(device)
    optimizer = torch.optim.Adam(agent.parameters(), lr=1e-03, weight_decay=5e-04)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)   # lr x0.1 every 5 steps
    baseline = 0.0                   # constant reward baseline (never updated)

    # ------------------------ Best-so-far bookkeeping ------------------------
    last_score = []                  # recent mean rewards, for the no-progress check
    all_final_params = []
    all_best_reward = -2.0           # sentinel; assumes rewards are >= -1 — TODO confirm
    all_best_face = init_face
    all_best_adv_face_ts = torch.stack(crops_tensor)
    num_iter = 0                     # RL framework iterations
    while num_iter < max_iter:
        # Agent output: feature maps -> per-step action distributions.
        featuremap, eps_logits = agent(clean_ts)
        pre_actions = vector_processor(featuremap, eps_logits, space_ts, device)
        cost = 0

        # Sample pg_m placements, score them, and accumulate the policy-gradient loss.
        pg_rewards = []
        phas_final_params = []
        phas_best_reward = -2.0
        phas_best_face = init_face
        phas_best_adv_face_ts = torch.stack(crops_tensor)
        for _ in range(pg_m):
            actions = generate_actions(pre_actions)          # sampling
            # Log-probability of the sampled actions, summed over all of them.
            log_pis = sum(pre_actions[t].log_prob(action) for t, action in enumerate(actions))
            params_slove = actions2params(actions, width)
            adv_face_ts, adv_face, _ = miattack_face(params_slove, model_names,
                                        init_face, label, target, device, sticker,
                                        width, height, emp_iterations, adv_img_folder, targeted = targeted)

            reward_m = reward_output(adv_face_ts, threat_model, threat_name, target, device)
            if not targeted:
                # Untargeted: a lower score w.r.t. the true label is better.
                reward_m = -1 * reward_m
            cost -= log_pis * (reward_m - baseline)
            pg_rewards.append(reward_m)

            if reward_m > phas_best_reward:
                phas_final_params = params_slove
                phas_best_reward = reward_m
                phas_best_face = adv_face
                phas_best_adv_face_ts = adv_face_ts

        observed_value = np.mean(pg_rewards)
        print('\n{}-th: Reward is'.format(num_iter), end=' ')
        for reward in pg_rewards:
            print('{:.5f}'.format(reward), end=' ')
        print('avg:{:.5f}'.format(observed_value))
        if display:
            visualizer.display_current_results(num_iter, observed_value, name='cosin similarity')
            visualizer.display_current_results(num_iter, cost.item(), name='loss')

        # ------------------------------ Update agent ------------------------------
        optimizer.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(agent.parameters(), 5.0)
        optimizer.step()
        scheduler.step()

        # ------------------------------ Check result ------------------------------
        if phas_best_reward > all_best_reward:
            all_final_params = phas_final_params
            all_best_reward = phas_best_reward
            all_best_face = phas_best_face
            all_best_adv_face_ts = phas_best_adv_face_ts

        sim_labels, sim_probs = check_all(all_best_adv_face_ts, threat_model, threat_name, device)
        succ_label = sim_labels[0][0]
        if ((targeted and succ_label == target) or
                (not targeted and succ_label != target)):          # early stop
            print('early stop at iterartion {},succ_label={}'.format(num_iter, succ_label))
            return True, num_iter, [all_best_face, all_best_reward, all_final_params, all_best_adv_face_ts]

        # No-progress stop: keep the last 200 mean rewards. Once the window is full,
        # give up if the newest mean is not above the oldest.
        last_score.append(observed_value)
        last_score = last_score[-200:]
        if last_score[-1] <= last_score[0] and len(last_score) == 200:
            print('FAIL: No Descent, Stop iteration')
            return False, num_iter, [all_best_face, all_best_reward, all_final_params, all_best_adv_face_ts]

        num_iter += 1

    return False, num_iter, [all_best_face, all_best_reward, all_final_params, all_best_adv_face_ts]


if __name__=="__main__":
    # Script entry point: attack each selected test image, then save the best
    # adversarial face as a JPEG and its tensor plus placement parameters via joblib.
    opt = Config()
    # NOTE(review): time.asctime() yields spaces and colons, which are not valid in
    # Windows paths. It is kept unchanged so output folder naming stays compatible.
    localtime1 = time.asctime(time.localtime(time.time()))
    folder_path = os.path.join(opt.adv_img_folder, localtime1)
    tempsave_path = os.path.join(opt.joblib_folder, localtime1)
    os.makedirs(folder_path, exist_ok=True)
    os.makedirs(tempsave_path, exist_ok=True)

    dataset = datasets.ImageFolder('./test_images')   # the path of test image dataset
    dataset.idx_to_class = {i: c for c, i in dataset.class_to_idx.items()}

    # Loop-invariant: load the threat model once instead of once per image.
    threat_model = load_model(opt.threat_name, torch.device('cpu'))

    inds = [6196]                                    # The index number of the image in the dataset
    for i, idx in enumerate(inds):
        img, label = dataset[idx]                    # load the sample once

        # ``visualizer`` is a module-level name read by attack_process when
        # opt.display is set. A fresh one is created for each image.
        if opt.display:
            visualizer = Visualizer()

        flag, iters, vector = attack_process(img, opt.sticker, threat_model, opt.threat_name, opt.model_names,
                            label, opt.target, opt.device, opt.width, opt.height, opt.emp_iterations,
                            folder_path, opt.targeted, opt.sapce_thd, pg_m=5, max_iter=40)
        final_img, _, final_params, final_facets = vector

        file_path = os.path.join(folder_path, '{}_{}.jpg'.format(i, idx))
        final_img.save(file_path, quality=99)

        tempsave_save = [final_facets, final_params]
        joblib.dump(tempsave_save,
                    os.path.join(tempsave_path, '{}_{}_face&params.pkl'.format(i, idx)))
