"""
@Description :   Base 数据集 (base dataset class plus image/contour augmentation helpers)
@Author      :   tqychy 
@Time        :   2025/01/13 18:46:16
"""
import sys

sys.path.append("./")
sys.path.append("./dataset")
import pickle
import random

import cv2
import numpy as np
import torch
from noise import pnoise2
from scripts.encoders import *
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm


def get_adjacent2(boundary, max_len, k=1):
    """Build a dense, zero-padded ring adjacency matrix for a closed contour.

    Point ``i`` is connected to itself and to every point at most ``k`` steps
    away along the contour (wrapping around the ends), i.e. to every ``j``
    whose circular index distance ``min(|i - j|, n - |i - j|)`` is ``<= k``.
    Entries are strictly 0/1 even for short contours (``n <= 2k``), where the
    neighbourhoods in both directions overlap.

    :param boundary: sequence of contour points; only ``len(boundary)`` is used.
    :param max_len: size of the square output; must be ``>= len(boundary)``.
    :param k: neighbourhood radius (hops in each direction). Values below 0
        behave like 0 (self-loops only).
    :return: ``torch.float64`` tensor of shape ``(max_len, max_len)``; rows and
        columns at or beyond ``len(boundary)`` are zero (non-sparse).
    :raises ValueError: if the contour has more than ``max_len`` points.
    """
    n = len(boundary)
    if n > max_len:
        raise ValueError(
            f"contour has {n} points, which exceeds max_len={max_len}")

    idx = np.arange(n)
    diff = np.abs(idx[:, None] - idx[None, :])
    # Distance along the closed ring: go either forwards or backwards.
    circ_dist = np.minimum(diff, n - diff)

    padded = np.zeros((max_len, max_len))
    padded[:n, :n] = circ_dist <= max(k, 0)
    return torch.from_numpy(padded)

def generate_perlin_noise_mask(h, w, scale=100, octaves=6, persistence=0.5, lacunarity=2.0, threshold=0.3):
    """Generate a binary mold-spot mask from 2-D Perlin noise.

    The noise field is sampled at ``(x / scale, y / scale)`` for every pixel,
    min-max normalised to ``[0, 1]`` and thresholded.

    :param h: mask height in pixels.
    :param w: mask width in pixels.
    :param scale: spatial scale of the noise (larger -> smoother blobs).
    :param octaves: forwarded to ``pnoise2``.
    :param persistence: forwarded to ``pnoise2``.
    :param lacunarity: forwarded to ``pnoise2``.
    :param threshold: normalised noise values strictly above this become 255.
    :return: ``uint8`` array of shape ``(h, w)`` with values 0 or 255. An
        empty shape yields an empty mask; a constant noise field (nothing
        to normalise) yields an all-zero mask.
    """
    noise_map = np.zeros((h, w), dtype=np.float32)
    if noise_map.size == 0:
        # min()/max() below would raise on an empty array.
        return noise_map.astype(np.uint8)

    for y in range(h):
        for x in range(w):
            noise_map[y, x] = pnoise2(x / scale, y / scale, octaves=octaves,
                                      persistence=persistence, lacunarity=lacunarity)

    lo, hi = noise_map.min(), noise_map.max()
    if hi <= lo:
        # Constant field: normalising would be 0/0 (NaN), which thresholds to
        # all zeros anyway -- return that directly without the NaN warning.
        return np.zeros((h, w), dtype=np.uint8)

    # Normalise to [0, 1], then threshold into the mold shape.
    noise_map = (noise_map - lo) / (hi - lo)
    return (noise_map > threshold).astype(np.uint8) * 255

def add_stains(image, num_stains=10, num_molds=5):
    """Draw random dark stains and Perlin-noise mold spots onto a copy of ``image``.

    The input array is copied first, so the caller's image (e.g. one cached
    inside a dataset) is never modified.

    :param image: ``(H, W, 3)`` image array. It is blended with a ``uint8``
        mold layer via ``cv2.addWeighted``, so it is assumed to be ``uint8``
        (and BGR, matching the rest of this module) -- TODO confirm.
    :param num_stains: number of small filled circles (radius 1-5 px, colour
        channels in 0-50) to draw.
    :param num_molds: number of mold patches (5-10 px wide) to blend in.
        Patches that cannot fit inside the image are skipped.
    :return: a new image with the stains and molds applied.
    """
    image = image.copy()
    h, w, _ = image.shape

    # Dark stains: small filled near-black circles.
    for _ in range(num_stains):
        x = random.randint(0, w - 1)
        y = random.randint(0, h - 1)
        size = random.randint(1, 5)
        color = (random.randint(0, 50), random.randint(0, 50), random.randint(0, 50))
        cv2.circle(image, (x, y), size, color, -1)

    alpha = 0.5  # blend weight of each mold layer
    for _ in range(num_molds):
        # Random patch size; the patch spans [centre - half, centre + half).
        size = random.randint(5, 10)
        half = size // 2
        if w < 2 * half or h < 2 * half:
            # Image smaller than the patch: random.randint would get
            # low > high and raise, so skip this mold instead.
            continue
        x = random.randint(half, w - half)
        y = random.randint(half, h - half)

        y0, y1 = max(0, y - half), min(h, y + half)
        x0, x1 = max(0, x - half), min(w, x + half)
        slice_h, slice_w = y1 - y0, x1 - x0
        if slice_h <= 0 or slice_w <= 0:
            continue

        # Perlin-noise mask sized to the patch.
        mold_mask = generate_perlin_noise_mask(slice_h, slice_w, scale=50, threshold=0.4)

        # Mold layer: zero everywhere except the patch. The middle channel is
        # the strongest, i.e. greenish if the image is BGR.
        mold = np.zeros((h, w, 3), dtype=np.uint8)
        color = (random.randint(0, 100), random.randint(100, 200), random.randint(0, 100))
        mask_3d = mold_mask[:, :, np.newaxis] / 255.0  # (h, w, 1) in [0, 1]
        mold[y0:y1, x0:x1] = mask_3d * color

        # Soften the edges, then blend onto the image.
        mold = cv2.GaussianBlur(mold, (15, 15), 0)
        image = cv2.addWeighted(image, 1, mold, alpha, 0)
    return image

def generate_new_edge(start_point, end_point, num_points, max_offset, img_shape):
    """Generate a jittered replacement edge between two contour points.

    Points are sampled evenly along the segment from ``start_point`` to
    ``end_point`` (both endpoints included when ``num_points >= 2``). Each
    point is shifted by independent uniform noise in
    ``[-max_offset, max_offset]`` per axis, then clipped to the image.

    :param start_point: ``(x, y)`` start of the damaged segment. Anything
        ``np.asarray`` accepts works, e.g. an ndarray or a CPU tensor.
    :param end_point: ``(x, y)`` end of the damaged segment.
    :param num_points: number of points to generate; 1 yields just the
        (jittered) start point.
    :param max_offset: maximum absolute random offset per axis.
    :param img_shape: image size as ``(height, width)``; x is clipped to
        ``[0, width - 1]`` and y to ``[0, height - 1]``.
    :return: ``float64`` array of shape ``(num_points, 2)``. It has shape
        ``(0, 2)`` when the segment has zero length or ``num_points <= 0``.
    """
    start = np.asarray(start_point, dtype=np.float64)
    end = np.asarray(end_point, dtype=np.float64)
    if num_points <= 0 or np.linalg.norm(end - start) == 0:
        return np.empty((0, 2), dtype=np.float64)

    # Even samples along the original segment (linspace handles num_points == 1).
    t = np.linspace(0.0, 1.0, num_points)[:, None]
    points = start + t * (end - start)
    points = points + np.random.uniform(-max_offset, max_offset, size=(num_points, 2))

    # Keep every new point inside the image.
    points[:, 0] = np.clip(points[:, 0], 0, img_shape[1] - 1)  # x
    points[:, 1] = np.clip(points[:, 1], 0, img_shape[0] - 1)  # y
    return points

def add_corrosion(img_ori, full_pcd, max_offset=5, prob=0.1):
    """Randomly simulate edge damage on a closed contour and its image.

    With probability ``prob``, a run of 2-15 consecutive contour points
    (wrapping around the end of the array) is replaced by a jittered
    straight-ish edge from ``generate_new_edge``. The new edge pixels are
    painted black on the image. The inputs are never modified: when damage
    is applied, copies are edited and returned.

    :param img_ori: ``(H, W, 3)`` image array.
    :param full_pcd: ``(N, 2)`` contour points as ``(x, y)``, either a
        ``torch.Tensor`` or a numpy array.
    :param max_offset: maximum random offset of the new edge points.
    :param prob: probability of applying the damage (default 0.1). Values
        ``<= 0`` never apply it.
    :return: ``(image, points)``. These are the original objects when no
        damage was applied; otherwise they are modified copies of the same
        types.
    """
    if prob <= 0.0 or random.random() > prob:
        return img_ori, full_pcd

    n_points = full_pcd.shape[0]
    corrosion_length = random.randint(2, 15)
    if n_points <= corrosion_length:
        return img_ori, full_pcd

    # Random start of the damaged run; the modulo supports closed contours.
    start_idx = random.randint(0, n_points - 1)
    end_idx = (start_idx + corrosion_length) % n_points

    img_shape = img_ori.shape[:2]
    new_points = generate_new_edge(full_pcd[start_idx], full_pcd[end_idx],
                                   corrosion_length, max_offset, img_shape)
    if len(new_points) == 0:
        return img_ori, full_pcd
    new_points = np.asarray(new_points)
    new_points_tensor = torch.from_numpy(new_points).long()

    # Work on copies so the caller's arrays (e.g. cached dataset samples that
    # torch.from_numpy shares memory with) are left untouched.
    img_ori = img_ori.copy()
    if isinstance(full_pcd, torch.Tensor):
        full_pcd = full_pcd.clone()
    else:
        full_pcd = np.array(full_pcd, copy=True)

    # Replace the damaged run of points.
    if start_idx < end_idx:
        full_pcd[start_idx:end_idx] = new_points_tensor
    else:
        # The run wraps past the end of the array (closed contour).
        head = n_points - start_idx
        full_pcd[start_idx:] = new_points_tensor[:head]
        full_pcd[:end_idx] = new_points_tensor[head:]

    # Paint the new edge points black on the image.
    height, width = img_shape
    for point in new_points:
        x, y = int(point[0]), int(point[1])
        if 0 <= x < width and 0 <= y < height:
            img_ori[y, x] = [0, 0, 0]

    return img_ori, full_pcd


class BaseDataSet(Dataset):
    """Contour-fragment dataset backed by a pickled dict of per-sample arrays.

    Each item bundles a normalised image tensor, the zero-padded and
    normalised contour points, texture/contour patches and, optionally, a
    dense ring adjacency matrix. When ``cfg.DATASET.ADD_STRAINS`` is set,
    stains and contour corrosion are applied to per-item copies, so the
    cached arrays in ``self.data`` stay unchanged across epochs.
    """

    def __init__(self, data_path: str, *args, calc_adjs=True):
        """
        :param data_path: path to a pickle file holding a dict with (at
            least) the keys ``'img_all'``, ``'full_pcd_all'`` and
            ``'shape_all'``.
        :param args: must be exactly ``(cfg, logger)``.
        :param calc_adjs: if True, precompute a ring adjacency (``k=8``) for
            every contour and keep only its sparse COO indices. If False,
            ``__getitem__`` returns a scalar placeholder under ``'adj'``.
        """
        super().__init__()
        self.cfg, self.logger = args
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # NOTE(review): pickle.load can execute arbitrary code; only load
        # dataset files from trusted sources.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)

        self.patch_size = self.cfg.NET.PATCH_SIZE
        self.c_model = self.cfg.NET.RESGCN.ENCODER_TYPE
        self.max_points = self.cfg.DATASET.CONTOUR_MAX_LEN
        self.n = len(self.data['img_all'])
        self.height_max = self.cfg.DATASET.TEXTURE_IMAGE_SIZE
        self.width_max = self.height_max
        self.long = list(map(len, self.data['full_pcd_all']))

        # Precompute adjacency matrices; store only the nonzero indices to
        # keep memory low (values are rebuilt as ones in __getitem__).
        self.adj_all = []
        if calc_adjs:
            for i in tqdm(range(self.n), desc="预计算邻接矩阵"):
                adj = get_adjacent2(
                    self.data['full_pcd_all'][i], self.max_points, k=8)
                adj_coo = adj.to_sparse_coo()
                self.adj_all.append(adj_coo.indices().clone())

    def __len__(self):
        """Return the number of samples (``len(data['img_all'])``)."""
        return self.n

    def process_image(self, img_ori, shape):
        """Paste the RGB version of ``img_ori`` onto a zero canvas and transform it.

        :param img_ori: BGR image array.
        :param shape: pair bounding the pasted region as ``rows = shape[1]``
            and ``cols = shape[0]``. Presumably ``(width, height)``; TODO
            confirm against the dataset builder.
        :return: the tensor produced by ``self.trans`` (ToTensor, Resize(224),
            Normalize).
        """
        canvas = np.zeros(
            (self.height_max, self.width_max, 3), dtype=np.uint8)
        canvas[:shape[1], :shape[0]] = cv2.cvtColor(
            img_ori, cv2.COLOR_BGR2RGB)
        return self.trans(canvas)

    def get_patches(self, img_ori, full_pcd, patch_type='texture'):
        """Encode patches around contour points on the fly.

        :param img_ori: BGR image array.
        :param full_pcd: ``(max_points, 2)`` padded contour tensor.
        :param patch_type: ``'texture'`` encodes RGB pixels scaled to [0, 1].
            Any other value encodes a binary foreground mask (pixels whose
            channels are all non-zero), using the contour encoder selected
            by ``self.c_model`` (``'l'``, ``'io'``, or anything else).
        :return: element ``[0]`` of the encoder output, i.e. the result for
            the single image passed as a batch of one.
        """
        img = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB) if patch_type == 'texture' \
            else (img_ori != 0).all(-1)
        # Pad 20 px on the bottom/right so patches near the border stay in range.
        img = np.pad(img, ((0, 20), (0, 20), (0, 0)) if patch_type == 'texture'
                     else ((0, 20), (0, 20)), mode='constant')

        if patch_type == 'texture':
            img_tensor = torch.from_numpy(img).permute(
                2, 0, 1).unsqueeze(0) / 255.0
            return img_patch_encoder(img_tensor, full_pcd.unsqueeze(0), self.patch_size)[0]

        img_tensor = torch.from_numpy(img).float().unsqueeze(0)
        if self.c_model == 'l':
            return pre_encoder1(img_tensor, full_pcd.unsqueeze(0), self.patch_size)[0]
        if self.c_model == 'io':
            return pre_encoder2(img_tensor, full_pcd.unsqueeze(0), self.patch_size)[0]
        return pre_encoder3(img_tensor, full_pcd.unsqueeze(0), self.patch_size)[0]

    def __getitem__(self, idx):
        """Assemble one training sample.

        :return: dict with the keys ``'img'``, ``'full_pcd'`` (padded and
            mapped by ``x / (height_max / 2) - 1``; padding rows become -1),
            ``'c_input'``, ``'t_input'``, ``'adj'``, ``'factor'`` and
            ``'index'``.
        """
        # torch.from_numpy shares memory with the cached array; clone so the
        # augmentation below can never write into self.data.
        full_pcd = torch.from_numpy(self.data['full_pcd_all'][idx]).clone()

        img_all = self.data['img_all'][idx]
        if self.cfg.DATASET.ADD_STRAINS:
            # Augment a private copy of the image for the same reason.
            img_all = add_stains(img_all.copy())
            img_all, full_pcd = add_corrosion(img_all, full_pcd)

        # Zero-pad the contour to a fixed length.
        full_pcd_padded = torch.zeros((self.max_points, 2))
        full_pcd_padded[:len(full_pcd)] = full_pcd

        img_tensor = self.process_image(
            img_all,
            self.data['shape_all'][idx][:2]
        )

        t_input = self.get_patches(img_all, full_pcd_padded, 'texture')
        c_input = self.get_patches(img_all, full_pcd_padded, 'contour')

        # Rebuild the dense adjacency matrix from the stored indices.
        if len(self.adj_all) > 0:
            adj_indices = self.adj_all[idx]
            adj_matrix = torch.sparse_coo_tensor(
                adj_indices,
                torch.ones(adj_indices.shape[1]),
                (self.max_points, self.max_points)
            ).to_dense()
        else:
            adj_matrix = torch.tensor(0.)

        full_pcd_padded = full_pcd_padded / (self.height_max / 2.0) - 1

        return {
            'img': img_tensor,
            'full_pcd': full_pcd_padded,
            'c_input': c_input,
            't_input': t_input,
            'adj': adj_matrix,
            'factor': torch.tensor(1.0),
            'index': idx
        }


if __name__ == "__main__":
    # Smoke test: build the dataset from a pickled split using a YAML config.
    import argparse

    from config.default import cfg
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        default="./config/llmco4mr_pairing/1.yaml"
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="dataset/2192_all/build_dataset_2192_pairing_test_set.pkl"
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_path)

    dataset = BaseDataSet(args.data_path, cfg, None)
    print(f"Loaded {len(dataset)} samples from {args.data_path}")
