from pycocotools.coco import COCO
from pycocotools import mask
import numpy as np
import random
import os
import cv2

import sys
sys.path.append('./')

### For visualizing the outputs ###
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms, utils
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm
from torchvision.io.image import read_image
from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights, deeplabv3_resnet101, lraspp_mobilenet_v3_large, LRASPP_MobileNet_V3_Large_Weights
from torchvision.transforms.functional import to_pil_image
from pathlib import Path
import random
from typing import Any, Callable, List, Optional, Tuple
import torch.optim as optim
from torch import nn, einsum
from torch.autograd import Variable

from torchvision.utils import make_grid

import torch
import torch.utils.data

from Guided.dataset.helpers import get_splitted_dataset
from Guided.helpers import get_parser, Operation, OptimizerDetails
from Guided.models.resnet import ResNet18_64x64, ResNet18_64x64_1, ResNet18_256x256
from scripts.imagenet import get_loader_from_dataset, get_train_val_datasets
import cv2

import torchvision
import cv2
from torchvision import transforms, utils
from torch.utils import data
import torch.nn.functional as F
import os
import errno
import shutil

from torchvision.transforms.functional import to_pil_image
from pathlib import Path
import random
from typing import Any, Callable, List, Optional, Tuple
import torch.optim as optim
from torch import nn, einsum
from torch.autograd import Variable
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm import trange

def create_folder(path):
    """Create the directory ``path`` if it does not already exist.

    Missing parent directories are created as well, so nested output
    locations such as ``results/run1/`` work without preparing the parent
    directory by hand. An already existing directory is left untouched.

    Args:
        path: Directory to create (``str`` or path-like).

    Raises:
        OSError: If the directory cannot be created, for example because of
            missing permissions or because ``path`` exists but is not a
            directory.
    """
    os.makedirs(path, exist_ok=True)

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)

# Start from the project-wide parser and register this script's options on it.
parser = get_parser()
parser.add_argument('--root', default='ICML')
parser.add_argument("--lr", default=1e-4, type=float)
parser.add_argument("--momentum", default=0.9, type=float)
parser.add_argument("--wd", default=1e-2, type=float)
# NOTE(review): no ``type``/``action`` is given, so a value supplied on the
# command line is stored as a string; any non-empty string (including
# "False") is truthy.
parser.add_argument("--shuffle", default=False, help='shuffles the data when we can the train and val data')
parser.add_argument('--direct', action='store_true', help='use direct sampling for noising and denoising')
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--run_command', default='PYTHONPATH=. python Guided/membership_classification.py',
                    help='How to run the script.')
parser.add_argument('--save_every', type=int, default=1)
parser.add_argument('--test_every', type=int, default=10)
parser.add_argument('--optimizer', default='adamw', choices=['sgd', 'adamw'])
parser.add_argument('--use_noise', action='store_true')
parser.add_argument('--fixed_noise', action='store_true')
parser.add_argument('--almost_fixed_noise', action='store_true')
parser.add_argument('--distribution', action='store_true')
parser.add_argument('--load', action='store_true')
parser.add_argument('--remove_bn', action='store_true', default=False)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--use_image', type=int, default=1)
parser.add_argument('--wandb', type=int, default=1)
parser.add_argument('--input_size', type=int, default=64)

# ``--optim_*`` options; most of them are copied onto the OptimizerDetails
# object that is configured later in this script.
parser.add_argument("--optim_lr", default=1e-3, type=float)
parser.add_argument('--optim_max_iters', type=int, default=1)
parser.add_argument("--optim_loss_cutoff", default=0.00001, type=float)
parser.add_argument('--optim_guidance_3', action='store_true', default=False)
parser.add_argument('--optim_original_guidance', action='store_true', default=False)
parser.add_argument("--optim_guidance_3_wt", default=2.0, type=float)
parser.add_argument('--optim_warm_start', action='store_true', default=False)
parser.add_argument('--optim_print', action='store_true', default=False)
parser.add_argument('--optim_aug', action='store_true', default=False)
# Directory that receives every image this script writes.
parser.add_argument('--optim_folder', default='./temp/')
parser.add_argument("--optim_num_steps", nargs="+", default=[1], type=int)
parser.add_argument("--optim_mask_fraction", default=0.5, type=float)


# Parsing happens at import time, so this module is meant to be run as a script.
args = parser.parse_args()

# MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
# PYTHONPATH=. python scripts/load_model.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS

# MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
# PYTHONPATH=. python Guided/Segmentation_mobilenet.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS --batch_size 4

class Normalize(object):
    """Divide an image by 255 and standardise it channel by channel.

    Args:
        mean (tuple): value subtracted from each channel after the division.
        std (tuple): value each channel is divided by afterwards.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        """Return a new sample dict holding float32 arrays.

        ``sample['image']`` is divided by 255, shifted by ``mean`` and scaled
        by ``std``; ``sample['label']`` is only converted to float32.
        """
        pixels = np.array(sample['image']).astype(np.float32)
        target = np.array(sample['label']).astype(np.float32)

        # In-place arithmetic keeps the float32 dtype of the buffer.
        pixels /= 255.0
        pixels -= self.mean
        pixels /= self.std

        return dict(image=pixels, label=target)

class FixScaleCrop(object):
    """Resize the shorter side to ``crop_size`` and take a centred square crop.

    The image is resampled bilinearly and the label with nearest-neighbour,
    and both receive exactly the same resize and crop box.
    """

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        """Return a new sample dict with the resized and cropped image/label."""
        image, target = sample['image'], sample['label']
        side = self.crop_size

        # Scale so the shorter edge equals ``side`` while keeping the aspect ratio.
        width, height = image.size
        if width > height:
            new_h = side
            new_w = int(1.0 * width * new_h / height)
        else:
            new_w = side
            new_h = int(1.0 * height * new_w / width)
        scaled_size = (new_w, new_h)
        image = image.resize(scaled_size, Image.BILINEAR)
        target = target.resize(scaled_size, Image.NEAREST)

        # Centre crop of ``side`` x ``side`` pixels.
        width, height = image.size
        left = int(round((width - side) / 2.))
        top = int(round((height - side) / 2.))
        box = (left, top, left + side, top + side)

        return {'image': image.crop(box),
                'label': target.crop(box)}

class ToTensor(object):
    """Turn the image and label of a sample into float tensors."""

    def __call__(self, sample):
        """Return a new sample dict of ``torch.float32`` tensors.

        The image axes are reordered from H x W x C (numpy layout) to
        C x H x W (torch layout); the label keeps its axes.
        """
        chw = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        target = np.array(sample['label']).astype(np.float32)

        return {
            'image': torch.from_numpy(chw).float(),
            'label': torch.from_numpy(target).float(),
        }


class COCOSegmentation(Dataset):
    """COCO images paired with per-pixel class-index masks.

    Only annotations whose ``category_id`` appears in ``CAT_LIST`` are used;
    the position of the id inside ``CAT_LIST`` becomes the pixel value in the
    mask, and pixels that no kept annotation covers stay 0.

    The list of usable image ids is cached in ``{split}_ids_{year}.pth`` in
    the current working directory and is rebuilt when that file is missing.
    """

    NUM_CLASSES = 21
    # Category ids kept from the annotations; an id's position in this list is
    # the class index written into the mask.
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
                1, 64, 20, 63, 7, 72]

    def __init__(self,
                 base_dir=Path("ICML/coco/"),
                 split='train',
                 year='2017'):
        """Load the annotation index and the list of usable image ids.

        Args:
            base_dir: Dataset root; annotations are read from
                ``annotations/instances_{split}{year}.json`` and images from
                ``images/{split}{year}`` below it.
            split: Split name used to build the file names above.
            year: Year suffix used to build the file names above.
        """
        super().__init__()
        ann_file = os.path.join(
            base_dir, 'annotations/instances_{}{}.json'.format(split, year))
        ids_file = os.path.join(
            Path("./"), '{}_ids_{}.pth'.format(split, year))
        self.img_dir = os.path.join(
            base_dir, 'images/{}{}'.format(split, year))
        self.split = split
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            # Reuse the ids filtered by an earlier run.
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)

        self.composed_transforms = transforms.Compose([
            FixScaleCrop(crop_size=520),
            Normalize(
                mean=(
                    0.485, 0.456, 0.406), std=(
                    0.229, 0.224, 0.225)),
            ToTensor()])

    def __getitem__(self, index):
        """Return the transformed ``{'image', 'label'}`` sample at ``index``.

        The same transform pipeline is applied whatever the split is.
        Earlier, the 'train' and 'val' branches were identical and any other
        split silently produced ``None``.
        """
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}
        return self.composed_transforms(sample)

    def _make_img_gt_point_pair(self, index):
        """Load the RGB image and build its mask image for ``self.ids[index]``."""
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))

        return _img, _target

    def _preprocess(self, ids, ids_file):
        """Keep the ids whose mask has more than 1000 labelled pixels.

        The surviving ids are written to ``ids_file`` with ``torch.save`` and
        returned.
        """
        print("Preprocessing mask, this will take a while. " +
              "But don't worry, it only run once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # more than 1k pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'.
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids

    def _gen_seg_mask(self, target, h, w):
        """Rasterise annotations into an ``(h, w)`` uint8 class-index mask.

        Annotations are painted in the given order and only onto pixels that
        are still 0, so earlier annotations win where instances overlap.

        Args:
            target: Iterable of annotation dicts with ``'category_id'`` and
                ``'segmentation'`` entries.
            h: Mask height in pixels.
            w: Mask width in pixels.
        """
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        class_index = {cat: idx for idx, cat in enumerate(self.CAT_LIST)}
        for instance in target:
            c = class_index.get(instance['category_id'])
            if c is None:
                # Check the category before decoding: annotations of unused
                # categories are dropped anyway, so decoding them is wasted
                # work.
                continue
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                # Several decoded layers: a pixel counts if any layer sets it.
                mask[:, :] += (mask == 0) * \
                    (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask

    def __len__(self):
        """Return the number of usable image ids."""
        return len(self.ids)

def get_pascal_labels():
    """Return the colour table that maps pascal class indices to colours.

    Returns:
        np.ndarray with dimensions (21, 3); row ``i`` holds the three colour
        components (0-255) used for class ``i``.
    """
    palette = [
        [0, 0, 0],
        [128, 0, 0],
        [0, 128, 0],
        [128, 128, 0],
        [0, 0, 128],
        [128, 0, 128],
        [0, 128, 128],
        [128, 128, 128],
        [64, 0, 0],
        [192, 0, 0],
        [64, 128, 0],
        [192, 128, 0],
        [64, 0, 128],
        [192, 0, 128],
        [64, 128, 128],
        [192, 128, 128],
        [0, 64, 0],
        [128, 64, 0],
        [0, 192, 0],
        [128, 192, 0],
        [0, 64, 128],
    ]
    return np.asarray(palette)

def decode_seg_map_sequence(label_masks):
    """Colourise a sequence of label masks.

    Every mask in ``label_masks`` is rendered with ``decode_segmap`` and the
    results are stacked, then reordered from (N, H, W, 3) to (N, 3, H, W).

    Returns:
        torch.Tensor built from the stacked numpy array.
    """
    coloured = [decode_segmap(single_mask) for single_mask in label_masks]
    batch = np.array(coloured).transpose([0, 3, 1, 2])
    return torch.from_numpy(batch)

def decode_segmap(label_mask):
    """Render a 2-D class-index mask as a colour image.

    Pixels equal to a class index in ``range(21)`` take that class's colour
    from ``get_pascal_labels``; any other value is carried over unchanged.
    Every channel is finally divided by 255.

    Returns:
        np.ndarray of shape (H, W, 3) with the default ``np.zeros`` dtype.
    """
    n_classes = 21
    palette = get_pascal_labels()

    height, width = label_mask.shape[0], label_mask.shape[1]
    rgb = np.zeros((height, width, 3))

    for channel in range(3):
        # Work on a copy so the plane keeps the dtype of the input mask.
        plane = label_mask.copy()
        for class_id in range(0, n_classes):
            plane[label_mask == class_id] = palette[class_id, channel]
        rgb[:, :, channel] = plane / 255.0

    return rgb

# Validation split of the COCO dataset and a loader that walks it in order.
BATCH_SIZE = args.batch_size
val_set = COCOSegmentation(split='val')
num_class = val_set.NUM_CLASSES
print(num_class)
val_loader = DataLoader(
    val_set,
    batch_size=args.batch_size,
    shuffle=False)

# Every image written by this script goes into this folder.
results_folder = args.optim_folder
create_folder(results_folder)

# Undoes the dataset's Normalize step: the first stage multiplies by the
# per-channel std (division by 1/std), the second adds the mean back.
invTrans = transforms.Compose([transforms.Normalize(mean = [ 0., 0., 0. ],
                                                     std = [ 1/0.229, 1/0.224, 1/0.225 ]),
                                transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                                     std = [ 1., 1., 1. ]),
                               ])


def CrossEntropyLoss(logit, target):
    """Cross-entropy averaged per sample over dims 1 and 2 of the loss map.

    Args:
        logit: Class scores with the class axis at dim 1
            (assumed (N, C, H, W) — TODO confirm against callers).
        target: Class indices (assumed (N, H, W)); cast to ``long`` here, so
            float label masks are accepted. Pixels equal to 255 are ignored:
            they contribute zero loss but still count in the mean.

    Returns:
        Tensor with one loss value per sample.
    """
    # ``reduction='none'`` is the supported spelling of the deprecated
    # ``reduce=False``. The former ``.cuda()`` call is dropped: the criterion
    # holds no parameters or buffers, so it moved nothing.
    criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=255)
    loss = criterion(logit, target.long())

    return loss.mean(dim=[1, 2])

def FocalLoss(logit, target, gamma=2, alpha=0.5):
    """Focal loss averaged per sample over dims 1 and 2 of the loss map.

    The per-pixel cross-entropy ``ce`` is turned into ``pt = exp(-ce)`` and
    weighted by ``(1 - pt) ** gamma``; when ``alpha`` is not ``None`` the
    result is additionally scaled by ``alpha``.

    Args:
        logit: Class scores with the class axis at dim 1
            (assumed (N, C, H, W) — TODO confirm against callers).
        target: Class indices (assumed (N, H, W)); cast to ``long`` here.
            Pixels equal to 255 are ignored.
        gamma: Exponent of the ``(1 - pt)`` factor.
        alpha: Optional constant scale; ``None`` disables it.

    Returns:
        Tensor with one loss value per sample.
    """
    # ``reduction='none'`` replaces the deprecated ``reduce=False``; the
    # parameter-free criterion does not need a ``.cuda()`` call.
    criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=255)
    logpt = -criterion(logit, target.long())

    # ``pt`` must be taken before ``alpha`` is folded into ``logpt``.
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt = logpt * alpha
    loss = -((1 - pt) ** gamma) * logpt

    return loss.mean(dim=[1, 2])

# Pretrained DeepLabV3-ResNet101 in eval mode; the trailing comments keep the
# FCN-ResNet50 alternative.
weights = DeepLabV3_ResNet101_Weights.DEFAULT #FCN_ResNet50_Weights.DEFAULT
model = deeplabv3_resnet101(weights=weights) #fcn_resnet50(weights=weights)
model = model.eval()
# Input normalisation applied inside the Segmnetation wrapper.
Trans = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# The network is only used as a fixed operator, so its weights are frozen.
for param in model.parameters():
    param.requires_grad = False

criterion = CrossEntropyLoss

class Segmnetation(nn.Module):
    """Wrap a segmentation network behind a fixed preprocessing pipeline.

    Args:
        model: Network whose output is a mapping with an ``'out'`` entry.
        Trans: Callable applied to the resized image before the network.
    """

    def __init__(self, model, Trans):
        super().__init__()
        self.model = model
        self.trans = Trans

    def forward(self, x):
        """Return ``model(...)['out']`` for the preprocessed input.

        ``x`` is mapped through ``(x + 1) * 0.5`` (presumably from [-1, 1] to
        [0, 1] — confirm with callers), resized bilinearly to 520 x 520 and
        passed through ``self.trans``.
        """
        rescaled = (x + 1) * 0.5
        resized = TF.resize(rescaled, (520, 520), interpolation=TF.InterpolationMode.BILINEAR)
        normalised = self.trans(resized)
        prediction = self.model(normalised)
        return prediction['out']

# Wrap the frozen network, move it to the GPU(s) and keep it in eval mode.
operation_func = Segmnetation(model, Trans)
operation_func = torch.nn.DataParallel(operation_func).cuda()
operation_func.eval()
for param in operation_func.parameters():
    param.requires_grad = False


# NOTE(review): repeats the results-folder setup done earlier in the script.
results_folder = args.optim_folder
create_folder(results_folder)

# Settings object handed to the project's Operation helper below.
operation = OptimizerDetails()

# Empty augmentation pipeline (a Sequential with no layers).
seq = []
pre = torch.nn.Sequential(*seq)

operation.num_steps = args.optim_num_steps #[2]
operation.operation_func = operation_func
operation.optimizer = 'Adam'
operation.lr = args.optim_lr #0.01
operation.loss_func = CrossEntropyLoss
operation.max_iters = args.optim_max_iters #00
operation.loss_cutoff = args.optim_loss_cutoff #0.00001
operation.tv_loss = None
operation.guidance_3 = args.optim_guidance_3 #True
operation.original_guidance = args.optim_original_guidance
operation.optim_guidance_3_wt = args.optim_guidance_3_wt
operation.warm_start = args.optim_warm_start #False
operation.print = args.optim_print
operation.print_every = 5
operation.folder = results_folder
# ``Aug`` is only set when --optim_aug is passed.
if args.optim_aug:
    operation.Aug = pre


# Operation comes from Guided.helpers; its behaviour is not visible in this file.
operator = Operation(args, operation=operation, shape=[BATCH_SIZE, 3, 256, 256], progressive=True)
cnt = 0

def return_cv2(img, path):
    """Write ``img`` to ``path`` and read it back with a 10-pixel border.

    ``img`` is mapped through ``(img + 1) * 0.5`` before it is saved with
    ``utils.save_image`` (one image per row). The file is then re-read with
    OpenCV and padded on every side with the constant colour
    ``[255, 255, 255]``.

    Returns:
        The padded array produced by ``cv2.copyMakeBorder``.
    """
    border_colour = [255, 255, 255]
    utils.save_image((img + 1) * 0.5, path, nrow=1)
    reloaded = cv2.imread(path)
    return cv2.copyMakeBorder(reloaded, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_colour)


# Run the guided operator on the first validation batch and save the original
# image, the colourised label, the produced image and its predicted
# segmentation.
for batch_ind, sample in enumerate(val_loader):
    # ``_`` holds the (normalised) input image here.
    _, label = sample['image'], sample['label']
    label = label.cuda()

    # Undo the dataset normalisation before saving the original image.
    _ = invTrans(_)
    utils.save_image(_, f'{results_folder}/og_img_{batch_ind}.png')

    label_save = decode_seg_map_sequence(torch.squeeze(label, 1).detach(
    ).cpu().numpy())
    utils.save_image(label_save, f'{results_folder}/label_{batch_ind}.png')

    # The label mask is passed as ``operated_image``; presumably the operator
    # produces an image whose segmentation matches it — confirm in
    # Guided.helpers.Operation.
    output_image = operator.operator(label=None, operated_image=label)
    output = operation_func(output_image)

    # Same (x + 1) * 0.5 mapping that the Segmnetation wrapper applies to its input.
    output_image = (output_image + 1) * 0.5
    utils.save_image(output_image, f'{results_folder}/output_image_{batch_ind}.png')

    # Arg-max over the class axis gives the predicted class index per pixel.
    output = decode_seg_map_sequence(torch.max(output, 1)[1].detach(
    ).cpu().numpy())
    utils.save_image(output, f'{results_folder}/output_{batch_ind}.png')

    # Only the first batch is processed.
    if batch_ind == 0:
        break

# for batch_ind, sample in enumerate(val_loader):
#     image, label = sample['image'], sample['label']
#
#     image = invTrans(image)
#     image = 2 * image - 1
#
#     output = operation_func(image)
#
#     loss = criterion(output, label)
#
#     label = decode_seg_map_sequence(torch.squeeze(label, 1).detach(
#     ).cpu().numpy())
#
#     output = decode_seg_map_sequence(torch.max(output, 1)[1].detach(
#     ).cpu().numpy())
#
#
#
#     utils.save_image(image, f'{results_folder}/og_img_{batch_ind}.png')
#     utils.save_image(label, f'{results_folder}/label_{batch_ind}.png')
#     utils.save_image(output, f'{results_folder}/output_{batch_ind}.png')
#
#
#
#     print(loss)
#
#     if batch_ind == 2:
#         break

# for batch_ind, sample in enumerate(val_loader):
#     image, label = sample['image'], sample['label']
#     output = model(image)['out']
#
#     loss = criterion(output, label)
#
#     label = decode_seg_map_sequence(torch.squeeze(label, 1).detach(
#     ).cpu().numpy())
#
#     output = decode_seg_map_sequence(torch.max(output, 1)[1].detach(
#     ).cpu().numpy())
#
#     image = invTrans(image)
#
#     utils.save_image(image, f'{results_folder}/og_img_{batch_ind}.png')
#     utils.save_image(label, f'{results_folder}/label_{batch_ind}.png')
#     utils.save_image(output, f'{results_folder}/output_{batch_ind}.png')
#
#     print(loss)
#
#     if batch_ind == 2:
#         break


# NOTE(review): this call ends the script, so nothing below this line is
# ever executed.
exit()




resolution_fact = 8

# Same model setup as earlier in the script (pretrained DeepLabV3-ResNet101,
# eval mode, frozen weights).
weights = DeepLabV3_ResNet101_Weights.DEFAULT #FCN_ResNet50_Weights.DEFAULT
model = deeplabv3_resnet101(weights=weights) #fcn_resnet50(weights=weights)
Trans = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model = model.eval()
for param in model.parameters():
    param.requires_grad = False


# NOTE(review): repeats the freezing loop directly above.
for param in model.parameters():
    param.requires_grad = False


class Segmnetation(nn.Module):
    """Wrap a segmentation network behind a fixed preprocessing pipeline.

    Args:
        model: Network whose output is a mapping with an ``'out'`` entry.
        Trans: Callable applied to the resized image before the network.
    """

    def __init__(self, model, Trans):
        super().__init__()
        # The whole model is kept (not only its ``.backbone``).
        self.model = model
        self.trans = Trans

    def forward(self, x):
        """Return ``model(...)['out']`` for the preprocessed input.

        ``x`` is mapped through ``(x + 1) * 0.5`` (presumably from [-1, 1] to
        [0, 1] — confirm with callers), resized bilinearly to 520 x 520 and
        passed through ``self.trans``.
        """
        rescaled = (x + 1) * 0.5
        resized = TF.resize(rescaled, (520, 520), interpolation=TF.InterpolationMode.BILINEAR)
        normalised = self.trans(resized)
        prediction = self.model(normalised)
        return prediction['out']

# Wrap the frozen network, move it to the GPU(s) and keep it in eval mode.
operation_func = Segmnetation(model, Trans)
operation_func = torch.nn.DataParallel(operation_func).cuda()
operation_func.eval()
# Freeze the wrapper's parameters as well.
for param in operation_func.parameters():
    param.requires_grad = False



def ce_loss(input, target):
    """Cross-entropy averaged per sample over dims 1 and 2 of the loss map.

    Args:
        input: Class scores with the class axis at dim 1
            (assumed (N, C, H, W) — TODO confirm against callers).
        target: Whatever ``nn.CrossEntropyLoss`` accepts for ``input``; it is
            passed through without a dtype cast.

    Returns:
        Tensor with one loss value per sample.
    """
    # ``reduction='none'`` is the supported spelling of the deprecated
    # ``reduce=False`` and keeps the per-pixel loss map.
    criterion = nn.CrossEntropyLoss(reduction='none')
    per_pixel = criterion(input, target)
    return per_pixel.mean(dim=[1, 2])

def mse_loss(input, target):
    """Return the squared difference averaged over dims 1, 2 and 3.

    The result keeps dim 0, i.e. one value per entry along the first axis.
    """
    squared_error = (input - target).pow(2)
    return squared_error.mean(dim=[1, 2, 3])

# operation = [2, operation_func, optim.Adam, 0.008, weighted_ce_loss, 500, 0.01, 1]
# operation = [5, operation_func, optim.Adam, 0.5 , mse_loss, 2000, 0.005, 1]

# Every image written by this part of the script goes into this folder.
results_folder = args.optim_folder
create_folder(results_folder)

# Settings object handed to the project's Operation helper below.
operation = OptimizerDetails()

# Empty augmentation pipeline (a Sequential with no layers).
seq = []
pre = torch.nn.Sequential(*seq)

operation.num_steps = args.optim_num_steps #[2]
operation.operation_func = operation_func
operation.optimizer = 'Adam'
operation.lr = args.optim_lr #0.01
# Unlike the first configuration in this script, this one uses ce_loss.
operation.loss_func = ce_loss
operation.max_iters = args.optim_max_iters #00
operation.loss_cutoff = args.optim_loss_cutoff #0.00001
operation.tv_loss = None
operation.guidance_3 = args.optim_guidance_3 #True
operation.original_guidance = args.optim_original_guidance
operation.optim_guidance_3_wt = args.optim_guidance_3_wt
operation.warm_start = args.optim_warm_start #False
operation.print = args.optim_print
operation.print_every = 5
operation.folder = results_folder
# ``Aug`` is only set when --optim_aug is passed.
if args.optim_aug:
    operation.Aug = pre


# operation = [2, operation_func, optim.Adam, 0.001, nn.MSELoss(), 1000, 0.001]

# Operation comes from Guided.helpers; its behaviour is not visible in this file.
operator = Operation(args, operation=operation, shape=[BATCH_SIZE, 3, 256, 256], progressive=True)
cnt = 0

def return_cv2(img, path):
    """Write ``img`` to ``path`` and read it back with a 10-pixel border.

    ``img`` is mapped through ``(img + 1) * 0.5`` before it is saved with
    ``utils.save_image`` (one image per row). The file is then re-read with
    OpenCV and padded on every side with the constant colour
    ``[255, 255, 255]``.

    Returns:
        The padded array produced by ``cv2.copyMakeBorder``.
    """
    border_colour = [255, 255, 255]
    utils.save_image((img + 1) * 0.5, path, nrow=1)
    reloaded = cv2.imread(path)
    return cv2.copyMakeBorder(reloaded, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_colour)

# Build the train/val datasets through the project helpers, split each one in
# two using the stored partition checkpoints, and wrap the parts in loaders.
print('loading the dataset...')
train_dataset, val_dataset = get_train_val_datasets(args)
print('done')
print('splitting the dataset...')
train1, train2 = get_splitted_dataset(dataset=train_dataset,
                                      checkpoint_path='checkpoints/non_equal_split/partitions_train.pt')
val1, val2 = get_splitted_dataset(dataset=val_dataset, checkpoint_path='checkpoints/non_equal_split/partitions_val.pt')
print('done')
# NOTE(review): the meaning of the third positional argument (True for the
# first part, False for the second) is defined in scripts.imagenet — confirm.
train1, train2 = get_loader_from_dataset(args, train1, True), get_loader_from_dataset(args, train2, False)
val1, val2 = get_loader_from_dataset(args, val1, True), get_loader_from_dataset(args, val2, False)




# For one batch: predict a segmentation, turn it into a one-hot target and
# dump every class channel (hard and soft) as an image.
for batch_ind, batch in enumerate(val1):
    image, label = batch
    image, label = image.cuda(), label.cuda()

    with torch.no_grad():
        # Per-pixel class probabilities.
        map = operation_func(image).softmax(dim=1)
        old_map = torch.clone(map)
        num_class = map.shape[1]
        print(map.shape)
        # Most likely class per pixel.
        max_vals, max_indices = torch.max(map, 1)
        print(max_indices.shape)
        # One-hot encode the arg-max and move the class axis back to dim 1.
        map = F.one_hot(max_indices, num_classes=num_class)
        map = map.permute(0, 3, 1, 2).float()
        print(map.shape)



    utils.save_image((image + 1) * 0.5, f'{results_folder}/og_img_{batch_ind}.png')

    # One file per class channel: ``a`` is the one-hot mask, ``b`` the soft one.
    for i in range(map.shape[1]):
        a = map[:, i: i + 1, :, :]
        b = old_map[:, i: i + 1, :, :]
        print(i)
        print(torch.max(a), torch.min(a))
        print(torch.max(b), torch.min(b))
        utils.save_image(a, f'{results_folder}/target_mask_{batch_ind}_{i}.png')
        utils.save_image(b, f'{results_folder}/target_mask_old_{batch_ind}_{i}.png')

    # NOTE(review): this call ends the script inside the first iteration, so
    # the rest of the loop body never runs.
    exit()
    print("Start")
    output = operator.operator(label=label, operated_image=map)
    output = (output + 1) * 0.5
    utils.save_image(output, f'{results_folder}/new_img_{batch_ind}.png')


    with torch.no_grad():
        map = operation_func(output)

    for i in range(map.shape[1]):
        a = map[:, i: i + 1, :, :]
        utils.save_image(a, f'{results_folder}/predict_mask_{batch_ind}_{i}.png')


    # Only the first batch would be processed.
    if batch_ind == 0:
        break



