import numpy as np
import os
import cv2

import torch
from torch.autograd import Variable
from torch.utils.data import IterableDataset, DataLoader, Dataset
from torchvision import transforms

from corruptions import gaussian_noise
from corruptions import shot_noise
from corruptions import impulse_noise
from corruptions import defocus_blur
from corruptions import glass_blur
from corruptions import motion_blur
from corruptions import zoom_blur
from corruptions import snow
from corruptions import frost
from corruptions import fog
from corruptions import brightness
from corruptions import contrast
from corruptions import elastic_transform
from corruptions import pixelate
from corruptions import jpeg_compression

import pandas as pd



def ensure_path(folderName):
    """Create directory *folderName*, including missing parents, if needed.

    Safe to call repeatedly, and safe when another process creates the same
    directory at the same moment.

    Args:
        folderName: Path of the directory to create.

    Raises:
        FileExistsError: if *folderName* exists but is not a directory.
    """
    # exist_ok=True replaces the racy "os.path.exists() then os.makedirs()"
    # sequence: with the old form, a directory created between the check and
    # the makedirs call made makedirs raise FileExistsError.
    os.makedirs(folderName, exist_ok=True)


class CustomTensorDataset(Dataset):
    """Map-style dataset that serves slices of one tensor.

    ``dataset[i]`` returns ``data_tensor[i]``. ``len(dataset)`` is the size of
    the tensor's first dimension.
    """

    def __init__(self, data_tensor):
        # Stored publicly under its original name so external code can read it.
        self.data_tensor = data_tensor

    def __len__(self):
        n_rows = self.data_tensor.size(0)
        return n_rows

    def __getitem__(self, index):
        source = self.data_tensor
        return source[index]


# Location of the traffic-sign training archive used when no path is given.
_DEFAULT_TRAFFIC_NPZ = '/data/datasets/traffic/traffic_10x8x100/train/traffic_8000_3x128x128.npz'


def get_dataloader(batch_size: int = 100,
                   num_workers: int = 2,
                   data_path: str = _DEFAULT_TRAFFIC_NPZ) -> DataLoader:
    """Build a shuffled DataLoader over the ``imgs`` array of an ``.npz`` archive.

    Args:
        batch_size: Number of images per batch.
        num_workers: Number of DataLoader worker processes.
        data_path: Path to an ``.npz`` file that contains an ``imgs`` array.
            Defaults to the original hard-coded traffic dataset location.

    Returns:
        A DataLoader over the images converted to float32. It is shuffled,
        uses pinned memory, and drops the final incomplete batch.
    """
    print("Load traffic data")
    # np.load on an .npz returns an NpzFile that keeps the archive open.
    # The context manager closes it once 'imgs' has been read into memory.
    with np.load(data_path) as data_zip:
        imgs = data_zip['imgs']
    imgs_tensor = torch.from_numpy(imgs).float()

    dataset = CustomTensorDataset(imgs_tensor)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)


# Corruption name -> severity level passed to the same-named corruption
# function. Insertion order is preserved, so anything iterating this dict
# visits the corruptions in the order listed here.
# NOTE(review): contrast=6, brightness=8 and pixelate=7 exceed 5; this assumes
# the local `corruptions` module accepts levels above 5 — confirm.
severity_map = dict(
    gaussian_noise=5,
    shot_noise=5,
    impulse_noise=5,

    glass_blur=5,
    defocus_blur=5,
    motion_blur=5,
    zoom_blur=5,

    fog=5,
    frost=5,
    snow=5,
    contrast=6,
    brightness=8,
    elastic_transform=5,

    jpeg_compression=5,
    pixelate=7,
    # Disabled entries:
    # pgd_attack_random=None,
    # ROA=None,
)



# Which reference sign image to corrupt; must be a key of _SIGN_SOURCES.
sign = 'real_deercrossing'

# sign name -> (image directory, image file name, labels).
_SIGN_SOURCES = {
    'stopsign': (
        '/data/datasets/traffic/traffic_10x8x100/train/stopsign',
        '449_col4_scal0_rot49_pos0.png',
        [6, 0, 4, 49],
    ),
    'deerCrossing': (
        '/data/datasets/traffic/traffic_10x8x100/train/deerCrossing',
        '49_col0_scal0_rot49_pos0.png',
        [0, 2, 0, 49],
    ),
    'real_deercrossing': (
        '../../real_data/deercrossing',
        'deercrossing0_blue_0.png',
        [0, 2, 0, 49],
    ),
}

try:
    img_root, file, labels = _SIGN_SOURCES[sign]
except KeyError:
    # Fail here with a clear message. Without this check, an unknown sign
    # surfaces later as a NameError on img_root.
    raise ValueError(
        f'Unknown sign {sign!r}; expected one of {sorted(_SIGN_SOURCES)}'
    ) from None

# Directory listing; only used by the commented-out batch loop at the end
# of this script.
files = os.listdir(img_root)

convert_img = transforms.Compose([transforms.ToPILImage()])

# Corruption name -> function. An explicit table instead of eval(phase), so
# only known corruption functions can be selected.
corruption_fns = {
    'gaussian_noise': gaussian_noise,
    'shot_noise': shot_noise,
    'impulse_noise': impulse_noise,
    'defocus_blur': defocus_blur,
    'glass_blur': glass_blur,
    'motion_blur': motion_blur,
    'zoom_blur': zoom_blur,
    'snow': snow,
    'frost': frost,
    'fog': fog,
    'brightness': brightness,
    'contrast': contrast,
    'elastic_transform': elastic_transform,
    'pixelate': pixelate,
    'jpeg_compression': jpeg_compression,
}


def _imwrite_checked(path, image):
    """Write *image* to *path* with cv2.imwrite, raising IOError on failure."""
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(path, image):
        raise IOError(f'cv2.imwrite failed for {path}')


img_path = os.path.join(img_root, file)
img = cv2.imread(img_path)
if img is None:
    # cv2.imread returns None (no exception) for missing or unreadable files.
    raise FileNotFoundError(f'Could not read image: {img_path}')
img_half = cv2.resize(img, (128, 128))
# NOTE: cv2.imread yields channels in BGR order, and the array is passed to
# ToPILImage unchanged. The PIL image's "RGB" channels therefore hold B, G, R.
# cv2.imwrite also expects BGR, so the saved files have correct colours.
# Whether any colour-sensitive corruption is affected by the swap is
# unverified — TODO confirm.
x = convert_img(img_half)

z_dict = {}

# The output location depends only on the source file, so compute it (and
# save the uncorrupted image) once instead of once per corruption.
prefix = os.path.splitext(file)[0]
save_root = os.path.join('imgs', prefix)
ensure_path(save_root)
_imwrite_checked(os.path.join(save_root, 'original.png'), np.array(x))

for phase, severity in severity_map.items():
    print('>> Phase : %s' % phase)

    save_path = os.path.join(save_root, f'{phase}.png')
    print(save_path)

    corruptor = corruption_fns[phase]

    # Implement attack
    x_adv = corruptor(x, severity)
    # Some corruptions (e.g. jpeg_compression, pixelate) return PIL images.
    # np.asarray converts those and leaves ndarray results untouched.
    _imwrite_checked(save_path, np.asarray(x_adv))

    # # test csv file
    # csv_path = f'./attacked_z/val_z_label_{phase}.csv'
    # df = pd.read_csv(csv_path)

    # z = df.loc[(df['class_label'] ==labels[0]) & \
    #            (df['shape_label'] ==labels[1]) & \
    #            (df['color_label'] ==labels[2]) & \
    #            (df['rotate_label']==labels[3])]
    # z_dict[phase] = z['z'].tolist()[0]
    # print("z:", z['z'].tolist()[0])


# Display order for the corruptions. Not referenced by the active code above;
# presumably intended for the commented-out z table below — confirm.
phase_in_order = [
    'impulse_noise',
    'gaussian_noise',
    'shot_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'brightness',
    'snow',
    'frost',
    'fog',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]
# columns = ['0', '1', '2 (color)', '3', '4', '5', '6', '7', '8 (orient)', '9 (shape)']
# z_array = []
# for phase in z_dict.keys():
#     z_array.append(z_dict[phase])
# print(z_array)

# for i in files:
#     img = cv2.imread(os.path.join(img_root, i))
#     img = np.expand_dims(img, 0) / 255.
#     x = np.transpose(img, (0, 3, 1, 2))
#     print(x.shape)

#     adv_img = gaussian_noise(x*255, 5)

#     cv2.imwrite('test.png', np.transpose(adv_img[0], (1, 2, 0)))
#     input()