"""
Based on https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer/blob/main/data/pascal.py
"""
import os
from PIL import Image
from scipy.io import loadmat
import numpy as np
import torch
from torch.utils.data import Dataset
import json


class DatasetCOCO2PASCAL(Dataset):
    def __init__(self, pascal_datapath, coco_datapath, fold, split, image_transform, mask_transform, padding: bool = 1,
                 use_original_imgsize: bool = False, flipped_order: bool = False,
                 reverse_support_and_query: bool = False, random: bool = False, ensemble: bool = False,
                 purple: bool = False, cluster: bool = False, feature_name: str = 'features_vit-laion2b_no_cls_trn',
                 percentage: str = '', seed: int = 0, mode: str = '', arr: str = 'a1', n_shot: int = 16):
        self.fold = fold
        self.split = split
        self.nfolds = 4
        self.flipped_order = flipped_order
        self.nclass = 20  # 20
        self.ncluster = 200
        self.padding = padding
        self.random = random
        self.ensemble = ensemble
        self.purple = purple
        self.cluster = cluster
        self.use_original_imgsize = use_original_imgsize

        self.pascal_img_path = os.path.join(pascal_datapath, 'VOC2012/JPEGImages/')
        self.pascal_ann_path = os.path.join(pascal_datapath, 'VOC2012/SegmentationClassAug/')
        self.coco_img_path = os.path.join(coco_datapath, 'train2014/')
        self.coco_ann_path = os.path.join(coco_datapath, 'annotations/train2014/')
        self.image_transform = image_transform
        self.reverse_support_and_query = reverse_support_and_query
        self.mask_transform = mask_transform
        self.n_shot = n_shot

        self.class_ids = self.build_class_ids()
        self.img_metadata_val = self.build_img_metadata('val')
        self.img_metadata_trn = self.build_coco_img_metadata('trn')
        self.feature_name = feature_name
        self.seed = seed
        self.percentage = percentage
        self.images_top50_val = self.get_top50_images_for_validation()
        self.images_top50_trn = self.get_top50_images_trn()
        self.mode = mode
        self.arr = arr

    def __len__(self):
        """Return the number of validation records (one dataset item per record)."""
        num_val_records = len(self.img_metadata_val)
        return num_val_records

    def get_top50_images_for_validation(self):
        print('feature name for val: ', self.feature_name[:-4] + '_val')
        with open(f"./pascal-5i/VOC2012/{self.feature_name[:-4]}_val/folder{self.fold}_new_coco_top50-similarity.json") as f:
            images_top50 = json.load(f)

        images_top50_new = {}
        for img_name, img_class in self.img_metadata_val:
            if img_name not in images_top50_new:
                images_top50_new[img_name] = {}

            images_top50_new[img_name]['top50'] = images_top50[img_name]
            images_top50_new[img_name]['class'] = img_class

        return images_top50_new

    def get_top50_images_trn(self):
        """Index the COCO training metadata by image name.

        Returns:
            dict mapping each training image name to ``{'class': <class id>}``.
            For a repeated name the class of its last occurrence is kept.
        """
        return {name: {'class': cls} for name, cls in self.img_metadata_trn}

    def create_gradiant_grid_images(self, support_img, support_mask, query_img, query_mask, arr):
        # create grid image for suppot images and query image.
        canvas = torch.ones((support_img.shape[0], 2 * support_img.shape[1] + 2 * self.padding,
                             2 * support_img.shape[2] + 2 * self.padding))

        content_list = [support_img, support_mask, query_img, query_mask]

        if arr == 'a1':
            support_img = content_list[0]
            support_mask = content_list[1]
            query_img = content_list[2]
            query_mask = content_list[3]

        elif arr == 'a2':
            support_img = content_list[1]
            support_mask = content_list[0]
            query_img = content_list[3]
            query_mask = content_list[2]

        elif arr == 'a3':
            support_img = content_list[3]
            support_mask = content_list[2]
            query_img = content_list[1]
            query_mask = content_list[0]

        elif arr == 'a4':
            support_img = content_list[2]
            support_mask = content_list[3]
            query_img = content_list[0]
            query_mask = content_list[1]

        elif arr == 'a5':
            support_img = content_list[1]
            support_mask = content_list[3]
            query_img = content_list[0]
            query_mask = content_list[2]

        elif arr == 'a6':
            support_img = content_list[3]
            support_mask = content_list[1]
            query_img = content_list[2]
            query_mask = content_list[0]

        elif arr == 'a7':
            support_img = content_list[2]
            support_mask = content_list[0]
            query_img = content_list[3]
            query_mask = content_list[1]

        elif arr == 'a8':
            support_img = content_list[0]
            support_mask = content_list[2]
            query_img = content_list[1]
            query_mask = content_list[3]

        canvas[:, :support_img.shape[1], :support_img.shape[2]] = support_img
        canvas[:, -query_img.shape[1]:, :query_img.shape[2]] = query_img
        canvas[:, :support_img.shape[1], -support_img.shape[2]:] = support_mask
        canvas[:, -query_img.shape[1]:, -support_img.shape[2]:] = query_mask

        return canvas

    def __getitem__(self, idx):
        """Build ``n_shot`` grid samples for the validation record at ``idx``.

        For every retrieval rank ``0 .. n_shot - 1`` the episode is sampled, the
        images and masks are loaded, masks are binarised for the sampled class,
        the optional transforms are applied and the four tiles are assembled
        into one grid canvas.

        Changes compared with the earlier implementation:
          * No CUDA tensor is created here. The previous placeholder
            (``torch.tensor([]).cuda()``) was replaced by the grid on every
            iteration, so it only forced a GPU dependency inside the dataset.
          * The query mask is extracted whether or not ``image_transform`` is
            set, matching the support branch; before, ``query_mask`` was
            undefined when ``image_transform`` was falsy.

        Note: ``create_gradiant_grid_images`` reads ``.shape`` on its inputs, so
        the transforms are expected to return tensors.

        Args:
            idx: Index into ``self.img_metadata_val``.

        Returns:
            list of ``n_shot`` dicts with the keys ``query_img``, ``query_mask``,
            ``support_img``, ``support_mask``, ``support_name``, ``suppor_class``,
            ``query_name_val``, ``support_name_datastore``,
            ``query_class_datastore`` and ``grid_stack`` (the grid canvas).
        """
        batch_list = []
        for sim_idx in range(self.n_shot):
            (query_name, support_name, class_sample_query, class_sample_support,
             query_2_find_support_name, support_name_datastore, query_class_datastore) = (
                self.sample_episode_for_training_retriever_itself(idx, sim_idx))
            query_img, query_cmask, support_img, support_cmask, _ = self.load_frame(query_name, support_name)

            # Query first, then support: keeps the call order of the transforms
            # unchanged in case they draw random numbers.
            if self.image_transform:
                query_img = self.image_transform(query_img)
            query_mask, _ = self.extract_ignore_idx(query_cmask, class_sample_query, purple=self.purple)
            if self.mask_transform:
                query_mask = self.mask_transform(query_mask)

            if self.image_transform:
                support_img = self.image_transform(support_img)
            support_mask, _ = self.extract_ignore_idx(support_cmask, class_sample_support, purple=self.purple)
            if self.mask_transform:
                support_mask = self.mask_transform(support_mask)

            grid = self.create_gradiant_grid_images(support_img, support_mask, query_img, query_mask, self.arr)

            # The 'suppor_class' key spelling is kept because consumers may rely on it.
            batch_list.append({'query_img': query_img,
                               'query_mask': query_mask,
                               'support_img': support_img,
                               'support_mask': support_mask,
                               'support_name': support_name,
                               'suppor_class': class_sample_support,
                               'query_name_val': query_2_find_support_name,
                               'support_name_datastore': support_name_datastore,
                               'query_class_datastore': query_class_datastore,
                               'grid_stack': grid,
                               })

        return batch_list

    def extract_ignore_idx(self, mask, class_id, purple):
        mask = np.array(mask)
        boundary = np.floor(mask / 255.)
        if not purple:
            mask[mask != class_id + 1] = 0
            mask[mask == class_id + 1] = 255
            return Image.fromarray(mask), boundary

    def extract_ignore_idx_coco(self, mask, class_id, purple):
        mask = np.array(mask)
        boundary = np.floor(mask / 255.)
        if not purple:
            mask[mask != class_id] = 0
            mask[mask == class_id] = 255
            return Image.fromarray(mask), boundary

    def load_frame(self, query_name, support_name):
        """Load image and mask for the query and the support sample from the COCO folders.

        Only the support image is converted to RGB; the query image keeps the
        mode it was opened with.

        Returns:
            ``(query_img, query_mask, support_img, support_mask, org_qry_imsize)``
            where ``org_qry_imsize`` is the PIL ``size`` of the query image.
        """
        query_img, query_mask = self.read_coco_img(query_name), self.read_coco_mask(query_name)
        support_raw, support_mask = self.read_coco_img(support_name), self.read_coco_mask(support_name)

        return query_img, query_mask, support_raw.convert('RGB'), support_mask, query_img.size

    def read_pascal_mask(self, img_name):
        r"""Open ``<pascal_ann_path>/<img_name>.png`` and return it as a PIL Image."""
        png_path = os.path.join(self.pascal_ann_path, img_name) + '.png'
        return Image.open(png_path)

    def read_pascal_img(self, img_name):
        r"""Open ``<pascal_img_path>/<img_name>.jpg`` as a PIL Image (no mode conversion)."""
        jpg_path = os.path.join(self.pascal_img_path, img_name) + '.jpg'
        return Image.open(jpg_path)

    def read_coco_mask(self, img_name):
        r"""Open ``<coco_ann_path>/<img_name>.png`` and return it as a PIL Image."""
        png_path = os.path.join(self.coco_ann_path, img_name) + '.png'
        return Image.open(png_path)

    def read_coco_img(self, img_name):
        r"""Open ``<coco_img_path>/<img_name>.jpg`` as a PIL Image (no mode conversion)."""
        jpg_path = os.path.join(self.coco_img_path, img_name) + '.jpg'
        return Image.open(jpg_path)

    def sample_episode(self, idx, sim_idx):
        """Pick the query at ``idx`` and its ``sim_idx``-th retrieved support.

        If the retrieved support has the same name as the query, a message is
        printed and the next rank is tried, repeating until the names differ.

        Returns:
            ``(query_name, support_name, query_class, support_class)``.
        """
        rank = sim_idx
        while True:
            query_name, class_sample = self.img_metadata_val[idx]
            candidate = self.images_top50_val[query_name]['top50'][rank]
            candidate_class = self.images_top50_trn[candidate]['class']

            if candidate != query_name:
                return query_name, candidate, class_sample, candidate_class

            print('support_name = query_name ' + candidate)
            rank += 1

    def sample_episode_for_training_retriever_itself(self, idx, sim_idx):
        """Sample an episode in which the retrieved image is both query and support.

        Described by the original author as the default configuration of PANICL.
        The validation record at ``idx`` is only used to look up its retrieval
        list; the ``sim_idx``-th retrieved training image fills both the query
        and the support slot.

        Returns:
            A 7-tuple ``(query_name, support_name, query_class, support_class,
            validation_name, top1_retrieved_name, validation_class)``, where the
            first two and the next two entries are pairwise identical.
        """
        anchor_name, anchor_class = self.img_metadata_val[idx]
        ranked = self.images_top50_val[anchor_name]['top50']

        retrieved_name = ranked[sim_idx]
        retrieved_class = self.images_top50_trn[retrieved_name]['class']
        nearest_name = ranked[0]

        return (retrieved_name, retrieved_name, retrieved_class, retrieved_class,
                anchor_name, nearest_name, anchor_class)

    def build_class_ids(self):
        """Return the class ids of the current fold.

        Each fold owns ``nclass // nfolds`` consecutive ids, starting at
        ``fold * (nclass // nfolds)``.
        """
        per_fold = self.nclass // self.nfolds
        first_id = self.fold * per_fold
        return [first_id + offset for offset in range(per_fold)]

    def build_img_metadata(self, split):

        def read_metadata(split, fold_id):
            cwd = './evaluate'

            if self.cluster:
                fold_n_metadata_path = os.path.join(cwd, 'splits/pascal/%s/fold_cluster%d.txt' % (split, fold_id))
            else:
                fold_n_metadata_path = os.path.join(cwd, 'splits/pascal/%s/fold%d.txt' % (split, fold_id))

            with open(fold_n_metadata_path, 'r') as f:
                fold_n_metadata = f.read().split('\n')[:-1]

            if self.cluster:
                fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1, int(data.split('__')[2]) - 1] for
                                   data in fold_n_metadata]
            else:
                fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]

            return fold_n_metadata

        img_metadata = read_metadata(split, self.fold)

        print('Total (%s) images are : %d' % (split, len(img_metadata)))

        return img_metadata

    def build_coco_img_metadata(self, split):

        def read_metadata(split, fold_id):
            cwd = './tools'

            fold_n_metadata_path = os.path.join(cwd, f'coco/{split}_/fold{fold_id}.txt')

            with open(fold_n_metadata_path, 'r') as f:
                fold_n_metadata = f.read().split('\n')[:-1]

            fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]

            return fold_n_metadata

        img_metadata = read_metadata(split, self.fold)

        print('Total (%s) images are : %d' % (split, len(img_metadata)))

        return img_metadata
