import os
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import time
from utils import normalize



def print_predictions(model, x, k=5):
    """Print the top-k indices and softmax scores ``model`` produces for ``x``.

    ``x`` is passed through ``normalize`` before being fed to ``model``; the
    resulting logits are soft-maxed over dimension 1 and the ``k`` largest
    entries (searched along the last dimension, ``torch.topk``'s default)
    are printed.

    Args:
        model: Callable returning logits for a normalized input batch.
        x: Input batch handed to ``normalize``.
        k: Number of entries to report. Clamped to the size of the last
            dimension so outputs with fewer than ``k`` entries do not make
            ``torch.topk`` fail.
    """
    # Debug/print-only helper: no gradients are needed, so skip building an
    # autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(normalize(x))
        predProb = torch.softmax(logits, dim=1)
        val, topk_comb = torch.topk(predProb, min(k, predProb.shape[-1]))
    print('combined image topK predicitons:', topk_comb)
    print('top K pred scores:', val.data)


def get_confident_samples(args, imgs, disciminators, loss_fn,
                          alpha, iterations=25, gpu_id=0):
    """Iteratively perturb ``imgs`` to lower the summed loss towards a target.

    A perturbation ``delta`` (initially zero) is added to a copy of ``imgs``.
    On every iteration the clamped sample is normalized, fed to each model in
    ``disciminators``, the losses against ``args.match_target`` are summed,
    and ``delta`` takes a gradient-descent step of size ``alpha``.

    Args:
        args: Object exposing ``match_target``; that value is repeated once
            per image and passed to ``loss_fn`` as the target.
        imgs: Batch of images; its first dimension is the batch size.
            NOTE(review): values are presumably expected in [0, 1] since the
            samples are clamped to that range before the forward pass -
            confirm against the caller.
        disciminators: Non-empty iterable of models called on the normalized
            samples.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar
            tensor.
        alpha: Step size multiplying the gradient.
        iterations: Number of update steps; ``0`` returns the unmodified
            copy of ``imgs``.
        gpu_id: CUDA device index on which all tensors are placed.

    Returns:
        Detached tensor ``imgs + delta`` on CUDA device ``gpu_id``.

    Raises:
        ValueError: If ``disciminators`` yields no model.
    """
    tar_lbl = torch.tensor([args.match_target]*imgs.shape[0]).cuda(gpu_id)
    perturbed = imgs.clone().cuda(gpu_id)
    # Create the perturbation directly as a leaf tensor on the device of
    # `perturbed`; moving a tensor after requesting gradients would produce
    # a non-leaf whose `.grad` is never populated.
    delta = torch.zeros_like(perturbed, requires_grad=True)

    for _ in range(iterations):
        inputs = (perturbed + delta).clamp(0, 1)

        losses = [loss_fn(model(normalize(inputs)), tar_lbl)
                  for model in disciminators]
        if not losses:
            raise ValueError('disciminators must contain at least one model')
        tot_loss = sum(losses)
        tot_loss.backward()
        grad = delta.grad.clone()
        # Gradients accumulate across backward passes; reset for the next step.
        delta.grad.zero_()

        delta.data = delta.data - alpha*grad
        # NOTE(review): the projection below is taken relative to `inputs`
        # (the clamped sample from the start of this step) while the returned
        # sample is `perturbed + delta`, so the result is not re-clamped to
        # [0, 1]. Kept as-is to preserve existing outputs - confirm whether
        # `perturbed` was intended here.
        delta.data = ((inputs + delta.data).clamp(0,1)) - inputs

    # Computed once after the loop so that `iterations == 0` is well defined.
    return (perturbed + delta).detach()


def craft_samples(args, disciminators, sample_dir, dataloader=None, gpu_id=0, iterations=25, world_size=1):
    """Craft confident samples, gather them from all ranks and save them.

    Two modes are selected by ``args.attack_type``:

    * ``'BAT_CS'``: start from the batches yielded by ``dataloader``
      (step size 0.25).
    * ``'BAT_CN'``: start from uniform random noise batches of shape
      ``(5, 3, 224, 224)`` (step size 1); enough batches are generated so
      that ``5 * world_size * num_batches >= args.TarSamNum``.

    Each batch is optimised with ``get_confident_samples``, collected from
    every rank with ``dist.all_gather`` and finally concatenated along
    dimension 0 and written to
    ``{sample_dir}/TarCls_{match_target}_samplesNum{TarSamNum}.pt``.

    Args:
        args: Object exposing ``attack_type``, ``match_target`` and
            ``TarSamNum``.
        disciminators: Models forwarded to ``get_confident_samples``.
        sample_dir: Output directory; created if missing.
        dataloader: Iterable of image batches; required for ``'BAT_CS'`` and
            must be ``None`` for ``'BAT_CN'``.
        gpu_id: CUDA device index; also used as the noise seed for
            ``'BAT_CN'``.
        iterations: Optimisation steps per batch.
        world_size: Number of participating processes in the gather.

    Raises:
        ValueError: On an unsupported ``attack_type`` or a ``dataloader``
            argument that does not match the selected mode.
    """
    print('Start crafting samples')

    # Validate up front so a misconfiguration fails with a clear message
    # instead of an unbound `alpha` or an empty `torch.cat` later on.
    if args.attack_type == 'BAT_CS':
        if dataloader is None:
            raise ValueError("attack_type 'BAT_CS' requires a dataloader")
        alpha = 0.25
    elif args.attack_type == 'BAT_CN':
        if dataloader is not None:
            raise ValueError("attack_type 'BAT_CN' does not take a dataloader")
        alpha = 1
    else:
        raise ValueError(f'Unsupported attack_type: {args.attack_type!r}')

    os.makedirs(sample_dir, exist_ok=True)
    tt = time.time()
    loss_fn = nn.CrossEntropyLoss()

    if args.attack_type == 'BAT_CS':
        batches = dataloader
    else:
        batch_sz = 5
        iter_number = math.ceil(args.TarSamNum/(batch_sz*world_size))
        # Seed once, before the loop: re-seeding on every iteration would
        # make every noise batch drawn by this process identical.
        torch.manual_seed(gpu_id)
        batches = (torch.rand(batch_sz, 3, 224, 224) for _ in range(iter_number))

    confident_samples = []
    for imgs in batches:
        conf_samples = get_confident_samples(args, imgs,
                                             disciminators, loss_fn, alpha,
                                             iterations=iterations, gpu_id=gpu_id)
        # all_gather fills one pre-allocated slot per rank, so every rank
        # ends up holding the samples crafted by all ranks.
        gather_list = [torch.zeros_like(conf_samples) for _ in range(world_size)]
        dist.all_gather(gather_list, conf_samples)
        confident_samples += gather_list

    print('Time required to craft the samples:', time.time()-tt)
    # NOTE(review): every process that calls this function writes the same
    # path - confirm this is intended when world_size > 1.
    torch.save(torch.cat(confident_samples, dim=0).cpu(), f'{sample_dir}/TarCls_{args.match_target}_samplesNum{args.TarSamNum}.pt')



    