from __future__ import print_function
from pycocotools.coco import COCO
import numpy as np
import random
import os
import cv2

import sys
sys.path.append('./')

### For visualizing the outputs ###
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms, utils
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm
from torchvision.io.image import read_image
from torchvision.models.segmentation import lraspp_mobilenet_v3_large, LRASPP_MobileNet_V3_Large_Weights
from torchvision.transforms.functional import to_pil_image
from pathlib import Path
import random
from typing import Any, Callable, List, Optional, Tuple
import torch.optim as optim
from torch import nn, einsum
from torch.autograd import Variable

import torch
import torch.utils.data

from Guided.dataset.helpers import get_splitted_dataset
from Guided.helpers import get_parser, Operation, OptimizerDetails
from Guided.models.resnet import ResNet18_64x64, ResNet18_64x64_1, ResNet18_256x256
from scripts.imagenet import get_loader_from_dataset, get_train_val_datasets
import cv2

import torchvision
import cv2
import torchvision.models as models
from torchvision import transforms, utils
from torch.utils import data
import torch.nn.functional as F
import os
import errno
import shutil



import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms

def create_folder(path):
    """Create directory ``path`` (and any missing parent directories) if needed.

    Calling it on an already-existing directory is a no-op.

    Raises:
        OSError: if the directory cannot be created, including when ``path``
            already exists but is not a directory (e.g. a regular file), so the
            problem surfaces here rather than on the first write into it.
    """
    os.makedirs(path, exist_ok=True)

def _str2bool(value):
    """Parse a boolean command-line value, case-insensitively.

    Accepts 'true'/'false', 't'/'f', 'yes'/'no', 'y'/'n' and '1'/'0'. Without
    such a ``type`` argparse keeps the raw string, so ``--shuffle False`` would
    store the truthy string ``'False'``.

    Raises:
        ValueError: for any other value (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


torch.manual_seed(0)

parser = get_parser()
parser.add_argument('--root', default='ICML')
parser.add_argument("--lr", default=1e-4, type=float)
parser.add_argument("--momentum", default=0.9, type=float)
parser.add_argument("--wd", default=1e-2, type=float)
parser.add_argument("--shuffle", default=False, type=_str2bool,
                    help='shuffles the data when we can the train and val data')
parser.add_argument('--direct', action='store_true', help='use direct sampling for noising and denoising')
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--run_command', default='PYTHONPATH=. python Guided/membership_classification.py',
                    help='How to run the script.')
parser.add_argument('--save_every', type=int, default=1)
parser.add_argument('--test_every', type=int, default=10)
parser.add_argument('--optimizer', default='adamw', choices=['sgd', 'adamw'])
parser.add_argument('--use_noise', action='store_true')
parser.add_argument('--fixed_noise', action='store_true')
parser.add_argument('--almost_fixed_noise', action='store_true')
parser.add_argument('--distribution', action='store_true')
parser.add_argument('--load', action='store_true')
parser.add_argument('--remove_bn', action='store_true', default=False)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--use_image', type=int, default=1)
parser.add_argument('--wandb', type=int, default=1)
parser.add_argument('--input_size', type=int, default=64)

# Options forwarded to the guided-sampling optimiser (OptimizerDetails below).
parser.add_argument("--optim_lr", default=1e-3, type=float)
parser.add_argument('--optim_max_iters', type=int, default=1)
parser.add_argument("--optim_loss_cutoff", default=0.00001, type=float)
parser.add_argument('--optim_guidance_3', action='store_true', default=False)
parser.add_argument('--optim_original_guidance', action='store_true', default=False)
parser.add_argument("--optim_guidance_3_wt", default=2.0, type=float)
parser.add_argument('--optim_warm_start', action='store_true', default=False)
parser.add_argument('--optim_print', action='store_true', default=False)
parser.add_argument('--optim_do_guidance_3_norm', action='store_true', default=False)
parser.add_argument('--optim_aug', action='store_true', default=False)
parser.add_argument('--optim_folder', default='./temp/')
parser.add_argument("--optim_num_steps", nargs="+", default=[1], type=int)
parser.add_argument('--optim_sampling_type', default=None)
parser.add_argument("--optim_mask_fraction", default=0.5, type=float)


args = parser.parse_args()

# MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
# PYTHONPATH=. python scripts/load_model.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS

# MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
# PYTHONPATH=. python Guided/Segmentation_mobilenet.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS --batch_size 4

# --batch_size is presumably defined by get_parser() — it is not added here.
BATCH_SIZE = args.batch_size


def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    ``input`` of shape ``(a, b, c, d)`` is flattened to ``(a * b, c * d)``,
    one row per (batch, channel) pair, and multiplied by its own transpose.
    The resulting ``(a * b, a * b)`` matrix is divided by the element count
    ``a * b * c * d``.

    ``reshape`` is used rather than ``view`` so non-contiguous inputs (e.g.
    after ``permute``/``transpose``) are accepted. For contiguous inputs it
    still returns a view, so nothing is copied.
    """
    a, b, c, d = input.size()
    features = input.reshape(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)



class StyleLoss(nn.Module):
    """Pass-through layer that records a Gram-matrix style loss.

    On construction it stores the detached Gram matrix of ``target_feature``.
    Every forward call stores ``mse(gram(input), target)`` in ``self.loss``
    and returns ``input`` unchanged, so the layer can be dropped into an
    ``nn.Sequential`` without affecting the activations.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Detached so the target is treated as a constant during optimisation.
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        # Side effect only: the loss is read later by whoever holds this module.
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input


# Convolutional trunk of an ImageNet-pretrained VGG-19, moved to the GPU and put in
# eval mode; used as the style feature extractor.
# `weights=VGG19_Weights.IMAGENET1K_V1` is the non-deprecated spelling of
# `pretrained=True` (same checkpoint); the weights API is already required by this
# file's `LRASPP_MobileNet_V3_Large_Weights` import.
cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.cuda()
cnn = cnn.eval()


class Normalization(nn.Module):
    """Rescale with x -> (x + 1) * 0.5, then apply the wrapped transform.

    The affine step maps the range [-1, 1] onto [0, 1] before ``Trans``
    (e.g. a ``transforms.Normalize``) is applied.
    """

    def __init__(self, Trans):
        super().__init__()
        self.trans = Trans

    def forward(self, img):
        unit_range = (img + 1) * 0.5
        return self.trans(unit_range)


def get_style_model_and_losses(cnn, style_img):
    """Build a truncated network from ``cnn`` with StyleLoss probes after conv_1..conv_5.

    The new ``nn.Sequential`` starts with a ``Normalization``: x -> (x + 1) / 2,
    then normalisation with the ImageNet mean/std. The direct children of ``cnn``
    are appended in order and named ``conv_k`` / ``relu_k`` / ``pool_k`` /
    ``bn_k``, where ``k`` counts the convolutions seen so far.

    After each of ``conv_1`` .. ``conv_5`` a ``StyleLoss`` is inserted. Its
    target is computed from ``style_img``'s activations at that depth. All
    layers after the last ``StyleLoss`` are dropped.

    Args:
        cnn: module whose children are Conv2d / ReLU / MaxPool2d / BatchNorm2d
            (e.g. ``vgg19().features``). Conv/pool/bn layers are reused, not
            copied; ReLUs are replaced by out-of-place instances.
        style_img: batch of style images used to compute the targets.

    Returns:
        ``(model, style_losses)``: the truncated ``nn.Sequential`` and the
        ``StyleLoss`` modules inside it, in network order.

    Raises:
        RuntimeError: if ``cnn`` has a child of any other layer type.
    """
    Trans = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

    # Normalization holds no parameters/buffers; follow style_img's device rather
    # than hard-coding CUDA.
    normalization = Normalization(Trans).to(style_img.device)
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # An in-place ReLU would overwrite the conv output that the preceding
            # StyleLoss consumed, which autograd may need for the backward pass.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in style_layers:
            # Targets are constants; no autograd graph is needed for this pass.
            with torch.no_grad():
                target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # Trim every layer after the last StyleLoss (index 0, the normalization, is
    # always kept).
    last = 0
    for idx, module in enumerate(model):
        if isinstance(module, StyleLoss):
            last = idx
    model = model[:last + 1]

    return model, style_losses

# style transfer
# Resize so the shorter side is 256 px, then convert to a [0, 1] CHW float tensor.
loader = transforms.Compose([
    transforms.Resize(256),  # scale imported image
    transforms.ToTensor()])

# `with` closes the file handle. convert('RGB') guarantees three channels, which the
# 3-channel ImageNet Normalize in the style network requires; greyscale or RGBA
# files would otherwise fail.
with Image.open("./picasso.jpg") as _style_pil:
    style_img = loader(_style_pil.convert('RGB')).unsqueeze(0)
style_img = 2 * style_img - 1  # [0, 1] -> [-1, 1]
style_img = style_img.cuda()

# One copy of the style image per batch element. A single repeat replaces
# BATCH_SIZE - 1 successive torch.cat calls.
style_img = style_img.repeat(BATCH_SIZE, 1, 1, 1)


results_folder = args.optim_folder
create_folder(results_folder)

class Final_loss(nn.Module):
    """Style-only objective built on the truncated VGG network.

    ``forward(img, dummy_img)``:
    - runs ``img`` through the network, which makes every ``StyleLoss`` layer
      update its ``loss``;
    - prints the sum of those losses, each scaled by 100;
    - returns that sum.

    ``dummy_img`` is ignored. It presumably exists to match the calling
    convention of ``Operation``'s ``loss_func`` (TODO confirm).
    """

    def __init__(self, cnn, style_img):
        super().__init__()
        self.model, self.style_losses = get_style_model_and_losses(cnn, style_img)

    def forward(self, img, dummy_img):
        # Freeze network parameters; gradients can still reach `img` if it requires grad.
        self.model.requires_grad_(False)
        self.model(img)  # output unused: StyleLoss layers store their loss as a side effect

        style_score = sum(100 * probe.loss for probe in self.style_losses)
        print(style_score)
        return style_score

fl = Final_loss(cnn, style_img)


# Settings for the guided-sampling optimiser consumed by Guided.helpers.Operation.
operation = OptimizerDetails()

# Empty pipeline (identity); installed as the augmentation only when --optim_aug is set.
pre = torch.nn.Sequential()

_operation_settings = {
    'num_steps': args.optim_num_steps,
    'operation_func': None,
    'optimizer': 'Adam',
    'lr': args.optim_lr,
    'loss_func': fl,
    'max_iters': args.optim_max_iters,
    'loss_cutoff': args.optim_loss_cutoff,
    'tv_loss': None,
    'guidance_3': args.optim_guidance_3,
    'original_guidance': args.optim_original_guidance,
    'optim_guidance_3_wt': args.optim_guidance_3_wt,
    'do_guidance_3_norm': args.optim_do_guidance_3_norm,
    'warm_start': args.optim_warm_start,
    'print': args.optim_print,
    'print_every': 10,
    'folder': results_folder,
}
# Applied in the order listed above.
for _field, _value in _operation_settings.items():
    setattr(operation, _field, _value)

if args.optim_aug:
    operation.Aug = pre


operator = Operation(args, operation=operation, shape=[BATCH_SIZE, 3, 256, 256], progressive=True)
cnt = 0



def return_cv2(img, path, border=10, color=(255, 255, 255)):
    """Save ``img`` to ``path``, read it back with OpenCV and add a constant border.

    ``img`` is mapped with (x + 1) * 0.5 before saving, so it is presumably in
    [-1, 1] (TODO confirm). It is written with ``utils.save_image(..., nrow=1)``.
    It is then re-read with ``cv2.imread``, which gives a BGR uint8 array, and
    padded on all four sides.

    Args:
        img: image tensor accepted by ``torchvision.utils.save_image``.
        path: file written (overwritten) and then read back.
        border: border width in pixels on every side (default 10).
        color: border colour in OpenCV channel order; the default is white.

    Returns:
        The bordered image array returned by ``cv2.copyMakeBorder``.

    Raises:
        OSError: if OpenCV cannot read the saved file back (``imread`` returns None).
    """
    utils.save_image((img + 1) * 0.5, path, nrow=1)
    bgr = cv2.imread(path)
    if bgr is None:
        raise OSError('cv2 could not read back {}'.format(path))
    return cv2.copyMakeBorder(bgr, border, border, border, border,
                              cv2.BORDER_CONSTANT, value=color)

def _loader_pair(first, second):
    """Wrap two dataset parts as loaders: third argument True for ``first``, False for ``second``."""
    first_loader = get_loader_from_dataset(args, first, True)
    second_loader = get_loader_from_dataset(args, second, False)
    return first_loader, second_loader


print('loading the dataset...')
train_dataset, val_dataset = get_train_val_datasets(args)
print('done')

print('splitting the dataset...')
train1, train2 = get_splitted_dataset(dataset=train_dataset,
                                      checkpoint_path='checkpoints/non_equal_split/partitions_train.pt')
val1, val2 = get_splitted_dataset(dataset=val_dataset,
                                  checkpoint_path='checkpoints/non_equal_split/partitions_val.pt')
print('done')

# The dataset parts are rebound to their loaders (same names).
train1, train2 = _loader_pair(train1, train2)
val1, val2 = _loader_pair(val1, val2)


def get_images(loader=None, label_ids=range(153, 260), count=None):
    """Collect the first ``count`` samples whose label is in ``label_ids``.

    The default ids 153-259 are presumably ImageNet dog classes, since the
    caller names the result ``dog_images`` (TODO confirm).

    Args:
        loader: iterable of ``(images, labels)`` batches. ``None`` means the
            module-level ``val1`` loader.
        label_ids: iterable of integer class ids to accept.
        count: number of samples to collect. ``None`` means ``BATCH_SIZE``.

    Returns:
        ``(images, labels)``: two lists with ``count`` entries each, in loader
        order. Each entry is a length-1 slice (``images[i:i+1]``,
        ``labels[i:i+1]``), so the lists can be concatenated along dim 0.

    Raises:
        RuntimeError: if the loader is exhausted before ``count`` matches are found.
    """
    if loader is None:
        loader = val1
    if count is None:
        count = BATCH_SIZE
    wanted = set(label_ids)  # O(1) membership per label

    images, labels = [], []
    if count <= 0:
        return images, labels

    for batch_images, batch_labels in loader:
        for i in range(batch_labels.shape[0]):
            if int(batch_labels[i]) in wanted:
                images.append(batch_images[i:i + 1])
                labels.append(batch_labels[i:i + 1])
                if len(images) == count:
                    return images, labels

    raise RuntimeError('found only {} of {} requested images with matching labels'.format(
        len(images), count))


# Gather one batch of matching validation images and stack it along the batch dimension.
dog_images, dog_labels = get_images()
dog_images = torch.cat(dog_images, dim=0)
dog_labels = torch.cat(dog_labels, dim=0)

# A single batch is processed; batch_ind only serves as the output file-name suffix.
batch_ind = 0
image, label = dog_images.cuda(), dog_labels.cuda()

# Save the reference inputs (assumes images are in [-1, 1]; mapped to [0, 1] for saving).
utils.save_image((image + 1) * 0.5, f'{results_folder}/og_img_{batch_ind}.png')

print("Start")
# Guided sampling: only the labels are passed to the operator (operated_image=None).
output = operator.operator(label=label, operated_image=None)
utils.save_image((output + 1) * 0.5, f'{results_folder}/new_img_{batch_ind}.png')



