import random
import math
from pathlib import Path
import numpy as np
import glob
import json
import skimage.io as sio
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import skimage.transform
import util.data_process as util
import scipy.io as scio
import datasets.transforms as T

import torch
from torch.utils.data import Dataset
import torchvision


def norm(img, mean=None, std=None):
    """Standardise ``img`` as ``(img - mean) / std``.

    Args:
        img: numpy array (or anything supporting numpy arithmetic).
        mean: value or per-channel array subtracted from ``img``; when
            ``None`` the global ``np.mean(img)`` is used.
        std: value or per-channel array dividing the centred ``img``; when
            ``None`` the global ``np.std(img)`` is used.

    Returns:
        The standardised array (a new object; ``img`` is not modified).
    """
    center = np.mean(img) if mean is None else mean
    scale = np.std(img) if std is None else std
    return (img - center) / scale


def make_transforms():
    """Build the training-time augmentation pipeline.

    The steps run in this order: random rotation/flip, random resized crop,
    then colour jitter, all wrapped in ``T.Compose``.
    """
    steps = [
        T.RandomRotationAndFlip(),
        T.RandomResizedCrop(),
        T.ColorJitter(),
    ]
    return T.Compose(steps)


class WireframeDataset(Dataset):
    """Wireframe line-detection dataset.

    Samples are the ``*_0.png`` images under ``{rootdir}/train`` (split
    ``"train"``) or ``{rootdir}/valid`` (any other split). Each image's line
    segments come from the ``lpos`` array of the sibling ``*_0_label.npz`` file.
    """

    # Per-channel RGB mean/std (the standard ImageNet values), applied to
    # images scaled to [0, 1].
    _MEAN = np.array([0.485, 0.456, 0.406])
    _STD = np.array([0.229, 0.224, 0.225])
    # Resolutions passed to the util.get_*_map helpers.
    _MAP_SIZES = (256, 128, 64, 32)

    def __init__(self, rootdir, split, num_sample=300, fsize=128, maxL=None):
        """Collect the image list and set up augmentation.

        Args:
            rootdir: dataset root containing ``train/`` and ``valid/``.
            split: ``"train"`` reads ``train/``; any other value reads
                ``valid/``. An augmentation pipeline is built unless split is
                ``"valid"`` or ``"test"``. It is only applied when split is
                ``"train"``.
            num_sample: stored as ``self.num_sample``.
            fsize: stored as ``self.fsize``; also the base of the default
                ``maxL``.
            maxL: stored as ``self.maxL``; defaults to ``int(fsize * sqrt(2))``.
        """
        self.rootdir = rootdir
        subdir = "train" if split == "train" else "valid"
        self.filelist = sorted(glob.glob(f"{rootdir}/{subdir}/*_0.png"))
        if maxL is None:
            maxL = int(fsize * 2 ** (1 / 2))
        # Previously computed and then discarded; keep it available to callers.
        self.maxL = maxL

        print(f"n{split}:", len(self.filelist))
        self.split = split
        self.num_sample = num_sample
        self.fsize = fsize
        self.transforms = None if self.split in ["valid", "test"] else make_transforms()

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.filelist)

    def get_edge(self, image):
        """Return a min-max normalised Sobel edge map of shape ``(1, H, W)``.

        The image is converted with ``rgb2gray``, blurred with a 3x3 Gaussian,
        filtered with Sobel, and rescaled to [0, 1].
        ``image`` is presumably an RGB ``(H, W, 3)`` array; confirm with callers.
        """
        # Import explicitly: the module only imports skimage.io and
        # skimage.transform, so ``skimage.color`` / ``skimage.filters`` are not
        # guaranteed to be loaded as attributes on every skimage version.
        from skimage import color as sk_color
        from skimage import filters as sk_filters

        gray = sk_color.rgb2gray(image)
        gray = cv2.GaussianBlur(gray, (3, 3), 0)
        edge = cv2.normalize(sk_filters.sobel(gray), None, 0, 1, cv2.NORM_MINMAX)
        return edge[None, :, :]

    @staticmethod
    def _order_endpoints(lines):
        """Reorder each ``[x1, y1, x2, y2]`` row in place and return ``lines``.

        Rows are first swapped so that ``y1 <= y2``, then swapped again so
        that ``x1 <= x2``. When the two disagree, the x ordering wins.
        """
        swap = lines[:, 1] > lines[:, 3]
        lines[swap] = lines[swap][:, [2, 3, 0, 1]]
        swap = lines[:, 0] > lines[:, 2]
        lines[swap] = lines[swap][:, [2, 3, 0, 1]]
        return lines

    @classmethod
    def _image_to_tensor(cls, img):
        """Scale an RGB image to [0, 1], normalise it per channel, and return a
        float32 ``(3, H, W)`` tensor."""
        arr = norm(np.array(img) / 255, cls._MEAN, cls._STD)
        return torch.tensor(np.ascontiguousarray(arr.transpose(2, 0, 1))).float()

    def __getitem__(self, idx):
        """Load sample ``idx`` and build all of its targets.

        Returns a dict with these entries:
            - ``img``: the normalised ``(3, H, W)`` image.
            - ``lines``: segments as ``[x1, y1, x2, y2]`` rows.
            - The outputs of ``util.get_lines_label`` at factors 4
              (``label_cls``/``label_lines``) and 8
              (``label_cls64``/``label_lines64``).
            - The outputs of ``util.get_random_pos_neg_lines``
              (``score_cls``/``score_lines``).
            - The per-resolution ``label_line_map``, ``label_junction_map`` and
              ``label_midpoint_map`` lists.
            - ``iname``: the image path.

        All arrays are converted to float32 tensors.
        """
        iname = self.filelist[idx]
        # Force 3 channels: the mean/std used below are 3-vectors, so an RGBA,
        # grayscale or palette PNG would otherwise fail to broadcast in norm().
        img = Image.open(iname).convert("RGB")

        with np.load(iname.replace(".png", "_label.npz")) as npz:
            # Shuffle the segments and swap each endpoint from (y, x) to (x, y).
            # Then flatten to [x1, y1, x2, y2] rows and scale by 4 (presumably
            # label grid -> image pixels; confirm).
            lines = np.random.permutation(npz["lpos"])[:, :, [1, 0]].reshape(-1, 4) * 4
        lines = self._order_endpoints(lines)

        if self.split == "train" and self.transforms is not None:
            # The augmentation pipeline takes and returns the segments as a
            # torch tensor, jointly with the image.
            img, lines = self.transforms(img, torch.tensor(lines))
            lines = lines.numpy()

        label_cls, label_lines = util.get_lines_label(lines, 4)
        label_cls64, label_lines64 = util.get_lines_label(lines, 8)
        score_cls, score_lines = util.get_random_pos_neg_lines(lines)
        # The map helpers receive coordinates divided by 512 (presumably the
        # image side length; confirm). A fresh array is built per call, as before.
        label_line_map = [util.get_line_map(lines / 512, f)[0] for f in self._MAP_SIZES]
        label_junction_map = [util.get_junction_map(lines / 512, f) for f in self._MAP_SIZES]
        label_midpoint_map = [util.get_midpoint_map(lines / 512, f) for f in self._MAP_SIZES]

        def as_float(x):
            return torch.tensor(x).float()

        return dict(
            img=self._image_to_tensor(img),
            lines=as_float(lines),
            label_cls=as_float(label_cls),
            label_lines=as_float(label_lines),
            label_cls64=as_float(label_cls64),
            label_lines64=as_float(label_lines64),
            score_cls=as_float(score_cls),
            score_lines=as_float(score_lines),
            label_line_map=[as_float(m) for m in label_line_map],
            label_junction_map=[as_float(m) for m in label_junction_map],
            label_midpoint_map=[as_float(m) for m in label_midpoint_map],
            iname=iname,
        )


class YUD(Dataset):
    """York Urban (YorkUrbanDB) line-detection dataset, with no augmentation.

    Samples are the ``*.jpg`` files one directory below ``rootdir``. Each
    image's segments are read from the ``lines`` entry of the ``.mat`` file
    whose path is the image path with ``.jpg`` replaced by ``LinesAndVP.mat``.
    Images are resized to 512x512 and the segments are rescaled to match.
    """

    # Per-channel RGB mean/std (the standard ImageNet values), applied to
    # images scaled to [0, 1].
    _MEAN = np.array([0.485, 0.456, 0.406])
    _STD = np.array([0.229, 0.224, 0.225])
    # Side length of the square resolution images are resized to.
    _SIZE = 512
    # Resolutions passed to the util.get_*_map helpers.
    _MAP_SIZES = (256, 128, 64, 32)

    def __init__(self, rootdir, split, num_sample=300, fsize=128, maxL=None):
        """Collect the sorted image list under ``rootdir``.

        ``split``, ``num_sample``, ``fsize`` and ``maxL`` are accepted but
        unused, presumably to mirror ``WireframeDataset``'s signature.
        """
        self.rootdir = rootdir
        self.filelist = sorted(glob.glob(f"{rootdir}/*/*.jpg"))

    def __len__(self):
        """Return the number of images found."""
        return len(self.filelist)

    @staticmethod
    def _order_endpoints(lines):
        """Reorder each ``[x1, y1, x2, y2]`` row in place and return ``lines``.

        Rows are first swapped so that ``y1 <= y2``, then swapped again so
        that ``x1 <= x2``. When the two disagree, the x ordering wins.
        """
        swap = lines[:, 1] > lines[:, 3]
        lines[swap] = lines[swap][:, [2, 3, 0, 1]]
        swap = lines[:, 0] > lines[:, 2]
        lines[swap] = lines[swap][:, [2, 3, 0, 1]]
        return lines

    @classmethod
    def _image_to_tensor(cls, img):
        """Scale an RGB image to [0, 1], normalise it per channel, and return a
        float32 ``(3, H, W)`` tensor."""
        arr = norm(np.array(img) / 255, cls._MEAN, cls._STD)
        return torch.tensor(np.ascontiguousarray(arr.transpose(2, 0, 1))).float()

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its evaluation targets.

        Returns a dict with these entries:
            - ``img``: the normalised ``(3, 512, 512)`` image.
            - ``lines``: segments as ``[x1, y1, x2, y2]`` rows in 512x512
              coordinates.
            - The outputs of ``util.get_lines_label`` at factors 4
              (``label_cls``/``label_lines``) and 8
              (``label_cls64``/``label_lines64``).
            - The outputs of ``util.get_random_pos_neg_lines``
              (``score_cls``/``score_lines``).
            - The per-resolution ``label_line_map`` and ``label_junction_map``
              lists.
            - ``iname``: the image path.

        All arrays are converted to float32 tensors.
        """
        iname = self.filelist[idx]
        # JPEGs may be grayscale; the normalisation below needs 3 channels.
        img = Image.open(iname).convert("RGB")

        mat = scio.loadmat(iname.replace(".jpg", "LinesAndVP.mat"))
        lines = self._order_endpoints(np.array(mat["lines"]).reshape(-1, 4))

        # Rescale by the actual image size rather than assuming 640x480. For
        # 640x480 images these are exactly the former 512/640 and 512/480 factors.
        width, height = img.size
        sx, sy = self._SIZE / width, self._SIZE / height
        img = img.resize((self._SIZE, self._SIZE))
        lines = lines * np.array([sx, sy, sx, sy])

        label_cls, label_lines = util.get_lines_label(lines, 4)
        label_cls64, label_lines64 = util.get_lines_label(lines, 8)
        score_cls, score_lines = util.get_random_pos_neg_lines(lines)
        # The map helpers receive coordinates divided by the image side length.
        # A fresh array is built per call, as before.
        label_line_map = [util.get_line_map(lines / self._SIZE, f)[0] for f in self._MAP_SIZES]
        label_junction_map = [util.get_junction_map(lines / self._SIZE, f) for f in self._MAP_SIZES]

        def as_float(x):
            return torch.tensor(x).float()

        # label_cls64 / label_lines64 were computed but previously dropped.
        # They are returned now, matching WireframeDataset's output keys.
        return dict(
            img=self._image_to_tensor(img),
            lines=as_float(lines),
            label_cls=as_float(label_cls),
            label_lines=as_float(label_lines),
            label_cls64=as_float(label_cls64),
            label_lines64=as_float(label_lines64),
            score_cls=as_float(score_cls),
            score_lines=as_float(score_lines),
            label_line_map=[as_float(m) for m in label_line_map],
            label_junction_map=[as_float(m) for m in label_junction_map],
            iname=iname,
        )


def generate_grid(scale_fct=128):
    """Enumerate the integer positions of a ``scale_fct`` x ``scale_fct`` grid.

    Each output row is ``[x, y, x, y]`` (float64). Rows are in row-major order
    (y in the outer loop, x in the inner loop), so the result has shape
    ``(scale_fct ** 2, 4)``.
    """
    axis = np.linspace(0, scale_fct, scale_fct, endpoint=False)
    cols, rows = np.meshgrid(axis, axis)
    stacked = np.stack([cols, rows, cols, rows], axis=-1)
    return stacked.reshape(-1, 4)


def build(image_set, args):
    """Construct the dataset selected by ``args.dataset_file``.

    Side effect: overwrites ``args.dataset_path`` with the hard-coded root of
    the chosen dataset.

    Args:
        image_set: split name forwarded to the dataset constructor.
        args: namespace providing ``dataset_file`` (``"wireframe"`` or
            ``"yud"``), ``num_sample`` and ``fsize``.

    Returns:
        A ``WireframeDataset`` or ``YUD`` instance.

    Raises:
        ValueError: if ``args.dataset_file`` names an unknown dataset.
        FileNotFoundError: if the dataset root does not exist.
    """
    default_roots = {
        "wireframe": "/home/data3/dataset/wireframe",
        "yud": "/home/data3/dataset/YorkUrbanDB",
    }
    if args.dataset_file not in default_roots:
        # Fail fast with a clear message instead of reaching ``return`` with
        # no dataset bound.
        raise ValueError(
            f"unknown dataset_file {args.dataset_file!r}; "
            f"expected one of {sorted(default_roots)}"
        )
    args.dataset_path = default_roots[args.dataset_file]

    root = Path(args.dataset_path)
    # Check before constructing the dataset (whose constructor globs the
    # directory). Raise explicitly so the check survives ``python -O``.
    if not root.exists():
        raise FileNotFoundError(f'provided Line Detection path {root} does not exist')

    dataset_cls = WireframeDataset if args.dataset_file == "wireframe" else YUD
    return dataset_cls(args.dataset_path, image_set, args.num_sample, args.fsize)


if __name__ == "__main__":
    # Manual check: load every York Urban sample and print it.
    # YUD defines no __iter__, so the loop indexes 0, 1, ... until
    # filelist[idx] raises IndexError.
    yud_dataset = YUD("/disk2/dataset/YorkUrbanDB", "test")
    for sample in yud_dataset:
        print(sample)
