from pathlib import Path
import time

import cv2
import lmdb
import numpy as np
import torch as pt
import torch.utils.data as ptud

from .dataset import compress, decompress
from .utils import (
    color_segment_to_index_segment_and_bbox,
    even_resize_and_center_crop,
    normaliz_for_visualiz,
    draw_segmentation_np,
)


class PascalVOC(ptud.Dataset):
    """Visual Object Classes Challenge 2012 (VOC2012) + 2007
    http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
    http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html#devkit
    http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar

    Samples are read from a single-file LMDB database produced by
    ``convert_dataset``; the list of sample keys is stored under ``__keys__``.

    Number of objects distribution:
    - train: {1: 1374, 2: 638, 3: 303, 4: 177, 5: 103, 6: 67, 7: 44, 8: 28, 9: 17, 10: 12}
    - val: {1: 129, 2: 126, 3: 55, 4: 33, 5: 23, 6: 17, 7: 8, 8: 4, 9: 4, 10: 2}
    """

    # Colour (in the channel order returned by ``cv2.imread``) that is reset to
    # 0 before colours are turned into instance indices.
    # NOTE(review): presumably the VOC "void"/boundary colour -- confirm.
    _VOID_BGR = (192, 224, 224)

    def __init__(
        self,
        data_file,
        transform=lambda **_: _,
        max_spare=4,
        base_dir: Path = None,
    ):
        """Open the LMDB file read-only and load the list of sample keys.

        - data_file: path of the LMDB file (a plain file, not a directory)
        - transform: callable invoked as ``transform(**sample)``; the default
          returns the sample dict unchanged
        - max_spare: forwarded to ``lmdb.open`` as ``max_spare_txns``
        - base_dir: optional directory that ``data_file`` is relative to
        """
        super().__init__()
        if base_dir:
            # Path(...) so that a plain ``str`` base_dir works as well.
            data_file = Path(base_dir) / data_file
        self.env = lmdb.open(
            str(data_file),
            subdir=False,
            readonly=True,
            readahead=False,
            meminit=False,
            max_spare_txns=max_spare,
            lock=False,
        )
        with self.env.begin(write=False) as txn:
            self.keys = decompress(txn.get(b"__keys__"))
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` and return ``transform(**sample)``.

        Before ``transform`` the sample holds:
        - image: in shape (c=3,h,w)
        - bbox: in shape (n,c=4)
        - smask: in shape (n,), bool, all True, masking slots
        - segment: in shape (h,w)

        NOTE(review): dtypes are whatever was stored in the LMDB record; the
        earlier documentation stated float32 for image/bbox and uint8 for
        segment -- confirm against the stored data / the transform in use.
        """
        with self.env.begin(write=False) as txn:
            record = decompress(txn.get(self.keys[index]))
        n, _ = record["bbox"].shape  # also enforces a 2-D bbox array
        sample = dict(
            image=pt.from_numpy(record["image"]).permute(2, 0, 1),  # hwc -> chw
            bbox=pt.from_numpy(record["bbox"]),
            smask=pt.ones([n], dtype=pt.bool),
            segment=pt.from_numpy(record["segment"]),
        )
        return self.transform(**sample)

    def __len__(self):
        """Return the number of samples (length of the stored key list)."""
        return len(self.keys)

    @staticmethod
    def _square_crops(image, segment):
        """Return matching lists of crops for ``image`` and ``segment``.

        When the height/width ratio is <= 2/3 or >= 3/2, three square crops
        (top-left, centre, bottom-right) with the side of the shorter edge are
        returned; otherwise the inputs are returned as single-element lists.
        ``segment`` is cropped with the windows computed from ``image``.
        """
        height, width = image.shape[:2]
        ar = height / width  # aspect ratio
        if 2 / 3 < ar < 3 / 2:
            return [image], [segment]
        side = min(height, width)
        top = (height - side) // 2
        left = (width - side) // 2
        windows = [
            (slice(None, side), slice(None, side)),
            (slice(top, top + side), slice(left, left + side)),
            (slice(-side, None), slice(-side, None)),
        ]
        return [image[w] for w in windows], [segment[w] for w in windows]

    @staticmethod
    def _zero_void_pixels(segment):
        """Return a copy of colour ``segment`` (h,w,3) with void pixels set to 0.

        A pixel is reset only when ALL three channels equal ``_VOID_BGR``;
        comparing channel by channel would also zero single channels of other
        colours that merely share one channel value with the void colour.
        """
        cleaned = segment.copy()
        is_void = np.all(segment == np.array(PascalVOC._VOID_BGR), axis=-1)
        cleaned[is_void] = 0
        return cleaned

    @staticmethod
    def _bbox_to_pixels(bbox, height, width):
        """Scale relative boxes to integer pixel coordinates without mutating input.

        Columns 0 and 2 are multiplied by ``width``, columns 1 and 3 by
        ``height``, then the result is truncated to ``int``.
        """
        if isinstance(bbox, pt.Tensor):
            # ``.numpy()`` of a CPU tensor shares memory with it, hence the copy.
            boxes = bbox.detach().cpu().numpy().copy()
        else:
            boxes = np.array(bbox, copy=True)
        boxes[:, 0::2] *= width
        boxes[:, 1::2] *= height
        return boxes.astype("int")

    @staticmethod
    def convert_dataset(
        src_dir=Path("/media/GeneralZ/Storage/Static/datasets/pascalvoc/VOCdevkit"),
        dst_dir=Path("voc"),
        resolut=(128, 128),  # spatial downsample
        max_nobj=10,  # images containing more objects will be discarded
        min_area=250,  # images containing largest area smaller than this will be discarded
    ):
        """
        Convert the raw VOC folders into ``train.lmdb`` / ``val.lmdb`` under ``dst_dir``.

        Structure dataset as follows and run it!
        - VOC2012/JPEGImages
          *.jpg
        - VOC2012/SegmentationObject
          *.png
        - VOC2012/ImageSets/Segmentation
          *.txt
        - VOC2007/JPEGImages
          *.jpg
        - VOC2007/SegmentationObject
          *.png
        - VOC2007/ImageSets/Segmentation
          *.txt

        VOC2012 ``trainval`` becomes the train split and VOC2007 ``trainval``
        the val split. Each stored record is a dict with ``image`` (h,w,c=3),
        ``bbox`` (n,c=4) divided by the target side, and ``segment`` (h,w).

        Raises:
        - ValueError: if ``resolut`` is not square
        - FileNotFoundError: if a listed image or segmentation cannot be read
        """
        # Validate before touching the filesystem so a bad call leaves no marker file.
        if resolut[0] != resolut[1]:
            raise ValueError(f"resolut must be square, got {resolut}")
        side = resolut[0]

        src_dir = Path(src_dir)
        dst_dir = Path(dst_dir)
        dst_dir.mkdir(parents=True, exist_ok=True)
        # Empty marker file whose name records the conversion settings.
        info_file = dst_dir / f"{resolut[0]}_{resolut[1]}-{max_nobj}"
        info_file.touch()

        splits = dict(
            train=[
                "VOC2012/ImageSets/Segmentation/trainval.txt",
                "VOC2012/JPEGImages",
                "VOC2012/SegmentationObject",
            ],
            val=[
                "VOC2007/ImageSets/Segmentation/trainval.txt",
                "VOC2007/JPEGImages",
                "VOC2007/SegmentationObject",
            ],
        )

        for split, (split_fn, image_dn, segment_dn) in splits.items():
            image_path = src_dir / image_dn
            segment_path = src_dir / segment_dn

            with open(src_dir / split_fn, "r") as f:
                # Skip blank lines; they would otherwise map to a bogus ".jpg".
                xample_fns = [line.strip() for line in f if line.strip()]

            dst_file = dst_dir / f"{split}.lmdb"
            lmdb_env = lmdb.open(
                str(dst_file),
                map_size=1024**4,
                subdir=False,
                readonly=False,
                meminit=False,
            )
            keys = []
            cnt = 0
            t0 = time.time()

            try:
                txn = lmdb_env.begin(write=True)

                for xample_fn in xample_fns:
                    image_file = image_path / f"{xample_fn}.jpg"
                    segment_file = segment_path / f"{xample_fn}.png"
                    image0 = cv2.imread(str(image_file))
                    segment0 = cv2.imread(str(segment_file))
                    # cv2.imread signals failure by returning None, not by raising.
                    if image0 is None:
                        raise FileNotFoundError(f"cannot read image: {image_file}")
                    if segment0 is None:
                        raise FileNotFoundError(
                            f"cannot read segmentation: {segment_file}"
                        )

                    crops_i, crops_s = PascalVOC._square_crops(image0, segment0)

                    for image, segment in zip(crops_i, crops_s):
                        image = even_resize_and_center_crop(image, side)
                        segment = even_resize_and_center_crop(
                            segment, side, cv2.INTER_NEAREST_EXACT
                        )
                        segment = PascalVOC._zero_void_pixels(segment)
                        segment, bbox = color_segment_to_index_segment_and_bbox(
                            segment
                        )

                        if bbox.shape[0] > max_nobj:
                            continue
                        labels, areas = np.unique(segment, return_counts=True)
                        areas = areas[labels != 0]  # remove background
                        if areas.size == 0 or areas.max() < min_area:
                            continue
                        cnt += 1

                        # __class__.visualiz(image, bbox / side, segment, 0)

                        sample_key = f"{cnt:06d}".encode("ascii")
                        keys.append(sample_key)

                        sample_dict = dict(
                            image=image,  # (h,w,c=3)
                            bbox=bbox / side,  # (n,c=4)
                            segment=segment,  # (h,w)
                        )
                        txn.put(sample_key, compress(sample_dict))

                        if (cnt + 1) % 64 == 0:  # write_freq
                            print(f"{(cnt+1):06d}")
                            txn.commit()
                            txn = lmdb_env.begin(write=True)

                # The key index is written last, in the final transaction.
                txn.put(b"__keys__", compress(keys))
                txn.commit()
            finally:
                # Always release the environment, even if conversion failed.
                lmdb_env.close()

            print(f"total={cnt}, time={time.time() - t0}")

    @staticmethod
    def visualiz(image, bbox=None, segment=None, wait=0):
        """Show ``image`` (window "i") and optionally its segmentation (window "s").

        - image: (c,h,w) tensor or (h,w,c) array; passed through
          ``normaliz_for_visualiz`` and then mapped with ``x * 127.5 + 127.5``
        - bbox: optional (n,4) relative boxes; drawn in black. The argument is
          NOT modified.
        - segment: optional index segmentation, overlaid via ``draw_segmentation_np``
        - wait: forwarded to ``cv2.waitKey``

        Returns the uint8 image (with boxes drawn) and the rendered segmentation
        (or ``None`` when no segment was given).
        """
        if isinstance(image, pt.Tensor):
            image = image.permute(1, 2, 0).cpu().contiguous().numpy()
        image = normaliz_for_visualiz(image)
        image = np.clip(image * 127.5 + 127.5, 0, 255).astype("uint8")

        if bbox is not None:
            boxes = PascalVOC._bbox_to_pixels(bbox, image.shape[0], image.shape[1])
            for box in boxes:
                # Plain int tuples: some OpenCV builds reject NumPy scalars as points.
                corner0 = tuple(int(v) for v in box[:2])
                corner1 = tuple(int(v) for v in box[2:])
                image = cv2.rectangle(image, corner0, corner1, (0, 0, 0))

        cv2.imshow("i", image)

        if segment is not None:
            if isinstance(segment, pt.Tensor):
                segment = segment.cpu().numpy()
            segment = draw_segmentation_np(image, segment, 0.6)
            cv2.imshow("s", segment)

        cv2.waitKey(wait)
        return image, segment


# def extract_instance_segmentation(image):
#     # Convert image to grayscale
#     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

#     # Threshold the image to extract thick contours
#     _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

#     # Find contours
#     contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

#     # Create an empty mask to draw the instances
#     instance_mask = np.zeros_like(gray)

#     # Draw contours on the mask with different integer numbers for each instance
#     for i, contour in enumerate(contours):
#         cv2.drawContours(instance_mask, [contour], 0, (i + 1), -1)

#     return instance_mask
