import torch.utils.data as data
from evaluate_detection.voc import (VOCDetection4KNN,
                                    make_transforms_large_canvas, CanvasesVOCDetection4Val)
import cv2
from PIL import Image
from evaluate_detection.voc import make_transforms, create_grid_from_images, create_gradiant_cross_grid_images
import torch
import numpy as np
import torchvision.transforms as T
import json
import random


def box_to_img(mask, target, border_width=4):
    """Draw the boxes of ``target`` as white rectangle outlines and return a PIL image.

    Args:
        mask: HxWx3 array to draw on (it is modified in place), or ``None`` to
            start from a fresh 112x112x3 black canvas.
        target: dict whose ``'boxes'`` tensor holds ``(x_min, y_min, x_max, y_max)``
            rows; presumably normalised to [0, 1] -- TODO confirm against callers.
        border_width: rectangle line thickness in pixels.

    Returns:
        ``PIL.Image`` built from ``mask`` cast to uint8.
    """
    if mask is None:
        mask = np.zeros((112, 112, 3))
    h, w, _ = mask.shape
    for box in target['boxes']:
        # x coordinates scale with the width and y coordinates with the height,
        # so non-square canvases are handled correctly (identical for square ones).
        # new_tensor keeps the box's dtype/device; tolist() yields plain ints for cv2.
        scale = box.new_tensor([w - 1, h - 1, w - 1, h - 1])
        x_min, y_min, x_max, y_max = (box * scale).round().int().tolist()
        cv2.rectangle(mask, (x_min, y_min), (x_max, y_max), (255, 255, 255), border_width)
    return Image.fromarray(mask.astype('uint8'))


def get_annotated_image(img, boxes, border_width=3, mode='draw', bgcolor='white', fg='image'):
    """Render bounding boxes onto, or cut them out of, an image.

    Args:
        img: RGB image as an HxWx3 array. ``'draw'`` also accepts a PIL image;
            ``'keep'`` accepts anything ``np.asarray`` converts.
        boxes: iterable of tensors in ``(x_min, y_min, x_max, y_max)`` pixel
            coordinates; values are truncated to ``int``.
        border_width: rectangle thickness used by ``'draw'`` (OpenCV fills the
            rectangle for a negative thickness). Ignored by ``'keep'``.
        mode: ``'draw'`` outlines every box with colour (255, 0, 0) on a copy of
            ``img``; ``'keep'`` starts from a solid ``bgcolor`` canvas of the same
            size and paints only the box regions.
        bgcolor: PIL colour of the ``'keep'`` background.
        fg: what ``'keep'`` paints inside each box: ``'image'`` copies the
            corresponding pixels of ``img``, ``'white'`` fills them with 255.

    Returns:
        A new numpy array; ``img`` itself is not modified.

    Raises:
        ValueError: for an unknown ``mode``, or an unknown ``fg`` in ``'keep'`` mode.
    """
    # Validate up front: an unknown mode used to surface as UnboundLocalError and
    # an unknown fg silently produced a background-only canvas.
    if mode not in ('draw', 'keep'):
        raise ValueError(f"mode must be 'draw' or 'keep', got {mode!r}")
    if mode == 'keep' and fg not in ('image', 'white'):
        raise ValueError(f"fg must be 'image' or 'white', got {fg!r}")

    if mode == 'draw':
        image_copy = np.array(img.copy())
        for box in boxes:
            box = box.numpy().astype('int')
            cv2.rectangle(image_copy, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), border_width)
        return image_copy

    src = np.asarray(img)
    image_copy = np.array(Image.new('RGB', (src.shape[1], src.shape[0]), color=bgcolor))
    for box in boxes:
        box = box.numpy().astype('int')
        region = (slice(box[1], box[3]), slice(box[0], box[2]))
        if fg == 'image':
            image_copy[region] = src[region]
        else:  # fg == 'white'
            image_copy[region] = 255
    return image_copy


class CanvasDataset4Val(data.Dataset):
    """Pascal VOC detection validation set rendered as single-support prompt canvases.

    Each item pairs one val query with one train support taken from a precomputed
    retrieval file and builds a canvas with ``create_grid_from_images``. Masks are
    black images in which every box of the image's first-listed class is white.
    """

    def __init__(self, pascal_path='pascal-5i', years=("2012",), random_example=False,
                 feature_name='features_vit-laion2b_pixel-level_val', **kwargs):
        """
        Args:
            pascal_path: dataset root; also the root of the retrieval-results folder.
            years: VOC years forwarded to ``VOCDetection4KNN``.
            random_example: if True, the support is drawn at random (seeded by the
                item index) from the retrieved list instead of taking its first entry.
            feature_name: retrieval-results folder prefix; ``_all_detection`` is appended.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.train_ds = VOCDetection4KNN(pascal_path, years, image_sets=['train'], transforms=None,
                                         keep_single_objs_only=1, filter_by_mask_size=1)
        self.val_ds = VOCDetection4KNN(pascal_path, years, image_sets=['val'], transforms=None,
                                       keep_single_objs_only=1, filter_by_mask_size=1)
        self.background_transforms = T.Compose([
            T.Resize((224, 224)),
            T.Compose([
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        ])
        self.data_path = pascal_path
        self.transforms = make_transforms('val')
        self.random_example = random_example
        self.feature_name = f'{feature_name}_all_detection'
        # The method is shadowed on the instance by its result; the rest of the
        # class indexes ``self.get_top_all_images_val[name]`` as a dict.
        self.get_top_all_images_val = self.get_top_all_images_val()

    def __len__(self):
        return len(self.val_ds)

    def get_top_all_images_val(self):
        """Load the retrieval JSON and wrap each entry as ``{name: {'top50': neighbours}}``."""
        path = f"{self.data_path}/VOC2012/{self.feature_name}/folder_cls_top_all-similarity.json"
        with open(path, encoding='utf-8') as f:
            images_top_all = json.load(f)
        return {img_name: {'top50': neighbours} for img_name, neighbours in images_top_all.items()}

    def _pick_support(self, query_name, idx):
        """Return ``(support_name, support_name_datastore)`` for ``query_name``.

        The datastore name is always the first retrieved entry. With
        ``random_example`` the support is drawn uniformly from the whole list using
        a private RNG seeded with ``idx``. This is reproducible, gives the same picks
        as seeding the module RNG, and leaves the global ``random`` state untouched.
        Otherwise the support is the first entry as well.
        """
        neighbours = self.get_top_all_images_val[query_name]['top50']
        if self.random_example:
            rng = random.Random(idx)
            support_name = neighbours[rng.choice(range(len(neighbours)))]
        else:
            support_name = neighbours[0]
        return support_name, neighbours[0]

    @staticmethod
    def _class_mask(image, target, label):
        """Return a PIL mask: black canvas with every box labelled ``label`` filled white."""
        boxes = target['boxes'][torch.where(target['labels'] == label)[0]]
        mask = get_annotated_image(np.array(image), boxes, border_width=-1, mode='keep',
                                   bgcolor='black', fg='white')
        return Image.fromarray(mask)

    def _background(self):
        """Return a freshly transformed 224x224 white background tensor."""
        return self.background_transforms(Image.new('RGB', (224, 224), color='white'))

    def __getitem__(self, idx):
        """Build the canvas for val query ``idx``.

        Returns:
            dict with transformed query/support images and masks, the query,
            support and datastore file names, and the canvas under ``'grid_stack'``.
        """
        query_name = self.val_ds.image_set[idx]
        support_name, support_name_datastore = self._pick_support(query_name, idx)

        query_image, query_target = self.val_ds.get_item_by_filename(query_name)
        support_image, support_target = self.train_ds.get_item_by_filename(support_name)
        # Each mask covers the boxes of the first-listed class of its own image.
        support_mask = self._class_mask(support_image, support_target, support_target['labels'].numpy()[0])
        query_mask = self._class_mask(query_image, query_target, query_target['labels'].numpy()[0])

        query_image_ten = self.transforms(query_image, None)[0]
        query_target_ten = self.transforms(query_mask, None)[0]
        support_target_ten = self.transforms(support_mask, None)[0]
        support_image_ten = self.transforms(support_image, None)[0]

        # The grid is returned as-is: no placeholder CUDA tensor is allocated, so
        # items can be built on CPU-only hosts and inside DataLoader workers.
        grid_stack = create_grid_from_images(self._background(), support_image_ten, support_target_ten,
                                             query_image_ten, query_target_ten)

        return {'query_img': query_image_ten,
                'query_mask': query_target_ten,
                'support_img': support_image_ten,
                'support_mask': support_target_ten,
                'support_name': support_name,
                'query_name': query_name,
                'support_name_datastore': support_name_datastore,
                'grid_stack': grid_stack}


class LargeCanvasDataset4Val(data.Dataset):
    """Pascal VOC detection validation set rendered as multi-support prompt canvases.

    Each item combines ``n_shot`` retrieved train supports and one val query into a
    single canvas via ``create_gradiant_cross_grid_images``. Masks are black images
    in which every box of the image's first-listed class is white.
    """

    def __init__(self, pascal_path='pascal-5i', years=("2012",), random_example=False, n_shot=2,
                 feature_name='features_vit-laion2b_pixel-level_val', **kwargs):
        """
        Args:
            pascal_path: dataset root; also the root of the retrieval-results folder.
            years: VOC years forwarded to ``CanvasesVOCDetection4Val``.
            random_example: if True, supports are drawn at random without
                replacement (seeded by the item index) from the retrieved list;
                otherwise the first ``n_shot`` entries are used.
            n_shot: number of supports per canvas.
            feature_name: retrieval-results folder prefix; ``_all_detection`` is
                appended. The default reproduces the previously hard-coded folder.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.train_ds = CanvasesVOCDetection4Val(pascal_path, years, image_sets=['train'], transforms=None,
                                                 keep_single_objs_only=1, filter_by_mask_size=1)
        self.val_ds = CanvasesVOCDetection4Val(pascal_path, years, image_sets=['val'], transforms=None,
                                               keep_single_objs_only=1, filter_by_mask_size=1)
        self.background_transforms = T.Compose([
            T.Resize((224, 224)),
            T.Compose([
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        ])
        self.data_path = pascal_path
        self.transforms = make_transforms_large_canvas('val')
        self.random_example = random_example
        self.feature_name = f'{feature_name}_all_detection'
        # The method is shadowed on the instance by its result; the rest of the
        # class indexes ``self.get_top_all_images_val[name]`` as a dict.
        self.get_top_all_images_val = self.get_top_all_images_val()
        self.n_shot = n_shot

    def __len__(self):
        return len(self.val_ds)

    def get_top_all_images_val(self):
        """Load the retrieval JSON (same-object-class detection task, MAE-VQGAN setting).

        Returns:
            ``{name: {'top50': neighbours}}`` for every entry of the file.
        """
        path = f"{self.data_path}/VOC2012/{self.feature_name}/folder_cls_top_all-similarity.json"
        with open(path, encoding='utf-8') as f:
            images_top_all = json.load(f)
        return {img_name: {'top50': neighbours} for img_name, neighbours in images_top_all.items()}

    def sample_episode(self, idx, sim_idx):
        """Return ``(query_name, support_name, query_2_find_support_name, support_name_datastore)``.

        ``support_name`` is entry ``sim_idx`` of the query's retrieved list and
        ``support_name_datastore`` its first entry; both query names are the same.
        """
        query_2_find_support_name = self.val_ds.image_set[idx]
        neighbours = self.get_top_all_images_val[query_2_find_support_name]['top50']
        return query_2_find_support_name, neighbours[sim_idx], query_2_find_support_name, neighbours[0]

    def _support_indices(self, idx):
        """Return the retrieved-list positions to use as supports for query ``idx``.

        With ``random_example`` the positions are drawn without replacement from a
        private RNG seeded with ``idx``. This gives the same picks as seeding the
        module RNG, without clobbering the global ``random`` state.
        """
        if not self.random_example:
            return list(range(self.n_shot))
        query_name = self.val_ds.image_set[idx]
        pool = list(range(len(self.get_top_all_images_val[query_name]['top50'])))
        rng = random.Random(idx)
        picked = []
        for _ in range(self.n_shot):
            choice = rng.choice(pool)
            pool.remove(choice)
            picked.append(choice)
        return picked

    @staticmethod
    def _class_mask(image, target, label):
        """Return a PIL mask: black canvas with every box labelled ``label`` filled white."""
        boxes = target['boxes'][torch.where(target['labels'] == label)[0]]
        mask = get_annotated_image(np.array(image), boxes, border_width=-1, mode='keep',
                                   bgcolor='black', fg='white')
        return Image.fromarray(mask)

    def _background(self):
        """Return a freshly transformed 224x224 white background tensor."""
        return self.background_transforms(Image.new('RGB', (224, 224), color='white'))

    def __getitem__(self, idx):
        """Build one multi-support canvas for val query ``idx``.

        Returns:
            dict with the transformed query image/mask, lists of transformed support
            images/masks, the query file name and the canvas under ``'grid_stack'``.
        """
        query_name = self.val_ds.image_set[idx]
        # The query does not depend on the supports, so it is loaded, masked and
        # transformed once rather than once per shot (assumes the 'val'
        # transforms are deterministic).
        query_image, query_target = self.val_ds[idx]
        query_mask = self._class_mask(query_image, query_target, query_target['labels'].numpy()[0])
        query_image_ten = self.transforms(query_image, None)[0]
        query_target_ten = self.transforms(query_mask, None)[0]

        support_imgs = []
        support_masks = []
        for sim_idx in self._support_indices(idx):
            _, support_name, _, _ = self.sample_episode(idx, sim_idx)
            support_image, support_target = self.train_ds.get_item_by_filename(support_name)
            support_mask = self._class_mask(support_image, support_target,
                                            support_target['labels'].numpy()[0])
            support_masks.append(self.transforms(support_mask, None)[0])
            support_imgs.append(self.transforms(support_image, None)[0])

        # The canvas is returned as-is: no placeholder CUDA tensor is allocated, so
        # items can be built on CPU-only hosts and inside DataLoader workers.
        grid_stack = create_gradiant_cross_grid_images(support_imgs, support_masks, query_image_ten,
                                                       query_target_ten, self._background())

        return {'query_img': query_image_ten,
                'query_mask': query_target_ten,
                'support_img': support_imgs,
                'support_mask': support_masks,
                'query_name': query_name,
                'grid_stack': grid_stack}


class CanvasDataset4ValKNN(data.Dataset):
    """Pascal VOC detection validation set yielding ``n_shot`` single-support episodes per query.

    Each item is a list of episodes. Each episode pairs the val query with one
    retrieved train support and holds its own ``create_grid_from_images`` canvas.
    Masks are black images in which every box of the image's first-listed class is
    white.
    """

    def __init__(self, pascal_path='pascal-5i', years=("2012",), random_example=False, n_shot=4,
                 feature_name='features_vit-laion2b_pixel-level_val', **kwargs):
        """
        Args:
            pascal_path: dataset root; also the root of the retrieval-results folder.
            years: VOC years forwarded to ``VOCDetection4KNN``.
            random_example: if True, supports are drawn at random without
                replacement (seeded by the item index) from the retrieved list;
                otherwise the first ``n_shot`` entries are used.
            n_shot: number of episodes per query.
            feature_name: retrieval-results folder prefix; ``_all_detection`` is appended.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.train_ds = VOCDetection4KNN(pascal_path, years, image_sets=['train'], transforms=None,
                                         keep_single_objs_only=1, filter_by_mask_size=1)
        self.val_ds = VOCDetection4KNN(pascal_path, years, image_sets=['val'], transforms=None,
                                       keep_single_objs_only=1, filter_by_mask_size=1)
        self.background_transforms = T.Compose([
            T.Resize((224, 224)),
            T.Compose([
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        ])
        self.data_path = pascal_path
        self.n_shot = n_shot
        self.transforms = make_transforms('val')
        self.random_example = random_example
        self.feature_name = f'{feature_name}_all_detection'
        # The method is shadowed on the instance by its result; the rest of the
        # class indexes ``self.get_top_all_images_val[name]`` as a dict.
        self.get_top_all_images_val = self.get_top_all_images_val()

    def __len__(self):
        return len(self.val_ds)

    def get_top_all_images_val(self):
        """Load the retrieval JSON and wrap each entry as ``{name: {'top50': neighbours}}``."""
        path = f"{self.data_path}/VOC2012/{self.feature_name}/folder_cls_top_all-similarity.json"
        with open(path, encoding='utf-8') as f:
            images_top_all = json.load(f)
        return {img_name: {'top50': neighbours} for img_name, neighbours in images_top_all.items()}

    def sample_episode_for_training_retriever_query(self, idx, sim_idx):
        """Return ``(query_name, support_name, query_2_find_support_name, support_name_datastore)``.

        ``support_name`` is entry ``sim_idx`` of the query's retrieved list and
        ``support_name_datastore`` its first entry; both query names are the same.
        """
        query_2_find_support_name = self.val_ds.image_set[idx]
        neighbours = self.get_top_all_images_val[query_2_find_support_name]['top50']
        return query_2_find_support_name, neighbours[sim_idx], query_2_find_support_name, neighbours[0]

    def _support_indices(self, idx):
        """Return the retrieved-list positions to use as supports for query ``idx``.

        With ``random_example`` the positions are drawn without replacement from a
        private RNG seeded with ``idx``. This gives the same picks as seeding the
        module RNG, without clobbering the global ``random`` state.
        """
        if not self.random_example:
            return list(range(self.n_shot))
        query_name = self.val_ds.image_set[idx]
        pool = list(range(len(self.get_top_all_images_val[query_name]['top50'])))
        rng = random.Random(idx)
        picked = []
        for _ in range(self.n_shot):
            choice = rng.choice(pool)
            pool.remove(choice)
            picked.append(choice)
        return picked

    def _load_query(self, idx):
        """Return ``(name, image, target)`` of val query ``idx``.

        The name comes from ``val_ds.image_set``, so the image is read from
        ``val_ds`` as well (not from ``train_ds``).
        """
        query_name = self.val_ds.image_set[idx]
        query_image, query_target = self.val_ds.get_item_by_filename(query_name)
        return query_name, query_image, query_target

    @staticmethod
    def _class_mask(image, target, label):
        """Return a PIL mask: black canvas with every box labelled ``label`` filled white."""
        boxes = target['boxes'][torch.where(target['labels'] == label)[0]]
        mask = get_annotated_image(np.array(image), boxes, border_width=-1, mode='keep',
                                   bgcolor='black', fg='white')
        return Image.fromarray(mask)

    def _background(self):
        """Return a freshly transformed 224x224 white background tensor."""
        return self.background_transforms(Image.new('RGB', (224, 224), color='white'))

    def __getitem__(self, idx):
        """Return a list of ``n_shot`` episode dicts for val query ``idx``.

        Each dict holds the transformed query/support images and masks, the support,
        query and datastore file names, and that episode's canvas under ``'grid_stack'``.
        """
        # The query does not depend on the support, so it is loaded and masked once.
        query_name, query_image, query_target = self._load_query(idx)
        query_mask = self._class_mask(query_image, query_target, query_target['labels'].numpy()[0])

        batch_list = []
        for sim_idx in self._support_indices(idx):
            _, support_name, _, support_name_datastore = (
                self.sample_episode_for_training_retriever_query(idx, sim_idx))
            support_image, support_target = self.train_ds.get_item_by_filename(support_name)
            support_mask = self._class_mask(support_image, support_target,
                                            support_target['labels'].numpy()[0])

            # Tensors are produced per episode so each dict owns its own tensors.
            query_image_ten = self.transforms(query_image, None)[0]
            query_target_ten = self.transforms(query_mask, None)[0]
            support_target_ten = self.transforms(support_mask, None)[0]
            support_image_ten = self.transforms(support_image, None)[0]

            # The grid is returned as-is: no placeholder CUDA tensor is allocated, so
            # items can be built on CPU-only hosts and inside DataLoader workers.
            grid_stack = create_grid_from_images(self._background(), support_image_ten, support_target_ten,
                                                 query_image_ten, query_target_ten)

            batch_list.append({'query_img': query_image_ten,
                               'query_mask': query_target_ten,
                               'support_img': support_image_ten,
                               'support_mask': support_target_ten,
                               'support_name': support_name,
                               'query_name_val': query_name,
                               'support_name_datastore': support_name_datastore,
                               'grid_stack': grid_stack})

        return batch_list


class CanvasesDataset4Val(data.Dataset):
    """Pascal VOC detection validation set yielding ``n_shot`` single-support episodes per query.

    Episode ``k`` pairs the val query with entry ``k`` of its retrieved train list
    and holds its own ``create_grid_from_images`` canvas. Masks are black images in
    which every box of the image's first-listed class is white.
    """

    def __init__(self, pascal_path='pascal-5i', years=("2012",), random=False, n_shot=2,
                 feature_name='features_vit-laion2b_pixel-level_val', **kwargs):
        """
        Args:
            pascal_path: dataset root; also the root of the retrieval-results folder.
            years: VOC years forwarded to ``CanvasesVOCDetection4Val``.
            random: stored as ``self.random``; not used for sampling in this class.
                (The name shadows the ``random`` module inside this method only.)
            n_shot: number of episodes per query.
            feature_name: retrieval-results folder prefix; ``_all_detection`` is
                appended. The default reproduces the previously hard-coded folder.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        self.train_ds = CanvasesVOCDetection4Val(pascal_path, years, image_sets=['train'], transforms=None,
                                                 keep_single_objs_only=1, filter_by_mask_size=1)
        self.val_ds = CanvasesVOCDetection4Val(pascal_path, years, image_sets=['val'], transforms=None,
                                               keep_single_objs_only=1, filter_by_mask_size=1)
        self.background_transforms = T.Compose([
            T.Resize((224, 224)),
            T.Compose([
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        ])
        self.data_path = pascal_path
        self.transforms = make_transforms('val')
        self.random = random
        self.n_shot = n_shot
        self.feature_name = f'{feature_name}_all_detection'
        # The method is shadowed on the instance by its result; the rest of the
        # class indexes ``self.get_top_all_images_val[name]`` as a dict.
        self.get_top_all_images_val = self.get_top_all_images_val()

    def __len__(self):
        return len(self.val_ds)

    def get_top_all_images_val(self):
        """Load the retrieval JSON and wrap each entry as ``{name: {'top50': neighbours}}``."""
        path = f"{self.data_path}/VOC2012/{self.feature_name}/folder_cls_top_all-similarity.json"
        with open(path, encoding='utf-8') as f:
            images_top_all = json.load(f)
        return {img_name: {'top50': neighbours} for img_name, neighbours in images_top_all.items()}

    def sample_episode(self, idx, sim_idx):
        """Return ``(query_name, support_name, query_2_find_support_name, support_name_datastore)``.

        ``support_name`` is entry ``sim_idx`` of the query's retrieved list and
        ``support_name_datastore`` its first entry; both query names are the same.
        """
        query_2_find_support_name = self.val_ds.image_set[idx]
        neighbours = self.get_top_all_images_val[query_2_find_support_name]['top50']
        return query_2_find_support_name, neighbours[sim_idx], query_2_find_support_name, neighbours[0]

    @staticmethod
    def _class_mask(image, target, label):
        """Return a PIL mask: black canvas with every box labelled ``label`` filled white."""
        boxes = target['boxes'][torch.where(target['labels'] == label)[0]]
        mask = get_annotated_image(np.array(image), boxes, border_width=-1, mode='keep',
                                   bgcolor='black', fg='white')
        return Image.fromarray(mask)

    def _background(self):
        """Return a freshly transformed 224x224 white background tensor."""
        return self.background_transforms(Image.new('RGB', (224, 224), color='white'))

    def __getitem__(self, idx):
        """Return a list of ``n_shot`` episode dicts for val query ``idx``.

        Each dict holds the transformed query/support images and masks, the support,
        query and datastore file names, the query's first label (``'query_class'``)
        and that episode's canvas under ``'grid_stack'``.
        """
        # The query does not depend on the support, so it is loaded and masked once.
        query_image, query_target = self.val_ds[idx]
        label = query_target['labels'].numpy()[0]
        query_mask = self._class_mask(query_image, query_target, label)

        batch_list = []
        for sim_idx in range(self.n_shot):
            _, support_name, query_2_find_support_name, support_name_datastore = (
                self.sample_episode(idx, sim_idx))
            support_image, support_target = self.train_ds.get_item_by_filename(support_name)
            support_mask = self._class_mask(support_image, support_target,
                                            support_target['labels'].numpy()[0])

            # Tensors are produced per episode so each dict owns its own tensors.
            query_image_ten = self.transforms(query_image, None)[0]
            query_target_ten = self.transforms(query_mask, None)[0]
            support_target_ten = self.transforms(support_mask, None)[0]
            support_image_ten = self.transforms(support_image, None)[0]

            # The grid is returned as-is: no placeholder CUDA tensor is allocated, so
            # items can be built on CPU-only hosts and inside DataLoader workers.
            grid_stack = create_grid_from_images(self._background(), support_image_ten, support_target_ten,
                                                 query_image_ten, query_target_ten)

            batch_list.append({'query_img': query_image_ten,
                               'query_mask': query_target_ten,
                               'support_img': support_image_ten,
                               'support_mask': support_target_ten,
                               'support_name': support_name,
                               'query_name': query_2_find_support_name,
                               'support_name_datastore': support_name_datastore,
                               'query_class': label,
                               'grid_stack': grid_stack})

        return batch_list
