"""
Run the GC attack on a given model and dataset.
Parts of this implementation are adapted from 
https://github.com/JonasGeiping/poisoning-gradient-matching (update step)
and
https://github.com/watml/plim (gradient canceling attack).
"""
import argparse
import logging
import os
import queue
import random
import sys
import time
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets as dset
from torchvision import transforms
from tqdm import tqdm

from victim_model import VictimModel

# Number of images in the CIFAR-10 training split (the dataset loaded by main()).
CIFAR_SIZE = 50000
# Per-channel mean and standard deviation of CIFAR images. They are used both to
# normalize the dataset (transforms.Normalize) and to express the valid [0, 1]
# pixel range and the eps budget in normalized units when projecting poisons.
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

def parse_args():
    """Parse the command-line options of the GC attack.

    ``--clean_grads_path`` is required: every worker unconditionally loads it
    with ``torch.load``, so omitting it can only fail (late, inside a worker).

    Returns:
        argparse.Namespace holding the model, data, attack and output settings.
    """
    parser = argparse.ArgumentParser(description='Construct poisoned training data for the given network and dataset')
    parser.add_argument('--model_name', default='pdarts', choices=["pdarts", 'resnet18', 'd-darts'], type=str, help='Model name')
    parser.add_argument('--data_dir', default='../data', type=str, help='Data directory')
    # NOTE(review): not currently read; workers always run on cuda:<rank>.
    parser.add_argument('--device', default='cuda', type=str, help='Device to use for training')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    parser.add_argument('--eps', default=16, type=float,
                        help='L-inf perturbation budget in 0-255 pixel units')
    parser.add_argument('--p_ratio', default=0.01, type=float, help='Fraction of training data that is poisoned')
    parser.add_argument('--save', default='./gc_runs/', type=str,
                        help='Base output directory; a timestamped run folder is created inside it')
    parser.add_argument('--model_path', default=None, type=str,
                        help='Model path passed to VictimModel')
    parser.add_argument('--clean_grads_path', required=True, type=str,
                        help='Path to the saved clean gradients (loaded with torch.load)')

    parser.add_argument('--attack_iters', default=250, type=int,
                        help='Number of attack iterations')
    parser.add_argument('--lr', default=0.1, type=float,
                        help='Learning rate of the Adam optimizer over the poison perturbation')
    # NOTE(review): not currently read; compute_updates always uses cross-entropy.
    parser.add_argument('--target_criterion', default='cross-entropy', type=str, help='Loss criterion for target loss')
    parser.add_argument('--pbatch', default=512, type=int, help='Poison batch size during optimization')
    # No default: None is forwarded to VictimModel when omitted.
    parser.add_argument('--param_type', type=str, choices=["weight", "arch", "all"], help='parameters to target from NAS model')

    parser.add_argument('--arch', type=str, default=None, help='Genotype of model for discretized DARTS architecture')

    return parser.parse_args()

def autograd(outputs, inputs, create_graph=False):
    """Differentiate ``outputs`` with respect to each tensor in ``inputs``.

    Thin wrapper around ``torch.autograd.grad`` with ``allow_unused=True``:
    any input that does not contribute to ``outputs`` receives an all-zero
    gradient of matching size instead of ``None``.

    Args:
        outputs: Tensor to differentiate (assumed to be a scalar).
        inputs: Sequence of tensors. It is iterated twice, so it must not be
            a one-shot iterator.
        create_graph: If True, build a graph of the gradient computation so
            the returned gradients can be differentiated again.

    Returns:
        list of gradient tensors, one per input, in input order.
    """
    raw_grads = torch.autograd.grad(
        outputs,
        inputs,
        create_graph=create_graph,
        allow_unused=True,
    )
    filled = []
    for param, grad in zip(inputs, raw_grads):
        if grad is None:
            grad = param.new_zeros(param.size())
        filled.append(grad)
    return filled

def create_exp_dir(args):
    """Create the run directory ``args.save`` and set up logging.

    The directory must not already exist (``os.makedirs`` with
    ``exist_ok=False``). The root logger writes INFO records to stdout and to
    ``<args.save>/log.txt``, and the parsed arguments are logged first.
    """
    fmt = '%(asctime)s %(message)s'
    os.makedirs(args.save, exist_ok=False)
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=fmt,
        datefmt='%m/%d %I:%M:%S %p',
    )
    file_handler = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    file_handler.setFormatter(logging.Formatter(fmt))
    root_logger = logging.getLogger()
    root_logger.addHandler(file_handler)
    logging.info("args = %s", args)

def setup(rank, world_size, backend='nccl'):
    """Bind this process to GPU ``rank`` and join the default process group.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
        backend: torch.distributed backend (default ``'nccl'``).
    """
    # init_process_group's default ``env://`` rendezvous raises if these are
    # unset. Provide single-node defaults, but keep any launcher-provided values.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    # Select the device before creating the (NCCL) process group so that its
    # collectives run on this rank's GPU rather than the default device.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend, rank=rank, world_size=world_size, timeout=timedelta(hours=1))

def cleanup():
    """Destroy the default torch.distributed process group joined in ``setup``."""
    dist.destroy_process_group()

def init_queue(task_queue, poison_slices):
    """Enqueue every poison batch task onto ``task_queue``.

    Args:
        task_queue: Queue-like object exposing ``empty()`` and ``put()``.
        poison_slices: Iterable of tasks, enqueued in order.

    Raises:
        ValueError: If ``task_queue`` still holds tasks, i.e. the previous
            round was not fully drained.
    """
    if not task_queue.empty():
        raise ValueError("Task queue is not empty")
    for task in poison_slices:
        task_queue.put(task)

def run_gc(rank, world_size, task_queue, return_dict, poison_slices, poison_bounds, poisoned_indices, args, lock):
    """Worker entry point: run the GC attack on GPU ``rank``.

    In every attack iteration:
      1. Rank 0 enqueues all poison batches.
      2. Every rank drains the queue via ``compute_updates``.
      3. Rank 0 collates the per-batch gradients from ``return_dict``, takes
         an Adam step on the perturbation and projects it back onto the
         feasible set.
      4. Rank 0 broadcasts the new perturbation to all ranks.

    Args:
        rank: Process index; also selects the device ``cuda:<rank>``.
        world_size: Number of worker processes.
        task_queue: Shared queue that receives the entries of ``poison_slices``.
        return_dict: Shared dict filled by ``compute_updates`` with
            ``batch_idx -> [slice, loss, cosine_similarity, gradient]``.
        poison_slices: ``(batch_idx, slice, batch)`` tasks covering all poisons.
        poison_bounds: Clean, normalized poison images. The perturbation is
            added to these, and its shape follows theirs.
        poisoned_indices: Dataset indices of the poisons, saved alongside them.
        args: Parsed command-line arguments.
        lock: Forwarded unchanged to ``compute_updates``.
    """
    setup(rank, world_size)

    device = f'cuda:{rank}'
    num_batches = len(poison_slices)

    # Load model and clean gradients
    model = VictimModel(args.model_name, args.model_path, device, args.arch, args.param_type)
    clean_grads = torch.load(args.clean_grads_path, map_location='cpu')

    # Perturbation for every poison image, kept on the CPU. Its shape follows
    # poison_bounds, so it always matches the poison set actually selected.
    poison_delta = torch.zeros_like(poison_bounds, device='cpu')
    # Rank 0 fills .grad slice by slice before each optimizer step, so the
    # buffer must exist up front.
    poison_delta.grad = torch.zeros_like(poison_delta, device='cpu')

    if rank == 0:
        # Only rank 0 owns the optimizer and performs update + projection.
        att_optimizer = torch.optim.Adam([poison_delta], lr=args.lr)

        # Per-channel mean/std shaped (3, 1, 1) to broadcast over (N, 3, H, W).
        dm = torch.tensor(CIFAR_MEAN)[:, None, None]
        ds = torch.tensor(CIFAR_STD)[:, None, None]

        # Lowest mean GC loss seen at a checkpoint iteration so far.
        best_loss = float('inf')

        pbar = tqdm(range(args.attack_iters), total=args.attack_iters, desc='Running GC', position=rank)
    else:
        pbar = range(args.attack_iters)

    for it in pbar:
        # Rank 0 refills the queue. The barrier keeps every rank from draining
        # the queue before all tasks have been enqueued.
        if rank == 0:
            init_queue(task_queue, poison_slices)
        dist.barrier()

        compute_updates(rank, task_queue, return_dict, model, clean_grads, poison_delta, device, args, lock)

        # Wait until every batch's result has been written to return_dict.
        dist.barrier()

        with torch.no_grad():
            if rank == 0:
                # Collate per-batch results into poison_delta.grad.
                total_loss, total_cosim = 0, 0
                for batch_slice, loss, cosim, update in return_dict.values():
                    total_loss += loss
                    total_cosim += cosim
                    poison_delta.grad[batch_slice] = update
                return_dict.clear()

                total_loss /= num_batches
                total_cosim /= num_batches

                # Adam step on the perturbation. Grads are zeroed in place, not
                # set to None, because the .grad buffer is reused next iteration.
                att_optimizer.step()
                att_optimizer.zero_grad(set_to_none=False)

                # Project onto the L-inf ball: |delta| <= eps / 255 in pixel
                # units, i.e. eps / (255 * std) in normalized units.
                poison_delta.data = torch.clamp(
                    poison_delta.data,
                    -args.eps / ds / 255,
                    args.eps / ds / 255,
                )

                # Keep poisoned pixels inside [0, 1], expressed in normalized space.
                clamped_poisons = torch.clamp(
                    poison_bounds + poison_delta.data,
                    -dm / ds,
                    (1 - dm) / ds,
                )
                poison_delta.data = clamped_poisons - poison_bounds

                # Log (and possibly checkpoint) every 10 iterations and on the last one.
                if it % 10 == 0 or it == args.attack_iters - 1:
                    logging.info(f"iter {it} loss: {total_loss:.6f} cosine similarity: {total_cosim:.6f}")

                    # NOTE(review): total_loss was measured on the pre-step
                    # perturbation, but the poisons saved below are post-step.
                    if total_loss < best_loss:
                        best_loss = total_loss

                        # Undo normalization. The clamp guards against float
                        # round-off outside [0, 1] before ToPILImage's 8-bit conversion.
                        poisoned_images = (poison_bounds + poison_delta) * ds + dm
                        poisoned_images = poisoned_images.clamp(0, 1)
                        to_pil = transforms.ToPILImage()
                        poisoned_images = [to_pil(img) for img in poisoned_images]

                        save_path = os.path.join(args.save, "poisons.pth")
                        logging.info("Saving poisons with loss={:.4f} to {}".format(best_loss, save_path))
                        torch.save({
                            "indices": poisoned_indices,
                            "poisoned_images": poisoned_images,
                        }, save_path)

            # Rank 0's projected perturbation is the source of truth for the
            # next iteration on every rank.
            poison_delta_copy = poison_delta.clone().to(device)
            dist.broadcast(poison_delta_copy, src=0)

            if rank != 0:
                poison_delta.data.copy_(poison_delta_copy.data)
                poison_delta.grad = torch.zeros_like(poison_delta)

    cleanup()
            

def compute_updates(rank, task_queue, return_dict, model, clean_grads, poison_delta, device, args, lock):
    """Drain ``task_queue``, computing the GC-loss gradient for each poison batch.

    For every task ``(batch_idx, curr_slice, (data_p, target))`` this function:
      1. Adds the current perturbation ``poison_delta[curr_slice]`` to the batch.
      2. Computes the cross-entropy gradient w.r.t. ``model.target_parameters()``.
      3. Forms the loss ``|| (1 - p) * clean + p * poisoned ||_2^2``, where
         ``p = args.p_ratio``.
      4. Stores ``[curr_slice, loss, cosine_similarity, d loss / d delta]`` in
         ``return_dict[batch_idx]``; the gradient is moved to the CPU.

    Args:
        rank: Process index (not used here).
        task_queue: Shared queue of poison batch tasks.
        return_dict: Shared dict receiving the per-batch results.
        model: Victim model; must be callable and expose ``target_parameters()``.
        clean_grads: Sequence of clean-gradient tensors, in the same order as
            ``model.target_parameters()``.
        poison_delta: Current CPU perturbation for all poisons (not modified).
        device: Device on which the batch is processed.
        args: Parsed arguments (``p_ratio`` is read).
        lock: Not used here; kept for interface compatibility.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    # The clean gradient is the same for every batch, so flatten it and move it
    # to the device once instead of once per batch.
    clean_grads_concat = torch.cat([g_c.reshape(-1).to(device) for g_c in clean_grads])

    while not task_queue.empty():
        try:
            batch_idx, curr_slice, (data_p, target) = task_queue.get_nowait()
        except queue.Empty:
            # Another process took the last task between empty() and
            # get_nowait(); re-check the loop condition.
            continue

        data_p, target = data_p.to(device), target.to(device).long()

        # Fresh leaf copy of this batch's perturbation. Gradients are taken
        # w.r.t. it, and the shared CPU poison_delta is never touched.
        poison_delta_slice = poison_delta[curr_slice].to(device, copy=True).requires_grad_(True)

        output_c = model(data_p + poison_delta_slice)
        loss_c = criterion(output_c, target)

        # Gradient w.r.t. the targeted model parameters. create_graph=True lets
        # the GC loss below be differentiated w.r.t. the perturbation.
        poisoned_grads = autograd(loss_c, tuple(model.target_parameters()), create_graph=True)
        poisoned_grads_concat = torch.cat([g_p.reshape(-1) for g_p in poisoned_grads])

        # We want (1 - p) * clean_grads + p * poisoned_grads = 0, so minimize
        # the squared L2 norm of that weighted sum.
        grad_sum = (1 - args.p_ratio) * clean_grads_concat + args.p_ratio * poisoned_grads_concat
        loss = torch.norm(grad_sum, 2).square()

        # Sanity check only: should be close to -1 as the gradients cancel.
        with torch.no_grad():
            cosim = F.cosine_similarity(clean_grads_concat, poisoned_grads_concat, dim=0).item()

        update = autograd(loss, [poison_delta_slice], create_graph=False)[0].cpu()

        return_dict[batch_idx] = [curr_slice, loss.item(), cosim, update]

def main():
    """Select the poison set, create the run directory and launch one GC worker per GPU.

    Raises:
        RuntimeError: If no CUDA device is visible, or if any worker process
            exits with a non-zero status.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # One worker per visible GPU. Fail before creating output directories or
    # downloading data if there are none; otherwise nothing would run at all.
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError("No CUDA devices found; the GC attack needs at least one GPU")

    # os.path.join so --save works with or without a trailing separator.
    run_name = 'gc-{}-{:.1f}%-{}'.format(
        args.model_name, args.p_ratio * 100, time.strftime("%Y%m%d-%H%M%S")
    )
    args.save = os.path.join(args.save, run_name)
    create_exp_dir(args)

    # Load the (normalized) training set and pick a random subset to poison.
    train_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
    train_dataset = dset.CIFAR10(root=args.data_dir, train=True, download=True, transform=train_transform)

    poisoned_indices = random.sample(range(len(train_dataset)), int(args.p_ratio * len(train_dataset)))
    poison_subset = Subset(train_dataset, poisoned_indices)
    # Clean normalized poison images; the perturbation is optimized relative to these.
    poison_bounds = torch.stack([image for image, _ in poison_subset], dim=0)

    # Pre-batch the poisons once so workers can take ready-made batches from the
    # queue. Each task records which rows of poison_bounds its batch covers.
    poison_loader = DataLoader(poison_subset, batch_size=args.pbatch, shuffle=False)
    poison_slices = []
    start = 0
    for batch_idx, data in enumerate(poison_loader):
        stop = start + len(data[0])
        poison_slices.append((batch_idx, slice(start, stop), data))
        start = stop

    manager = mp.Manager()
    return_dict = manager.dict()    # workers store per-batch results here
    task_queue = manager.Queue()    # workers take poison batch tasks from here
    lock = mp.Lock()

    processes = []
    for rank in range(world_size):
        p = mp.Process(target=run_gc, args=(
            rank, world_size,
            task_queue, return_dict,
            poison_slices,
            poison_bounds, poisoned_indices,
            args, lock,
        ))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Without this check, a crashed worker would leave the script exiting with status 0.
    failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"GC worker(s) with rank {failed} exited with a non-zero status")

if __name__ == '__main__':
    # Script entry point: parse arguments and launch the per-GPU GC workers.
    main()



