import random
import math
from pathlib import Path
import numpy as np
import cv2
from multiprocessing import Pool
import skimage.io as sio
import matplotlib.pyplot as plt
import os
import skimage.transform
import numpy.linalg as LA
from scipy.ndimage import gaussian_filter1d
import scipy.io as scio


def get_standard_location(max_L):
    """
    Build the sample-point lookup table for every (theta, rho) candidate line.

    Candidates use the normal form x*cos(theta) + y*sin(theta) = rho, with
    180 theta values in [0, pi) and ``max_L`` rho values in
    [-sqrt(2), sqrt(2)].  Each candidate is first sampled at ``max_L``
    positions along its direction; candidates with at most ``max_L // 20``
    of those samples inside the square [-1, 1] x [-1, 1] are discarded.
    For the remaining ones, ``max_L`` points are re-sampled evenly between
    the first and last inside sample.

    Rows are laid out rho-major: ``row = rho_index * 180 + theta_index``.

    Parameters
    ----------
    max_L : int
        Number of rho bins and number of sample points per line.

    Returns
    -------
    maps : float32 array, shape (180 * max_L, max_L, 2)
        (x, y) sample points per candidate; all zeros for discarded rows.
    masks : bool array, shape (180 * max_L, max_L)
        Always all False (kept lines are re-sampled entirely inside the
        square, discarded rows are never written).
    theta_rho : float32 array, shape (180 * max_L, 2)
        (theta, rho) per kept candidate; zeros for discarded rows.
    ignore : bool array, shape (max_L, 180)
        True where the candidate was discarded.
    """
    min_L = max_L // 20  # minimum number of inside samples for a line to be kept
    n_theta = 180
    n_rho = max_L
    sqrt2 = 2 ** (1 / 2)
    theta = np.linspace(0, np.pi, n_theta, endpoint=False)
    rho = np.linspace(-sqrt2, sqrt2, n_rho, endpoint=True)
    l = np.linspace(-sqrt2, sqrt2, max_L, endpoint=True)

    num_lines = n_theta * n_rho
    maps = np.zeros((num_lines, max_L, 2), dtype=np.float32)
    masks = np.zeros((num_lines, max_L), dtype=bool)
    theta_rho = np.zeros((num_lines, 2), dtype=np.float32)
    keep = np.zeros(num_lines, dtype=bool)

    # Theta is the outer loop so the trig terms are evaluated once per angle;
    # the row index below still follows the rho-major layout.
    for i in range(n_theta):
        t = theta[i]
        cos_t, sin_t = np.cos(t), np.sin(t)
        along_x = l * sin_t
        along_y = l * cos_t
        for j in range(n_rho):
            r = rho[j]
            xs = r * cos_t - along_x
            ys = r * sin_t + along_y
            inside = (xs >= -1) & (xs <= 1) & (ys >= -1) & (ys <= 1)
            inside_idx = np.flatnonzero(inside)
            if inside_idx.size <= min_L:
                continue

            row = j * n_theta + i
            first, last = inside_idx[0], inside_idx[-1]
            keep[row] = True
            theta_rho[row, 0], theta_rho[row, 1] = t, r
            # Re-sample evenly between the first and last in-square samples,
            # so every stored point is valid and no mask entry is needed.
            maps[row, :, 0] = np.linspace(xs[first], xs[last], max_L)
            maps[row, :, 1] = np.linspace(ys[first], ys[last], max_L)

    print('Hough transform mapping dict has generated.')
    return maps, masks, theta_rho, ~keep.reshape(n_rho, n_theta)  # True means ignore


def get_dis(lines, theta_rhos):
    """
    Distance from each segment to each (theta, rho) line, taken at the worse endpoint.

    For every pair the absolute value of x*cos(theta) + y*sin(theta) - rho is
    evaluated at both endpoints and the larger of the two is returned.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
    theta_rhos : array indexed as (num_candidates, 2), rows are (theta, rho).

    Returns
    -------
    Array of shape (num_candidates, num_lines).
    """
    cos_col = np.cos(theta_rhos[:, 0:1])
    sin_col = np.sin(theta_rhos[:, 0:1])
    rho_col = theta_rhos[:, 1:2]

    def endpoint_offset(x_col, y_col):
        # (num_candidates, 1) against (1, num_lines) -> (num_candidates, num_lines)
        return np.abs(cos_col @ x_col.T + sin_col * y_col.T - rho_col)

    first = endpoint_offset(lines[:, 0:1], lines[:, 1:2])
    second = endpoint_offset(lines[:, 2:3], lines[:, 3:4])
    return np.where(first > second, first, second)


def get_len(lines):
    """Return the Euclidean length of each segment; rows of ``lines`` are (x1, y1, x2, y2)."""
    delta_x = lines[:, 0] - lines[:, 2]
    delta_y = lines[:, 1] - lines[:, 3]
    squared = delta_y ** 2 + delta_x ** 2
    return squared ** 0.5


def get_midpoint_map(lines, fsize):
    """
    Rasterise segment midpoints onto an ``fsize`` x ``fsize`` map.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
        Presumably normalised to [0, 1]; coordinates are multiplied by
        ``fsize`` -- TODO confirm against the caller.
    fsize : int
        Side length of the output map.

    Returns
    -------
    float32 array of shape (fsize, fsize), indexed [y, x], with 1 at every
    midpoint cell and 0 elsewhere.
    """
    midpoints = lines.reshape((-1, 2, 2)).mean(1)
    cells = np.round(midpoints * fsize).astype(int).clip(0, fsize - 1)
    midpoint_map = np.zeros((fsize, fsize), dtype=np.float32)
    # The coordinates are already integers, so floor and ceil both return the
    # cell itself; a single write replaces the four-neighbour loop.
    # NOTE(review): get_junction_map scales by (fsize - 1) without rounding,
    # whereas this scales by fsize and rounds -- confirm the difference is
    # intended.
    midpoint_map[cells[:, 1], cells[:, 0]] = 1.0
    return midpoint_map


def get_junction_map(lines, fsize):
    """
    Rasterise segment endpoints (junctions) onto an ``fsize`` x ``fsize`` map.

    Each endpoint is scaled by ``fsize - 1`` and marks those of its four
    surrounding integer cells (floor/ceil combinations) that lie strictly
    closer than one pixel to it.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
        Presumably normalised to [0, 1] -- TODO confirm against the caller.
    fsize : int
        Side length of the output map.

    Returns
    -------
    float32 array of shape (fsize, fsize), indexed [y, x], with 1 at marked
    cells and 0 elsewhere.
    """
    junctions = lines.reshape((-1, 2)) * (fsize - 1)
    jx, jy = junctions[:, 0], junctions[:, 1]
    junction_map = np.zeros((fsize, fsize), dtype=np.float32)
    corner_rounders = (
        (np.floor, np.floor),
        (np.floor, np.ceil),
        (np.ceil, np.floor),
        (np.ceil, np.ceil),
    )
    for round_x, round_y in corner_rounders:
        x, y = round_x(jx), round_y(jy)
        near = np.sqrt((x - jx) ** 2 + (y - jy) ** 2) < 1.
        # Only ever set cells.  Writing the boolean for every corner let a
        # far corner of one junction clear a cell that another junction (or
        # an earlier pass) had already marked.
        rows = y[near].clip(0, fsize - 1).astype(int)
        cols = x[near].clip(0, fsize - 1).astype(int)
        junction_map[rows, cols] = 1.
    return junction_map


def get_line_map(lines, fsize):
    """
    Rasterise segments into a binary line map and a per-pixel direction map.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
        Presumably normalised to [0, 1]; coordinates are multiplied by
        ``fsize`` and rounded -- TODO confirm against the caller.
    fsize : int
        Side length of both output maps.

    Returns
    -------
    line_map : float32 array, shape (fsize, fsize)
        1 where an anti-aliased line pixel has coverage above 0.33, else 0.
    line_direction_map : float32 array, shape (fsize, fsize)
        Angle in [0, pi] written to every pixel returned by the anti-aliased
        rasteriser (including pixels below the 0.33 coverage threshold);
        later lines overwrite earlier ones.
    """
    # Imported explicitly: the file only imports skimage.io and
    # skimage.transform, so skimage.draw is not guaranteed to be loaded.
    from skimage.draw import line_aa

    endpoints = np.round(lines * fsize).astype(int).clip(0, fsize - 1)
    line_map = np.zeros((fsize, fsize))
    line_direction_map = np.zeros((fsize, fsize), dtype=np.float32)
    for x1, y1, x2, y2 in endpoints:
        rows, cols, coverage = line_aa(y1, x1, y2, x2)  # anti-aliased pixels
        line_map[rows, cols] = np.maximum(line_map[rows, cols], coverage > 0.33)
        theta = np.arctan2(x1 - x2, y2 - y1)
        if theta < 0:
            # Fold negative angles so opposite directions share one value.
            theta = theta + np.pi
        line_direction_map[rows, cols] = theta
    return line_map.astype(np.float32), line_direction_map


def get_corresponding_maps(lines, standard_maps, standard_masks, standard_theta_rhos, fsize):
    """
    Build training targets by matching every segment to its closest candidate line.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
        Presumably normalised to [0, 1] (remapped with ``lines * 2 - 1``
        below) -- TODO confirm against the caller.
    standard_maps : array indexed as (num_candidates, L, 2)
        (x, y) sample points along each candidate line.
    standard_masks : bool array indexed as (num_candidates, L)
        True marks sample points that are excluded when matching.
    standard_theta_rhos : array indexed as (num_candidates, 2)
        (theta, rho) of each candidate, passed to ``get_dis``.
    fsize : int
        Map size passed to the junction/line rasterisers; also sets the
        distance thresholds (``2 / fsize``).

    Returns
    -------
    clss : float32 (num_candidates,)
        1 for a candidate that is the nearest one to at least one segment,
        0 for a candidate farther than ``2 / fsize`` from every segment
        (this overrides a 1), -1 otherwise.
    root_map : float32 (num_candidates, L)
        1 on sample indices ``[e1, e2)`` of the matched candidate.
    dis_map : float32 (num_candidates, L, 2)
        On root positions, scaled index offsets to ``e1`` and ``e2``.
    junction_map, line_map, line_direction_map
        Outputs of ``get_junction_map`` and ``get_line_map``.
    lp_map, rp_map, ep_map : float32 (num_candidates, L)
        Endpoint markers: lower-index endpoint, higher-index endpoint, and
        both combined.
    """
    junction_map = get_junction_map(lines, fsize)
    line_map, line_direction_map = get_line_map(lines, fsize)
    # Remap coordinates for comparison with the candidate (theta, rho) lines.
    lines = lines * 2 - 1
    # dist_map[k, i]: larger endpoint distance of segment i to candidate k.
    dist_map = get_dis(lines, standard_theta_rhos)

    num, l = standard_maps.shape[0], standard_maps.shape[1]
    clss = -np.ones(num, dtype=np.float32)
    root_map = np.zeros((num, l), dtype=np.float32)
    dis_map = np.zeros((num, l, 2), dtype=np.float32)
    lp_map = np.zeros((num, l), dtype=np.float32)
    rp_map = np.zeros((num, l), dtype=np.float32)
    ep_map = np.zeros((num, l), dtype=np.float32)

    # Positive: the nearest candidate of each segment.
    clss[np.argmin(dist_map, axis=0)] = 1  # min
    # Negative: candidates farther than 2 / fsize from every segment.  Runs
    # after the positive assignment, so it can turn a positive into a negative.
    clss[(np.min(dist_map, axis=1) > 2 / fsize * 1)] = 0
    # clss[(np.min(dist_map, axis=1) < 2 / fsize * 2) & (np.min(dist_map, axis=1) > 2 / fsize * 1)] = 0  # 2.5 pixel length
    # standard_maps = standard_maps[:, :, 0, :]  # for 3 samples
    for i, line in enumerate(lines):
        index = np.argmin(dist_map[:, i])
        # NOTE(review): positions in sample_point are used below as column
        # indices into the length-L maps; that lines up only when no earlier
        # sample of this candidate is masked out -- confirm with the mask source.
        sample_point = standard_maps[index][~standard_masks[index]]
        flatten_line = line.reshape((-1, 2))
        # dis[p, e]: Euclidean distance from sample point p to endpoint e.
        dis = sample_point[:, None] - flatten_line[None]
        dis = np.sqrt(dis[:, :, 0] ** 2 + dis[:, :, 1] ** 2)
        # Nearest sample index for each of the two endpoints.
        min_dis_index = np.argmin(dis, axis=0)
        e1, e2 = min(min_dis_index), max(min_dis_index)
        # root_map[index, (e1+e2)//2] = 1
        # root_map[index, e1:e2 + 1] = np.convolve(root_map[index, e1:e2 + 1], np.array([0.5, 1, 0.5]))[1:-1]  # add gaussian blur
        # Half-open range: e2 itself is not marked, and nothing is marked
        # when both endpoints fall on the same sample.
        root_map[index, e1:e2] = 1
        # The selection covers every root position currently set on this row,
        # including ones written by an earlier segment with the same candidate.
        dis_map[index, root_map[index, :] > 1e-5, 0] = e1 - np.where(root_map[index, :] > 1e-5)[0]
        dis_map[index, root_map[index, :] > 1e-5, 1] = e2 - np.where(root_map[index, :] > 1e-5)[0]
        # Scale the index offsets by 2 / (fsize * sqrt(2)).
        dis_map[index, root_map[index, :] > 1e-5] = dis_map[index, root_map[index, :] > 1e-5] * 2 / (fsize * 2 ** 0.5)
        # Mark samples within 2 / fsize of the endpoint with the smaller
        # (lp) / larger (rp) nearest-sample index, plus that sample itself.
        lp_map[index, np.where(dis[:, np.argmin(min_dis_index)] < 2 / fsize)] = 1
        rp_map[index, np.where(dis[:, np.argmax(min_dis_index)] < 2 / fsize)] = 1
        lp_map[index, e1] = 1
        rp_map[index, e2] = 1

        # ep_map receives the same marks as lp_map and rp_map combined.
        ep_map[index, np.where(dis[:, np.argmin(min_dis_index)] < 2 / fsize)] = 1
        ep_map[index, np.where(dis[:, np.argmax(min_dis_index)] < 2 / fsize)] = 1
        ep_map[index, e1] = 1
        ep_map[index, e2] = 1

    return clss, root_map, dis_map, junction_map, line_map, line_direction_map, lp_map, rp_map, ep_map


def get_sample_location(num_sample, locations, masks, clss, root_map, dis_map, lp_map, rp_map, ep_map):
    """
    Draw ``num_sample`` candidates: all positives (capped), the rest negatives.

    Positives are rows with ``clss == 1``; if there are more than
    ``num_sample`` of them a random subset is taken without replacement.
    The remaining slots are filled with rows drawn without replacement from
    ``clss == 0``.  The selection is then shuffled and applied identically to
    every input array.

    Returns
    -------
    Tuple ``(locations, masks, clss, root_map, dis_map, lp_map, rp_map,
    ep_map)``, each with ``num_sample`` rows in the same shuffled order.
    """
    positive_idx = np.where(clss == 1)[0]
    if len(positive_idx) > num_sample:
        positive_idx = np.random.choice(positive_idx, num_sample, replace=False)

    negative_pool = np.where(clss == 0)[0]
    negative_idx = np.random.choice(negative_pool, num_sample - len(positive_idx), replace=False)

    order = np.random.choice(np.arange(0, num_sample, 1), num_sample, replace=False)
    # One combined row index replaces gathering, concatenating and shuffling
    # each of the eight arrays separately.
    picked = np.concatenate((positive_idx, negative_idx), axis=0)[order]

    sources = (locations, masks, clss, root_map, dis_map, lp_map, rp_map, ep_map)
    return tuple(source[picked] for source in sources)


def get_CN4_grid(W1=64, H1=64, W2=64, H2=64, n=16):
    """
    Sample ``n`` points along the segment joining every pair of grid points.

    The first endpoint ranges over a ``W1`` x ``H1`` grid and the second over
    a ``W2`` x ``H2`` grid, both with coordinates ``k / W`` (resp. ``k / H``)
    in [0, 1).  Grid points are enumerated row by row (y outer, x inner) and
    pairs are ordered with the first endpoint as the slow index.

    Parameters
    ----------
    W1, H1 : int
        Grid resolution for the first endpoint.
    W2, H2 : int
        Grid resolution for the second endpoint.
    n : int
        Number of samples per segment, endpoints included.

    Returns
    -------
    float64 array of shape (W1 * H1 * W2 * H2, n, 2); entry ``[k, i]`` is the
    (x, y) point at fraction ``i / (n - 1)`` from the first endpoint of pair
    ``k`` to its second endpoint.
    """
    x1 = np.linspace(0, 1, W1, endpoint=False)
    y1 = np.linspace(0, 1, H1, endpoint=False)
    x2 = np.linspace(0, 1, W2, endpoint=False)
    y2 = np.linspace(0, 1, H2, endpoint=False)
    first = np.stack(np.meshgrid(x1, y1), axis=-1).reshape(-1, 1, 2)
    second = np.stack(np.meshgrid(x2, y2), axis=-1).reshape(1, -1, 2)

    pair_shape = (first.shape[0], second.shape[1], 2)
    starts = np.broadcast_to(first, pair_shape).reshape(-1, 2)
    stops = np.broadcast_to(second, pair_shape).reshape(-1, 2)
    # Interpolate from endpoint 1 to endpoint 2.  The earlier reshape chain
    # used a no-op transpose(0, 1, 2) and therefore interpolated from the x
    # to the y coordinate of the same endpoint.
    return np.linspace(starts, stops, n, endpoint=True, axis=1)


def get_lines_label(lines, s=8):
    """
    Build per-cell classification and regression labels from segment midpoints.

    The 512 x 512 image is divided into ``(512 // s) ** 2`` cells of side
    ``s``; each segment is assigned to the cell containing its midpoint.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
        Presumably in pixel units of a 512 x 512 image -- TODO confirm.
    s : int
        Cell side length in pixels.

    Returns
    -------
    label_cls : int array, shape (N,)
        1 for cells containing at least one midpoint.
    label_reg : float array, shape (N, 4)
        ``lines / 512`` of a segment assigned to the cell, zeros elsewhere.
    """
    # Process segments from longest to shortest.
    # NOTE(review): when several segments share a cell, the fancy-index
    # assignment below keeps the LAST one written, i.e. the shortest --
    # confirm that this is the intended tie-break.
    order = np.argsort(-get_len(lines))
    lines = lines[order]

    grid = 512 // s
    N = grid ** 2
    xm = (lines[:, 0] + lines[:, 2]) / 2
    ym = (lines[:, 1] + lines[:, 3]) / 2
    # The row stride must be the integer grid width; ``512 / s`` is only
    # correct when s divides 512.
    cell = ((ym // s) * grid + (xm // s)).astype(int)
    label_cls = (np.bincount(cell, minlength=N) > 0).astype(int)

    label_reg = np.zeros((N, 4))
    label_reg[cell] = lines / 512
    return label_cls, label_reg


def get_lines_label_dense(lines, s=8):
    """
    Build per-cell labels where regression targets cover every cell a segment crosses.

    Classification labels mark the cell containing each segment's midpoint;
    regression labels are written to all cells on the anti-aliased raster of
    the segment (coverage above 0.33) on the ``512 // s`` grid.

    Parameters
    ----------
    lines : array indexed as (num_lines, 4), rows are (x1, y1, x2, y2).
        Presumably in pixel units of a 512 x 512 image -- TODO confirm.
    s : int
        Cell side length in pixels.

    Returns
    -------
    label_cls : int array, shape (N,), with ``N = (512 // s) ** 2``
        1 for cells containing at least one midpoint.
    label_reg : float array, shape (N, 4)
        ``lines / 512`` of a segment crossing the cell, zeros elsewhere.
    """
    # Imported explicitly: the file only imports skimage.io and
    # skimage.transform, so skimage.draw is not guaranteed to be loaded.
    from skimage.draw import line_aa

    # Longest segments first, so shorter ones are rasterised later and
    # overwrite cells they share with longer ones.
    order = np.argsort(-get_len(lines))
    lines = lines[order]

    grid = 512 // s
    N = grid ** 2
    xm = (lines[:, 0] + lines[:, 2]) / 2
    ym = (lines[:, 1] + lines[:, 3]) / 2
    # Integer grid width as the row stride (``512 / s`` is wrong when s does
    # not divide 512).
    cell = ((ym // s) * grid + (xm // s)).astype(int)
    label_cls = (np.bincount(cell, minlength=N) > 0).astype(int)

    label_reg = np.zeros((N, 4))

    endpoints = (lines // s).clip(0, grid - 1).astype(int)
    for ends, target in zip(endpoints, lines / 512):
        rows, cols, coverage = line_aa(ends[1], ends[0], ends[3], ends[2])  # anti-aliasing
        covered = coverage > 0.33
        # The flat index uses the actual grid width; the previous hard-coded
        # 64 was only right for s == 8.
        label_reg[rows[covered] * grid + cols[covered]] = target

    return label_cls, label_reg


def adjust_contrast_adaptive(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """
    Apply CLAHE (adaptive contrast equalisation) to the luma channel of an image.

    Parameters
    ----------
    image : input image; it is converted with ``cv2.COLOR_RGB2YCrCb``, i.e.
        treated as RGB.
    clip_limit : contrast-limiting threshold (typically 2.0 or 4.0).
    tile_grid_size : tile grid size used by the CLAHE algorithm.

    Returns
    -------
    The contrast-adjusted image, converted YCrCb -> BGR -> RGB.
    """
    # CLAHE is applied to the Y (luma) channel only; the chroma channels
    # pass through untouched.
    luma, chroma_r, chroma_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_RGB2YCrCb))

    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    enhanced_ycrcb = cv2.merge((equalizer.apply(luma), chroma_r, chroma_b))

    # Back to RGB by way of BGR, as in the original conversion chain.
    as_bgr = cv2.cvtColor(enhanced_ycrcb, cv2.COLOR_YCrCb2BGR)
    return cv2.cvtColor(as_bgr, cv2.COLOR_BGR2RGB)


def generate_segments(line_segments, threshold_range, num_segments_per_line=4):
    """
    Generate randomly perturbed copies of each input segment.

    For every input segment, ``num_segments_per_line`` new segments are
    produced by moving the first endpoint a distance ``d1`` and the second a
    distance ``d2`` in independent random directions, where
    ``d1 ** 2 + d2 ** 2`` is drawn uniformly from ``threshold_range``.

    Parameters
    ----------
    line_segments : array-like indexed as (num_lines, 4)
        Rows are (x1, y1, x2, y2).  May be empty.
    threshold_range : pair of numbers
        Bounds of the total squared displacement; accepted in either order.
    num_segments_per_line : int
        Number of perturbed segments generated per input segment.

    Returns
    -------
    float array of shape (num_lines * num_segments_per_line, 4); the copies
    of one input segment are consecutive rows.
    """
    # reshape(-1, 4) keeps an empty input two-dimensional, so the column
    # slicing below works for zero segments as well.
    segments = np.asarray(line_segments, dtype=float).reshape(-1, 4)
    starts, ends = segments[:, 0:2], segments[:, 2:4]

    # np.random.uniform documents low > high as undefined, so order the bounds.
    bound_a, bound_b = threshold_range
    low, high = min(bound_a, bound_b), max(bound_a, bound_b)

    shape = (segments.shape[0], num_segments_per_line)
    budget = np.random.uniform(low, high, shape)  # total squared displacement
    d1 = np.random.uniform(0, np.sqrt(budget))
    # Clamp at zero: rounding can leave budget - d1**2 marginally negative.
    d2 = np.sqrt(np.maximum(budget - d1 ** 2, 0.0))
    angles1 = np.random.uniform(0, 2 * math.pi, shape)
    angles2 = np.random.uniform(0, 2 * math.pi, shape)

    # Perturbed endpoints, each of shape (num_lines, num_segments_per_line).
    p1_x = starts[:, 0:1] + d1 * np.cos(angles1)
    p1_y = starts[:, 1:2] + d1 * np.sin(angles1)
    p2_x = ends[:, 0:1] + d2 * np.cos(angles2)
    p2_y = ends[:, 1:2] + d2 * np.sin(angles2)

    perturbed = np.stack((p1_x.flatten(), p1_y.flatten(), p2_x.flatten(), p2_y.flatten()), axis=-1)
    return perturbed.reshape(-1, 4)


def get_random_pos_neg_lines(lines, max_num=2000):
    """
    Create a shuffled set of ``max_num`` scored candidate segments.

    The input segments are divided by 4, then for each one a slightly
    perturbed copy (label 1) and a strongly perturbed copy (label 0) are
    generated; the remaining slots are filled with uniformly random segments
    in [0, 128) (label 0).

    Returns
    -------
    score_cls : float array, shape (max_num,), 1 for positives and 0 otherwise.
    score_lines : float array, shape (max_num, 4), coordinates divided by 128.
    """
    lines = lines / 4
    num_gt = lines.shape[0]
    pos_per_line = 1
    neg_per_line = 1

    positives = generate_segments(lines, (0, 2), pos_per_line)
    # Some of these perturbed copies may be inliers of other ground-truth
    # lines; they could be filtered out.
    negatives = generate_segments(lines, (8, 16), neg_per_line)
    num_random = max_num - num_gt * (neg_per_line + pos_per_line)
    randoms = np.random.uniform(0, 128, (num_random, 4))

    candidates = np.concatenate((positives, negatives, randoms), axis=0)
    labels = np.zeros(max_num)
    labels[:pos_per_line * num_gt] = 1  # positives occupy the leading rows

    order = np.arange(max_num)
    np.random.shuffle(order)
    return labels[order], candidates[order] / 128


