import torch
import os, argparse
import numpy as np
import json
import utils
import shutil
import math

from attack import *

from research_pool.config import att_pretty
from victim_models.utils import init_classifier
from generator.utils import init_generator

parser = argparse.ArgumentParser()
parser.add_argument('--config', default='victim_models/config/train_Imagenet_Resnet50.json', help='config file')
parser.add_argument('--gpuid', nargs='+', type=str, default="0")
parser.add_argument('--select_images', nargs='+', type=str, default=[], help='Indices of images to select')
parser.add_argument('--clean', action='store_true', help='Clean some of the previous experiment byproducts if found.')

args = vars(parser.parse_args())

# With nargs='+' a user-supplied --gpuid is a list of strings, but the default
# stays the plain string "0". Joining a string would split it into characters
# ("10" -> "1,0"), so only join real lists.
gpuid = args['gpuid']
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpuid) if isinstance(gpuid, list) else gpuid

with open(args['config']) as config_file:
    state = json.load(config_file)

# Seed numpy and torch from the config so random sampling is repeatable.
np.random.seed(state['seed'])
torch.manual_seed(state['seed'])

# Perturbation norm: config order 2 selects L2, anything else selects Linf.
order = 2 if state['order'] == 2 else np.inf

# Use the configured run id when it is present and non-empty (a JSON null is
# treated as unset); otherwise fall back to a timestamp. Always a str because
# it becomes a path component below.
run_id = state.get("run_id")
if run_id is None or run_id == "":
    run_id = utils.get_time_stamp()
run_id = str(run_id)

# Results go to <save_dir>/<attack>/<dataset>/<run_id>; the rest of the script
# reads the updated state['save_dir'].
state['save_dir'] = os.path.join(state['save_dir'], state['attack'], state['dataset'], run_id)
# exist_ok avoids the exists()/makedirs() race when runs share a directory.
os.makedirs(state['save_dir'], exist_ok=True)

# Keep a copy of the config next to the results.
config_basename = os.path.basename(args['config'])
shutil.copyfile(args['config'], os.path.join(state['save_dir'], config_basename))


# Load the victim classifier wrapper (it tracks queries: see reset_queries /
# get_num_queries below), the dataset images are drawn from, and the loader
# forwarded to every attack call as `target_loader`.
model_wrapper, gen_dataset, target_loader = init_classifier(state)

# Registry mapping the config value state['attack'] to its attack class.
# Insertion order matches the original literal.
attack_map = dict(
    PGD=PGD,
    Sign_OPT=OPT_attack_sign_SGD,
    Sign_OPT_lf=OPT_attack_sign_SGD_lf,
    CW=CW,
    # HLM-prefixed variants (constructed with an encoder/decoder pair).
    HLM_OPT_attack=HLM_OPT_attack,
    HLM_Sign_OPT=HLM_OPT_attack_sign_SGD,
    HLM_RayS=HLM_RayS,
    # NOTE(review): "HLMasdf" looks like a leftover/debug key — confirm before removing.
    HLMasdf_RayS=HLMasdf_RayS,
    HLM_HSJA=HLM_HSJA,
    # Sampling baselines.
    Sampling_OPT_attack=Sampling_OPT_attack,
    Sampling_Sign_OPT=Sampling_OPT_attack_sign_SGD,
    Sampling_RayS=Sampling_RayS,
    Sampling_HSJA=Sampling_HSJA,
    # Random-sampling baselines.
    RandSampling_OPT_attack=RandSampling_OPT_attack,
    RandSampling_Sign_OPT=RandSampling_OPT_attack_sign_SGD,
    RandSampling_HSJA=RandSampling_HSJA,
    # Remaining attacks, constructed without a generator.
    OPT_attack=OPT_attack,
    HSJA=HSJA,
    OPT_attack_lf=OPT_attack_lf,
    FGSM=FGSM,
    NES=NES,
    Bandit=Bandit,
    NATTACK=NATTACK,
    Sign_SGD=Sign_SGD,
    ZOO=ZOO,
    Liu=OPT_attack_sign_SGD_v2,
    Evolutionary=Evolutionary,
    RayS=RayS,
)

# Attacks whose averaged log curve is saved and plotted at the end of the run.
paper_attacks = list(att_pretty.keys())
    

# Class-conditional generator architectures come with a list of allowed class
# labels (loaded from classes_path). The main loop skips images whose label is
# not in it. Other architectures use an empty list and no filtering.
class_conditional = state['hlm_architecture'] in ['CC-AE', 'CC-VAE']
possible_ix = np.load(state['classes_path']) if class_conditional else []

# Extra constructor arguments that only some attack families take.
kwargs = {}
if "HSJA" in state['attack']:
    # Every HSJA variant receives the config's gamma hyper-parameter.
    kwargs['gamma'] = state['gamma']

attack_name = state['attack']
attack_cls = attack_map[attack_name]
# Keyword arguments shared by every attack constructor below.
common_kwargs = dict(dataset=state["dataset"],
                     early_stopping=state['early_stopping'],
                     order=order,
                     **kwargs)

# Branch order matters: "RandSampling_X" also contains "Sampling_", and
# "Sampling_RayS" needs its own a/b arguments instead of a generator.
if "HLM" in attack_name:
    # HLM attacks are built around the encoder/decoder returned by
    # init_generator. Validate with an exception rather than `assert`,
    # which is stripped when Python runs with -O.
    for key in ('enc_path', 'dec_path'):
        if state.get(key) is None:
            raise ValueError(f"HLM attack requires {key} argument.")
    enc, dec = init_generator(state, variant=state['hlm_architecture'])
    attack = attack_cls(model_wrapper, enc, dec, **common_kwargs)
elif attack_name == "Sampling_RayS":
    attack = attack_cls(model_wrapper, a=state["a"], b=state["b"], **common_kwargs)
elif "RandSampling_" in attack_name:
    # Random-sampling baselines use the 'randsampling' generator variant.
    enc, dec = init_generator(state, variant='randsampling')
    attack = attack_cls(model_wrapper, enc, dec, **common_kwargs)
elif "Sampling_" in attack_name:
    # Sampling baselines use the 'sampling' generator variant.
    enc, dec = init_generator(state, variant='sampling')
    attack = attack_cls(model_wrapper, enc, dec, **common_kwargs)
else:
    attack = attack_cls(model_wrapper, **common_kwargs)


# Counters reported at the end of non-paper runs. Nothing in this script
# currently increments them (the accumulation in the main loop is disabled).
total_r_count = 0
total_clean_count = 0


# =================== prepare ===================

logs_path = os.path.join(state['save_dir'], f"{state['dataset']}-{state['attack']}_log.npy")
all_logs_path = os.path.join(state['save_dir'], f"{state['dataset']}-{state['attack']}_all_log.npy")
succ_path = os.path.join(state['save_dir'], "successful_indices.npy")

if args['clean']:
    for path in (logs_path, all_logs_path, succ_path):
        if os.path.exists(path):
            os.remove(path)
            print(f"\tCLEAN: removed {path}")

# Resume from the un-normalized running sum (all_logs_path). logs_path holds a
# normalized copy that cannot be summed into. Test for the file actually loaded.
# ei takes care of updating the successful indices.
if os.path.exists(all_logs_path):
    logs = torch.from_numpy(utils.safe_load(all_logs_path))
    print("\tRESTART: Logs reload was successful.")
else:
    if os.path.exists(logs_path):
        # Only the normalized file exists (e.g. interrupted between the two
        # saves). Sums restart from zero, so averages may be off; consider --clean.
        print(f"\tWARNING: {logs_path} exists but {all_logs_path} does not; log sums restart from zero.")
    logs = torch.zeros(state['query_limit'], 2)

# A stale canary from a previous run is always removed; it is rewritten only
# when this run reaches the end.
canary_path = os.path.join(state['save_dir'], "canary.txt")
if os.path.exists(canary_path):
    os.remove(canary_path)


if os.path.exists(succ_path):
    restart_indices = list(utils.safe_load(succ_path))
    if restart_indices:
        # Guard: dividing by an empty index list would print inf/nan.
        print(f"Last known state {logs / len(restart_indices)}")
else:
    restart_indices = []

ei = utils.EligibleIndex(state, gen_dataset, model_wrapper, restart_indices)


# =================== start ===================

def _draw_target(num_classes, shape):
    """Return a CPU long tensor of ``shape`` filled with one random class id.

    Uses numpy's global RNG (seeded from the config) so target choice is repeatable.
    """
    return torch.full(shape, int(np.random.randint(num_classes)), dtype=torch.long)


while not ei.stopping_condition():
    next_ix = ei.pop()
    if next_ix is None:
        continue

    # Query budget is counted per image.
    model_wrapper.reset_queries()

    xi, yi = utils.load_sample(gen_dataset, next_ix)

    # Class-conditional generators only handle their configured classes.
    if class_conditional and int(yi.item()) not in possible_ix:
        continue

    if torch.cuda.is_available():
        xi, yi = xi.cuda(), yi.cuda()

    print(f"image: {ei.status()}")

    if state['targeted']:
        # Pick a random target class that differs from the true label. Compare
        # explicitly: a tensor holding class 0 is falsy, so a truthiness guard
        # would skip the re-draw exactly when the drawn class is 0.
        target = _draw_target(model_wrapper.num_classes, yi.shape)
        while torch.any(target == yi.cpu()):
            print('re-generate target label')
            target = _draw_target(model_wrapper.num_classes, yi.shape)
    else:
        target = None

    adv = attack(xi, yi,
                 query_limit=state['query_limit'],
                 epsilon=state['epsilon'],
                 target=target,
                 target_loader=target_loader)

    if adv is None:
        # Failed to classify x0 or initialize
        continue

    ei.step(next_ix)

    # Per-sample artifacts: images/<true label>/iter=<index>/...
    yi_str = str(yi.item())
    dec_str = str(model_wrapper.predict_label(adv).item())
    yi_folder = os.path.join(state['save_dir'], 'images', yi_str, f'iter={next_ix}')
    os.makedirs(yi_folder, exist_ok=True)

    np.save(os.path.join(yi_folder, 'benign.npy'), xi.cpu().numpy())
    np.save(os.path.join(yi_folder, f'adv_{yi_str}->{dec_str}.npy'), adv.cpu().detach().numpy())
    np.save(os.path.join(yi_folder, 'trajectory.npy'), attack.get_trajectory())
    np.save(os.path.join(yi_folder, 'gradients.npy'), attack.get_gradients())
    print(f"Wrote (or updated) sample files at {yi_folder}.")

    np.save(succ_path, np.asarray(ei.successful_indices))

    # NOTE: robust/clean accuracy accumulation into total_r_count /
    # total_clean_count is disabled, so those counters stay at their initial 0.

    # logs is a running sum over successful samples. logs_path gets the mean;
    # all_logs_path gets the raw sum so a restart can keep accumulating.
    logs += attack.get_log()
    mean_logs = (logs / len(ei.successful_indices)).numpy()
    print(f"Wrote (or updated) logs file at {logs_path}.")
    print(f"\tVISUAL CHECK: {mean_logs[:10, 0]}...")
    np.save(logs_path, mean_logs)
    # Un normalized version for restart
    np.save(all_logs_path, logs.numpy())
            

if state['attack'] in paper_attacks:
    # Convert the running sum into a mean over successful samples and plot
    # column 0 (labelled distortion) against the query index.
    logs /= len(ei.successful_indices)
    print("saving final logs to numpy array")
    plot_log = np.asarray(logs)
    np.save(logs_path, plot_log)

    import matplotlib.pyplot as plt  # local import: only this branch plots
    fig, ax = plt.subplots()
    ax.plot(plot_log[:, 0])
    ax.set_ylabel('Distortion')
    ax.set_xlabel('Num of queries')
    fig.savefig(os.path.join(state['save_dir'], f"{state['dataset']}-{state['test_batch']}-{state['attack']}_plot.png"))
    # Release the figure; pyplot otherwise keeps it alive.
    plt.close(fig)

else:
    # NOTE(review): queries are reset per image in the main loop, so this is
    # the count for the last attacked image only — confirm that is intended.
    num_queries = model_wrapper.get_num_queries()
    print(f"clean count:{total_clean_count}")
    print(f"acc under attack count:{total_r_count}")
    print(f"number queries used:{num_queries}")


# Marker file: written only when the run reaches this point (removed at start).
with open(canary_path, 'w') as cf:
    cf.write("We made it.\n")
