from pycocotools.coco import COCO
import numpy as np
import random
import os
import cv2

### For visualizing the outputs ###
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms, utils
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm
from torchvision.io.image import read_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights, deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
from pathlib import Path
import random
from typing import Any, Callable, List, Optional, Tuple
import torch.optim as optim
from torch import nn, einsum
from torch.autograd import Variable

import torch
import torch.utils.data

from Guided.dataset.helpers import get_splitted_dataset
from Guided.helpers import get_model, get_parser, Operation, OptimizerDetails
from Guided.models.resnet import ResNet18_64x64, ResNet18_64x64_1, ResNet18_256x256
from scripts.imagenet import get_loader_from_dataset, get_train_val_datasets
import cv2

import torchvision
import cv2
from torchvision import transforms, utils
from torch.utils import data
import torch.nn.functional as F
import torchvision.models as models
import os
import errno
import shutil

def create_folder(path):
    """Create the single directory ``path`` unless something already exists there.

    Only the final path component is created; missing parents still raise
    (``FileNotFoundError``). If ``path`` already exists — even as a
    non-directory — the call silently does nothing. Every other ``OSError``
    from ``os.mkdir`` propagates to the caller.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # os.mkdir reports EEXIST as FileExistsError; an existing path is fine.
        return

def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    argparse hands the raw string to ``type=``; without this conversion any
    non-empty string (including ``"False"``) would end up truthy.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling
            (argparse reports this as a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise ValueError(f'expected a boolean value, got {value!r}')


torch.manual_seed(0)

parser = get_parser()
parser.add_argument('--root', default='ICML')
parser.add_argument("--lr", default=1e-4, type=float)
parser.add_argument("--momentum", default=0.9, type=float)
parser.add_argument("--wd", default=1e-2, type=float)
# type=_str_to_bool: previously `--shuffle False` produced the truthy string "False".
parser.add_argument("--shuffle", default=False, type=_str_to_bool,
                    help='shuffles the data when we load the train and val data')
parser.add_argument('--direct', action='store_true', help='use direct sampling for noising and denoising')
parser.add_argument('--epochs', default=100, type=int)
# NOTE(review): this default names membership_classification.py rather than this
# script (see the Adv_Resnet_attack.py command below) — confirm it is intended.
parser.add_argument('--run_command', default='PYTHONPATH=. python Guided/membership_classification.py',
                    help='How to run the script.')
parser.add_argument('--save_every', type=int, default=1)
parser.add_argument('--test_every', type=int, default=10)
parser.add_argument('--optimizer', default='adamw', choices=['sgd', 'adamw'])
parser.add_argument('--use_noise', action='store_true')
parser.add_argument('--fixed_noise', action='store_true')
parser.add_argument('--almost_fixed_noise', action='store_true')
parser.add_argument('--distribution', action='store_true')
parser.add_argument('--load', action='store_true')
parser.add_argument('--remove_bn', action='store_true', default=False)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--use_image', type=int, default=1)
parser.add_argument('--wandb', type=int, default=1)
parser.add_argument('--input_size', type=int, default=64)

args = parser.parse_args()

# MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
# PYTHONPATH=. python scripts/load_model.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS

# MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
# PYTHONPATH=. python Guided/Adv_Resnet_attack.py $MODEL_FLAGS --classifier_scale 0.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion.pt $SAMPLE_FLAGS --batch_size 4



BATCH_SIZE = args.batch_size  # batch_size is defined by get_parser()
resolution_fact = 8  # not referenced elsewhere in this script

# ImageNet-pretrained ResNet-18. `weights=` replaces the deprecated
# `pretrained=True`, which loaded these same IMAGENET1K_V1 weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# ImageNet channel mean/std, applied to images already scaled to [0, 1].
Trans = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

class Attack_Model(nn.Module):
    """Classifier wrapper that takes images in the diffusion model's value range.

    ``forward`` rescales the input with ``(x + 1) * 0.5`` (i.e. it assumes
    [-1, 1] inputs — TODO confirm against the caller). It then bilinearly
    resizes the result to 224x224, applies the supplied normalisation
    ``Trans``, and returns the wrapped ``model``'s output.

    Args:
        model: Classifier applied to the normalised 224x224 batch.
        Trans: Callable/module applied after resizing (e.g. a Normalize).
    """

    def __init__(self, model, Trans):
        super().__init__()
        # Registration order (trans, then model) is kept as-is.
        self.trans = Trans
        self.model = model

    def forward(self, x):
        unit_range = (x + 1) * 0.5
        resized = TF.resize(unit_range, (224, 224), interpolation=TF.InterpolationMode.BILINEAR)
        normalised = self.trans(resized)
        return self.model(normalised)

# Wrap the classifier in DataParallel on the GPU and freeze it. It is only
# used in eval mode as a fixed guidance/evaluation model, never trained here.
operation_func = torch.nn.DataParallel(Attack_Model(model, Trans)).cuda()
operation_func.eval()
operation_func.requires_grad_(False)

def mse_ce_loss(input, target, ce_weight=0.1):
    """Per-sample reconstruction + classification loss used to guide the attack.

    Args:
        input: Candidate image batch. It is reduced over dims 1-3, so it is
            assumed to be (batch, channels, height, width) — TODO confirm.
        target: Indexable ``(classifier, reference_image, labels)``.
            ``classifier`` is called on ``input``. ``reference_image`` is
            compared to ``input`` element-wise. ``labels`` are the class
            indices the cross-entropy term pulls predictions towards.
        ce_weight: Weight of the cross-entropy term (default 0.1, the
            previously hard-coded value).

    Returns:
        One loss value per sample: the MSE to ``reference_image`` plus
        ``ce_weight`` times the unreduced cross-entropy.
    """
    # Indexed (not unpacked) so targets with extra trailing entries still work.
    classifier = target[0]
    reference = target[1]
    labels = target[2]

    mse_l = ((input - reference) ** 2).mean(dim=[1, 2, 3])
    # reduction='none' keeps one value per sample. The old
    # CrossEntropyLoss(reduce=False) did the same but is deprecated and warned.
    ce_l = F.cross_entropy(classifier(input), labels, reduction='none')

    return mse_l + ce_weight * ce_l


# Optimisation settings consumed by `Operation` (Guided.helpers); see that
# class for how each field is interpreted.
operation = OptimizerDetails()
operation.num_steps = [5]
operation.operation_func = None  # presumably: no transform before the loss — confirm in Operation
operation.optimizer = 'Adam'
operation.lr = 1e-1
operation.loss_func = mse_ce_loss
operation.max_iters = 1_000
operation.loss_cutoff = 1e-4

# NOTE(review): spatial size is fixed at 256x256 here, independent of
# args.input_size — confirm this matches the diffusion model's image size.
operator = Operation(
    args,
    operation=operation,
    shape=[BATCH_SIZE, 3, 256, 256],
    progressive=True,
)
cnt = 0  # not referenced elsewhere in this script

def return_cv2(img, path):
    """Write ``img`` to ``path`` and read it back as an OpenCV image with a frame.

    ``img`` is rescaled with ``(img + 1) * 0.5`` (so it is presumably in
    [-1, 1] — TODO confirm). It is saved with ``utils.save_image`` (one image
    per row), re-read with ``cv2.imread``, and padded by a constant 10-pixel
    border on every side. Side effect: the file at ``path`` is (over)written.
    """
    white = [255, 255, 255]  # border colour: all channels at 255 is white
    utils.save_image((img + 1) * 0.5, path, nrow=1)
    reloaded = cv2.imread(path)
    return cv2.copyMakeBorder(reloaded, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=white)

print('loading the dataset...')
train_dataset, val_dataset = get_train_val_datasets(args)
print('done')
print('splitting the dataset...')
train1, train2 = get_splitted_dataset(dataset=train_dataset,
                                      checkpoint_path='checkpoints/non_equal_split/partitions_train.pt')
val1, val2 = get_splitted_dataset(dataset=val_dataset, checkpoint_path='checkpoints/non_equal_split/partitions_val.pt')
print('done')
# Turn each split into a loader. The boolean is forwarded to
# get_loader_from_dataset (presumably a shuffle/train flag — confirm there).
train1, train2 = get_loader_from_dataset(args, train1, True), get_loader_from_dataset(args, train2, False)
val1, val2 = get_loader_from_dataset(args, val1, True), get_loader_from_dataset(args, val2, False)


results_folder = './adv_resnet_attack/'
create_folder(results_folder)


# Attack the first batch of val1 only (see the break below). For each image,
# operator.operator optimises a new image; with mse_ce_loss the objective
# presumably keeps it close to the original (MSE term) while pushing the
# classifier towards class 1 (cross-entropy term). Predictions before and
# after are printed and both images are saved.
for batch_ind, batch in enumerate(val1):
    image, label = batch
    image, label = image.cuda(), label.cuda()

    # Loss target read by mse_ce_loss as (classifier, reference image, labels).
    # Not named `map`, to avoid shadowing the builtin.
    attack_target = [operation_func, image, torch.ones_like(label)]

    # Inference-only pass: no autograd graph needed.
    with torch.no_grad():
        outputs = operation_func(image)
    _, max_indices = torch.max(outputs, 1)
    print(max_indices)

    utils.save_image((image + 1) * 0.5, f'{results_folder}/og_img_{batch_ind}.png')

    # Grad mode is left enabled here: the optimisation (Adam, lr/max_iters set on
    # `operation`) runs inside operator.operator.
    output = operator.operator(label=label, operated_image=attack_target)

    print("Actual label ", label)

    # Evaluation of the clean and optimised images; no gradients required.
    with torch.no_grad():
        outputs = operation_func(image)
        max_vals, max_indices = torch.max(outputs, 1)
        print("The predicted label was ", max_indices)

        outputs = operation_func(output)
        max_vals, max_indices = torch.max(outputs, 1)
        print("The predicted label for new image is ", max_indices)

        mse_l = ((image - output) ** 2).mean(dim=[1, 2, 3])
        print("The mse loss is ", mse_l)

    output = (output + 1) * 0.5
    utils.save_image(output, f'{results_folder}/new_img_{batch_ind}.png')

    if batch_ind == 0:
        break




