import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional as TF
import numpy as np
import imageio
import cv2
import skimage

#  from util import GaussianBlur


class ColorJitterDataset(torch.utils.data.Dataset):
    """Wrap a dataset and apply one random color jitter to each item's images.

    Items of ``base_dset`` must be dicts with an ``"images"`` entry. All images
    of one item receive the same randomly drawn hue / saturation / contrast /
    brightness adjustment.

    Args:
        base_dset: Dataset to wrap. It must expose every attribute named in the
            default inherit list (``z_near``, ``z_far``, ``lindisp``,
            ``base_path``, ``image_to_tensor``) plus ``extra_inherit_attrs``.
        hue_range: Hue factor is drawn uniformly from ``[-hue_range, hue_range]``.
        saturation_range: Saturation factor is drawn from ``[1 - r, 1 + r]``.
        brightness_range: Brightness factor is drawn from ``[1 - r, 1 + r]``.
        contrast_range: Contrast factor is drawn from ``[1 - r, 1 + r]``.
        extra_inherit_attrs: Additional attribute names to copy from
            ``base_dset`` onto this wrapper. ``None`` (default) means none.
    """

    def __init__(
        self,
        base_dset,
        hue_range=0.1,
        saturation_range=0.1,
        brightness_range=0.1,
        contrast_range=0.1,
        extra_inherit_attrs=None,
    ):
        self.hue_range = [-hue_range, hue_range]
        self.saturation_range = [1 - saturation_range, 1 + saturation_range]
        self.brightness_range = [1 - brightness_range, 1 + brightness_range]
        self.contrast_range = [1 - contrast_range, 1 + contrast_range]

        # Fresh list per call. A mutable default argument would be shared
        # across every instance constructed without the argument.
        inherit_attrs = ["z_near", "z_far", "lindisp", "base_path", "image_to_tensor"]
        if extra_inherit_attrs is not None:
            inherit_attrs.extend(extra_inherit_attrs)

        self.base_dset = base_dset
        # Copy the listed attributes from the wrapped dataset onto this wrapper.
        for inherit_attr in inherit_attrs:
            setattr(self, inherit_attr, getattr(self.base_dset, inherit_attr))

    def apply_color_jitter(self, images):
        """Apply one randomly drawn color jitter to every element of ``images``.

        The ``(x + 1) * 0.5`` / ``x * 2 - 1`` round trip implies the images are
        expected in [-1, 1]. The torchvision adjustments run on the [0, 1]
        version. ``images`` is modified in place (each element is reassigned)
        and is also returned.

        NOTE(review): this assumes each element is an image tensor accepted by
        torchvision's tensor color functions (e.g. CHW) — TODO confirm.
        NOTE(review): the factors come from NumPy's global RNG. With
        multi-worker DataLoaders, confirm that workers are seeded differently.
        """
        # One set of factors shared by the whole batch.
        hue_factor = np.random.uniform(*self.hue_range)
        saturation_factor = np.random.uniform(*self.saturation_range)
        brightness_factor = np.random.uniform(*self.brightness_range)
        contrast_factor = np.random.uniform(*self.contrast_range)
        for idx, image in enumerate(images):
            tmp = (image + 1.0) * 0.5
            tmp = F_t.adjust_saturation(tmp, saturation_factor)
            tmp = F_t.adjust_hue(tmp, hue_factor)
            tmp = F_t.adjust_contrast(tmp, contrast_factor)
            tmp = F_t.adjust_brightness(tmp, brightness_factor)
            images[idx] = tmp * 2.0 - 1.0
        return images

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.base_dset)

    def __getitem__(self, idx):
        """Fetch item ``idx`` from the wrapped dataset and jitter its ``"images"``."""
        data = self.base_dset[idx]
        data["images"] = self.apply_color_jitter(data["images"])
        return data


def square_crop_img(img):
    """Center-crop ``img`` to a square whose side is its shorter spatial dimension.

    Args:
        img: Array indexed as ``img[rows, cols, ...]``. Trailing axes (e.g.
            channels) are kept untouched.

    Returns:
        The cropped slice, with shape ``(s, s, ...)`` where
        ``s = min(img.shape[0], img.shape[1])``.
    """
    side = min(img.shape[:2])
    row_start = img.shape[0] // 2 - side // 2
    col_start = img.shape[1] // 2 - side // 2
    # End at start + side rather than center + side // 2, so an odd side is not
    # truncated by one pixel (and a 1x1 image does not become empty). For even
    # sides this yields exactly the same window.
    return img[row_start:row_start + side, col_start:col_start + side]

def load_rgb(path, sidelength=None):
    """Load an image as a square, channels-first float32 RGBA array.

    Fully transparent pixels (alpha == 0) are set to 1.0 (white) in all four
    channels. Images without an alpha channel are treated as fully opaque, so
    the output always has four channels.

    Args:
        path: Image file readable by ``imageio.imread``.
        sidelength: If given, the center square crop is resized to
            ``sidelength x sidelength`` with ``cv2.INTER_AREA``.

    Returns:
        float32 array of shape ``(4, S, S)``. Values are the
        ``skimage.img_as_float32`` values mapped by ``2 * (x - 0.5)``, which
        gives [-1, 1] for integer-typed source images.
    """
    img = skimage.img_as_float32(imageio.imread(path))
    if img.shape[2] == 3:
        # An RGB image has no alpha channel. Append an opaque one so the output
        # stays RGBA and the transparency mask below is well-defined.
        opaque = np.ones(img.shape[:2] + (1,), dtype=img.dtype)
        img = np.concatenate([img, opaque], axis=2)
    img = img[:, :, :4]
    # A per-pixel (H, W) mask selects whole pixels, so all four channels
    # (alpha included) of fully transparent pixels are set to 1.0.
    img[img[:, :, 3] == 0] = 1.0

    img = square_crop_img(img)

    if sidelength is not None:
        img = cv2.resize(img, (sidelength, sidelength), interpolation=cv2.INTER_AREA)

    img -= 0.5
    img *= 2.
    # HWC -> CHW.
    img = img.transpose(2, 0, 1)

    return img

def load_seg(path, sidelength=None):
    """Load an array with ``np.load`` and center-crop it to a square.

    Args:
        path: File passed to ``np.load``.
        sidelength: If given, the square crop is resized to
            ``sidelength x sidelength`` with nearest-neighbour interpolation,
            so no new (interpolated) values are introduced.

    Returns:
        The cropped (and optionally resized) array.
    """
    seg = square_crop_img(np.load(path))
    if sidelength is None:
        return seg
    target_size = (sidelength, sidelength)
    return cv2.resize(seg, target_size, interpolation=cv2.INTER_NEAREST)


def load_pose(filename):
    """Read a pose matrix from a text file.

    Two layouts are accepted:
      * a single line holding 16 values in row-major order (the first 16
        tokens fill a 4x4 matrix), or
      * one matrix row per line. The first four tokens of each of the first
        four lines are used.

    Tokens may be separated by any run of whitespace, and blank lines are
    ignored.

    Returns:
        float32 ``np.ndarray``, shape ``(4, 4)`` for well-formed input.
    """
    # A context manager closes the file handle deterministically.
    with open(filename) as fin:
        lines = [line for line in fin.read().splitlines() if line.strip()]

    if len(lines) == 1:
        tokens = lines[0].split()
        return np.asarray(tokens[:16]).astype(np.float32).reshape(4, 4)

    rows = [line.split()[:4] for line in lines[:4]]
    return np.asarray(rows).astype(np.float32)

def load_pts(path):
    """Read a whitespace-separated point file with per-point colors.

    Each non-blank line must have at least 9 columns. Columns 0-2 are parsed
    as float coordinates and columns 6-8 as integer R, G, B values; columns
    3-5 are ignored. Blank lines are skipped.

    Returns:
        ``(pts, rgb)``: two float32 arrays of shape ``(N, 3)``. ``rgb`` is the
        integer color divided by 255.
    """
    xyz_rows = []
    rgb_rows = []
    with open(path, 'r') as fin:
        for line in fin:
            cols = line.split()  # split once per line instead of six times
            if not cols:
                continue  # tolerate blank lines, e.g. a trailing empty line
            xyz_rows.append([float(cols[0]), float(cols[1]), float(cols[2])])
            rgb_rows.append([int(cols[6]), int(cols[7]), int(cols[8])])

    # reshape keeps the (N, 3) shape even when the file has no points.
    pts = np.array(xyz_rows, dtype=np.float32).reshape(-1, 3)
    rgb = np.array(rgb_rows, dtype=np.float32).reshape(-1, 3) / 255.0
    return pts, rgb

def load_labels(path):
    """Read one numeric label per line, taken from the first column.

    Blank lines are skipped.

    Returns:
        float32 array of shape ``(N, 1)``.
    """
    with open(path, 'r') as fin:
        values = [float(line.split()[0]) for line in fin if line.strip()]
    # reshape keeps the (N, 1) shape even when the file has no labels.
    return np.array(values, dtype=np.float32).reshape(-1, 1)
