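"""
Run the Smart Label Flipping attack on a given model and dataset.

Selects the training samples the model is most confident about (largest
logit), flips each selected label to a low-scoring class, and saves the
poisoned indices together with their new labels via ``torch.save``.
"""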
import os
import sys
import time
import torch
import random
import argparse
import logging
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torchvision import datasets as dset
from torchvision import transforms
from torch.utils.data import DataLoader

from models.resnet import resnet18

# Per-channel normalization statistics for CIFAR-10 images; passed to
# transforms.Normalize when building the dataset transform in main().
CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_STD = [0.2471, 0.2435, 0.2616]

def parse_args(argv=None):
    """Parse command-line options for poison construction.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as usual; passing a list lets the
            parser be driven from tests or other scripts.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(description='Construct poisoned training data for the given network and dataset')
    parser.add_argument('--model_name', default='resnet18', type=str, help='Model name')
    parser.add_argument('--weights_path', default='./models/state_dicts/resnet18.pt', type=str, help='Path to model state dictionary')
    parser.add_argument('--data_dir', default='../data', type=str, help='Data directory')
    # The script only runs inference (computing logits), never training.
    parser.add_argument('--device', default='cuda', type=str, help='Device to run the model on')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    parser.add_argument('--p_ratio', default=0.01, type=float, help='Fraction of training data that is poisoned')
    parser.add_argument('--batch_size', default=256, type=int, help='Batch size')
    parser.add_argument('--save_dir', default='./poisons/clf/', type=str, help='Directory in which the poison file is saved')

    return parser.parse_args(argv)

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Args:
        model: Classifier whose output is argmax-ed over dim 1 to get the
            predicted class.
        device: Device each batch is moved to before the forward pass.
        test_loader: DataLoader yielding ``(data, target)`` batches; its
            ``.dataset`` length is used as the sample count.

    Returns:
        Tuple ``(average_loss, accuracy_percent)``.
    """
    model.eval()
    # reduction='sum' so that dividing by the number of samples below gives the
    # true per-sample average; the default 'mean' would sum per-batch means.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Testing", total=len(test_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum of per-sample losses
            pred = output.argmax(dim=1)  # index of the highest-scoring class
            correct += pred.eq(target).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    accuracy = 100. * correct / n_samples

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, n_samples, accuracy))
    return test_loss, accuracy

@torch.no_grad()
def get_training_logits(model, device, test_loader):
    """Run ``model`` over every batch of ``test_loader`` and collect per-sample outputs.

    The third element of each returned tuple is the sample's position in the
    loader's iteration order. It equals the dataset index only when the loader
    iterates in order (``shuffle=False``, no custom sampler).

    Args:
        model: Network to evaluate; put into eval mode here.
        device: Device the input batches are moved to.
        test_loader: DataLoader yielding ``(data, targets)`` batches.

    Returns:
        List of ``(logits, label, index)`` tuples, one per sample, where
        ``logits`` is a NumPy array, ``label`` a Python int and ``index`` an int.
    """
    model.eval()
    logits_and_labels = []
    running_idx = 0
    for data, targets in tqdm(test_loader, desc="Getting training logits", total=len(test_loader)):
        outputs = model(data.to(device))
        # One device->host transfer per batch instead of one (plus a sync) per sample.
        batch_logits = outputs.cpu().numpy()
        batch_labels = targets.tolist()
        logits_and_labels.extend(
            (logits, label, running_idx + offset)
            for offset, (logits, label) in enumerate(zip(batch_logits, batch_labels))
        )
        running_idx += len(batch_labels)

    return logits_and_labels

def _select_poisons(logits_and_labels, n_poisons, targets):
    """Pick the ``n_poisons`` most confident samples and choose a flipped label for each.

    Args:
        logits_and_labels: List of ``(logits, label, index)`` tuples as produced
            by ``get_training_logits``. Not modified.
        n_poisons: Number of samples to poison.
        targets: The dataset's label sequence, used to check that each stored
            index still refers to the sample its logits came from.

    Returns:
        ``(indices, poisoned_labels)``: parallel lists of plain Python ints.

    Raises:
        RuntimeError: If a stored label disagrees with ``targets[index]``
            (e.g. the loader was shuffled).
    """
    # Most confident first: rank by each sample's largest logit (stable sort).
    ranked = sorted(logits_and_labels, key=lambda x: x[0].max(), reverse=True)
    indices = []
    poisoned_labels = []
    for logits, label, index in ranked[:n_poisons]:
        if label != targets[index]:
            raise RuntimeError(
                f"Label mismatch at index {index}: logits were computed for label "
                f"{label} but the dataset has {targets[index]}")
        # Flip to the lowest-scoring class, excluding the true label so the
        # label is guaranteed to change. Identical to argmin(logits) whenever
        # argmin(logits) is not already the true label.
        masked = logits.astype(np.float64)
        masked[label] = np.inf
        indices.append(int(index))
        # Plain ints (not np.int64) keep the saved payload free of NumPy scalars.
        poisoned_labels.append(int(np.argmin(masked)))
    return indices, poisoned_labels


def main():
    """Build label-flipping poisons for the CIFAR-10 training set and save them.

    Loads the pretrained ResNet-18, computes logits for every training image
    with deterministic preprocessing, selects the ``p_ratio`` fraction with the
    largest logit, flips their labels, and writes a dict with keys
    ``"indices"`` and ``"poisoned_labels"`` to ``args.save_dir``.
    """
    args = parse_args()
    if not 0.0 <= args.p_ratio <= 1.0:
        raise ValueError(f"--p_ratio must be in [0, 1], got {args.p_ratio}")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Load the pretrained model; map_location lets GPU-saved weights load on any device.
    model = resnet18(pretrained=False).to(args.device)
    model.load_state_dict(torch.load(args.weights_path, map_location=args.device))

    # Deterministic preprocessing only: random crops/flips would make the logits,
    # and therefore the selected poisons, differ from run to run.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    train_dataset = dset.CIFAR10(root=args.data_dir, train=True, download=True, transform=eval_transform)
    # shuffle=False: get_training_logits derives dataset indices from iteration order.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Logits for the entire training set.
    logits_and_labels = get_training_logits(model, args.device, train_loader)

    n_poisons = int(args.p_ratio * len(train_dataset))
    indices, poisoned_labels = _select_poisons(logits_and_labels, n_poisons, train_dataset.targets)

    # Save poisoned indices and their new labels.
    save_str = f"clf-resnet18-cifar10-{args.p_ratio * 100:.1f}%.pth"
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, save_str)
    print("Saving poisons to {}".format(save_path))
    torch.save({
        "indices": indices,
        "poisoned_labels": poisoned_labels
    }, save_path)


if __name__ == '__main__':
    # Script entry point: parse CLI options, build the poisons and save them.
    main()