import csv
import math
import torch
import numpy as np
import torch.nn as nn
import scipy.stats as st
import torch.nn.functional as F


# Load image file names and their ground-truth labels from a CSV file.
def get_truth_info(csv_file):
    """Read image ids and ground-truth labels from ``csv_file``.

    The CSV must have a header row with ``ImageId`` and ``TrueLabel``
    columns. Rows are returned in file order.

    Args:
        csv_file: path to the CSV file.

    Returns:
        ``(img_paths, labels)``: two parallel lists holding the ``ImageId``
        and ``TrueLabel`` value of every row. Both are left as strings, so
        callers must convert labels themselves if they need ints.

    Raises:
        KeyError: if the header has no ``ImageId`` or ``TrueLabel`` column.
    """
    img_paths = []  # image identifiers (ImageId column)
    labels = []     # ground-truth class of each image (TrueLabel column)
    # The csv module requires files to be opened with newline=''; otherwise
    # quoted fields containing line breaks are mangled.
    with open(csv_file, newline='') as f:
        for row in csv.DictReader(f, delimiter=','):
            img_paths.append(row['ImageId'])
            labels.append(row['TrueLabel'])
    return img_paths, labels


# Image normalization (not intended for adversarially trained models).
class imgnormalize(nn.Module):
    """Normalize images channel-wise with a fixed per-channel mean and std.

    The constants are the values torchvision documents for ImageNet-pretrained
    models. Accepts a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``.
    In both cases the channel axis is dim -3. The input tensor is never
    modified in place, so leaf tensors that require grad can be passed.
    """

    def __init__(self):
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]  # per-channel mean
        self.std = [0.229, 0.224, 0.225]   # per-channel standard deviation

    def forward(self, x):
        """Return ``(x - mean) / std`` broadcast over the channel axis (dim -3).

        Args:
            x: float tensor of shape ``(C, H, W)`` or ``(N, C, H, W)`` with
                ``C == len(self.mean)``.

        Returns:
            A new tensor with the same shape, dtype and device as ``x``.
        """
        # Shape (C, 1, 1) broadcasts over H, W and any leading batch dim, so
        # channels are normalized regardless of whether x is batched.
        mean = torch.as_tensor(self.mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
        return (x - mean) / std


def gkern(kernlen=15, nsig=3, *, channels=3, device="cuda"):
    """Build a normalized 2-D Gaussian kernel, one copy per channel.

    Kernel for the Translation-Invariant attack:
    https://arxiv.org/abs/1904.02884

    The 1-D profile is the standard normal pdf sampled at ``kernlen`` evenly
    spaced points on ``[-nsig, nsig]``. Its outer product is normalized so
    each 2-D kernel sums to 1. The ``(channels, 1, kernlen, kernlen)`` layout
    matches a depthwise ``F.conv2d`` weight with ``groups=channels``.

    Args:
        kernlen: side length of the kernel.
        nsig: half-width of the sampling range, in standard deviations.
        channels: number of identical kernel copies (default 3, as before).
        device: device of the returned tensor. The default ``"cuda"``
            behaves like the previous ``.cuda()`` call.

    Returns:
        float32 tensor of shape ``(channels, 1, kernlen, kernlen)``.
    """
    x = np.linspace(-nsig, nsig, kernlen)
    kern1d = st.norm.pdf(x)
    kernel_raw = np.outer(kern1d, kern1d)
    kernel = (kernel_raw / kernel_raw.sum()).astype(np.float32)  # (kernlen, kernlen), sums to 1
    stacked = np.stack([kernel] * channels)   # (channels, kernlen, kernlen)
    stacked = np.expand_dims(stacked, 1)      # (channels, 1, kernlen, kernlen)
    return torch.from_numpy(stacked).to(device)


def DI(x, resize_rate=1.15, diversity_prob=0.5):
    """Apply the Input Diversity transform (https://arxiv.org/abs/1803.06978).

    With probability ``diversity_prob`` the batch is transformed:

    1. It is bilinearly resized to a random side ``rnd`` in
       ``[img_size, img_resize)``.
    2. It is zero-padded at a random offset to ``img_resize x img_resize``.

    Here ``img_size = x.shape[-1]`` and ``img_resize = int(img_size * resize_rate)``.
    Otherwise ``x`` is returned unchanged.

    Args:
        x: image batch. Assumed ``(N, C, H, W)`` with ``H == W`` — only
            ``x.shape[-1]`` is read and it is used for both spatial axes.
            TODO confirm callers only pass square images.
        resize_rate: ratio of the padded output side to the input side (>= 1).
        diversity_prob: probability of applying the transform, in [0, 1].

    Returns:
        The transformed batch with spatial side ``img_resize``, or ``x`` itself
        in either of these cases:

        - the transform is skipped;
        - ``img_resize == img_size`` (e.g. ``resize_rate == 1.0`` or a very
          small image), leaving no room to resize and pad.
    """
    assert resize_rate >= 1.0
    assert diversity_prob >= 0.0 and diversity_prob <= 1.0
    # Decide first so the resize/pad work is skipped when x is returned as-is.
    if torch.rand(1).item() >= diversity_prob:
        return x
    img_size = x.shape[-1]
    img_resize = int(img_size * resize_rate)
    if img_resize <= img_size:
        # torch.randint requires low < high, and the transform would be an
        # identity anyway.
        return x
    # torch.randint's `high` is exclusive: rnd is in [img_size, img_resize - 1].
    rnd = int(torch.randint(low=img_size, high=img_resize, size=(1,)).item())
    rescaled = F.interpolate(x, size=[rnd, rnd], mode='bilinear', align_corners=False)
    rem = img_resize - rnd  # >= 1 since rnd < img_resize; same for both axes
    pad_top = int(torch.randint(low=0, high=rem, size=(1,)).item())
    pad_left = int(torch.randint(low=0, high=rem, size=(1,)).item())
    # F.pad order for the last two dims: (left, right, top, bottom).
    return F.pad(rescaled, [pad_left, rem - pad_left, pad_top, rem - pad_top], value=0)