import os
import random
import shutil
import time
import warnings
import sys
import numpy as np
import logging
import matplotlib.pyplot as plt
import cv2
import configparser

from PIL import Image
from alexnet_fc7out import alexnet, NormalizeByChannelMeanStd, resnet18
from dataset import PoisonGenerationDataset

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms

config = configparser.ConfigParser()
config.read(sys.argv[1])

experimentID = config["experiment"]["ID"]

options = config["poison_generation"]
data_root    = options["data_root"]
txt_root    = options["txt_root"]
arch        = options["arch"]
seed        = None
gpu         = int(options["gpu"])
epochs      = int(options["epochs"])
patch_size  = int(options["patch_size"])
eps         = int(options["eps"])
lr          = float(options["lr"])
rand_loc    = options.getboolean("rand_loc")
trigger_id  = int(options["trigger_id"])
num_iter    = int(options["num_iter"])
num_poison = int(options["num_poison"])
logfile     = options["logfile"].format(experimentID, arch, eps, rand_loc, patch_size, num_poison, trigger_id)
target_wnid = options["target_wnid"]
source_wnid_list = options["source_wnid_list"].format(experimentID)
num_source = int(options["num_source"])

saveDir_poison = "poison_data/" + experimentID + "/" + arch + "/rand_loc_" +  str(rand_loc) + '/eps_' + str(eps) + \
                    '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id)
saveDir_patched = "patched_data/" + experimentID + "/" + arch + "/rand_loc_" +  str(rand_loc) + '/eps_' + str(eps) + \
                    '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id)

if not os.path.exists(saveDir_poison):
    os.makedirs(saveDir_poison)
if not os.path.exists(saveDir_patched):
    os.makedirs(saveDir_patched)

if not os.path.exists("data/{}".format(experimentID)):
    os.makedirs("data/{}".format(experimentID))

def main():
    #logging
    if not os.path.exists(os.path.dirname(logfile)):
            os.makedirs(os.path.dirname(logfile))

    logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    handlers=[
        logging.FileHandler(logfile, "w"),
        logging.StreamHandler()
    ])

    logging.info("Experiment ID: {}".format(experimentID))

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    main_worker()


def main_worker():
    global best_acc1

    if gpu is not None:
        logging.info("Use GPU: {} for training".format(gpu))

    # create model
    logging.info("=> using pre-trained model '{}'".format("alexnet"))
    normalize = NormalizeByChannelMeanStd(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if arch == "alexnet":
        model = alexnet(pretrained=True)
    elif arch == "resnet":
        model = resnet18(pretrained=True)

    model.eval()
    model = nn.Sequential(normalize, model)

    model = model.cuda(gpu)

    for epoch in range(epochs):
        # run one epoch
        train(model, epoch)

def save_image(img, fname):
    img = img.data.numpy()
    img = np.transpose(img, (1, 2, 0))
    img = img[: , :, ::-1]
    cv2.imwrite(fname, np.uint8(255 * img), [cv2.IMWRITE_PNG_COMPRESSION, 0])

def train(model, epoch):
    since = time.time()
    # AVERAGE METER
    losses = AverageMeter()

    # TRIGGER PARAMETERS
    trans_image = transforms.Compose([transforms.Resize((224, 224)),
                                      transforms.ToTensor(),
                                      ])
    trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
                                        transforms.ToTensor(),
                                        ])

    # PERTURBATION PARAMETERS
    eps1 = (eps/255.0)
    lr1 = lr

    trigger = Image.open('data/triggers/trigger_{}.png'.format(trigger_id)).convert('RGB')
    trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu)

    # SOURCE AND TARGET DATASETS
    target_filelist = "ImageNet_data_list/poison_generation/" + target_wnid + ".txt"

    # Use source wnid list
    if num_source==1:
        logging.info("Using single source for this experiment.")
    else:
        logging.info("Using multiple source for this experiment.")

    with open("data/{}/multi_source_filelist.txt".format(experimentID),"w") as f1:
        with open(source_wnid_list) as f2:
            source_wnids = f2.readlines()
            source_wnids = [s.strip() for s in source_wnids]

            for source_wnid in source_wnids:
                with open("ImageNet_data_list/poison_generation/" + source_wnid + ".txt", "r") as f2:
                    shutil.copyfileobj(f2, f1)

    source_filelist = "data/{}/multi_source_filelist.txt".format(experimentID)


    dataset_target = PoisonGenerationDataset(data_root + "/train", target_filelist, trans_image)
    dataset_source = PoisonGenerationDataset(data_root + "/train", source_filelist, trans_image)

    # SOURCE AND TARGET DATALOADERS
    train_loader_target = torch.utils.data.DataLoader(dataset_target,
                                                    batch_size=100,
                                                    shuffle=True,
                                                    num_workers=8,
                                                    pin_memory=True)

    train_loader_source = torch.utils.data.DataLoader(dataset_source,
                                                      batch_size=100,
                                                      shuffle=True,
                                                      num_workers=8,
                                                      pin_memory=True)


    logging.info("Number of target images:{}".format(len(dataset_target)))
    logging.info("Number of source images:{}".format(len(dataset_source)))

    # USE ITERATORS ON DATALOADERS TO HAVE DISTINCT PAIRING EACH TIME
    iter_target = iter(train_loader_target)
    iter_source = iter(train_loader_source)

    num_poisoned = 0
    for i in range(len(train_loader_target)):

        # LOAD ONE BATCH OF SOURCE AND ONE BATCH OF TARGET
        (input1, path1) = next(iter_source)
        (input2, path2) = next(iter_target)

        img_ctr = 0

        input1 = input1.cuda(gpu)
        input2 = input2.cuda(gpu)
        pert = nn.Parameter(torch.zeros_like(input2, requires_grad=True).cuda(gpu))

        for z in range(input1.size(0)):
            if not rand_loc:
                start_x = 224-patch_size-5
                start_y = 224-patch_size-5
            else:
                start_x = random.randint(0, 224-patch_size-1)
                start_y = random.randint(0, 224-patch_size-1)

            # PASTE TRIGGER ON SOURCE IMAGES
            input1[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger

        output1, feat1 = model(input1)
        feat1 = feat1.detach().clone()

        for k in range(input1.size(0)):
            img_ctr = img_ctr+1

            fname = saveDir_patched + '/' + 'htba_' + str(os.path.basename(path1[k])).split('.')[0] + '_' + 'epoch_' + str(epoch).zfill(2)\
                    + str(img_ctr).zfill(5)+'.png'

            save_image(input1[k].clone().cpu(), fname)
            num_poisoned +=1

        for j in range(num_iter):
            lr1 = adjust_learning_rate(lr, j)

            output2, feat2 = model(input2+pert)

            # FIND CLOSEST PAIR WITHOUT REPLACEMENT
            feat11 = feat1.clone()
            dist = torch.cdist(feat1, feat2)
            for _ in range(feat2.size(0)):
                dist_min_index = (dist == torch.min(dist)).nonzero().squeeze()
                feat1[dist_min_index[1]] = feat11[dist_min_index[0]]
                dist[dist_min_index[0], dist_min_index[1]] = 1e5

            loss1 = ((feat1-feat2)**2).sum(dim=1)
            loss = loss1.sum()

            losses.update(loss.item(), input1.size(0))

            loss.backward()

            pert = pert- lr1*pert.grad
            pert = torch.clamp(pert, -eps1, eps1).detach_()

            pert = pert + input2

            pert = pert.clamp(0, 1)

            if j%100 == 0:
                logging.info("Epoch: {:2d} | i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
                             .format(epoch, i, j, lr1, losses.val, losses.avg))

            if loss1.max().item() < 10 or j == (num_iter-1):
                for k in range(input2.size(0)):
                    img_ctr = img_ctr+1
                    input2_pert = (pert[k].clone().cpu())

                    fname = saveDir_poison + '/' + 'loss_' + str(int(loss1[k].item())).zfill(5) + '_' + 'epoch_' + \
                            str(epoch).zfill(2) + '_' + str(os.path.basename(path2[k])).split('.')[0] + '_' + \
                            str(os.path.basename(path1[k])).split('.')[0] + '_kk_' + str(img_ctr).zfill(5)+'.png'

                    save_image(input2_pert, fname)
                    num_poisoned +=1

                break

            pert = pert - input2
            pert.requires_grad = True

    time_elapsed = time.time() - since
    logging.info('Training complete one epoch in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(lr, iter):
    """Sets the learning rate to the initial LR decayed by 0.5 every 1000 iterations"""
    lr = lr * (0.5 ** (iter // 1000))
    return lr

if __name__ == '__main__':
    main()