import numpy as np
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn

from purchase100 import *

def attack_rd_flip(ff, start, end, labels, criterion, num_run=5, loss_thres=0.5, gamma=0.2,
                   model=None):
    """
    Implementation of Algorithm 2, and update functions for missing features
    (random flipping).
    Also saves losses for Algorithm 1.

    For every sample index ``k`` in ``[start, end)`` and each of ``num_run``
    repetitions, ``ff[k]`` is perturbed with ``random_flip`` and the loss of
    the perturbed sample is recorded. Features are then flipped greedily
    (``x -> 1 - x``), one per step. Each step picks the not-yet-flipped
    feature whose flip yields the lowest loss. Flipping stops once the loss
    is at most ``loss_thres`` or every feature has been flipped. The number
    of flips performed is recorded; it is 0 if the perturbed sample is
    already at or below the threshold.

    Args:
        ff: indexable collection of per-sample feature tensors. Each ``ff[k]``
            is assumed to be a 1-D tensor of 0/1 features (600 for
            Purchase-100); the ``1 - x`` flip presumes binary values.
        start, end: half-open range of sample indices to attack.
        labels: label tensor indexed in parallel with ``ff``.
        criterion: loss called as ``criterion(logits, labels)``; must return
            a single-element tensor.
        num_run: number of random perturbations per sample.
        loss_thres: greedy flipping stops once the loss is <= this value.
        gamma: forwarded to ``random_flip`` (presumably the flip
            probability -- confirm against purchase100).
        model: network under attack. Defaults to the module-level ``net``
            (set by the ``__main__`` block) so existing callers keep working.

    Returns:
        ``(flip_counts, losses)``: two lists of length
        ``(end - start) * num_run``. ``flip_counts[j]`` is the number of
        greedy flips for run ``j``. ``losses[j]`` is the loss of the randomly
        perturbed sample before any greedy flip.
    """
    # Backward compatibility: fall back to the global model defined in __main__.
    model = net if model is None else model

    flipped_train = []
    losses_train = []
    # Only forward passes are needed; avoid building autograd graphs.
    with torch.no_grad():
        for k in tqdm(range(start, end)):
            label = labels[k:k + 1]
            for _ in range(num_run):
                # Defensive clone: the in-place flips below must never write
                # through a possible view into the shared dataset tensor.
                noisy = random_flip(ff[k], gamma).clone()
                n_feat = noisy.numel()
                loss2 = criterion(model(noisy.view(1, -1)), label).item()
                losses_train.append(loss2)

                num_flipped = 0
                # Track the loss of the *current* sample (after accepted flips),
                # so the stopping test reflects the sample actually reached.
                current_loss = loss2
                flipped_ind = set()
                while current_loss > loss_thres and num_flipped < n_feat:
                    best_idx, best_loss = None, float('inf')
                    for i in range(n_feat):
                        if i in flipped_ind:
                            # Already-flipped features are never candidates.
                            continue
                        noisy[i] = 1 - noisy[i]
                        cand = criterion(model(noisy.view(1, -1)), label).item()
                        noisy[i] = 1 - noisy[i]  # revert the trial flip
                        # Strict '<' keeps the first index on ties, like np.argmax.
                        if cand < best_loss:
                            best_idx, best_loss = i, cand
                    if best_idx is None:
                        # Every candidate loss was NaN; no meaningful flip exists.
                        break
                    noisy[best_idx] = 1 - noisy[best_idx]
                    flipped_ind.add(best_idx)
                    num_flipped += 1
                    current_loss = best_loss
                flipped_train.append(num_flipped)
    return (flipped_train, losses_train)

if __name__ == '__main__':
    import argparse
    import os

    parser = argparse.ArgumentParser("Attacks on Purchase-100")
    # Path arguments have no sensible default; fail early with a clear CLI error.
    parser.add_argument("--load_dir", type=str, required=True)
    parser.add_argument("--num_run", type=int, default=5)
    parser.add_argument("--split_ind", type=int, default=0)
    parser.add_argument("--loss_thres", type=float, default=0.1)
    parser.add_argument("--gamma", type=float, default=0.2)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--file_name", type=str, required=True)

    args = parser.parse_args()

    num_run = args.num_run
    ind = args.split_ind
    loss_thres = args.loss_thres
    gamma = args.gamma

    # NOTE: pickle.load / torch.load execute arbitrary code from the file;
    # only point these at trusted, locally produced artifacts.
    with open('SAVED_INDEX_PURMUTATION', 'rb') as fp:
        perm = pickle.load(fp)

    dataset = Purchase100(perm=perm)

    model_path = os.path.join(args.load_dir, 'MODEL_OF_GIVEN_INDEX_PATH')
    # `net` stays a module-level name: attack_rd_flip reads it as a global.
    net = Purchase_Net()
    net.load_state_dict(torch.load(model_path))
    # Inference mode so dropout/batch-norm (if the net has any) are deterministic.
    net.eval()

    criterion = nn.CrossEntropyLoss()

    ff = dataset.features
    labels = dataset.labels
    # Each result is a (flip_counts, losses) tuple returned by attack_rd_flip.
    # no_grad: forward passes only; it also keeps the loss tensors grad-free.
    with torch.no_grad():
        flipped_train = attack_rd_flip(ff, ind*10000, (ind+1)*10000, labels, criterion, num_run=num_run, loss_thres=loss_thres, gamma=gamma)
        flipped_test = attack_rd_flip(ff, 50000, 60000, labels, criterion, num_run=num_run, loss_thres=loss_thres, gamma=gamma)

    save_dir = args.save_dir
    # makedirs handles nested paths and an already-existing directory.
    os.makedirs(save_dir, exist_ok=True)

    file_name = args.file_name
    with open(os.path.join(save_dir, file_name), 'wb') as fp:
        pickle.dump((flipped_train, flipped_test), fp)
    
