"""Based on https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer/blob/main/data/pascal.py
"""
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import json
import random


class DatasetPASCAL(Dataset):
    """PASCAL-5i dataset yielding ``n_shot`` support/query episodes per validation entry.

    Indexing follows the validation split of the chosen fold: ``len(dataset)`` is the
    number of validation entries and ``dataset[idx]`` returns a list with one dict per
    shot.  For every validation entry a pre-computed JSON file supplies a list of
    ``'<image name>__<class id>'`` candidates (presumably ordered by similarity -- the
    ordering is not checked here); ``anchor_mode`` decides how the query and support
    images of an episode are picked from that list.

    Files read at construction time (relative to the current working directory):

    * ``./evaluate/splits/pascal/<val|trn>/fold<fold>.txt``
    * ``./pascal-5i/VOC2012/<feature_name[:-4]>_val/folder<fold>_true_cls_base_top_all-similarity.json``

    Images and masks are read lazily from ``<datapath>/VOC2012/JPEGImages/`` and
    ``<datapath>/VOC2012/SegmentationClassAug/``.

    Args:
        datapath: Root directory that contains ``VOC2012``.
        fold: Fold index used in the split/JSON file names and for ``class_ids``.
        split: Stored on the instance; not read by any method of this class.
        image_transform: Callable applied to the PIL images.  Its output must expose
            ``.shape`` as ``(C, H, W)`` because it is pasted into a tensor canvas.
        mask_transform: Callable applied to the binary PIL masks; same output
            requirement as ``image_transform``.
        padding: Half of the gap (in pixels) left between the canvas tiles.
        arr: Layout code ``'a1'`` .. ``'a8'`` for the canvas; any other value
            (including the default ``False``) behaves like ``'a1'``.
        n_shot: Number of episodes produced per index.
        purple: Only ``False`` is supported by ``extract_ignore_idx``.
        feature_name: Name whose last four characters are dropped to build the JSON path.
        anchor_mode: One of ``'itself'``, ``'query'``, ``'random'``, ``'seq'``.
        use_original_imgsize, flipped_order, reverse_support_and_query, random,
        ensemble, cluster, percentage, seed, cls_base, random_example:
            Stored on the instance; not read by any method of this class.

    Raises:
        ValueError: If ``anchor_mode`` is not one of the supported values.
    """

    # Layout code -> indices into (support_img, support_mask, query_img, query_mask)
    # giving the tile placed at (top-left, top-right, bottom-left, bottom-right).
    _ARRANGEMENTS = {
        'a1': (0, 1, 2, 3),
        'a2': (1, 0, 3, 2),
        'a3': (3, 2, 1, 0),
        'a4': (2, 3, 0, 1),
        'a5': (1, 3, 0, 2),
        'a6': (3, 1, 2, 0),
        'a7': (2, 0, 3, 1),
        'a8': (0, 2, 1, 3),
    }

    def __init__(self, datapath, fold, split, image_transform, mask_transform, padding: int = 1,
                 use_original_imgsize: bool = False, flipped_order: bool = False,
                 reverse_support_and_query: bool = False, random: bool = False, ensemble: bool = False,
                 purple: bool = False, cluster: bool = False, feature_name: str = 'features_vit-laion2b_no_cls_trn',
                 percentage: str = '', seed: int = 0, arr: "str | bool" = False, n_shot: int = 16,
                 cls_base: bool = False, random_example: bool = False, anchor_mode: str = 'query'):
        # NOTE: the ``random`` parameter shadows the ``random`` module inside this
        # method only; the module is not needed here.
        episode_samplers = {
            'itself': self.sample_episode_for_training_retriever_itself,
            'query': self.sample_episode_for_training_retriever_ori_query,
            'random': self.sample_episode_for_training_retriever_random_anchor,
            'seq': self.sample_episode_for_training_retriever_sequence,
        }
        # Validate before touching the file system.
        if anchor_mode not in episode_samplers:
            raise ValueError('anchor_mode must be in ["itself", "query", "random", "seq"]!')
        self.anchor_mode = anchor_mode
        self.sample_episode = episode_samplers[anchor_mode]

        self.fold = fold
        self.split = split
        self.nfolds = 4
        self.nclass = 20
        self.ncluster = 200
        self.n_shot = n_shot

        self.padding = padding
        self.arr = arr
        self.purple = purple
        self.image_transform = image_transform
        self.mask_transform = mask_transform

        # Options kept for interface compatibility; no method of this class reads them.
        self.flipped_order = flipped_order
        self.random = random
        self.ensemble = ensemble
        self.cluster = cluster
        self.use_original_imgsize = use_original_imgsize
        self.cls_base = cls_base
        self.reverse_support_and_query = reverse_support_and_query
        self.random_example = random_example
        self.seed = seed
        self.percentage = percentage

        self.data_path = datapath
        self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
        self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
        self.feature_name = feature_name

        self.class_ids = self.build_class_ids()
        self.img_metadata_val = self.build_img_metadata('val')
        self.img_metadata_trn = self.build_img_metadata('trn')

        # The samplers read ``images_top50_val``.  Storing the result under the
        # method's own name would both hide the method and leave that attribute unset.
        self.images_top50_val = self.get_top50_images_val()
        self.images_top50_trn = self.get_top50_images_trn()

    def __len__(self):
        """Return the number of validation entries of the fold."""
        return len(self.img_metadata_val)

    def get_top50_images_val(self):
        """Load the candidate lists for every validation entry.

        Returns:
            Dict keyed by ``'<image name>__<class id + 1, zero padded to 2 digits>'``
            whose values are ``{'top50': <list from the JSON file>, 'class': <class id>}``.

        Raises:
            KeyError: If the JSON file has no entry for a validation key.
        """
        json_path = (f"./pascal-5i/VOC2012/{self.feature_name[:-4]}_val/"
                     f"folder{self.fold}_true_cls_base_top_all-similarity.json")
        with open(json_path) as f:
            candidates_by_key = json.load(f)

        lookup = {}
        for img_name, img_class in self.img_metadata_val:
            key = f'{img_name}__{int(img_class) + 1:02d}'
            lookup[key] = {'top50': candidates_by_key[key], 'class': img_class}
        return lookup

    def get_top50_images_trn(self):
        """Return ``{image name: {'class': class id}}`` for the training split.

        When an image name occurs several times the last class id wins.
        """
        return {img_name: {'class': img_class} for img_name, img_class in self.img_metadata_trn}

    def create_gradiant_grid_images(self, support_img, support_mask, query_img, query_mask, arr):
        """Paste the four tiles into the corners of a 2x2 canvas filled with ones.

        The canvas size is derived from ``support_img`` as passed in (before any
        re-arrangement): ``(C, 2 * H + 2 * padding, 2 * W + 2 * padding)``.

        Args:
            support_img, support_mask, query_img, query_mask: Tensors indexed as
                ``(C, H, W)``.  NOTE(review): the slicing assumes all four share the
                same spatial size -- confirm against the transforms in use.
            arr: Layout code, see ``_ARRANGEMENTS``; unknown values fall back to ``'a1'``.

        Returns:
            The filled canvas tensor.
        """
        canvas = torch.ones((support_img.shape[0],
                             2 * support_img.shape[1] + 2 * self.padding,
                             2 * support_img.shape[2] + 2 * self.padding))

        tiles = (support_img, support_mask, query_img, query_mask)
        order = self._ARRANGEMENTS.get(arr, self._ARRANGEMENTS['a1'])
        top_left, top_right, bottom_left, bottom_right = (tiles[i] for i in order)

        top_rows, bottom_rows = top_left.shape[1], bottom_left.shape[1]
        canvas[:, :top_rows, :top_left.shape[2]] = top_left
        canvas[:, -bottom_rows:, :bottom_left.shape[2]] = bottom_left
        # Both right-hand tiles are aligned using the width of the top-left tile.
        canvas[:, :top_rows, -top_left.shape[2]:] = top_right
        canvas[:, -bottom_rows:, -top_left.shape[2]:] = bottom_right

        return canvas

    def __getitem__(self, idx):
        """Build the ``n_shot`` episodes that belong to validation entry ``idx``.

        The global ``random`` module is re-seeded with ``idx`` so that any random
        choice made by the episode sampler is reproducible per index.

        Returns:
            A list with one dict per shot.  The key ``'suppor_class'`` keeps its
            historical spelling because consumers look it up by that name.
            ``'grid_stack'`` holds the canvas of that single shot.
        """
        random.seed(idx)
        batch_list = []

        for sim_idx in range(self.n_shot):
            (query_name, support_name, class_sample_query, class_sample_support,
             query_2_find_support_name, support_name_datastore,
             query_class_datastore) = self.sample_episode(idx, sim_idx)
            query_img, query_cmask, support_img, support_cmask, _ = self.load_frame(query_name, support_name)

            # Transform order (query image, query mask, support image, support mask)
            # is kept fixed in case the transforms consume random state.
            if self.image_transform:
                query_img = self.image_transform(query_img)
            # Mask extraction must not depend on an image transform being configured.
            query_mask, _ = self.extract_ignore_idx(query_cmask, class_sample_query, purple=self.purple)
            if self.mask_transform:
                query_mask = self.mask_transform(query_mask)

            if self.image_transform:
                support_img = self.image_transform(support_img)
            support_mask, _ = self.extract_ignore_idx(support_cmask, class_sample_support, purple=self.purple)
            if self.mask_transform:
                support_mask = self.mask_transform(support_mask)

            # No CUDA tensor is created here: the canvas is built on the CPU and the
            # dataset stays usable in worker processes and on CPU-only machines.
            grid = self.create_gradiant_grid_images(support_img, support_mask, query_img, query_mask, self.arr)

            batch_list.append({'query_img': query_img,
                               'query_mask': query_mask,
                               'support_img': support_img,
                               'support_mask': support_mask,
                               'support_name': support_name,
                               'suppor_class': class_sample_support,
                               'query_name_val': query_2_find_support_name,
                               'support_name_datastore': support_name_datastore,
                               'query_class_datastore': query_class_datastore,
                               'grid_stack': grid
                               })

        return batch_list

    def extract_ignore_idx(self, mask, class_id, purple):
        """Turn a class-index mask into a binary mask for ``class_id``.

        Args:
            mask: Segmentation mask (PIL image or array-like) of class indices.
            class_id: Zero-based class id; pixels equal to ``class_id + 1`` are foreground.
            purple: Must be falsy; the colour-coded variant is not available here.

        Returns:
            ``(binary_mask, boundary)`` where ``binary_mask`` is a PIL image with 255 on
            the class and 0 elsewhere, and ``boundary`` is ``floor(mask / 255)`` computed
            on the untouched mask.

        Raises:
            NotImplementedError: If ``purple`` is truthy.
        """
        if purple:
            # Fail loudly instead of returning None, which the caller cannot unpack.
            raise NotImplementedError('extract_ignore_idx does not support purple=True')

        mask = np.array(mask)
        boundary = np.floor(mask / 255.)
        foreground = mask == class_id + 1
        mask[~foreground] = 0
        mask[foreground] = 255
        return Image.fromarray(mask), boundary

    def load_frame(self, query_name, support_name):
        """Read image and mask of the query and the support.

        Returns:
            ``(query_img, query_mask, support_img, support_mask, query_img.size)``.
        """
        query_img = self.read_img(query_name)
        query_mask = self.read_mask(query_name)
        support_img = self.read_img(support_name)
        support_mask = self.read_mask(support_name)

        return query_img, query_mask, support_img, support_mask, query_img.size

    def read_mask(self, img_name):
        r"""Open ``<ann_path>/<img_name>.png`` and return it as a PIL image."""
        return Image.open(os.path.join(self.ann_path, img_name) + '.png')

    def read_img(self, img_name):
        r"""Open ``<img_path>/<img_name>.jpg`` and return it as a PIL image (no mode conversion)."""
        return Image.open(os.path.join(self.img_path, img_name) + '.jpg')

    def _anchor_and_candidates(self, idx):
        """Return ``(name, class id, candidate list)`` of validation entry ``idx``."""
        anchor_name, anchor_class = self.img_metadata_val[idx]
        key = f'{anchor_name}__{int(anchor_class) + 1:02d}'
        return anchor_name, anchor_class, self.images_top50_val[key]['top50']

    def sample_episode_for_training_retriever_ori_query(self, idx, sim_idx):
        """The default setting of PANICL ('query' mode).

        The validation image is the query; candidate ``sim_idx`` is the support.

        Returns:
            ``(query_name, support_name, class_sample, support_class,
            query_2_find_support_name, support_name_datastore, query_class_datastore)``.
        """
        anchor_name, anchor_class, candidates = self._anchor_and_candidates(idx)

        support_entry = candidates[sim_idx]
        support_name = support_entry.split('__')[0]
        support_class = int(support_entry.split('__')[1]) - 1
        assert support_class == anchor_class
        support_name_datastore = candidates[0].split('__')[0]

        return (anchor_name, support_name, anchor_class, support_class,
                anchor_name, support_name_datastore, anchor_class)

    def sample_episode_for_training_retriever_random_anchor(self, idx, sim_idx):
        """'random' mode: a randomly drawn candidate is the query, candidate 0 the support.

        ``sim_idx`` is not used.  Returns the same 7-tuple layout as the 'query' mode.
        """
        anchor_name, anchor_class, candidates = self._anchor_and_candidates(idx)

        query_entry = random.sample(candidates, 1)[0]
        query_name = query_entry.split('__')[0]
        query_class = int(query_entry.split('__')[1]) - 1
        assert query_class == anchor_class
        support_name_datastore = candidates[0].split('__')[0]

        return (query_name, support_name_datastore, anchor_class, query_class,
                anchor_name, support_name_datastore, anchor_class)

    def sample_episode_for_training_retriever_sequence(self, idx, sim_idx):
        """'seq' mode: candidate 0 is the support, candidate ``sim_idx + 1`` the query.

        Returns the same 7-tuple layout as the 'query' mode.
        """
        anchor_name, anchor_class, candidates = self._anchor_and_candidates(idx)

        support_name = candidates[0].split('__')[0]
        query_name = candidates[sim_idx + 1].split('__')[0]

        return (query_name, support_name, anchor_class, anchor_class,
                anchor_name, support_name, anchor_class)

    def sample_episode_for_training_retriever_itself(self, idx, sim_idx):
        """'itself' mode: candidate ``sim_idx`` serves as both query and support.

        Returns the same 7-tuple layout as the 'query' mode.
        """
        anchor_name, anchor_class, candidates = self._anchor_and_candidates(idx)

        support_name = candidates[sim_idx].split('__')[0]
        support_name_datastore = candidates[0].split('__')[0]

        return (support_name, support_name, anchor_class, anchor_class,
                anchor_name, support_name_datastore, anchor_class)

    def build_class_ids(self):
        """Return the ``nclass // nfolds`` consecutive class ids that belong to ``self.fold``."""
        classes_per_fold = self.nclass // self.nfolds
        first_class = self.fold * classes_per_fold
        return list(range(first_class, first_class + classes_per_fold))

    def build_img_metadata(self, split):
        """Read ``./evaluate/splits/pascal/<split>/fold<fold>.txt``.

        Each non-blank line has the form ``<image name>__<class id + 1>``.

        Returns:
            A list of ``[image name, zero-based class id]`` pairs in file order.
        """
        metadata_path = os.path.join('./evaluate', 'splits/pascal/%s/fold%d.txt' % (split, self.fold))

        with open(metadata_path, 'r') as f:
            # Skipping blank lines (rather than cutting off the last element) keeps the
            # final entry when the file does not end with a newline.
            entries = [line.strip() for line in f if line.strip()]

        img_metadata = [[entry.split('__')[0], int(entry.split('__')[1]) - 1] for entry in entries]

        print('length of self.img_metadata_val: ', len(img_metadata))

        print('Total (%s) images are : %d' % (split, len(img_metadata)))

        return img_metadata
