import torch

from torchvision import transforms

import os
import sys
import cv2
import csv
import random
import argparse
import pandas as pd

import numpy as np

from PIL import Image

sys.path.append(sys.path[0] + '/..')
from utils import str2bool
from solver_128_swap import Solver
# from solver_128 import Solver
from adv_dataset import return_data, get_dataloader

from corruptions import gaussian_noise
from corruptions import shot_noise
from corruptions import impulse_noise
from corruptions import defocus_blur
from corruptions import glass_blur
from corruptions import motion_blur
from corruptions import zoom_blur
from corruptions import snow
from corruptions import frost
from corruptions import fog
from corruptions import brightness
from corruptions import contrast
from corruptions import elastic_transform
from corruptions import pixelate
from corruptions import jpeg_compression
from linf import pgd_attack_random
from ROA import ROA
from model import NEURAL


# Phase name -> severity value handed to the corruption function of the same
# name (imported from `corruptions`). Dict order is the order in which main()
# processes the phases.
severity_map = {
    # noise
    'gaussian_noise': 5,
    'shot_noise': 5,
    'impulse_noise': 5,
    # blur
    'glass_blur': 5,
    'defocus_blur': 5,
    'motion_blur': 5,
    'zoom_blur': 5,
    # weather
    'fog': 5,
    'frost': 5,
    'snow': 5,
    # photometric / geometric
    'contrast': 6,
    'brightness': 8,
    'elastic_transform': 5,
    # digital
    'jpeg_compression': 5,
    'pixelate': 7,
    # Adversarial phases are disabled by default; they ignore the severity
    # value and need pgd_ckpt to be specified.
    # 'pgd_attack_random' : None, # Need to specify pgd_ckpt
    # 'ROA' : None # Need to specify pgd_ckpt
}


# Process-wide cuDNN configuration, applied at import time: turn cuDNN on and
# enable its benchmark mode.
# NOTE(review): benchmark mode can select different kernels from run to run,
# which may weaken the reproducibility that the seeding in main() aims for —
# confirm whether bit-exact repeatability matters here.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def main(args):
    """Encode corrupted/attacked images and dump their latents to CSV files.

    For every phase in ``severity_map`` the first image of each batch from
    ``return_data(args)`` is corrupted (or adversarially perturbed for the
    ``'pgd_attack_random'`` / ``'ROA'`` phases), passed through the solver's
    encoder, and the first ``args.z_dim`` encoder outputs are written, with
    the sample's labels, to
    ``./attacked_z_real_signs/<args.viz_name>/val_z_label_<phase>.csv``.

    Args:
        args: Parsed command-line namespace. ``gpu``, ``seed``, ``viz_name``
            and ``z_dim`` are read here; the whole namespace is also handed
            to ``Solver`` and ``return_data``.

    Raises:
        ValueError: If a phase name in ``severity_map`` does not correspond to
            a callable available in this module.
    """
    torch.cuda.set_device(args.gpu)
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    device = f"cuda:{args.gpu}"

    solver = Solver(args)
    data_loader = return_data(args)
    # data_loader = get_dataloader(batch_size=args.batch_size)

    print("------- save extect embedding ---------")
    print("write attacked embedding now...")
    solver.net_mode(train=False)
    encoder = solver.net.encoder

    attack_dir = os.path.join('./attacked_z_real_signs', args.viz_name)
    # Create the output directory up front; otherwise opening the first CSV
    # for writing fails on a fresh checkout.
    os.makedirs(attack_dir, exist_ok=True)

    # Classifier that the 'pgd_attack_random' and 'ROA' phases attack.
    pgd_ckpt = torch.load("./KEMLP/raw/adv_train_8_ckpt/model_3_adv_acc=0.996094.ckpt", map_location=device)

    model = NEURAL(n_class=8, n_channel=3)
    model = model.to(device)
    model.eval()
    model.load_state_dict(pgd_ckpt)

    to_pil = transforms.Compose([transforms.ToPILImage()])
    # The adversarial branches resize to 32x32 before the attack and back to
    # 128x128 before encoding.
    resize_to_32 = transforms.Compose([transforms.ToPILImage(),
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()])
    resize_to_128 = transforms.Compose([transforms.ToPILImage(),
                                        transforms.Resize((128, 128)),
                                        transforms.ToTensor()])

    headers = ['z', 'class_label', 'shape_label', 'color_label', 'rotate_label']

    for phase, severity in severity_map.items():
        save_path = os.path.join(attack_dir, f'val_z_label_{phase}.csv')
        print('>> Phase : %s' % phase)

        # Resolve the phase name to the imported callable of the same name
        # with an explicit lookup instead of eval().
        corruptor = globals().get(phase)
        if not callable(corruptor):
            raise ValueError(f"Unknown corruption/attack phase: {phase!r}")

        if phase == 'ROA':
            attacker = corruptor(base_classifier=model, size=32, device=device)

        # newline='' is what the csv module expects for files it writes to;
        # without it rows are separated by blank lines on Windows.
        with open(save_path, 'w', encoding='UTF8', newline='') as f_in:
            writer = csv.DictWriter(f_in, fieldnames=headers)
            writer.writeheader()

            # NOTE(review): only element 0 of each batch (image and labels)
            # is used; any further batch elements are ignored.
            for x, class_label, shape_label, color_label, rotate_label in data_loader:
                if phase == 'pgd_attack_random':
                    x = resize_to_32(x[0])
                    x = torch.unsqueeze(x, 0).to(device)
                    x_adv = corruptor(model, x, class_label, eps=2/255.0, alpha=1/255,
                                      iters=40, randomize=True, gpu_id=args.gpu)
                    x_adv = resize_to_128(x_adv[0])
                    x_adv = torch.unsqueeze(x_adv, 0).to(device)
                elif phase == 'ROA':
                    x = resize_to_32(x[0])
                    x = torch.unsqueeze(x, 0).to(device)
                    x_adv = attacker.exhaustive_search(x, class_label, 0.05, 30, 5, 5, 2, 2, False)
                    x_adv = resize_to_128(x_adv[0])
                    x_adv = torch.unsqueeze(x_adv, 0).to(device)
                else:  # corruptions
                    img = to_pil(x[0])
                    x_adv = corruptor(img, severity)
                    if phase in ['jpeg_compression', 'pixelate']:
                        # The result of these two is converted to an ndarray
                        # before the arithmetic below.
                        x_adv = np.array(x_adv)
                    # Scale by 1/255 and reorder HWC -> CHW with a leading
                    # batch axis for the encoder.
                    x_adv = torch.from_numpy(x_adv / 255.).to(dtype=torch.float32).to(solver.device)
                    x_adv = torch.unsqueeze(x_adv.permute(2, 0, 1), 0)

                # Only the values are needed, so skip autograd bookkeeping
                # for the encoder forward pass.
                with torch.no_grad():
                    z = encoder(x_adv)[:, :args.z_dim]
                z_vec = z.detach().cpu().numpy()[0].tolist()

                writer.writerow({
                    'z': z_vec,
                    'class_label': class_label.numpy()[0],
                    'shape_label': shape_label.numpy()[0],
                    'color_label': color_label.numpy()[0],
                    'rotate_label': rotate_label.numpy()[0],
                })
        print("***---save embedding to file---***")


if __name__ == '__main__':
    # Command-line entry point: declare every option in one table, parse,
    # echo the resulting namespace and hand it to main().
    parser = argparse.ArgumentParser(description='toy Beta-VAE')

    # (flag, default, type converter, help text) — registered in this order.
    option_table = (
        ('--train', False, str2bool, 'train or traverse'),
        ('--seed', 1, int, 'random seed'),
        ('--cuda', True, str2bool, 'enable cuda'),
        ('--gpu', 0, int, 'gpu id'),
        ('--C_start', 0, float, 'start value of C'),

        ('--max_iter', 1e6, float, 'maximum training iteration'),
        ('--batch_size', 64, int, 'batch size'),
        ('--limit', 3, float, 'traverse limits'),
        ('--inter', 2/3, float, 'intercept'),
        ('--KL_loss', 25, float, 'KL_divergence'),
        ('--step_val', 0.15, float, 'step_val'),
        ('--pid_fixed', False, str2bool, 'if fixed PID or dynamic'),
        ('--is_PID', True, str2bool, 'if use pid or not'),

        ('--z_dim', 10, int, 'dimension of the representation z'),
        ('--beta', 4, float, 'beta parameter for KL-term in original beta-VAE'),
        ('--objective', 'H', str, 'beta-vae objective proposed in Higgins et al. or Burgess et al. H/B'),
        ('--model', 'H', str, 'model proposed in Higgins et al. or Burgess et al. H/B'),
        ('--gamma', 1000, float, 'gamma parameter for KL-term in understanding beta-VAE'),
        ('--C_max', 25, float, 'capacity parameter(C) of bottleneck channel'),
        ('--C_stop_iter', 1e5, float, 'when to stop increasing the capacity'),
        ('--lr', 1e-4, float, 'learning rate'),
        ('--beta1', 0.9, float, 'Adam optimizer beta1'),
        ('--beta2', 0.999, float, 'Adam optimizer beta2'),

        ('--dset_dir', 'data', str, 'dataset directory'),
        ('--dataset', 'traffic', str, 'dataset name'),
        ('--image_size', 128, int, 'image size. now only (64,64) is supported'),
        ('--num_workers', 16, int, 'dataloader num_workers'),

        ('--viz_on', True, str2bool, 'enable visdom visualization'),
        # main() also uses viz_name as the output sub-directory name.
        ('--viz_name', 'main', str, 'visdom env name'),
        ('--viz_port', 8090, str, 'visdom port number'),
        ('--save_output', True, str2bool, 'save traverse images and gif'),
        ('--output_dir', 'outputs', str, 'output directory'),

        ('--gather_step', 10000, int, 'numer of iterations after which data is gathered for visdom'),
        ('--display_step', 10000, int, 'number of iterations after which loss data is printed and visdom is updated'),
        ('--save_step', 10000, int, 'number of iterations after which a checkpoint is saved'),

        ('--ckpt_dir', '../checkpoints', str, 'checkpoint directory'),
        ('--ckpt_name', 'last', str, 'load previous checkpoint. insert checkpoint filename'),
        ('--warmup', 4000, float, 'Warm up iterations.'),
        ('--compare_weight', 0.5, float, 'Weight of swapped reconstrucion loss.'),
        ('--threshold', 0.5, float, 'Threshold value for filtering latents.'),
        ('--semi_percentage', 0.2, float, 'Percentage of semi-supervised pairwise inputs.'),

        ## Attack z embedding ##
        # NOTE(review): main() loops over every phase in severity_map and does
        # not select by this value — confirm whether that is intended.
        ('--attack', 'gaussian_noise', str, 'Choosing attacks.'),
    )
    for flag, default_value, converter, help_text in option_table:
        parser.add_argument(flag, default=default_value, type=converter, help=help_text)

    args = parser.parse_args()

    print(args)

    main(args)
    
    # net = Solver(args)
