import os
import os.path as op
import json

# import logging
import base64
import yaml
import errno
import io
import math
from PIL import Image, ImageDraw
import torch
import numpy as np
import collections
import pycocotools.mask as mask_utils


def load_linelist_file(linelist_file):
    if linelist_file is not None:
        line_list = []
        with open(linelist_file, "r") as fp:
            for i in fp:
                line_list.append(int(i.strip()))
        return line_list


def img_from_base64(imagestring):
    try:
        img = Image.open(io.BytesIO(base64.b64decode(imagestring)))
        return img.convert("RGB")
    except ValueError:
        return None


def load_from_yaml_file(yaml_file):
    with open(yaml_file, "r") as fp:
        return yaml.load(fp, Loader=yaml.CLoader)


def find_file_path_in_yaml(fname, root):
    if fname is not None:
        found_file = None
        if op.isfile(fname):
            found_file = fname
        elif op.isfile(op.join(root, fname)):
            found_file = op.join(root, fname)
        else:
            # be a bit more flexible and try to find the file in the root recursively
            try_time = 3
            while try_time > 0:
                try_time -= 1
                root = os.path.dirname(root)
                if op.isfile(op.join(root, fname)):
                    found_file = op.join(root, fname)
                    break
        if found_file is None:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
            )
        print('found file: {}'.format(found_file))
        return found_file

def create_lineidx(filein, idxout):
    idxout_tmp = idxout + ".tmp"
    with open(filein, "r") as tsvin, open(idxout_tmp, "w") as tsvout:
        fsize = os.fstat(tsvin.fileno()).st_size
        fpos = 0
        while fpos != fsize:
            tsvout.write(str(fpos) + "\n")
            tsvin.readline()
            fpos = tsvin.tell()
    os.rename(idxout_tmp, idxout)


def read_to_character(fp, c):
    result = []
    while True:
        s = fp.read(32)
        assert s != ""
        if c in s:
            result.append(s[: s.index(c)])
            break
        else:
            result.append(s)
    return "".join(result)


class TSVFile(object):
    def __init__(self, tsv_file, generate_lineidx=False):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + ".lineidx"
        self._fp = None
        self._lineidx = None
        # the process always keeps the process which opens the file.
        # If the pid is not equal to the currrent pid, we will re-open the file.
        self.pid = None
        # generate lineidx if not exist
        if not op.isfile(self.lineidx) and generate_lineidx:
            create_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except:
            # logging.info('{}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split("\t")]

    def seek_first_column(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, "\t")

    def get_key(self, idx):
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            # logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, "r") as fp:
                self._lineidx = [int(i.strip()) for i in fp.readlines()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, "r")
            self.pid = os.getpid()

        if self.pid != os.getpid():
            # logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, "r")
            self.pid = os.getpid()


class CompositeTSVFile:
    def __init__(self, file_list, seq_file, root="."):
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        else:
            assert isinstance(file_list, list)
            self.file_list = file_list

        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        idx_source, idx_row = self.seq[index]
        k = self.tsvs[idx_source].get_key(idx_row)
        return "_".join([self.file_list[idx_source], k])

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        idx_source, idx_row = self.seq[index]
        return self.tsvs[idx_source].seek(idx_row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        """
        this function has to be called in init function if cache_policy is
        enabled. Thus, let's always call it in init funciton to make it simple.
        """
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, "r") as fp:
            for line in fp:
                parts = line.strip().split("\t")
                self.seq.append([int(parts[0]), int(parts[1])])
        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
        self.initialized = True


def load_list_file(fname):
    with open(fname, "r") as fp:
        lines = fp.readlines()
    result = [line.strip() for line in lines]
    if len(result) > 0 and result[-1] == "":
        result = result[:-1]
    return result


class TSVDataset(object):
    def __init__(self, img_file, label_file=None, hw_file=None, linelist_file=None, imageid2idx_file=None):
        """Constructor.
        Args:
            img_file: Image file with image key and base64 encoded image str.
            label_file: An optional label file with image key and label information.
                A label_file is required for training and optional for testing.
            hw_file: An optional file with image key and image height/width info.
            linelist_file: An optional file with a list of line indexes to load samples.
                It is useful to select a subset of samples or duplicate samples.
        """
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file

        self.img_tsv = TSVFile(img_file)
        self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True)
        self.hw_tsv = None if hw_file is None else TSVFile(hw_file)
        self.line_list = load_linelist_file(linelist_file)
        self.imageid2idx = None
        if imageid2idx_file is not None:
            self.imageid2idx = json.load(open(imageid2idx_file, "r"))

        self.transforms = None

    def __len__(self):
        if self.line_list is None:
            if self.imageid2idx is not None:
                assert self.label_tsv is not None, "label_tsv is None!!!"
                return self.label_tsv.num_rows()
            return self.img_tsv.num_rows()
        else:
            return len(self.line_list)

    def __getitem__(self, idx):
        img = self.get_image(idx)
        img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        # print(idx, annotations)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        img, target = self.apply_transforms(img, target)

        if self.transforms is None:
            return img, target, idx, 1.0
        else:
            new_img_size = img.shape[1:]
            scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1]))
            return img, target, idx, scale

    def get_line_no(self, idx):
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        line_no = self.get_line_no(idx)
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            imageid = annotations["img_id"]
            line_no = self.imageid2idx[imageid]
        row = self.img_tsv.seek(line_no)
        # use -1 to support old format with multiple columns.
        img = img_from_base64(row[-1])
        return img

    def get_annotations(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations
        else:
            return []

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def apply_transforms(self, image, target=None):
        # This function will be overwritten by each dataset to
        # apply transforms to image and targets.
        return image, target

    def get_img_info(self, idx):
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            line_no = self.get_line_no(idx)
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return {"height": int(annotations["img_w"]), "width": int(annotations["img_w"])}

        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv.seek(line_no)
            try:
                # json string format with "height" and "width" being the keys
                data = json.loads(row[1])
                if type(data) == list:
                    return data[0]
                elif type(data) == dict:
                    return data
            except ValueError:
                # list of strings representing height and width in order
                hw_str = row[1].split(" ")
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        line_no = self.get_line_no(idx)
        # based on the overhead of reading each row.
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations["img_id"]

        if self.hw_tsv:
            return self.hw_tsv.seek(line_no)[0]
        elif self.label_tsv:
            return self.label_tsv.seek(line_no)[0]
        else:
            return self.img_tsv.seek(line_no)[0]


class TSVYamlDataset(TSVDataset):
    """TSVDataset taking a Yaml file for easy function call"""

    def __init__(self, yaml_file, root=None, replace_clean_label=False):
        print("Reading {}".format(yaml_file))
        self.cfg = load_from_yaml_file(yaml_file)
        if root:
            self.root = root
        else:
            self.root = op.dirname(yaml_file)
        img_file = find_file_path_in_yaml(self.cfg["img"], self.root)
        label_file = find_file_path_in_yaml(self.cfg.get("label", None), self.root)
        hw_file = find_file_path_in_yaml(self.cfg.get("hw", None), self.root)
        linelist_file = find_file_path_in_yaml(self.cfg.get("linelist", None), self.root)
        imageid2idx_file = find_file_path_in_yaml(self.cfg.get("imageid2idx", None), self.root)

        if replace_clean_label:
            assert "raw_label" in label_file
            label_file = label_file.replace("raw_label", "clean_label")

        super(TSVYamlDataset, self).__init__(img_file, label_file, hw_file, linelist_file, imageid2idx_file)


class ODTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection.
    """

    def __init__(self, yaml_file, extra_fields=(), transforms=None, is_load_label=True, **kwargs):
        if yaml_file is None:
            return
        super(ODTSVDataset, self).__init__(yaml_file)

        self.transforms = transforms
        self.is_load_label = is_load_label
        self.attribute_on = False
        # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False

        if self.is_load_label:
            # construct maps
            jsondict_file = find_file_path_in_yaml(self.cfg.get("labelmap", None), self.root)
            if jsondict_file is None:
                jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
            if "json" in jsondict_file:
                jsondict = json.load(open(jsondict_file, "r"))
                if "label_to_idx" not in jsondict:
                    jsondict = {"label_to_idx": jsondict}
            elif "tsv" in jsondict_file:
                label_to_idx = {}
                counter = 1
                with open(jsondict_file) as f:
                    for line in f:
                        label_to_idx[line.strip()] = counter
                        counter += 1
                jsondict = {"label_to_idx": label_to_idx}
            else:
                assert 0

            self.labelmap = {}
            self.class_to_ind = jsondict["label_to_idx"]
            self.class_to_ind["__background__"] = 0
            self.ind_to_class = {v: k for k, v in self.class_to_ind.items()}
            self.labelmap["class_to_ind"] = self.class_to_ind

            if self.attribute_on:
                self.attribute_to_ind = jsondict["attribute_to_idx"]
                self.attribute_to_ind["__no_attribute__"] = 0
                self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()}
                self.labelmap["attribute_to_ind"] = self.attribute_to_ind

            self.label_loader = LabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
            )

    def get_target_from_annotations(self, annotations, img_size, idx):
        if isinstance(annotations, list):
            annotations = {"objects": annotations}
        if self.is_load_label:
            return self.label_loader(annotations["objects"], img_size)

    def apply_transforms(self, img, target=None):
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
    
class LabelLoader(object):
    def __init__(
        self,
        labelmap,
        extra_fields=(),
        filter_duplicate_relations=False,
        ignore_attr=None,
        ignore_rel=None,
        mask_mode="poly",
    ):
        self.labelmap = labelmap
        self.extra_fields = extra_fields
        self.supported_fields = ["class", "conf", "attributes", "scores_all", "boxes_all", "feature", "mask"]
        self.filter_duplicate_relations = filter_duplicate_relations
        self.ignore_attr = set(ignore_attr) if ignore_attr != None else set()
        self.ignore_rel = set(ignore_rel) if ignore_rel != None else set()
        assert mask_mode == "poly" or mask_mode == "mask"
        self.mask_mode = mask_mode

    def __call__(self, annotations, img_size, remove_empty=False, load_fields=None):
        boxes = [obj["rect"] for obj in annotations]
        #boxes = torch.as_tensor(boxes).reshape(-1, 4)
        #target = BoxList(boxes, img_size, mode="xyxy")
        target = torch.as_tensor(boxes).reshape(-1, 4)
        return target

    def add_anno_id(self, annotations):
        anno_id = []
        for obj in annotations:
            anno_id.append(obj["id"])
        return torch.tensor(anno_id)

    def get_box_mask(self, rect, img_size):
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        if self.mask_mode == "poly":
            return [[x1, y1, x1, y2, x2, y2, x2, y1]]
        elif self.mask_mode == "mask":
            # note the order of height/width order in mask is opposite to image
            mask = np.zeros([img_size[1], img_size[0]], dtype=np.uint8)
            mask[math.floor(y1) : math.ceil(y2), math.floor(x1) : math.ceil(x2)] = 255
            encoded_mask = mask_utils.encode(np.asfortranarray(mask))
            encoded_mask["counts"] = encoded_mask["counts"].decode("utf-8")
            return encoded_mask

    def add_masks(self, annotations, img_size):
        masks = []
        is_box_mask = []
        for obj in annotations:
            if "mask" in obj:
                masks.append(obj["mask"])
                is_box_mask.append(0)
            else:
                masks.append(self.get_box_mask(obj["rect"], img_size))
                is_box_mask.append(1)
        masks = SegmentationMask(masks, img_size, mode=self.mask_mode)
        is_box_mask = torch.tensor(is_box_mask)
        return masks, is_box_mask

    def add_classes(self, annotations):
        class_names = [obj["class"] for obj in annotations]
        classes = [None] * len(class_names)
        for i in range(len(class_names)):
            classes[i] = self.labelmap["class_to_ind"][class_names[i]]
        return torch.tensor(classes)

    def add_confidences(self, annotations):
        confidences = []
        for obj in annotations:
            if "conf" in obj:
                confidences.append(obj["conf"])
            else:
                confidences.append(1.0)
        return torch.tensor(confidences)

    def add_attributes(self, annotations):
        # the maximal number of attributes per object is 16
        attributes = [[0] * 16 for _ in range(len(annotations))]
        for i, obj in enumerate(annotations):
            for j, attr in enumerate(obj["attributes"]):
                attributes[i][j] = self.labelmap["attribute_to_ind"][attr]
        return torch.tensor(attributes)

    def add_features(self, annotations):
        features = []
        for obj in annotations:
            features.append(np.frombuffer(base64.b64decode(obj["feature"]), np.float32))
        return torch.tensor(features)

    def add_scores_all(self, annotations):
        scores_all = []
        for obj in annotations:
            scores_all.append(np.frombuffer(base64.b64decode(obj["scores_all"]), np.float32))
        return torch.tensor(scores_all)

    def add_boxes_all(self, annotations):
        boxes_all = []
        for obj in annotations:
            boxes_all.append(np.frombuffer(base64.b64decode(obj["boxes_all"]), np.float32).reshape(-1, 4))
        return torch.tensor(boxes_all)

    def relation_loader(self, relation_annos, target):
        if self.filter_duplicate_relations:
            # Filter out dupes!
            all_rel_sets = collections.defaultdict(list)
            for triplet in relation_annos:
                all_rel_sets[(triplet["subj_id"], triplet["obj_id"])].append(triplet)
            relation_annos = [np.random.choice(v) for v in all_rel_sets.values()]

        # get M*M pred_labels
        relation_triplets = []
        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
        for i in range(len(relation_annos)):
            if len(self.ignore_rel) != 0 and relation_annos[i]["class"] in self.ignore_rel:
                continue
            subj_id = relation_annos[i]["subj_id"]
            obj_id = relation_annos[i]["obj_id"]
            predicate = self.labelmap["relation_to_ind"][relation_annos[i]["class"]]
            relations[subj_id, obj_id] = predicate
            relation_triplets.append([subj_id, obj_id, predicate])

        relation_triplets = torch.tensor(relation_triplets)
        target.add_field("relation_labels", relation_triplets)
        target.add_field("pred_labels", relations)
        return target