import argparse
import os
import numpy as np
import random

import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import distributed as dist
import time

from generators import *
from gaussian_smoothing import *

from utils import *
import dataset
from dataset import CraftedTarSamples, SamplesFromImNames, SamplesFrom50000TrainData
from craft_samples import craft_samples


#################
# Seed every RNG the script can touch (Python, PyTorch CPU/CUDA, NumPy) and
# force deterministic cuDNN kernels so runs are reproducible. `random` is
# imported by this file, so it is seeded too.
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
##################






def parse_args(argv=None):
    """Parse the command-line options for generator training.

    Gaussian smoothing of the generator output (``gs``) is enabled by
    default. ``--gs`` is still accepted for backward compatibility (it keeps
    the default), and ``--no_gs`` switches smoothing off.

    Args:
        argv: list of argument strings to parse; ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='ProposedTransferable Targeted Attack')
    parser.add_argument('--src_dir', default='path to source samples', help='Source Domain: natural images, paintings, medical scans, etc')
    parser.add_argument('--match_target', type=int, default=600,
                        help='Target Domain (23, 54, 60, 124, 344, 443, 465, 600, 642, 744, 769, 885) samples')
    parser.add_argument('--match_dir', default= 'path to target samples for BAT_BS and BAT_CS',
                        help='Path to data folder with target domain samples')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Number of trainig samples/batch')
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of training epochs')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='Initial learning rate for adam')
    parser.add_argument('--eps', type=int, default=10,
                        help='Perturbation Budget during training, eps')
    # Smoothing is on by default: `--gs` keeps it on, `--no_gs` disables it.
    # (Previously the parsed value was overwritten with True, so the flag
    # had no effect and smoothing could not be disabled.)
    parser.add_argument('--gs', dest='gs', action='store_true',
                        help='Apply gaussian smoothing')
    parser.add_argument('--no_gs', dest='gs', action='store_false',
                        help='Disable gaussian smoothing (enabled by default)')
    parser.set_defaults(gs=True)
    parser.add_argument('--save_dir', type=str, default='pretrained_generators',
                        help='Directory to save generators')
    parser.add_argument('--gamma', type=float, default=1.5,
                        help='multiplier with the similarity loss')
    parser.add_argument('--attack_type', type=str, default='BAT_CS',
                        help='BAT_BS, BAT_CS, BAT_CN')
    parser.add_argument('--TarSamNum', type=int, default=300,
                        help='number of samples from the target class')
    parser.add_argument('--surr_name', type=str, default='resnet50',
                        help='Surrogate model: resnet50, densenet121')
    parser.add_argument('--K', type=int, default=5,
                        help='Number of discriminators derived from surrogate model')
    return parser.parse_args(argv)
# NOTE(review): this parses sys.argv as a side effect of importing the module.
# The `__main__` block below parses again and hands that result to
# train_generator(); no function visible in this file reads this global, so
# it looks safe to remove — confirm no external importer relies on it.
args = parse_args()



def normalize(_t, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Return a copy of ``_t`` with its leading channels standardised.

    Channel ``c`` (for ``c < len(mean)``) becomes ``(x - mean[c]) / std[c]``;
    any further channels are copied unchanged. ``_t`` itself is never
    modified, and the result stays connected to ``_t``'s autograd graph.

    Args:
        _t: tensor laid out as (batch, channel, ...) with at least two dims;
            in this script it receives (N, 3, H, W) image batches clamped
            to [0, 1].
        mean: per-channel means (any sequence of floats). The defaults are
            the values commonly used for ImageNet-pretrained models.
        std: per-channel standard deviations, same length as ``mean``.

    Returns:
        A new tensor with the same shape, dtype and device as ``_t``.
    """
    out = _t.clone()
    n = len(mean)
    # Broadcast the statistics over the batch dim and all trailing dims.
    stat_shape = (1, n) + (1,) * (out.dim() - 2)
    m = torch.as_tensor(mean, dtype=out.dtype, device=out.device).view(stat_shape)
    s = torch.as_tensor(std, dtype=out.dtype, device=out.device).view(stat_shape)
    out[:, :n] = (out[:, :n] - m) / s
    return out


def reduce_mean(tensor, nprocs):
    """Return the sum of ``tensor`` over all ranks divided by ``nprocs``.

    The all-reduce is performed on a clone, so ``tensor`` itself is left
    untouched. This issues a collective, so every rank of the default
    process group must call it.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(nprocs)
    return averaged


def setup():
    """Join the NCCL process group and bind this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable
    (presumably set by the distributed launcher, e.g. torchrun — confirm).
    """
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

def cleanup():
    """Destroy the default process group if one is active.

    Does nothing when distributed support is unavailable or no group has
    been initialised (e.g. ``setup()`` never ran or failed part-way). This
    makes the function safe to call unconditionally at shutdown.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()



# Surrogates whose intermediate features enter the similarity loss, mapped to
# the layer name handed to FeatureExtractor.get_activation.
_FEATURE_LAYERS = {'resnet50': 'layer3', 'densenet121': 'denseblock3'}
# Surrogate names trained with the KL terms only (no feature hook).
_NO_FEATURE_SURROGATES = ('ens_reg_resnets', '3RN2DN')
_ATTACK_TYPES = ('BAT_BS', 'BAT_CS', 'BAT_CN')
# Rank 0 prints the running loss every this many batches.
_LOG_EVERY = 100


def _validate_args(args):
    """Reject option values that would otherwise only fail deep into training."""
    if args.attack_type not in _ATTACK_TYPES:
        raise ValueError(
            f'Unknown attack_type {args.attack_type!r}; expected one of {_ATTACK_TYPES}')
    if args.surr_name not in _FEATURE_LAYERS and args.surr_name not in _NO_FEATURE_SURROGATES:
        raise ValueError(
            f'Unsupported surr_name {args.surr_name!r}; expected one of '
            f'{tuple(_FEATURE_LAYERS) + _NO_FEATURE_SURROGATES}')


def _project(adv, clean, eps):
    """Clip ``adv`` into the L-inf ball of radius ``eps`` around ``clean``, then into [0, 1]."""
    adv = torch.min(torch.max(adv, clean - eps), clean + eps)
    return torch.clamp(adv, 0.0, 1.0)


def _build_source_loader(args):
    """Build the distributed loader over source images, excluding the target class.

    Returns:
        ``(loader, sampler)``. The sampler is returned so the caller can call
        ``set_epoch`` on it every epoch.
    """
    from itertools import islice

    file_path = os.path.join(args.src_dir, 'train50000_samples_labels.txt')
    with open(file_path, 'r') as f:
        im_info = f.read().split('\n')

    train_data = SamplesFrom50000TrainData(os.path.join(args.src_dir, 'Images'), im_info)
    # Scan at most the first 50000 entries and drop samples whose label is
    # the target class.
    kept = [
        (path, label)
        for path, label in islice(zip(train_data.image_paths, train_data.labels), 50000)
        if label != args.match_target
    ]
    train_data.labels = [label for _, label in kept]
    train_data.image_paths = [path for path, _ in kept]

    train_sampler = DistributedSampler(train_data)
    train_loader = DataLoader(train_data, batch_size=args.batch_size,
                              shuffle=False, num_workers=4,
                              pin_memory=True, sampler=train_sampler,
                              drop_last=True)
    print('Training data size:', len(train_data))
    return train_loader, train_sampler


def _load_target_data(args, all_discriminator, gpu_id, world_size):
    """Return the dataset of target-class samples used as matching references.

    - BAT_BS: the top ``TarSamNum`` real target images, in the order listed
      in a cached name file (created by
      ``save_high_confidence_tar_samples_name`` if missing).
    - BAT_CS: those same images passed through ``craft_samples``.
    - BAT_CN: ``craft_samples`` called without a dataloader.

    Crafted samples are loaded from a ``.pt`` file under ``./craftedSamples``
    and are only crafted when that file does not exist yet.
    """
    crafted_dir = f'./craftedSamples/{args.attack_type}_{args.surr_name}_K{args.K}'
    crafted_file = f'{crafted_dir}/TarCls_{args.match_target}_samplesNum{args.TarSamNum}.pt'

    if args.attack_type == 'BAT_CN':
        if not os.path.exists(crafted_file):
            craft_samples(args, all_discriminator, crafted_dir, dataloader=None,
                          gpu_id=gpu_id, iterations=25, world_size=world_size)
        return CraftedTarSamples(crafted_file)

    # BAT_BS / BAT_CS start from the ranked list of real target images.
    tar_sam_name = f'HighConfTarSortedSamples_{args.match_target}_{args.surr_name}_K{args.K}.txt'
    meta_data_path = f'./meta_data/{args.surr_name}'
    os.makedirs(meta_data_path, exist_ok=True)
    file_path = os.path.join(meta_data_path, tar_sam_name)
    if not os.path.exists(file_path):
        print('saving images names...')
        save_high_confidence_tar_samples_name(file_path, all_discriminator, args, gpu_id=gpu_id)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f'Target sample name file was not created: {file_path}')

    with open(file_path, 'r') as f:
        # Skip blank lines (e.g. the one produced by a trailing newline) so
        # they never reach the image loader as empty file names.
        im_names = [name for name in f.read().split('\n') if name]
    if gpu_id == 0:
        print(f'Image name path: {file_path}')
        print('first ten image names:')
        print(im_names[:10])
        print('total samples number', len(im_names))
    im_names = im_names[:args.TarSamNum]
    print('number of best target samples', len(im_names))

    folders = os.listdir(args.match_dir)
    matches = [folder for folder in folders if f'classID{args.match_target}_' in folder]
    if not matches:
        raise FileNotFoundError(
            f'No folder containing "classID{args.match_target}_" in {args.match_dir}')
    data_path = os.path.join(args.match_dir, matches[0])
    print(f'Image folder path: {data_path}')
    tar_data = SamplesFromImNames(data_path, im_names)

    if args.attack_type == 'BAT_CS':
        if not os.path.exists(crafted_file):
            tar_sampler = DistributedSampler(tar_data, num_replicas=dist.get_world_size(),
                                             rank=dist.get_rank())
            dataloader = DataLoader(tar_data, batch_size=5, sampler=tar_sampler)
            craft_samples(args, all_discriminator, crafted_dir, dataloader=dataloader,
                          gpu_id=gpu_id, iterations=25, world_size=world_size)
        tar_data = CraftedTarSamples(crafted_file)
    return tar_data


def _run_training(args):
    """Train the generator; assumes the process group is already initialised."""
    gpu_id = int(os.environ["LOCAL_RANK"])

    discriminators_name = [args.surr_name] + [args.surr_name + f'_pruned{k+1}' for k in range(args.K - 1)]
    print(discriminators_name)

    args.save_dir = os.path.join(args.save_dir, args.surr_name)
    if gpu_id == 0:
        # exist_ok: tolerate the directory appearing between check and create.
        os.makedirs(args.save_dir, exist_ok=True)

    uses_features = args.surr_name not in _NO_FEATURE_SURROGATES
    target_layer = _FEATURE_LAYERS.get(args.surr_name)
    mod_name = args.surr_name
    eps = args.eps / 255

    world_size = dist.get_world_size()
    if gpu_id == 0:
        print('total_batch_size', world_size * args.batch_size)

    # Discriminators: the surrogate plus its K-1 pruned variants, frozen in eval mode.
    all_discriminator = []
    for name in discriminators_name:
        model = get_model(name, args.surr_name)
        model = model.cuda(gpu_id)
        model.eval()
        all_discriminator.append(model)

    # Generator
    netG = GeneratorResnet()
    netG.cuda(gpu_id)
    optimG = optim.Adam(netG.parameters(), lr=args.lr, betas=(0.5, 0.999))
    netG = torch.nn.SyncBatchNorm.convert_sync_batchnorm(netG)
    netG = torch.nn.parallel.DistributedDataParallel(
        netG,
        device_ids=[gpu_id],
        output_device=gpu_id,
        find_unused_parameters=False,
        broadcast_buffers=False,
    )

    # Data
    train_loader, train_sampler = _build_source_loader(args)
    tar_data = _load_target_data(args, all_discriminator, gpu_id, world_size)

    tar_sampler = DistributedSampler(tar_data)
    train_loader_match = DataLoader(tar_data, batch_size=args.batch_size,
                                    shuffle=False, num_workers=4,
                                    pin_memory=True, sampler=tar_sampler,
                                    drop_last=True)
    print('Training (Match) data size:', len(tar_data))
    if len(train_loader_match) == 0:
        # Otherwise next() below would raise a bare StopIteration mid-loop.
        raise ValueError(
            f'Target data ({len(tar_data)} samples over {world_size} ranks) does not '
            f'fill a single batch of {args.batch_size} per rank; lower --batch_size '
            f'or raise --TarSamNum')
    match_epoch = 0
    dataiter = iter(train_loader_match)

    kernel = None
    if args.gs:
        kernel = get_gaussian_kernel(kernel_size=3, pad=2, sigma=1).cuda(gpu_id)

    criterion_kl = nn.KLDivLoss(reduction='sum')

    def get_loss(discriminator, adv, adv_rot, adv_aug, img_match, feat_extract=None):
        """Loss of one discriminator for the three adversarial views.

        The loss is a symmetric KL divergence between each view's logits and the
        target samples' logits, scaled by ``1 / batch_size``. When
        ``feat_extract`` is given, ``gamma`` times the mean summed cosine
        similarity of the hooked ``target_layer`` features is subtracted.
        """
        def forward(x):
            # Returns (logits, hooked activation or None).
            if feat_extract is None:
                return discriminator(normalize(x)), None
            feat_extract.activation = None
            hook = feat_extract.get_activation(target_layer)
            try:
                out = discriminator(normalize(x))
            finally:
                hook.remove()
            return out, feat_extract.activation

        adv_out, adv_feat = forward(adv)
        adv_rot_out, adv_rot_feat = forward(adv_rot)
        adv_aug_out, adv_aug_feat = forward(adv_aug)
        img_match_out, img_match_feat = forward(img_match)

        simi_loss = 0
        if feat_extract is not None:
            # Per-sample cosine similarity of flattened features.
            match_flat = img_match_feat.flatten(start_dim=1)
            similarity_losses = [
                nn.functional.cosine_similarity(feat.flatten(start_dim=1), match_flat)
                for feat in (adv_feat, adv_rot_feat, adv_aug_feat)
            ]
            simi_loss = torch.mean(sum(similarity_losses))

        match_prob = nn.functional.softmax(img_match_out, dim=1)
        match_log_prob = nn.functional.log_softmax(img_match_out, dim=1)
        loss_kl = 0.0
        for out in (adv_out, adv_rot_out, adv_aug_out):
            loss_kl += (1.0 / args.batch_size) * criterion_kl(nn.functional.log_softmax(out, dim=1),
                                                                match_prob)
            loss_kl += (1.0 / args.batch_size) * criterion_kl(match_log_prob,
                                                                nn.functional.softmax(out, dim=1))
        return loss_kl - args.gamma * simi_loss

    def checkpoint_path(epch):
        name = '{}_gamma{}_epch{}_tarClass{}_Surr_{}_K{}_SampNum{}.pth'.format(
            args.attack_type, args.gamma, epch, args.match_target,
            args.surr_name, args.K, args.TarSamNum)
        return os.path.join(args.save_dir, name)

    tt = time.time()
    for epoch in range(args.epochs):
        # Without set_epoch, DistributedSampler replays the same shuffle every epoch.
        train_sampler.set_epoch(epoch)
        t1 = time.time()
        running_loss = 0.0
        running_count = 0
        for i, (imgs, aug_imgs, _labels) in enumerate(train_loader):
            img = imgs.cuda(gpu_id, non_blocking=True)
            img_rot = rotation(imgs)[0].cuda(gpu_id, non_blocking=True)
            img_aug = aug_imgs.cuda(gpu_id, non_blocking=True)

            try:
                img_match = next(dataiter)
            except StopIteration:
                # Target loader exhausted: reshuffle it and start a new pass.
                match_epoch += 1
                tar_sampler.set_epoch(match_epoch)
                dataiter = iter(train_loader_match)
                img_match = next(dataiter)
            img_match = img_match[:img.shape[0]].cuda(gpu_id, non_blocking=True)

            netG.train()
            optimG.zero_grad()

            # Unconstrained adversaries
            adv = netG(img)
            adv_rot = netG(img_rot)
            adv_aug = netG(img_aug)

            # Smoothing
            if args.gs:
                adv = kernel(adv)
                adv_rot = kernel(adv_rot)
                adv_aug = kernel(adv_aug)

            # Projection onto the eps-ball around each clean view.
            adv = _project(adv, img, eps)
            adv_rot = _project(adv_rot, img_rot, eps)
            adv_aug = _project(adv_aug, img_aug, eps)

            # Loss averaged over all discriminators.
            loss = 0
            for discriminator in all_discriminator:
                feat_extract = FeatureExtractor(discriminator, mod_name) if uses_features else None
                loss += get_loss(discriminator, adv, adv_rot, adv_aug, img_match,
                                 feat_extract=feat_extract)
            loss = loss / len(all_discriminator)

            # Backpropagate the local loss; DDP averages gradients across ranks.
            # dist.all_reduce is not an autograd op, so the cross-rank mean is
            # taken on a detached copy and used for logging only.
            loss.backward()
            optimG.step()
            running_loss += reduce_mean(loss.detach(), world_size).item()
            running_count += 1

            if i % _LOG_EVERY == 0 and dist.get_rank() == 0:
                print('Epoch: {0} \t Batch: {1} \t loss: {2:.5f} \t time: {3:3f}'.format(
                    epoch, i, running_loss / running_count, time.time() - t1))
                running_loss = 0.0
                running_count = 0

        if gpu_id == 0:
            # Keep only the newest checkpoint.
            torch.save(netG.module.state_dict(), checkpoint_path(epoch))
            print('Total required time:', time.time() - tt)
            if epoch > 0:
                remove_path = checkpoint_path(epoch - 1)
                if os.path.exists(remove_path):
                    os.remove(remove_path)


def train_generator(args):
    """Train a targeted perturbation generator with DistributedDataParallel.

    The generator is optimised against ``args.K`` discriminators:
    ``args.surr_name`` plus its pruned variants. It is trained on source
    images outside class ``args.match_target``, using target-class samples
    selected or crafted according to ``args.attack_type``. After every epoch,
    local rank 0 saves the generator's ``state_dict`` under
    ``<save_dir>/<surr_name>/`` and deletes the previous epoch's file.

    One process per GPU is expected, with ``LOCAL_RANK`` set in the
    environment. Note that ``args.save_dir`` is rewritten in place to include
    the surrogate name.

    Raises:
        ValueError: unsupported ``attack_type``/``surr_name``, or too little
            target data to form one batch per rank.
        FileNotFoundError: the target-class folder or name file is missing.
    """
    print(args)
    _validate_args(args)
    setup()
    _run_training(args)
    cleanup()



if __name__ == '__main__':
    # Re-assert the deterministic cuDNN configuration right before training
    # (the same flags are also set when the module is imported).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    args = parse_args()
    train_generator(args)