import math
import os
import torch
import torch.nn.functional as F
from diffmat import MaterialGraphTranslator as MGT
from utils import load_image_as_tensor


class ProceduralMaterial(torch.nn.Module):
    def __init__(self, sbs_file_path, mgt_res, external_input_path, ckp_path=None, albedo_image_path=None,
                 filter_mode='linear-mipmap-linear', init_scale=2.0, init_rotation=0.0):
        super(ProceduralMaterial, self).__init__()
        self.filter_mode = filter_mode
        self.texture_res = 2 ** mgt_res

        # 创建程序化图（MGT），并翻译、编译、设为可训练状态
        print(f"Translating Substance file: {sbs_file_path}")
        translator = MGT(sbs_file_path, res=mgt_res)
        self.graph = translator.translate(seed=0,
                                          use_alpha=False,
                                          normal_format='gl',
                                          gen_external_input=True,
                                          external_input_folder=external_input_path, device='cuda')
        self.graph.compile()
        self.graph.train()
        print(f"Substance file translated successfully.")

        # 如果提供了 checkpoint 路径，则加载初始参数
        if ckp_path is not None:
            ckp_data = torch.load(ckp_path)
            self.graph.set_parameters_from_tensor(ckp_data['param'])

        scale_param = torch.tensor(1.0, dtype=torch.float32).cuda()
        rotate_param = torch.tensor(0.0, dtype=torch.float32).cuda()
        albedo_gain = torch.ones(3, dtype=torch.float32).cuda()
        albedo_bias = torch.zeros(3, dtype=torch.float32).cuda()
        blend_factor = torch.tensor(1.0, dtype=torch.float32).cuda()

        self.scale_param = torch.nn.Parameter(scale_param)
        self.rotate_param = torch.nn.Parameter(rotate_param)
        self.albedo_gain = torch.nn.Parameter(albedo_gain)
        self.albedo_bias = torch.nn.Parameter(albedo_bias)
        self.blend_factor = torch.nn.Parameter(blend_factor)

        # 1,3, res, res
        default_normal = torch.tensor([0.5, 0.5, 1.0], dtype=torch.float32)
        self.default_normal = default_normal.view(1, 3, 1, 1).expand(1, 3, self.texture_res, self.texture_res).cuda()


        self.set_scale(init_scale)
        self.set_rotate(init_rotation)

        # 创建基础 UV 网格
        self._create_uv_grid()

        # 将albedo_image_path转换为张量，用于和输出的albedo进行混合
        if albedo_image_path is not None:
            print(f"Loading albedo image from: {albedo_image_path}")
            albedo_image = load_image_as_tensor(albedo_image_path, resolution=self.texture_res)[:, :3, :, :]
            self.albedo_image = albedo_image.cuda()
        else:
            self.albedo_image = None

    def _create_uv_grid(self):
        """创建标准化的 UV 坐标网格"""
        res = self.texture_res
        u = torch.linspace(0, 1, res, device='cuda')
        v = torch.linspace(0, 1, res, device='cuda')
        v_grid, u_grid = torch.meshgrid(v, u, indexing='ij')
        self.uv_grid = torch.stack([u_grid, v_grid], dim=-1).cuda()

    def get_scale(self):
        """获取当前的缩放值，范围在0.25~8.0之间平滑变化"""
        # 使用sigmoid函数将任意实数映射到(0,1)区间，再缩放到(0.25,8.0)区间
        # x/(1+|x|)是一个平滑的非线性函数，有较好的梯度特性
        # self.scale_param.data = torch.clamp(self.scale_param, -10.0, 10.0)  # 限制范围，避免数值不稳定
        normalized = torch.sigmoid(self.scale_param)  # 映射到(0,1)
        return 0.25 + normalized * 7.75  # 映射到(0.25,8.0)

    def get_rotate(self):
        """获取当前的旋转角度（弧度），在-π~π之间平滑过渡"""
        # 使用tanh函数将输入平滑映射到(-1,1)区间，再缩放到(-π,π)
        # self.rotate_param.data = torch.clamp(self.rotate_param, -10.0, 10.0)
        return torch.tanh(self.rotate_param) * math.pi

    def set_scale(self, scale_value):
        scale_value = torch.clamp(torch.tensor(scale_value, dtype=torch.float32), 0.25, 8.0)
        normalized = (scale_value - 0.25) / 7.75
        self.scale_param.data = torch.log(normalized / (1 - normalized))

    def set_rotate(self, rotate_value):
        pi = math.pi
        rotate_value = torch.clamp(torch.tensor(rotate_value, dtype=torch.float32), -pi, pi)
        normalized = rotate_value / pi
        eps = 1e-7  # 防止数值不稳定
        normalized = torch.clamp(normalized, -1.0 + eps, 1.0 - eps)
        self.rotate_param.data = 0.5 * torch.log((1 + normalized) / (1 - normalized))

    def _transform_uv(self):
        """应用UV变换并确保无限平铺"""
        # 获取缩放和旋转值
        scale = self.get_scale()
        rotate = self.get_rotate()

        # 预计算旋转三角函数值
        cos_r = torch.cos(rotate)
        sin_r = torch.sin(rotate)

        # 将UV坐标移到中心点(0.5, 0.5)
        u = self.uv_grid[..., 0] - 0.5
        v = self.uv_grid[..., 1] - 0.5

        # 应用旋转和缩放，一步到位
        u_rot = (u * cos_r - v * sin_r) * scale
        v_rot = (u * sin_r + v * cos_r) * scale

        # 移回原始坐标系
        u_new = u_rot + 0.5
        v_new = v_rot + 0.5

        # 梯度友好的平铺处理
        u_tiled = u_new - torch.floor(u_new)
        v_tiled = v_new - torch.floor(v_new)

        # 平滑过渡，避免在整数边界处的梯度突变
        epsilon = 1e-6
        u_tiled = torch.clamp(u_tiled, epsilon, 1.0 - epsilon)
        v_tiled = torch.clamp(v_tiled, epsilon, 1.0 - epsilon)

        # 合并为新的纹理坐标
        return torch.stack([u_tiled, v_tiled], dim=-1)

    def get_pbr_materials(self):
        # 生成基础纹理
        eval_graph = self.graph.evaluate_maps()

        # 处理原始纹理
        albedo_orig = eval_graph[0][:, :3, :, :].clip(0.0, 1.0)
        # normal_orig = eval_graph[1][:, :3, :, :].clip(0.0, 1.0)
        roughness_orig = eval_graph[2][:, :1, :, :].clip(0.0, 1.0)
        metallic_orig = eval_graph[3][:, :1, :, :].clip(0.0, 1.0)

        # 使用默认法线贴图 0,0,1
        normal_orig = self.default_normal.clone()

        # # 处理法线贴图 - 确保正确规范化
        # # 转换到[-1,1]范围
        # normal_centered = normal_orig * 2.0 - 1.0
        #
        # # 计算向量长度并归一化
        # norm = torch.sqrt(torch.sum(normal_centered ** 2, dim=1, keepdim=True))
        # normal_normalized = normal_centered / (norm + 1e-6)  # 防止除零
        #
        # # 转回[0,1]范围
        # normal_orig = normal_normalized * 0.5 + 0.5
        #
        # # 翻转Y通道 (OpenGL坐标系调整)
        # normal_orig[:, 1:2, :, :] = 1.0 - normal_orig[:, 1:2, :, :]

        # 如果提供了albedo_image_path，则将其与生成的漫反射纹理混合
        if self.albedo_image is not None:
            # 将self.blend_factor的值限制在0.0到1.0之间
            self.blend_factor.data = torch.clamp(self.blend_factor, 0.0, 1.0)
            albedo_orig = self.albedo_image * self.blend_factor + albedo_orig * (1 - self.blend_factor)

        gain = self.albedo_gain.view(1, 3, 1, 1).expand(albedo_orig.shape)
        bias = self.albedo_bias.view(1, 3, 1, 1).expand(albedo_orig.shape)
        adjusted_albedo = (albedo_orig * gain + bias).clip(0.0, 1.0)

        # 合并所有纹理通道到一个大张量
        pbr_materials = torch.cat([
            adjusted_albedo,
            normal_orig,
            roughness_orig,
            metallic_orig
        ], dim=1)

        return pbr_materials

    def evaluate_maps(self):
        """调用程序化图的 evaluate_maps() 获得所有纹理图输出，并应用 UV 变换"""
        # 计算变换后的 UV 坐标
        transformed_uv = self._transform_uv()
        # 生成基础纹理
        combined_textures = self.get_pbr_materials()
        # 使用grid_sample采样原始纹理
        transformed_textures = F.grid_sample(
            combined_textures,
            transformed_uv.unsqueeze(0),
            mode='bilinear',
            padding_mode='border',  # 这里无所谓，因为坐标已经被限制在[0,1]范围内
            align_corners=True
        )

        # 分离各个通道
        albedo = transformed_textures[0, :3, :, :].permute(1, 2, 0)
        normal = transformed_textures[0, 3:6, :, :].permute(1, 2, 0)
        roughness = transformed_textures[0, 6:7, :, :].permute(1, 2, 0)
        metallic = transformed_textures[0, 7:8, :, :].permute(1, 2, 0)

        # 构建结果字典
        maps = {
            'albedo': albedo,
            'normal': normal,
            'roughness': roughness,
            'metallic': metallic
        }
        return maps

    def get_parameters(self):
        # 获取所有参数
        graph_params = list(self.graph.parameters(-1))
        albedo_params = [self.albedo_gain, self.albedo_bias]
        uv_params = [self.scale_param, self.rotate_param]

        # 构建参数组
        param_groups = [
            {'params': graph_params, 'lr_mult': 0.1, 'name': 'graph'},
            {'params': albedo_params, 'lr_mult': 0.1, 'name': 'albedo'},
            {'params': self.blend_factor, 'lr_mult': 0.1, 'name': 'blend_factor'},
            {'params': uv_params, 'lr_mult': 0.1, 'name': 'uv'}
        ]

        return param_groups


class UniformMaterial(torch.nn.Module):
    """
    用于优化的均匀材质，只包含一个颜色属性。
    初始化参数:
      color: 初始颜色，格式为 [r, g, b]
      min_max: （可选）用于 clamp 操作的下界和上界，格式为 ([min0, min1, ...], [max0, max1, ...])
    """
    def __init__(self, base_color=[0.5, 0.5, 0.5], roughness=0.5, metallic=0.0, min_max=None, albedo_image_path=None):
        super(UniformMaterial, self).__init__()
        base_color = torch.tensor(base_color, dtype=torch.float32).cuda()
        roughness = torch.tensor(roughness, dtype=torch.float32).cuda()
        metallic = torch.tensor(metallic, dtype=torch.float32).cuda()
        self.base_color = torch.nn.Parameter(base_color)
        self.roughness = torch.nn.Parameter(roughness)
        self.metallic = torch.nn.Parameter(metallic)
        self.blend_factor = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32).cuda())

        # 将albedo_image_path转换为张量，用于和输出的albedo进行混合
        if albedo_image_path is not None and os.path.exists(albedo_image_path):
            print(f"Loading albedo image from: {albedo_image_path}")
            # 使用默认分辨率1024，因为UniformMaterial没有texture_res属性
            albedo_image = load_image_as_tensor(albedo_image_path, resolution=1024)[:, :3, :, :]
            self.albedo_image = albedo_image.cuda()
        else:
            self.albedo_image = None

    def evaluate_maps(self):
        # 确保颜色在物理合理的范围内
        base_color = self.base_color.clamp(0.0, 1.0)
        roughness = self.roughness.clamp(0.01, 0.99)
        metallic = self.metallic.clamp(0.0, 1.0)

        # 如果提供了albedo_image_path，则将其与生成的漫反射纹理混合
        if self.albedo_image is not None:
            # 将self.blend_factor的值限制在0.0到1.0之间
            self.blend_factor.data = torch.clamp(self.blend_factor, 0.0, 1.0)
            # 将base_color扩展为与albedo_image相同的形状进行混合
            base_color_expanded = base_color.view(1, 3, 1, 1).expand_as(self.albedo_image)
            mixed_albedo = self.albedo_image * self.blend_factor + base_color_expanded * (1 - self.blend_factor)
            # 取平均值作为最终的base_color
            base_color = mixed_albedo.mean(dim=(2, 3)).squeeze(0)

        return {
            'albedo': base_color,
            'roughness': roughness,
            'metallic': metallic
        }

    def get_parameters(self):
        # 获取所有参数
        params = [self.base_color, self.roughness, self.metallic]
        # 构建参数组字典
        param_groups = [
            {'params': params, 'lr_mult': 1.0},
            {'params': self.blend_factor, 'lr_mult': 0.1, 'name': 'blend_factor'},
        ]
        return param_groups