import os
import re
import cv2
import json
import torch
import numpy as np
import mitsuba as mi
mi.set_variant('cuda_ad_rgb')
import drjit as dr
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image


def load_image_as_tensor(image_path, resolution=None, convert_gray=False):
    """
    Load an image file and convert it to a batched tensor.

    Args:
        image_path: Path to the image.
        resolution: Optional target size. An ``int`` resizes to a square
            ``(resolution, resolution)``; a 2-element tuple/list is passed to
            PIL ``resize`` as ``(width, height)``. Lanczos filtering is used.
        convert_gray: If True, convert the image to single-channel ``'L'`` mode.

    Returns:
        torch.Tensor: Tensor of shape (1, C, H, W) produced by
        ``torchvision.transforms.ToTensor`` plus a leading batch dimension.

    Raises:
        ValueError: If ``resolution`` is neither an int nor a 2-element tuple/list.
    """
    # Use the context manager so the underlying file handle is closed once the
    # pixel data has been converted to a tensor.
    with Image.open(image_path) as opened:
        image = opened
        if convert_gray and image.mode != 'L':
            image = image.convert('L')

        if resolution is not None:
            if isinstance(resolution, int):
                size = (resolution, resolution)
            elif isinstance(resolution, (tuple, list)):
                if len(resolution) != 2:
                    raise ValueError("Resolution must be a tuple or list of 2 integers")
                size = tuple(resolution)
            else:
                raise ValueError(f"Invalid resolution format: {resolution}")
            image = image.resize(size, Image.LANCZOS)

        # Grayscale and color images are handled the same way: ToTensor yields
        # (C, H, W), and a batch dimension is prepended below.
        tensor = transforms.ToTensor()(image)
    return tensor.unsqueeze(0)


def save_masked_image(mask_dir, output_dir, image):
    """
    Apply every ``mask_*.png`` mask in a directory to an image and save the results.

    Args:
        mask_dir: Directory containing ``mask_*.png`` files.
        output_dir: Directory where ``img_<mask filename>`` files are written;
            created if it does not exist.
        image: Image laid out as (height, width, channels), either a torch
            tensor or a Mitsuba ``TensorXf`` (converted via ``.torch()``).
    """
    if isinstance(image, mi.TensorXf):
        image = image.torch()
    height, width = image.shape[0], image.shape[1]
    os.makedirs(output_dir, exist_ok=True)

    # Sort for a deterministic processing order (os.listdir order is arbitrary).
    mask_files = sorted(
        f for f in os.listdir(mask_dir) if f.startswith("mask_") and f.endswith(".png")
    )
    for mask_file in mask_files:
        mask_path = os.path.join(mask_dir, mask_file)
        with Image.open(mask_path) as mask_img:
            # PIL sizes are (width, height), while the tensor is (height, width, C);
            # passing (shape[0], shape[1]) would transpose non-square masks.
            mask = mask_img.convert("L").resize((width, height), Image.LANCZOS)
        mask_np = np.array(mask) / 255.0
        mask_tensor = torch.from_numpy(mask_np).unsqueeze(-1).to(image.device)

        masked_image = image * mask_tensor
        # save_image expects (batch, channels, height, width).
        masked_image = masked_image.permute(2, 0, 1).unsqueeze(0)
        save_image(masked_image, os.path.join(output_dir, f"img_{mask_file}"))


def get_obj_files_list(obj_file_path):
    """
    Resolve an OBJ path into a list of OBJ file paths.

    - If ``obj_file_path`` is a file, return a one-element list containing it.
    - If it is a directory, return the paths of its entries named
      ``obj_<index>.obj``, ordered by the integer ``<index>``.

    Raises:
        ValueError: If the path is neither an existing file nor a directory.
    """
    if os.path.isfile(obj_file_path):
        return [obj_file_path]
    if not os.path.isdir(obj_file_path):
        raise ValueError("OBJ文件路径无效")

    name_re = re.compile(r'obj_(\d+)\.obj$')
    indexed = []
    for entry in os.listdir(obj_file_path):
        m = name_re.match(entry)
        if m is not None:
            indexed.append((int(m.group(1)), os.path.join(obj_file_path, entry)))
    # Sort numerically by index so obj_10 comes after obj_2.
    return [path for _, path in sorted(indexed, key=lambda pair: pair[0])]


def clean_tensor_for_gamma(tensor):
    """
    Sanitize a tensor so it is safe to raise to a fractional power.

    Values are clamped to [1e-8, 1.0]; afterwards NaN becomes 1e-8,
    +inf becomes 1.0 and -inf becomes 1e-8.

    Args:
        tensor: Input tensor.

    Returns:
        torch.Tensor: New tensor with every element in [1e-8, 1.0].
    """
    floor = 1e-8
    clamped = tensor.clamp(floor, 1.0)
    # NaN is unordered, so it is replaced explicitly after clamping.
    return torch.nan_to_num(clamped, nan=floor, posinf=1.0, neginf=floor)


def srgb_to_linear(image):
    """
    Convert from sRGB to linear space using a simple 2.2 gamma curve.

    The input is first passed through ``clean_tensor_for_gamma``, so values
    are clamped to [1e-8, 1] and non-finite entries are replaced before the
    power is applied.

    Args:
        image: sRGB-encoded image tensor.

    Returns:
        torch.Tensor: Linear-space image.
    """
    safe = clean_tensor_for_gamma(image)
    return safe.pow(2.2)


def linear_to_srgb(image):
    """
    Convert from linear space to sRGB using a simple 1/2.2 gamma curve.

    The input is first passed through ``clean_tensor_for_gamma`` (clamped to
    [1e-8, 1], non-finite entries replaced) so the fractional power is safe.

    Args:
        image: Linear-space image tensor.

    Returns:
        torch.Tensor: sRGB-encoded image.
    """
    return torch.pow(clean_tensor_for_gamma(image), 1.0 / 2.2)


def linear_to_srgb_drjit(image):
    """
    Convert from linear to sRGB space (1/2.2 gamma) using Dr.Jit operations.

    Values are clamped to [0, 1] with ``dr.clamp`` before the power is applied.
    Unlike ``linear_to_srgb``, there is no explicit NaN/inf replacement.

    Args:
        image: Linear-space Dr.Jit value (presumably a Mitsuba/Dr.Jit tensor
            such as ``mi.TensorXf`` -- confirm against callers).

    Returns:
        The Dr.Jit result of the clamp and power operations (not a torch.Tensor).
    """
    image = dr.clamp(image, 0.0, 1.0)
    return image ** (1.0 / 2.2)


def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
    """
    Load an LDR image (e.g. PNG/JPG) with OpenCV as a float32 RGB tensor.

    Args:
        image_path: Path to the image file.
        from_srgb: If True, convert from sRGB to linear with a 2.2 power curve.
        clamp: If True, clamp values to [0, 1].
        normalize: If True, remap values from [0, 1] to [-1, 1] and then
            L2-normalize each pixel's channel vector.

    Returns:
        torch.Tensor: Float32 tensor of shape (C, H, W), RGB channel order.

    Raises:
        OSError: If ``cv2.imread`` cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure (missing/unreadable file) by returning None
        # instead of raising; fail here with a clear message.
        raise OSError(f"cv2.imread could not read image: {image_path}")
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype(np.float32) / 255.0)  # (h, w, c)
    # Defensive: zero out any non-finite values.
    image[~torch.isfinite(image)] = 0
    if from_srgb:
        # Convert from sRGB to linear RGB (2.2 gamma approximation).
        image = image**2.2
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        # Remap to [-1, 1], then normalize each pixel over the channel axis.
        image = image * 2.0 - 1.0
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)


def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
    """
    Load an HDR image (e.g. EXR) with OpenCV as a float32 RGB tensor.

    Args:
        image_path: Path to the image file. OpenCV may need OpenEXR support
            enabled (e.g. ``OPENCV_IO_ENABLE_OPENEXR``) to decode EXR files.
        tonemaping: If True, scale luminance so the log-average (geometric
            mean) luminance maps to the key value 0.18; chromaticity is kept.
            Only this exposure scaling is applied, no compression curve.
        clamp: If True, clamp values to [0, 1].
        normalize: If True, L2-normalize each pixel's channel vector.

    Returns:
        torch.Tensor: Float32 tensor of shape (C, H, W), RGB channel order.

    Raises:
        OSError: If ``cv2.imread`` cannot read the file.
    """
    raw = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not read image: {image_path}")
    image = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype("float32"))  # (h, w, c)
    # Replace non-finite values (NaN/inf) with 0.
    image[~torch.isfinite(image)] = 0
    if tonemaping:
        # Exposure adjustment in Yxy space: only Y is rescaled.
        image_Yxy = convert_rgb_2_Yxy(image)
        lum = (
            image[:, :, 0:1] * 0.2125
            + image[:, :, 1:2] * 0.7154
            + image[:, :, 2:3] * 0.0721
        )
        log_lum = torch.log(torch.clamp(lum, min=1e-6))
        lum_mean = torch.exp(torch.mean(log_lum))  # log-average luminance
        lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
        image_Yxy[:, :, 0:1] = lp
        image = convert_Yxy_2_rgb(image_Yxy)
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)


def convert_rgb_2_XYZ(rgb):
    """
    Convert linear sRGB (D65) to CIE XYZ.

    Reference: Bruce Lindbloom, RGB/XYZ matrices
    https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html

    Args:
        rgb: Tensor of shape (h, w, 3).

    Returns:
        torch.Tensor: Tensor of shape (h, w, 3) holding X, Y, Z.
    """
    rows = (
        (0.4124564, 0.3575761, 0.1804375),
        (0.2126729, 0.7151522, 0.0721750),
        (0.0193339, 0.1191920, 0.9503041),
    )
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    XYZ = torch.ones_like(rgb)
    for channel, (kr, kg, kb) in enumerate(rows):
        XYZ[:, :, channel] = kr * r + kg * g + kb * b
    return XYZ


def convert_XYZ_2_Yxy(XYZ):
    """
    Convert CIE XYZ to Yxy (luminance Y plus chromaticity x, y).

    x = X / (X + Y + Z) and y = Y / (X + Y + Z), where the channel sum is
    floored at 1e-4 to avoid division by zero.

    Args:
        XYZ: Tensor of shape (h, w, 3).

    Returns:
        torch.Tensor: Tensor of shape (h, w, 3) holding (Y, x, y).
    """
    X, Y = XYZ[:, :, 0], XYZ[:, :, 1]
    inv_total = 1.0 / torch.clamp(XYZ.sum(dim=2), min=1e-4)
    Yxy = torch.ones_like(XYZ)
    Yxy[:, :, 0] = Y
    Yxy[:, :, 1] = X * inv_total
    Yxy[:, :, 2] = Y * inv_total
    return Yxy


def convert_rgb_2_Yxy(rgb):
    """
    Convert linear sRGB to Yxy by going through CIE XYZ.

    Args:
        rgb: Tensor of shape (h, w, 3).

    Returns:
        torch.Tensor: Tensor of shape (h, w, 3) holding (Y, x, y).
    """
    xyz = convert_rgb_2_XYZ(rgb)
    return convert_XYZ_2_Yxy(xyz)


def convert_XYZ_2_rgb(XYZ):
    """
    Convert CIE XYZ to linear sRGB (D65).

    Uses Bruce Lindbloom's XYZ->sRGB matrix, which is approximately the
    inverse of the matrix in ``convert_rgb_2_XYZ``.

    Args:
        XYZ: Tensor of shape (h, w, 3).

    Returns:
        torch.Tensor: Tensor of shape (h, w, 3) holding R, G, B.
    """
    rows = (
        (3.2404542, -1.5371385, -0.4985314),
        (-0.9692660, 1.8760108, 0.0415560),
        (0.0556434, -0.2040259, 1.0572252),
    )
    X, Y, Z = XYZ[:, :, 0], XYZ[:, :, 1], XYZ[:, :, 2]
    rgb = torch.ones_like(XYZ)
    for channel, (kx, ky, kz) in enumerate(rows):
        rgb[:, :, channel] = kx * X + ky * Y + kz * Z
    return rgb


def convert_Yxy_2_XYZ(Yxy):
    """
    Convert Yxy (luminance Y plus chromaticity x, y) back to CIE XYZ.

    X = x / y * Y and Z = (1 - x - y) / y * Y. The divisor y is floored to
    avoid division by zero: at 1e-6 for X and at 1e-4 for Z.

    Args:
        Yxy: Tensor of shape (h, w, 3) holding (Y, x, y).

    Returns:
        torch.Tensor: Tensor of shape (h, w, 3) holding X, Y, Z.
    """
    lum, cx, cy = Yxy[:, :, 0], Yxy[:, :, 1], Yxy[:, :, 2]
    XYZ = torch.ones_like(Yxy)
    # NOTE(review): the two floors differ (1e-6 vs 1e-4), so for
    # 1e-6 <= y < 1e-4 the X and Z channels use different divisors.
    # Kept as-is to preserve existing outputs.
    XYZ[:, :, 0] = cx / torch.clamp(cy, min=1e-6) * lum
    XYZ[:, :, 1] = lum
    XYZ[:, :, 2] = (1.0 - cx - cy) / torch.clamp(cy, min=1e-4) * lum
    return XYZ


def convert_Yxy_2_rgb(Yxy):
    """
    Convert Yxy back to linear sRGB by going through CIE XYZ.

    Args:
        Yxy: Tensor of shape (h, w, 3) holding (Y, x, y).

    Returns:
        torch.Tensor: Tensor of shape (h, w, 3) holding R, G, B.
    """
    xyz = convert_Yxy_2_XYZ(Yxy)
    return convert_XYZ_2_rgb(xyz)

def slice_to_tuple(s):
    """Return the ``(start, stop, step)`` of slice ``s`` as a JSON-friendly tuple."""
    start, stop, step = s.start, s.stop, s.step
    return start, stop, step

def tuple_to_slice(t):
    """
    Rebuild a ``slice`` from a ``[start, stop]`` or ``[start, stop, step]`` sequence.

    This is the inverse of ``slice_to_tuple``. After a JSON round-trip the
    tuple arrives as a list with ``None`` for missing parts, e.g.
    ``[910, 1357, None]`` -> ``slice(910, 1357, None)``.

    Args:
        t: Sequence of length 2 or 3; elements may be ``None``.

    Returns:
        slice: The reconstructed slice.

    Raises:
        ValueError: If ``t`` does not have exactly 2 or 3 elements.
    """
    # A 3-element entry with a None step is already valid for slice(), so no
    # special-casing is needed; any other length is malformed and would
    # otherwise be silently truncated or fail with an unrelated IndexError.
    if len(t) not in (2, 3):
        raise ValueError(f"Expected [start, stop] or [start, stop, step], got {t!r}")
    return slice(*t)


class NumpyEncoder(json.JSONEncoder):
    """
    JSON encoder that converts NumPy scalars and arrays to native Python types.

    Pass as ``cls=NumpyEncoder`` to ``json.dump`` / ``json.dumps``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``, or defer to the base class."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # The base implementation raises TypeError for unsupported objects.
        return super().default(obj)

def load_material_pair_json(load_path):
    """
    Load a material-pair mapping written by ``save_material_pair_json``.

    Each JSON value is restored as follows:
      - ``null`` -> ``None``
      - entries whose second element is ``null`` ->
        ``(None, None, value[2], None, None)``
      - otherwise ``(resolution, slices, mat_category, mat_name, param_i)``,
        with ``resolution`` converted to a tuple and each slice rebuilt from
        its list form via ``tuple_to_slice``.

    Args:
        load_path: Path to the JSON file (read as UTF-8).

    Returns:
        dict: Mapping from the JSON keys to the restored values.
    """
    with open(load_path, 'r', encoding='utf-8') as fh:
        raw = json.load(fh)

    def _restore(entry):
        # Rebuild one stored value into its in-memory tuple form.
        if entry is None:
            return None
        if entry[1] is None:
            return (None, None, entry[2], None, None)
        resolution, slice_specs, mat_category, mat_name, param_i = entry
        return (
            tuple(resolution),
            tuple(tuple_to_slice(spec) for spec in slice_specs),
            mat_category,
            mat_name,
            param_i,
        )

    return {key: _restore(entry) for key, entry in raw.items()}

def save_material_pair_json(paired_materials, save_path):
    """
    Serialize a material-pair mapping to JSON (read back by ``load_material_pair_json``).

    Each value is written as follows:
      - ``None`` -> ``null``
      - values whose second element is ``None`` ->
        ``[null, null, value[2], null, null]``
      - otherwise ``(resolution, slices, mat_category, mat_name, param_i)``,
        with ``resolution`` as a list and every slice as ``[start, stop, step]``.

    NumPy scalars and arrays are converted by ``NumpyEncoder``.

    Args:
        paired_materials: Mapping whose values are ``None`` or 5-element sequences.
        save_path: Destination file path (written as UTF-8).
    """
    processed_data = {}
    for key, value in paired_materials.items():
        if value is None:
            processed_data[key] = None
        elif value[1] is None:
            processed_data[key] = (None, None, value[2], None, None)
        else:
            resolution, slices, mat_category, mat_name, param_i = value
            # slice objects are not JSON-serializable; store them as tuples.
            processed_data[key] = (
                tuple(resolution),
                tuple(slice_to_tuple(s) for s in slices),
                mat_category,
                mat_name,
                param_i,
            )

    # Use the same explicit UTF-8 encoding as load_material_pair_json.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(processed_data, f, cls=NumpyEncoder)