from functools import lru_cache
import sys

# Fail fast on unsupported interpreters. Compare the parsed version tuple rather
# than the version string prefix so the ">= 3.8" requirement is actually enforced,
# and raise explicitly so the check survives `python -O`.
if sys.version_info < (3, 8):
    raise RuntimeError("Python version >= 3.8")
from PIL import Image
import numpy as np
import os, json
import logging

import torch, torchvision
import torch.utils.data as tdata

from . import data_utils
from utils import config

import pickle as pkl


class ImageDataset(tdata.Dataset):
    """Image-level dataset over the OCKB annotation pickle.

    Each item is one image together with all of its annotated objects
    (attributes, affordances, object ids and boxes). Category vocabularies are
    read from ``utils/aux_data/{ds_name}_{attrs,objs,affs}.json`` under
    ``config.ROOT_DIR``.
    """

    # Images whose longer side exceeds this many pixels are halved in both
    # dimensions (and their boxes halved accordingly) before ``self.transform``.
    DOWNSAMPLE_THRESHOLD = 1800

    def __init__(self, phase, ds_name):
        super().__init__()
        self.logger = logging.getLogger(f"Dataset {ds_name}")
        self.phase = phase
        self.ds_name = ds_name
        self.transform = data_utils.imagenet_transform(phase)

        # reading annotation pickle
        # NOTE: this path is relative to the current working directory, unlike the
        # category files which are resolved against config.ROOT_DIR.
        # pickle.load can execute arbitrary code -- only load trusted files.
        path = './data/OCKB_%s.pkl' % phase
        self.logger.info(path)
        with open(path, 'rb') as f:
            self.pkl_data = pkl.load(f)

        # One (image index, object index) pair per annotated object.
        self.instance_indices = [
            (i, j) for i, img in enumerate(self.pkl_data) for j in range(len(img['objects']))
        ]
        self.logger.info(f"{len(self.instance_indices)} instances")

        # construct attr/aff/obj list/matrix
        self.load_category_lists(ds_name)

        self.logger.info("#obj %d, #attr %d, #aff %d",
                         self.num_obj, self.num_attr, self.num_aff)

    @staticmethod
    def _read_json(path):
        """Load and return the JSON document stored at *path* (file is closed)."""
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    def load_category_lists(self, ds_name):
        """Read the attribute, object and affordance vocabularies for *ds_name*.

        Sets ``self.attrs``, ``self.objs``, ``self.affs``, the name -> index map
        ``self.obj2id`` and the counts ``num_attr``, ``num_obj`` and ``num_aff``.
        """

        def load_and_filter(path):
            # Keep only the second field of every entry (presumably the
            # category name -- confirm against the aux_data files).
            return [entry[1] for entry in self._read_json(path)]

        self.attrs = load_and_filter(
            os.path.join(config.ROOT_DIR, f"utils/aux_data/{ds_name}_attrs.json"))
        self.objs = load_and_filter(
            os.path.join(config.ROOT_DIR, f"utils/aux_data/{ds_name}_objs.json"))
        self.obj2id = {x: i for i, x in enumerate(self.objs)}
        self.affs = self._read_json(
            os.path.join(config.ROOT_DIR, f"utils/aux_data/{ds_name}_affs.json"))

        self.num_aff = len(self.affs)
        self.num_attr = len(self.attrs)
        self.num_obj = len(self.objs)

    def __len__(self):
        return len(self.pkl_data)

    @staticmethod
    def _downsampled_size(image_width, image_height):
        """Return the ``(height, width)`` target size that halves an image.

        ``torchvision.transforms.Resize`` interprets a sequence size as (h, w),
        whereas PIL's ``Image.size`` is (w, h); swapping them here keeps the
        aspect ratio (and therefore the halved boxes) correct.
        """
        return image_height // 2, image_width // 2

    @staticmethod
    def _prepare_boxes(objects, image_width, image_height, downsample):
        """Build the float32 box array for *objects*, one row per object.

        Objects without a ``'bbox'`` get the whole-image box
        ``[0, 0, image_width, image_height]``. ndarray boxes are squeezed to a
        flat list individually, so default list boxes and ndarray boxes may be
        mixed. When *downsample* is true every coordinate is floor-halved to
        match the halved image.
        """
        boxes = []
        for obj in objects:
            box = obj['bbox'] if 'bbox' in obj else [0, 0, image_width, image_height]
            if isinstance(box, np.ndarray):
                box = box.squeeze().tolist()
            boxes.append(box)

        if downsample:
            boxes = [(x[0] // 2, x[1] // 2, x[2] // 2, x[3] // 2) for x in boxes]

        return np.array(boxes, dtype=np.float32)

    def __getitem__(self, index):
        """Return the image and all object annotations for image *index*.

        The returned dict holds ``image`` (output of ``self.transform``),
        ``file_path``, ``gt_bbox``/``main_bbox`` (the same float32 array, one row
        per object), ``gt_attr``/``gt_aff`` (float32, binarised at 0.5),
        ``gt_obj_id`` and ``val_mask`` (True for every object when the image's
        split is ``'valid'``).
        """
        info = self.pkl_data[index]
        objects = info["objects"]
        file_path = info['file_path']
        # Context manager closes the underlying file; convert() returns a new image.
        with Image.open(file_path) as raw:
            image = raw.convert('RGB')
        image_width, image_height = image.size
        do_downsampling = max(image_width, image_height) > self.DOWNSAMPLE_THRESHOLD
        if do_downsampling:
            resize_trans = torchvision.transforms.Resize(
                self._downsampled_size(image_width, image_height))
            image = resize_trans(image)
        image = self.transform(image)

        attr = np.stack([obj['attr'] for obj in objects], axis=0)
        attr = (attr > 0.5).astype(np.float32)

        gt_bbox = self._prepare_boxes(objects, image_width, image_height, do_downsampling)

        aff = np.stack([obj['aff'] for obj in objects], axis=0)
        aff = (aff > 0.5).astype(np.float32)

        sample = {
            "image": image,
            "file_path": file_path,
            "gt_bbox": gt_bbox,
            'gt_aff': aff,
            "gt_attr": attr,
            'gt_obj_id': np.array([self.obj2id[obj['ockb_obj_name']] for obj in objects]),
            'val_mask': np.array([info['split'] == 'valid' for _ in objects])
        }

        sample["main_bbox"] = sample["gt_bbox"]

        return sample


class FeatureDataset(ImageDataset):
    """Object-level dataset that serves precomputed features instead of images.

    Item ``index`` is the ``index``-th entry of ``self.instance_indices`` (one
    annotated object). Its ``"image"`` field is row ``instId`` of the per-image
    feature block returned by ``data_utils.features_loader``.
    """

    def __init__(self, phase, ds_name, feature_dir):
        # Validate up front so an unsupported name fails before the base class
        # loads the annotation pickle and category files.
        if not (ds_name in ['OCKB'] or ds_name.startswith("Demo")):
            raise ValueError(f"unsupported dataset for FeatureDataset: {ds_name!r}")
        super().__init__(phase, ds_name)

        # reading feature .t7
        feature_path = os.path.join(config.ROOT_DIR, feature_dir, f"{phase}.t7")
        self.logger.info("reading %s", feature_path)

        # load features into memory
        self.features_list, self.feature_dim = data_utils.features_loader(feature_path, self.pkl_data)

        # Per-instance memo of built samples keyed by index. This replaces a
        # method-level lru_cache, which keyed on ``self`` and therefore kept every
        # dataset instance alive for the lifetime of the process.
        self._sample_cache = {}

    def __len__(self):
        return len(self.instance_indices)

    def __getitem__(self, index):
        """Return the (memoised) sample dict for object instance *index*.

        Repeated calls with the same index return the same dict object, so
        callers should not mutate it.
        """
        try:
            return self._sample_cache[index]
        except KeyError:
            pass
        sample = self._build_sample(index)
        self._sample_cache[index] = sample
        return sample

    def _build_sample(self, index):
        """Assemble the sample dict for object instance *index* (uncached)."""
        img_id, inst_id = self.instance_indices[index]
        img = self.pkl_data[img_id]
        obj = img['objects'][inst_id]

        feature = self.features_list[img_id][inst_id, ...]

        obj_id = self.obj2id[obj['ockb_obj_name']]

        # Binarise soft labels at 0.5.
        attr = (np.array(obj['attr']) > 0.5).astype(np.float32)
        aff = (np.array(obj['aff']) > 0.5).astype(np.float32)

        sample = {
            "image": feature,
            "gt_attr": attr,
            'gt_obj_id': np.array(obj_id, dtype=int),
            'gt_aff': aff,
        }

        # add val/test mask: True when the image belongs to the 'valid' split
        sample['val_mask'] = np.array(img['split'] == "valid")
        # NOTE(review): objects without a 'causal' annotation get a plain empty
        # list rather than an empty array; kept as-is because downstream batch
        # collation may rely on it -- confirm before unifying the type.
        sample['gt_causal'] = np.array(obj['causal'], dtype=int) if 'causal' in obj else []

        return sample
