import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import pilgram
import numpy as np
from tqdm import tqdm
from io import BytesIO
from PIL import Image

import pickle
import numpy as np

## Code modified from:
## https://github.com/Queuecumber/torchjpeg/
class JPEG:
    """
    JPEG compression as a PIL-image transform.

    Calling an instance encodes the image to JPEG in memory at the configured
    quality and decodes it again, so the returned image carries the JPEG
    compression artifacts.
    """

    def __init__(self, quality):
        # Quality factor forwarded to PIL's ``Image.save``.
        self.q = quality

    def __call__(self, pil_image):
        with BytesIO() as buffer:
            pil_image.save(buffer, format="jpeg", quality=self.q)
            buffer.seek(0)
            decoded = Image.open(buffer)
            # PIL decodes lazily; force it now because the buffer is closed
            # when the ``with`` block exits.
            decoded.load()
            return decoded

def jpeg_compress_tf(imgs, q=100):
    """
    Batch JPEG compress for torch FloatTensor images.

    Each image is assumed to be a 3-channel tensor in [-1, 1]. The first
    normalization computes (x + 1) / 2, mapping that range to [0, 1]. The
    image is then converted to PIL, JPEG round-tripped at quality ``q``,
    converted back to a tensor and mapped back to [-1, 1].

    Returns:
        The processed images stacked into a single tensor.
    """
    to_unit_range = transforms.Normalize((-1, -1, -1), (2, 2, 2))
    to_signed_range = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([
        to_unit_range,
        transforms.ToPILImage(),
        JPEG(quality=q),
        transforms.ToTensor(),
        to_signed_range,
    ])
    compressed = [pipeline(img) for img in imgs]
    return torch.stack(compressed)

def rotate(x, d):
    """
    Rotate a batch of images by per-sample angles given in degrees.

    Args:
        x: image batch of shape (N, C, H, W).
        d: rotation angles in degrees. Either a tensor of shape (N,) or a
            scalar (tensor or Python number) that is applied to every image.

    Returns:
        The rotated batch, with the same shape, dtype and device as ``x``.
        Locations sampled from outside the image use ``grid_sample``'s default
        zero padding.
    """
    n = x.shape[0]
    # Build the angles on x's device and dtype. The original hard-coded
    # .cuda(), which broke CPU inputs, and a float64 ``d`` produced a grid
    # whose dtype did not match ``x`` in grid_sample. A scalar angle is
    # broadcast to the whole batch.
    rad = torch.deg2rad(torch.as_tensor(d, dtype=x.dtype, device=x.device)).expand(n)
    cos, sin = torch.cos(rad), torch.sin(rad)
    zeros = torch.zeros_like(rad)
    theta = torch.stack([cos, -sin, zeros, sin, cos, zeros], dim=1).view(-1, 2, 3)
    # align_corners=False is PyTorch's default behaviour. It is passed
    # explicitly to keep the same result and silence the warning.
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

def resize_tf(imgs, scale=1.5):
    """
    Image scaling round-trip applied to every image in a batch.

    ``imgs`` is unpacked as (N, C, H, W), and each image is assumed to be a
    3-channel tensor in [-1, 1]. Each image is processed as follows:
    - mapped to [0, 1] and converted to PIL;
    - resized to ``(int(H * scale), int(W * scale))``;
    - resized back to (H, W);
    - converted to a tensor and mapped back to [-1, 1].

    Returns:
        The processed images stacked into a single tensor.
    """
    _, _, height, width = imgs.shape
    scaled_size = (int(height * scale), int(width * scale))
    pipeline = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        transforms.ToPILImage(),
        transforms.Resize(scaled_size),
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return torch.stack(list(map(pipeline, imgs)))

# Image filters from ``pilgram``, keyed by the name that ``filter_tf`` accepts.
filters = dict(
    clarendon=pilgram.clarendon,
    gingham=pilgram.gingham,
    moon=pilgram.moon,
)
def filter_tf(imgs, filter='clarendon'):
    """
    Image filtering: apply a named ``pilgram`` filter to every image in a batch.

    Args:
        imgs: iterable of 3-channel image tensors, assumed to lie in [-1, 1].
        filter: key into the module-level ``filters`` table. An unknown name
            raises ``KeyError``.

    Returns:
        The filtered images mapped back to [-1, 1] and stacked into one tensor.
    """
    apply_filter = filters[filter]
    pipeline = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        transforms.ToPILImage(),
        apply_filter,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    filtered = []
    for img in imgs:
        filtered.append(pipeline(img))
    return torch.stack(filtered)

def gen_adv(net, criterion, x, label, steps=5, eps=0.1, step_size=0.01, device="cuda"):
    """
    PGD attack: signed-gradient ascent on the loss within an L-inf ball.

    Args:
        net: model evaluated on ``(x + noise)`` after moving it to ``device``.
        criterion: loss function called as ``criterion(out, label)``. A
            non-scalar (per-sample) loss is averaged before differentiation.
        x: clean input batch. The returned noise has its shape and device.
        label: labels, moved to ``device``.
        steps: number of signed-gradient steps.
        eps: L-inf bound. The noise is clamped to ``[-eps, eps]`` after each step.
        step_size: magnitude of each signed step (previously hard-coded to 0.01).
        device: device the model runs on (previously hard-coded to CUDA).

    Returns:
        The perturbation, detached from the autograd graph. ``x + noise`` is
        not clamped to any valid pixel range.
    """
    label = label.to(device)
    noise = torch.zeros_like(x, requires_grad=True)
    for _ in range(steps):
        out = net((x + noise).to(device))
        loss = criterion(out, label)
        if loss.dim() >= 1:
            loss = loss.mean()
        grad, = torch.autograd.grad(loss, noise)
        # Ascend and project back onto the eps-ball. The gradient already
        # lives on noise's device. The original's .cpu() round-trip broke
        # inputs on other devices, and torch.tensor(tensor) emitted a copy
        # warning.
        noise = (noise.detach() + step_size * grad.sign()).clamp(-eps, eps)
        noise.requires_grad_(True)
    return noise.detach()

def gen_adv_target(net, criterion, x, label, target_loss=10, eps=0.1, max_iter=10,
                   step_size=0.01, device="cuda"):
    """
    PGD attack until target loss.

    Runs signed-gradient ascent, and stops moving a sample once its loss
    reaches ``target_loss``. The loop ends when every sample has reached the
    target or after ``max_iter`` iterations.

    Args:
        net: model evaluated on ``(x + noise)`` after moving it to ``device``.
        criterion: loss function called as ``criterion(out, label)``.
            - A per-sample loss of shape (N,) lets each sample stop
              independently.
            - A scalar loss is also accepted; it then gates the whole batch
              together.
        x: clean input batch of shape (N, ...). Any rank is supported.
        label: labels, moved to ``device``.
        target_loss: a sample keeps being perturbed while its loss is below this.
        eps: L-inf bound. The noise is clamped to ``[-eps, eps]`` after each step.
        max_iter: maximum number of iterations.
        step_size: magnitude of each signed step (previously hard-coded to 0.01).
        device: device the model runs on (previously hard-coded to CUDA).

    Returns:
        The perturbation, detached from the autograd graph, with the shape
        and device of ``x``.
    """
    batch = x.shape[0]
    label = label.to(device)
    # Broadcast shape for the per-sample mask. It works for any input rank;
    # the original .view(N, 1, 1, 1) only handled 4-D batches.
    mask_shape = (batch,) + (1,) * (x.dim() - 1)
    noise = torch.zeros_like(x, requires_grad=True)
    # Start from zero loss so the first iteration runs whenever target_loss > 0.
    loss = torch.zeros(batch, device=device)
    it = 0
    while torch.any(loss < target_loss) and it < max_iter:
        it += 1
        out = net((x + noise).to(device))
        loss = criterion(out, label)
        # mean() works for per-sample and scalar losses alike. The original
        # left avg_loss undefined (NameError) when the loss was a scalar.
        grad, = torch.autograd.grad(loss.mean(), noise)
        loss = loss.detach()
        # Freeze samples whose loss has already reached the target.
        active = (loss < target_loss).reshape(-1).expand(batch)
        grad = grad * active.reshape(mask_shape).to(grad.device)
        noise = (noise.detach() + step_size * grad.sign()).clamp(-eps, eps)
        noise.requires_grad_(True)
    return noise.detach()

def thres_attack(train, test, steps=20):
    """
    Find the best threshold for training and test statistics.

    Scans ``steps + 1`` evenly spaced thresholds from the median of ``train``
    to the median of ``test``.
    - Values ``<= thres`` are classified as ``train``; values ``> thres`` are
      classified as ``test``.
    - The accuracy at a threshold is the mean of the two per-set accuracies.

    Args:
        train: 1-D sequence of statistics for the training set.
        test: 1-D sequence of statistics for the test set.
        steps: number of intervals between the two medians.

    Returns:
        A tuple ``(best_acc, best_thres, best_acc_train, best_acc_test,
        md_train, md_test)``.
        - On ties, the first (closest to ``md_train``) threshold wins.
        - If no threshold scores above 0, the accuracies and ``best_thres``
          are 0.
    """
    # Convert once instead of re-building both arrays on every iteration.
    train = np.asarray(train)
    test = np.asarray(test)
    md_train = np.median(train)
    md_test = np.median(test)
    step_size = (md_test - md_train) / steps
    best_acc = best_thres = best_acc_train = best_acc_test = 0
    for i in tqdm(range(steps + 1)):
        thres = md_train + i * step_size
        acc_train = np.mean(train <= thres)
        acc_test = np.mean(test > thres)
        acc = (acc_train + acc_test) / 2
        if acc > best_acc:
            best_acc, best_thres = acc, thres
            best_acc_train, best_acc_test = acc_train, acc_test
    return best_acc, best_thres, best_acc_train, best_acc_test, md_train, md_test

def cover_diff(train1, test1, train2, test2, thres1=None, thres2=None):
    """
    Coverage difference rate between two threshold attacks.

    Each attack flags a training sample when its statistic is ``<= thres``.
    The result is the fraction of the ``len(train1)`` training samples that
    are flagged by exactly one of the two attacks.

    When a threshold is not given, it defaults to the midpoint of that
    attack's train and test medians. ``test1`` and ``test2`` are only used
    for this default.
    """
    def _midpoint(train, test):
        return (np.median(train) + np.median(test)) / 2

    if thres1 is None:
        thres1 = _midpoint(train1, test1)
    if thres2 is None:
        thres2 = _midpoint(train2, test2)

    flagged1 = np.array(train1) <= thres1
    flagged2 = np.array(train2) <= thres2
    # "Exactly one" = union minus intersection = XOR of the two masks.
    return sum(np.logical_xor(flagged1, flagged2)) / len(train1)

def attack_given_thres(train, test, thres):
    """
    Compute attack accuracy with a given threshold.

    Values ``<= thres`` are classified as ``train`` and values ``> thres`` as
    ``test``. The result is the mean of the two per-set accuracies.
    """
    train_hits = sum(np.array(train) <= thres) / len(train)
    test_hits = sum(np.array(test) > thres) / len(test)
    return (train_hits + test_hits) / 2


### utils for current MIAs
def att_classification(r1, r2):
    """
    CC attack accuracy from per-sample correctness flags.

    Args:
        r1: array-like of flags for the first set. Elements ``== True`` count
            as hits.
        r2: array-like of flags for the second set. Elements ``== False``
            count as hits.

    Returns:
        A tuple ``(mean of the two hit rates, hit rate on r1, hit rate on r2)``.
    """
    # The explicit comparisons are kept on purpose: they are element-wise on
    # arrays and tensors, and non-bool values compare by equality.
    hits_first = sum(r1 == True) / len(r1)  # noqa: E712
    hits_second = sum(r2 == False) / len(r2)  # noqa: E712
    balanced = (hits_first + hits_second) / 2
    return balanced, hits_first, hits_second

def MH(p, lbl):
    """
    Entropy metric for the ET attack.

    For each row, with y = lbl and p_y the probability of the true class:

        -(1 - p_y) * log(p_y) - sum_{i != y} p_i * log(1 - p_i)

    Args:
        p: tensor of shape (N, K), assumed to hold class probabilities (e.g.
            softmax outputs). The result is converted with ``.numpy()``, so
            it must be a CPU tensor that does not require grad.
        lbl: int64 class indices, viewable as (N, 1).

    Returns:
        NumPy array of shape (N,).
        - NaNs (now only possible from NaN inputs) are replaced by 10000.
        - +/-inf (e.g. p_y == 0) become the dtype's largest finite values,
          via ``np.nan_to_num``.
    """
    idx = lbl.view(-1, 1)
    # True-class term: -(1 - p_y) * log(p_y).
    term1 = -torch.gather((1 - p) * torch.log(p), dim=1, index=idx).view(-1)
    # Other-class terms: zero the true-class entry *before* summing. The
    # original summed every entry and then subtracted the true-class one.
    # That is inf - inf = NaN whenever p_y == 1 (a confident, correct
    # prediction), and the NaN was reported as 10000 instead of the correct 0.
    others = (p * torch.log(1 - p)).scatter(1, idx, 0.0)
    term2 = -others.sum(dim=1)
    return np.nan_to_num((term1 + term2).numpy(), nan=10000)


# utils for NN-attack
class MEMBERINF(nn.Module):
    """
    NN-attack binary classifier.

    A fully connected network that maps a 10-dimensional input through five
    ReLU hidden layers (64-128-256-128-64) to 2 output values. The output
    layer has no activation.
    """

    def __init__(self):
        super().__init__()
        # The layers are created in this order and with these names, so the
        # parameter initialisation sequence and the state_dict keys
        # (fc1..fc6) are unchanged.
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)
        self.fc6 = nn.Linear(64, 2)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            hidden = F.relu(layer(hidden))
        return self.fc6(hidden)

class MEMBERINF_P(nn.Module):
    """
    NN-attack binary classifier for a single class.

    One ReLU hidden layer of width 128 between an ``input_dim``-dimensional
    input and 2 output values. The output layer has no activation.
    """

    def __init__(self, input_dim=100):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))
        
def separate_dataset(path_list, num_class=10):
    """
    Separate saved logits records by class label.

    Args:
        path_list: paths of pickle files. Each file holds an iterable of
            ``(label, logits, in_out)`` triples.
        num_class: the result has one bucket per key in ``range(num_class)``.
            Labels are used directly as dict keys, so a label outside that
            range raises ``KeyError``.

    Returns:
        A dict mapping each class index to a list of ``(logits, in_out)``
        pairs, ordered by file and then by record.
    """
    by_class = {cls: [] for cls in range(num_class)}
    for path in path_list:
        # NOTE(review): pickle.load can execute arbitrary code from the file.
        # Only use this on trusted, locally produced dumps.
        with open(path, 'rb') as handle:
            records = pickle.load(handle)
        for label, logits, in_out in records:
            by_class[label].append((logits, in_out))
    return by_class

def collect_dataset(path_list):
    """
    Collect a dataset from saved logits, ignoring the class label.

    Args:
        path_list: paths of pickle files. Each file holds an iterable of
            ``(label, logits, in_out)`` triples.

    Returns:
        A flat list of ``(logits, in_out)`` pairs, ordered by file and then
        by record.
    """
    collected = []
    for path in path_list:
        # NOTE(review): pickle.load can execute arbitrary code from the file.
        # Only use this on trusted, locally produced dumps.
        with open(path, 'rb') as handle:
            records = pickle.load(handle)
        collected.extend((logits, in_out) for _label, logits, in_out in records)
    return collected

def get_batch_dataset(data_list, batch_size=64):
    '''
    Batch (logits, label) pairs and split the batches 80/20.

        data_list: [(logits, label)]

    Consecutive groups of ``batch_size`` pairs are stacked with ``np.stack``.
    A trailing partial batch is dropped. The first ``int(0.8 * n_batches)``
    batches form the training set, and the remaining batches form the test set.

    Returns:
        A tuple ``(train_set, test_set)``. Each is a list of
        ``(logits_batch, labels_batch)`` tuples.
    '''
    all_logits = [pair[0] for pair in data_list]
    all_labels = [pair[1] for pair in data_list]
    n_batches = int(len(data_list) / batch_size)
    batches = []
    for b in range(n_batches):
        window = slice(b * batch_size, (b + 1) * batch_size)
        batches.append((np.stack(all_logits[window]), np.stack(all_labels[window])))
    cut = int(0.8 * len(batches))
    return batches[:cut], batches[cut:]

def eval_dataset(net, dataloader):
    """
    Evaluate the classification accuracy of a network, in percent.

    Each item of ``dataloader`` is an ``(inputs, labels)`` pair. Inputs are
    converted with ``torch.FloatTensor`` and labels with ``torch.LongTensor``.
    The prediction is the arg-max of the network output over dimension 1.

    Gradients are disabled during evaluation. The network's train/eval mode is
    not changed. Raises ``ZeroDivisionError`` if the loader yields no samples.
    """
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = torch.FloatTensor(batch[0])
            targets = torch.LongTensor(batch[1])
            predicted = net(inputs).max(dim=1).indices
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()
    return 100 * n_correct / n_seen