import os
os.environ['MPLBACKEND'] = 'agg'
os.environ['DISPLAY'] = ''
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from collections import Counter
from tqdm import tqdm
from skimage.filters import sobel
import matplotlib.pyplot as plt
from diffmat.optim.descriptor import TextureDescriptor
from config import MATERIAL_ROOT_DIR


def search_material(search_img, mat_category, td, device='cuda', n_top=10, n_samples=5):
    """Rank the sampled materials of one category by texture-descriptor distance.

    For every material directory under ``MATERIAL_ROOT_DIR/<mat_category>``
    this reads ``sampled/relighting_td/params_{i}.pt`` for ``i`` in
    ``range(n_samples)`` and compares it to the descriptor of ``search_img``.

    Args:
        search_img: Image tensor to match; moved to ``device`` and passed to
            ``td.evaluate``.
        mat_category: Name of the category sub-directory to search.
        td: Descriptor object exposing ``evaluate(img)``.
        device: Torch device used for loading and comparison.
        n_top: Maximum number of matches to return.
        n_samples: Number of sampled parameter sets stored per material.

    Returns:
        List of ``("<mat_name>-params_<i>", distance)`` tuples sorted by
        ascending mean L1 distance, at most ``n_top`` long.
    """
    search_val = td.evaluate(search_img.to(device))

    root_dir = os.path.join(MATERIAL_ROOT_DIR, mat_category)
    # Only sub-directories are materials; stray files (e.g. .DS_Store) are skipped.
    mat_names = [name for name in sorted(os.listdir(root_dir))
                 if os.path.isdir(os.path.join(root_dir, name))]

    dist_dict = {}
    print("Searching for matching material...")
    for mat_name in tqdm(mat_names):
        mat_dir = os.path.join(root_dir, mat_name, "sampled")
        for i in range(n_samples):
            td_path = os.path.join(mat_dir, "relighting_td", f"params_{i}.pt")
            # map_location loads straight onto `device` instead of first
            # restoring the tensor on the device it was saved from.
            td_val = torch.load(td_path, map_location=device)
            dist = torch.nn.functional.l1_loss(search_val, td_val)
            dist_dict[f"{mat_name}-params_{i}"] = dist.item()

    # Smallest distances first; keep the n_top best.
    return sorted(dist_dict.items(), key=lambda x: x[1])[:n_top]


def search_all_material(search_roughness, search_metallic, td, device='cuda', n_top=3, n_samples=5):
    """Rank every sampled material in the library by roughness + metallic distance.

    Reads ``MATERIAL_ROOT_DIR/<category>/<mat_name>/sampled/roughness_td/params_{i}.pt``
    and ``.../metallic_td/params_{i}.pt`` for ``i`` in ``range(n_samples)``.

    Args:
        search_roughness: Roughness image tensor; moved to ``device`` and
            passed to ``td.evaluate``.
        search_metallic: Metallic image tensor; handled the same way.
        td: Descriptor object exposing ``evaluate(img)``.
        device: Torch device used for loading and comparison.
        n_top: Maximum number of matches to return.
        n_samples: Number of sampled parameter sets stored per material.

    Returns:
        List of ``("<category>-<mat_name>-params_<i>", distance)`` tuples sorted
        by ascending distance (at most ``n_top``). The distance is the mean of
        the roughness and metallic descriptor distances. The list is also
        printed.
    """
    search_val_roughness = td.evaluate(search_roughness.to(device))
    search_val_metallic = td.evaluate(search_metallic.to(device))

    def descriptor_distance(query, reference):
        # Element-wise L1, averaged over dim 1, then the minimum over the rest.
        diff = torch.nn.functional.l1_loss(query, reference, reduction='none')
        return diff.mean(dim=1).min().item()

    def list_subdirs(path):
        # Sorted child directories only; stray files are ignored.
        return [name for name in sorted(os.listdir(path))
                if os.path.isdir(os.path.join(path, name))]

    dist_dict = {}
    print("Searching for matching material...")
    # Sort first, then wrap in tqdm: sorted(tqdm(...)) drains the iterator up
    # front, so the progress bar would finish before any material is scored.
    for mat_category in tqdm(list_subdirs(MATERIAL_ROOT_DIR)):
        root_dir = os.path.join(MATERIAL_ROOT_DIR, mat_category)
        for mat_name in list_subdirs(root_dir):
            mat_dir = os.path.join(root_dir, mat_name, "sampled")
            for i in range(n_samples):
                roughness_ref = torch.load(
                    os.path.join(mat_dir, "roughness_td", f"params_{i}.pt"), map_location=device)
                metallic_ref = torch.load(
                    os.path.join(mat_dir, "metallic_td", f"params_{i}.pt"), map_location=device)

                dist1 = descriptor_distance(search_val_roughness, roughness_ref)
                dist2 = descriptor_distance(search_val_metallic, metallic_ref)
                dist_dict[f"{mat_category}-{mat_name}-params_{i}"] = (dist1 + dist2) / 2

    # Smallest distances first; keep the n_top best.
    top = sorted(dist_dict.items(), key=lambda x: x[1])[:n_top]
    print(top)
    return top


def match_all_material_image(search_roughness, search_metallic, texture_descriptor=None, device='cuda'):
    """Choose the library material that best matches a roughness/metallic crop pair.

    Takes the 9 nearest sampled materials from ``search_all_material``. It then
    votes per material (category + name) and returns the material that appears
    most often. Ties go to the material whose best entry has the smallest
    distance.

    Args:
        search_roughness: Roughness crop tensor forwarded to the search.
        search_metallic: Metallic crop tensor forwarded to the search.
        texture_descriptor: Descriptor to reuse; a new ``TextureDescriptor``
            is created on ``device`` when omitted.
        device: Torch device.

    Returns:
        ``(category, material_name, param_index)``. ``param_index`` is the
        string index of that material's closest sampled parameter set.

    Raises:
        RuntimeError: if the search returns no candidates.
    """
    if texture_descriptor is None:
        texture_descriptor = TextureDescriptor(device)

    top_matches = search_all_material(search_roughness, search_metallic, texture_descriptor, device, n_top=9)
    if not top_matches:
        raise RuntimeError("No candidate materials found; check MATERIAL_ROOT_DIR contents.")

    votes = Counter()
    best_per_material = {}  # "<category>-<name>" -> (distance, param index string)
    for item, dist in top_matches:
        # Keys look like "<category>-<name>-params_<i>"; split on the last marker.
        mat_cat_name, param_i = item.rsplit('-params_', 1)
        votes[mat_cat_name] += 1
        if mat_cat_name not in best_per_material or dist < best_per_material[mat_cat_name][0]:
            best_per_material[mat_cat_name] = (dist, param_i)

    # Most votes wins; among equal vote counts prefer the smaller best distance.
    best_mat_cate_name = max(votes, key=lambda m: (votes[m], -best_per_material[m][0]))
    best_param_i = best_per_material[best_mat_cate_name][1]
    # Split only on the first '-' so material names containing '-' stay intact
    # (assumes category directory names contain no '-').
    best_mat_cate, best_mat_name = best_mat_cate_name.split('-', 1)

    print(f"材质类别: {best_mat_cate_name}")
    print(f"材质名: {best_mat_name}")
    print(f"最佳匹配参数编号(i): {best_param_i}")

    return best_mat_cate, best_mat_name, best_param_i


def prepare_crop_areas(mask, normal_map, resolution):
    """Find a square crop that is fully inside the mask and has flat normals.

    The first attempt is strict. If it fails, a second attempt halves the
    maximum crop size and loosens the gradient and normal-angle thresholds.

    Args:
        mask: 2-D mask array, non-zero on the object.
        normal_map: Normal map aligned with ``mask`` (H, W, 3).
        resolution: Image side length in pixels; the crop search starts at
            ``resolution // 2`` and never goes below ``max(64, resolution // 32)``.

    Returns:
        ``(row_slice, col_slice)`` for the chosen square, or ``None`` if both
        attempts fail.
    """
    start_size = resolution // 2
    min_size = max(64, resolution // 32)  # smallest acceptable crop side

    # (max_size, score_threshold, normal_angle_threshold), strictest first.
    attempts = (
        (start_size, 0.02, 0.05),
        (start_size // 2, 0.08, 0.2),
    )

    for attempt_idx, (max_size, score_threshold, angle_threshold) in enumerate(attempts):
        if attempt_idx > 0:
            print("未找到最佳裁剪区域, 降低阈值后重新搜索")

        best_result, best_score = find_material_sample_patch(
            mask,
            normal_map,
            max_size=max_size,
            min_size=min_size,
            mask_threshold=1.0,  # strict: the crop must be fully covered by the mask
            score_threshold=score_threshold,  # max mean normal-map gradient
            normal_angle_threshold=angle_threshold,
        )

        if best_result is not None:
            i, j, size = best_result
            print(f"找到最佳裁剪区域: 位置=({i},{j}), 尺寸={size}x{size}, 法线变化分数={best_score:.4f}")
            return (slice(i, i + size), slice(j, j + size))

    print("仍未找到最佳裁剪区域")
    return None


def find_material_sample_patch(mask, normal_map, max_size=384, min_size=64,
                               mask_threshold=0.95, score_threshold=0.05,
                               normal_angle_threshold=0.1):
    """Search for a large square patch with high mask coverage and flat normals.

    The search is confined to the mask's bounding box. Per-window sums use
    integral images (summed-area tables). Patch sizes are explored by a binary
    search between ``min_size`` and ``max_size``. That search treats "a size
    fits" as monotone, which is only approximately true here, so the result is
    a heuristic large patch, not guaranteed to be the largest.

    Args:
        mask: 2-D array (H, W); non-zero values mark the object.
        normal_map: (H, W, 3) array aligned with ``mask``. It is used as given;
            no decoding to [-1, 1] happens here.
        max_size: Largest patch side to try (clamped to the bounding box).
        min_size: Smallest patch side to accept.
        mask_threshold: Minimum mean mask value inside a patch.
        score_threshold: Maximum mean normal-map gradient magnitude over the
            patch's masked pixels.
        normal_angle_threshold: Maximum mean angle (radians) between the patch's
            normals and their mean direction. Also scales the coarse
            normal-variation pre-filter.

    Returns:
        ``((i, j, size), score)``, where ``(i, j)`` is the patch's top-left
        (row, col) and ``score`` its mean gradient magnitude. Returns
        ``(None, inf)`` when no patch qualifies.
    """
    h, w = mask.shape
    # Absolute slack on window mask sums: the integral-image difference carries
    # floating-point rounding, so an exactly covered window can otherwise land a
    # hair below mask_threshold * area (fatal when mask_threshold == 1.0).
    mask_sum_tol = 1e-6

    # 1. Bounding box of the mask.
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    if not np.any(rows) or not np.any(cols):
        return None, float('inf')  # empty mask

    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]
    bb_height = y_max - y_min + 1
    bb_width = x_max - x_min + 1

    # 2. No square larger than the bounding box can be covered by the mask.
    max_possible_size = min(bb_height, bb_width)
    max_size = min(max_size, max_possible_size)
    if max_possible_size < min_size:
        return None, float('inf')

    # 3. Per-pixel gradient magnitude of the normal map, summed over channels.
    grad_magnitude = np.zeros_like(mask, dtype=float)
    for c in range(3):
        grad_x = sobel(normal_map[..., c], axis=1)
        grad_y = sobel(normal_map[..., c], axis=0)
        grad_magnitude += np.sqrt(grad_x ** 2 + grad_y ** 2)

    # 4. Integral images. The mask is accumulated in float64 to keep window
    #    sums accurate (a float32 cumsum loses precision on large images).
    mask_integral = np.cumsum(np.cumsum(mask, axis=0, dtype=np.float64), axis=1)
    grad_integral = np.cumsum(np.cumsum(grad_magnitude, axis=0), axis=1)

    # 5. Cheap normal-variation estimate: absolute forward differences along
    #    both axes. Pad with the last row/column (not 0) so the image border
    #    does not receive a spurious |0 - value| jump.
    normal_variation = np.zeros_like(mask, dtype=float)
    for c in range(3):
        channel = normal_map[..., c]
        normal_variation += (np.abs(np.diff(channel, axis=0, append=channel[-1:, :]))
                             + np.abs(np.diff(channel, axis=1, append=channel[:, -1:])))
    normal_var_integral = np.cumsum(np.cumsum(normal_variation, axis=0), axis=1)

    def get_area_sum(integral, x1, y1, x2, y2):
        """Sum over rows y1..y2 and columns x1..x2 (inclusive) via an integral image."""
        if x1 == 0 and y1 == 0:
            return integral[y2, x2]
        elif x1 == 0:
            return integral[y2, x2] - integral[y1 - 1, x2]
        elif y1 == 0:
            return integral[y2, x2] - integral[y2, x1 - 1]
        else:
            return (integral[y2, x2] - integral[y2, x1 - 1]
                    - integral[y1 - 1, x2] + integral[y1 - 1, x1 - 1])

    def check_size(size, position_step):
        """Return ((i, j, size), score) for the first acceptable window of this size, else (None, None)."""
        if size > max_possible_size:
            return None, None

        area = size * size
        # Top-left corners: the bounding box widened by size // 4, clipped to the image.
        start_y = max(0, y_min - size // 4)
        end_y = min(h - size + 1, y_max + size // 4)
        start_x = max(0, x_min - size // 4)
        end_x = min(w - size + 1, x_max + size // 4)

        # Stage 1: cheap screening with integral images.
        candidates = []
        for i in range(start_y, end_y, position_step):
            for j in range(start_x, end_x, position_step):
                y2, x2 = i + size - 1, j + size - 1

                mask_sum = get_area_sum(mask_integral, j, i, x2, y2)
                if mask_sum < mask_threshold * area - mask_sum_tol:
                    continue

                normal_var_sum = get_area_sum(normal_var_integral, j, i, x2, y2)
                if normal_var_sum / area > normal_angle_threshold * 10:
                    continue

                grad_sum = get_area_sum(grad_integral, j, i, x2, y2)
                grad_avg = grad_sum / max(1, mask_sum)
                if grad_avg < score_threshold * 1.5:
                    candidates.append((i, j, grad_avg))

        # Stage 2: exact checks on the 20 lowest-gradient candidates.
        candidates.sort(key=lambda cand: cand[2])
        for i, j, _ in candidates[:20]:
            square_mask = mask[i:i + size, j:j + size]
            valid_pixels = square_mask > 0
            if np.sum(valid_pixels) == 0:
                continue

            square_normals = normal_map[i:i + size, j:j + size]
            square_grad = grad_magnitude[i:i + size, j:j + size]
            valid_normals = square_normals[valid_pixels]

            # Normal consistency: mean angle between each normal and the mean direction.
            mean_normal = np.mean(valid_normals, axis=0)
            mean_normal_norm = np.sqrt(np.sum(mean_normal ** 2))
            if mean_normal_norm > 0:
                mean_normal = mean_normal / mean_normal_norm
                norms = np.sqrt(np.sum(valid_normals ** 2, axis=1, keepdims=True))
                valid_normals_normalized = valid_normals / np.maximum(norms, 1e-10)
                # abs() makes opposite directions count as aligned.
                cos_angles = np.abs(np.sum(valid_normals_normalized * mean_normal, axis=1))
                cos_angles = np.clip(cos_angles, -1, 1)  # guard arccos against rounding
                mean_angle = np.mean(np.arccos(cos_angles))
                if mean_angle > normal_angle_threshold:
                    continue

            score = np.mean(square_grad[valid_pixels])
            if score < score_threshold:
                return (i, j, size), score

        return None, None

    # 6. Binary search over patch size: grow after a success, shrink after a failure.
    #    Each mid value is visited at most once, so no memoisation is needed.
    low, high = min_size, max_size
    best_result = None
    best_score = float('inf')
    while low <= high:
        mid = (low + high) // 2
        result, score = check_size(mid, max(4, mid // 16))
        if result is not None:
            best_result, best_score = result, score
            low = mid + 1
        else:
            high = mid - 1

    return best_result, best_score


def plot_failed_search_result(obj_idx, search_image, mask, mat_category, output_path):
    """Save a figure of the masked search image for an object with no usable crop.

    Pixels where ``mask == 0`` are blacked out in a copy of ``search_image``.
    The figure is written to ``<output_path>/material_{obj_idx}.png``.
    """
    search_image = search_image.copy()  # leave the caller's array untouched
    search_image[mask == 0] = 0

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(search_image)
    ax[0].set_title("Search Image")
    ax[0].axis('off')
    # There is no match to show on the right; hide the empty axes frame.
    ax[1].axis('off')

    fig.suptitle(f"Object {obj_idx} - {mat_category} - No Crop Area")
    fig.tight_layout()
    output_file = os.path.join(output_path, f"material_{obj_idx}.png")
    fig.savefig(output_file)
    plt.close(fig)  # release figure memory


def plot_search_result(obj_idx, crop_img, normal_map, slice_indices, material_image, best_mat_name, mat_category, output_path):
    """Save a 1x3 figure: cropped image, normal map with the crop box, matched material.

    Args:
        obj_idx: Object index, used in the title and the output file name.
        crop_img: Batch whose first item is shown. Torch tensors are indexed
            ``[0]`` and permuted from (C, H, W) to (H, W, C); numpy inputs are
            shown as ``crop_img[0]``.
        normal_map: Image drawn in the middle panel.
        slice_indices: ``(rows, cols)``, each either a ``slice`` or a
            ``(start, stop)`` pair, marking the crop region.
        material_image: Image of the matched material (converted with ``np.array``).
        best_mat_name: Material name shown in the third panel's title.
        mat_category: Category shown in the figure title.
        output_path: Directory that receives ``material_{obj_idx}.png``.
    """
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))

    # Panel 1: the cropped image.
    first = crop_img[0]
    if hasattr(first, "permute"):  # torch.Tensor
        img = first.permute(1, 2, 0).cpu().numpy()
    else:  # numpy
        img = first
    ax[0].imshow(img)
    ax[0].set_title("Cropped Image")
    ax[0].axis('off')

    # Panel 2: normal map with the selected region outlined.
    ax[1].imshow(normal_map)
    ax[1].set_title("Normal Map with Selection")
    ax[1].axis('off')

    img_h, img_w = np.shape(normal_map)[:2]

    def bounds(index, length):
        # Accept a slice or a (start, stop) pair; open slice ends mean the image edge.
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = length if index.stop is None else index.stop
            return start, stop
        return index

    y_start, y_end = bounds(slice_indices[0], img_h)
    x_start, x_end = bounds(slice_indices[1], img_w)

    # imshow centres pixel k at coordinate k, so pixel edges sit at k - 0.5.
    rect = plt.Rectangle((x_start - 0.5, y_start - 0.5), x_end - x_start, y_end - y_start,
                         linewidth=2, edgecolor='r', facecolor='none')
    ax[1].add_patch(rect)

    # Panel 3: the matched material.
    ax[2].imshow(np.array(material_image))
    ax[2].set_title(f"Material: {best_mat_name}")
    ax[2].axis('off')

    fig.suptitle(f"Object {obj_idx} - {mat_category} - Matched")
    fig.tight_layout()
    output_file = os.path.join(output_path, f"material_{obj_idx}.png")
    fig.savefig(output_file)
    plt.close(fig)  # release figure memory


def pair_materials(obj_mask_paths, normal_path, roughness_image_path, metallic_image_path, output_path=None, device='cuda'):
    """Match each masked object to a library material via roughness/metallic crops.

    For every mask, a flat, fully-masked square crop is located on the normal
    map. The roughness and metallic images are cropped there, resized to
    512x512, and matched against the material library.

    Args:
        obj_mask_paths: Mask image paths. The object index is parsed from the
            second '_'-separated field of the file name, up to the first '.'.
        normal_path: Normal map image path.
        roughness_image_path: Roughness image path; its size sets the working
            resolution.
        metallic_image_path: Metallic image path (same size as roughness).
        output_path: If given, a diagnostic figure per object is saved here.
        device: Torch device for descriptor evaluation.

    Returns:
        Dict keyed by ``str(obj_idx)``. On success the value is
        ``((H, W), (row_slice, col_slice), category, material_name, param_index)``;
        when no crop is found it is ``(None, None, None, None, None)``.
    """
    obj_idx_list = [int(os.path.basename(p).split('_')[1].split('.')[0]) for p in obj_mask_paths]
    search_image_roughness = np.array(Image.open(roughness_image_path).convert('RGB')) / 255.0
    search_image_metallic = np.array(Image.open(metallic_image_path).convert('RGB')) / 255.0

    resolution = search_image_roughness.shape[:2]  # (height, width)
    # PIL's resize takes (width, height), the reverse of the numpy shape.
    pil_size = (resolution[1], resolution[0])
    masks = [np.array(Image.open(p).convert('L').resize(pil_size), dtype=np.float32) / 255.0 for p in obj_mask_paths]
    normal_map = np.array(Image.open(normal_path).convert('RGB').resize(pil_size), dtype=np.float32) / 255.0

    td = TextureDescriptor(device)
    search_image_roughness_th = torch.tensor(search_image_roughness).permute(2, 0, 1).unsqueeze(0).float()
    search_image_metallic_th = torch.tensor(search_image_metallic).permute(2, 0, 1).unsqueeze(0).float()

    paired_materials = {}
    for obj_idx, mask in zip(obj_idx_list, masks):
        print(f"Pairing material for obj {obj_idx}...")
        # Use the shorter side so crop sizes also fit non-square images.
        slice_indices = prepare_crop_areas(mask, normal_map, min(resolution))
        if slice_indices is None:
            paired_materials[f"{obj_idx}"] = (None, None, None, None, None)
            if output_path is not None:
                plot_failed_search_result(obj_idx, normal_map, mask, None, output_path)
            continue

        rows, cols = slice_indices
        crop_img_roughness = search_image_roughness_th[..., rows, cols]
        crop_img_metallic = search_image_metallic_th[..., rows, cols]
        crop_img_roughness = torch.nn.functional.interpolate(crop_img_roughness, (512, 512), mode='bilinear', align_corners=False)
        crop_img_metallic = torch.nn.functional.interpolate(crop_img_metallic, (512, 512), mode='bilinear', align_corners=False)
        best_mat_cate, best_mat_name, best_param_i = match_all_material_image(crop_img_roughness,
                                                                              crop_img_metallic,
                                                                              texture_descriptor=td, device=device)

        paired_materials[f"{obj_idx}"] = (resolution, slice_indices, best_mat_cate, best_mat_name, best_param_i)

        if output_path is not None:
            material_image_path = os.path.join(MATERIAL_ROOT_DIR, best_mat_cate, best_mat_name, "sampled", "render",
                                               f"params_{best_param_i}.png")
            material_image = Image.open(material_image_path).convert('RGB').resize((512, 512), Image.LANCZOS)
            # Show the actual roughness crop (a (1, C, H, W) tensor), not the full normal map.
            plot_search_result(obj_idx, crop_img_roughness, normal_map, slice_indices, material_image,
                               best_mat_name, best_mat_cate, output_path)

    return paired_materials