import os

import numpy as np
import torch
from torch import nn
from plyfile import PlyData, PlyElement

from src.utils.sh import RGB2SH
from simple_knn._C import distCUDA2
from src.utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation, strip_symmetric, \
    build_scaling_rotation
from src.utils.point import BasicPointCloud

import numpy as np
import matplotlib.pyplot as plt

# Bump-shaped activation built from two sigmoids.
def custom_activation(x, alpha=10, beta=0.5):
    """
    Elementwise product of a rising and a falling sigmoid.

    Computes ``sigmoid(alpha * (x - beta)) * (1 - sigmoid(alpha * (x - (1 - beta))))``.
    The first factor rises around ``x = beta`` and the second falls around
    ``x = 1 - beta``, so for ``beta < 0.5`` the output is a bump spanning
    roughly ``[beta, 1 - beta]``. With the default ``beta = 0.5`` both factors
    share the same centre and the result reduces to ``s * (1 - s)`` with
    ``s = sigmoid(alpha * (x - 0.5))``: a single peak of height 0.25 at
    ``x = 0.5`` (not a double peak).

    Args:
        x: input tensor; the operation is elementwise, so any shape works.
        alpha: steepness of both sigmoid transitions (default 10).
        beta: centre of the rising transition; the falling one is at
            ``1 - beta`` (default 0.5).

    Returns:
        Tensor with the same shape as ``x`` and values in [0, 1].
    """
    # torch.sigmoid is numerically stable. The hand-written 1 / (1 + exp(-z))
    # overflows exp() for very negative z (below about -88 in float32), and
    # autograd then produces NaN gradients (0 * inf).
    rising = torch.sigmoid(alpha * (x - beta))
    falling = 1 - torch.sigmoid(alpha * (x - (1 - beta)))
    return rising * falling

def fibonacci_sphere(n_pts, dtype=torch.float32):
    """
    Return ``n_pts`` approximately uniformly distributed points on the unit sphere.

    Uses the Fibonacci (golden-angle) lattice: point ``i`` has polar angle
    ``phi_i = acos(1 - 2 * (i + 0.5) / n_pts)`` in (0, pi) and azimuth
    ``theta_i = 2 * pi * frac(i / golden_ratio)`` in [0, 2 * pi).

    Args:
        n_pts: number of points to generate.
        dtype: dtype of the returned tensor (default float32, as before).

    Returns:
        Tensor of shape ``(n_pts, 3)`` whose rows are unit vectors (x, y, z).
    """
    # Work in float64: the unwrapped azimuth 2*pi*i/golden_ratio grows linearly
    # with i, and in float32 its absolute precision degrades badly (spacing of
    # ~0.06 rad by i ~ 2e5), visibly distorting the lattice for large n_pts.
    indices = torch.arange(n_pts, dtype=torch.float64)
    phi = torch.acos(1 - 2 * (indices + 0.5) / n_pts)  # polar angle in (0, pi)
    golden_ratio = (1 + 5 ** 0.5) / 2
    # Wrap the azimuth via the fractional part so it stays in [0, 2*pi).
    theta = 2 * torch.pi * torch.frac(indices / golden_ratio)

    # Convert spherical coordinates to Cartesian coordinates.
    sin_phi = torch.sin(phi)
    x = sin_phi * torch.cos(theta)
    y = sin_phi * torch.sin(theta)
    z = torch.cos(phi)

    # Stack x, y, z into an (n_pts, 3) tensor.
    points = torch.stack((x, y, z), dim=1)
    return points.to(dtype)

class GaussianModel(nn.Module):
    """
    Set of 3D Gaussians whose attributes are decoded from learnable per-anchor features.

    ``n_pts`` fixed anchors are placed on the unit sphere with
    ``fibonacci_sphere``; each anchor owns a learnable ``n_dim`` feature vector.
    Per-attribute heads in ``self.heads`` turn features into Gaussian position
    offsets, scaling, rotation, opacity and color.
    """

    @staticmethod
    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        """
        Build covariances ``L @ L^T`` from (modified) scaling and rotation.

        ``L`` comes from ``build_scaling_rotation``; ``transpose(1, 2)`` means
        ``L`` is a batch of matrices. The product is passed through
        ``strip_symmetric`` (presumably keeping only the unique entries of each
        symmetric matrix — confirm against its definition).
        """
        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
        actual_covariance = L @ L.transpose(1, 2)
        symm = strip_symmetric(actual_covariance)
        return symm

    def setup_functions(self):
        """Attach the activations that map raw head outputs to attribute values."""
        # Sigmoid keeps the raw scaling in (0, 1); get_scaling rescales it to
        # (0, max_scale).
        self.scaling_activation = nn.Sigmoid()
        # Inverse of the sigmoid above. This was torch.log, a leftover from the
        # former exp activation, which does not invert a sigmoid.
        # NOTE: it does not undo the max_scale factor; divide by max_scale first.
        self.scaling_inverse_activation = inverse_sigmoid

        self.covariance_activation = GaussianModel.build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize
        self.color_activation = torch.sigmoid

    def __init__(self, n_pts, n_dim, query_transformer, fold_transformer, head_config: dict):
        """
        Args:
            n_pts: number of Gaussians / anchors (must be > 0 for max_scale).
            n_dim: size of each anchor's learnable feature vector.
            query_transformer: stored as ``self.query_transformer`` (not used
                inside this class).
            fold_transformer: stored as ``self.fold_transformer`` (not used
                inside this class).
            head_config: maps a head name to a dict with ``"in_dim"``,
                ``"hidden_dim"`` and ``"out_dim"``. The getters expect heads
                named ``"xyz"``, ``"scaling"``, ``"rotation"``, ``"opacity"``
                and ``"color"``.
        """
        super().__init__()
        self.heads = nn.ModuleDict()
        # Anchors are fixed (no gradient) points spread uniformly over the sphere.
        anchors = fibonacci_sphere(n_pts)
        self.anchors = nn.Parameter(anchors, requires_grad=False)
        self.feature = nn.Parameter(torch.randn(n_pts, n_dim), requires_grad=True)
        self.query_transformer = query_transformer
        self.fold_transformer = fold_transformer
        # Upper bound on each scale component; shrinks like n_pts ** (-1/3).
        # The 1.6 factor is a tuning constant that is not explained here.
        self.max_scale = (1 / n_pts) ** (1 / 3) * 1.6
        for name, config in head_config.items():
            # No nonlinearity between the layers, so this stack is effectively a
            # single affine map. Inserting ReLUs would shift the nn.Sequential
            # indices (and thus state_dict keys) and break existing checkpoints.
            self.heads[name] = nn.Sequential(
                nn.Linear(config["in_dim"], config["hidden_dim"]),
                nn.Linear(config["hidden_dim"], config["hidden_dim"]),
                nn.Linear(config["hidden_dim"], config["out_dim"]),
            )
        self.setup_functions()

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """
        Forward the xformers memory-efficient-attention toggle to every descendant
        module that defines ``set_use_memory_efficient_attention_xformers``.
        """
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # children() only yields nn.Module instances, so no type check is needed.
        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def get_scaling(self, feature):
        """Decode per-Gaussian scales in (0, max_scale) from ``feature``."""
        return self.scaling_activation(self.heads["scaling"](feature)) * self.max_scale

    def get_rotation(self, feature):
        """Decode rotations, L2-normalized along dim 1, from ``feature``."""
        return self.rotation_activation(self.heads["rotation"](feature))

    def get_xyz(self, feature):
        """Return Gaussian centres: the xyz head output as an offset from the anchors."""
        return self.heads["xyz"](feature) + self.anchors

    def get_opacity(self, feature):
        """Decode opacities in (0, 1) from ``feature``."""
        return self.opacity_activation(self.heads["opacity"](feature))

    def get_color(self, feature):
        """
        Return the raw color head output reshaped to ``(-1, 16, 3)``.

        No activation is applied (``color_activation`` is not used here). The
        16 x 3 layout presumably holds degree-3 spherical-harmonic coefficients
        per RGB channel — confirm against the renderer.
        """
        return self.heads["color"](feature).reshape(-1, 16, 3)

    def get_covariance(self, scaling_modifier=1, feature=None):
        """
        Return the covariance of every Gaussian decoded from ``feature``.

        Args:
            scaling_modifier: multiplier applied to the scaling before building
                the covariance.
            feature: per-point features to decode scaling and rotation from;
                required. It follows ``scaling_modifier`` so the existing
                ``get_covariance(scaling_modifier)`` call shape is preserved.

        Raises:
            ValueError: if ``feature`` is not provided. (The old implementation
                passed the ``get_scaling`` method object and a non-existent
                ``self._rotation`` attribute, so it could never succeed.)
        """
        if feature is None:
            raise ValueError(
                "get_covariance requires `feature` to decode scaling and rotation"
            )
        return self.covariance_activation(
            self.get_scaling(feature), scaling_modifier, self.get_rotation(feature)
        )

    def construct_list_of_attributes(self):
        """
        Return the ordered attribute names: position, normal, opacity, three
        scale components and four rotation components.

        These are presumably PLY vertex property names (the file imports
        plyfile). NOTE(review): no color / SH coefficient names are listed;
        confirm the PLY writer adds them separately if colors are exported.
        """
        names = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'opacity']
        names += ['scale_{}'.format(i) for i in range(3)]
        names += ['rot_{}'.format(i) for i in range(4)]
        return names

