import os

import numpy as np
import torch
from torch import nn
from plyfile import PlyData, PlyElement

from src.utils.sh import RGB2SH
from simple_knn._C import distCUDA2
from src.utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation, strip_symmetric, \
    build_scaling_rotation
from src.utils.point import BasicPointCloud
from einops import rearrange, repeat
import numpy as np
import matplotlib.pyplot as plt
from pytorch3d.renderer import HarmonicEmbedding

# Let's define a custom activation function
def custom_activation(x, alpha=10, beta=0.5):
    """Return the product of a rising and a falling sigmoid of ``x``.

    The result is ``s1 * (1 - s2)`` with

    * ``s1 = sigmoid(alpha * (x - beta))``        (rises around ``x = beta``)
    * ``s2 = sigmoid(alpha * (x - (1 - beta)))``  (rises around ``x = 1 - beta``)

    With the default ``beta = 0.5`` both sigmoids coincide, so the output is
    ``s * (1 - s)``: a single smooth bump centred at ``x = 0.5`` whose maximum
    is 0.25 (it is *not* double-peaked). Choosing ``beta < 0.5`` separates the
    two transitions and yields a soft window that is close to 1 between
    ``beta`` and ``1 - beta``.

    Args:
        x: Input tensor; the function is applied element-wise.
        alpha: Steepness of both sigmoid transitions.
        beta: Position of the rising transition; the falling transition is
            placed at ``1 - beta``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    rising = torch.sigmoid(alpha * (x - beta))
    falling_gate = torch.sigmoid(alpha * (x - (1 - beta)))
    return rising * (1 - falling_gate)

def fibonacci_sphere(n_pts):
    """Return ``n_pts`` approximately evenly spread points on the unit sphere.

    Points are placed with the Fibonacci (golden-ratio) lattice: the polar
    angle is chosen so that ``z = cos(polar)`` is evenly spaced in (-1, 1),
    and the azimuth advances by ``2 * pi / golden_ratio`` per point.

    Args:
        n_pts: Number of points to generate.

    Returns:
        A float32 tensor of shape ``(n_pts, 3)`` holding ``(x, y, z)`` rows.
    """
    idx = torch.arange(0, n_pts, dtype=torch.float32)
    golden_ratio = (1 + 5**0.5) / 2

    # Spherical angles: polar in [0, pi]; the azimuth is left unwrapped.
    polar = torch.acos(1 - 2*(idx+0.5)/n_pts)
    azimuth = 2 * torch.pi * idx / golden_ratio

    # Spherical -> Cartesian, stacked column-wise into one tensor.
    return torch.stack(
        (
            torch.sin(polar) * torch.cos(azimuth),
            torch.sin(polar) * torch.sin(azimuth),
            torch.cos(polar),
        ),
        dim=1,
    )

class GaussianModel(nn.Module):
    """Predict per-point 3D Gaussian attributes from anchors on a sphere.

    ``n_pts`` anchor points are created with ``fibonacci_sphere`` and stored
    as a frozen parameter. In ``forward`` the anchors are embedded, mapped to
    ``n_dim`` features, passed through each module in ``self.transformers``
    and decoded by the per-attribute heads in ``self.heads`` (the code looks
    up the keys ``"xyz"``, ``"opacity"``, ``"scaling"``, ``"rotation"`` and
    ``"color"``, so ``head_config`` must provide them).
    """

    @staticmethod
    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        """Build covariances as ``L @ L^T`` from scales and rotations.

        ``L`` comes from ``build_scaling_rotation(scaling_modifier * scaling,
        rotation)`` and is treated as a batch of matrices (3-D tensor). The
        product is passed through ``strip_symmetric`` (presumably keeps only
        the unique entries of each symmetric matrix — confirm in the helper).
        """
        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
        actual_covariance = L @ L.transpose(1, 2)
        return strip_symmetric(actual_covariance)

    @staticmethod
    def _inverse_softplus(y):
        """Return ``x`` such that ``softplus(x) == y`` (beta = 1, ``y > 0``).

        Uses ``y + log(1 - exp(-y))``, which equals ``log(exp(y) - 1)`` but
        does not overflow for large ``y``.
        """
        return y + torch.log(-torch.expm1(-y))

    def setup_functions(self):
        """Install the activation functions used by the attribute getters."""
        # Scales must be positive; softplus replaces the earlier torch.exp.
        self.scaling_activation = nn.Softplus(beta=1.)
        # The inverse has to match the activation above: torch.log only
        # inverted the former exp activation, not softplus.
        self.scaling_inverse_activation = GaussianModel._inverse_softplus

        self.covariance_activation = GaussianModel.build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize
        self.color_activation = torch.sigmoid

    def __init__(self, n_pts, n_dim, query_transformer, fold_transformer, head_config: dict):
        """Create anchors, the position mapping and the attribute heads.

        Args:
            n_pts: Number of anchor points placed on the unit sphere.
            n_dim: Width of the latent tokens fed to the transformers.
            query_transformer: Module applied in ``forward``; it is called as
                ``transformer(tokens, meta=...)`` and its result must expose
                a ``.sample`` attribute.
            fold_transformer: Stored on the module but not used by
                ``forward`` in this class.
            head_config: Maps a head name to a dict with the keys
                ``"in_dim"``, ``"hidden_dim"`` and ``"out_dim"``.
        """
        super().__init__()
        self.heads = nn.ModuleDict()
        # Uniform sphere initialisation; anchors are not trained.
        anchors = fibonacci_sphere(n_pts)
        self.anchors = nn.Parameter(anchors, requires_grad=False)
        self.position_embedding = HarmonicEmbedding()
        input_dim = self.position_embedding.get_output_dim()
        # Two stacked linear layers (the activations are commented out).
        self.pos_mapping = nn.Sequential(
            nn.Linear(input_dim, 4*n_dim),
            # nn.ReLU(),
            nn.Linear(4*n_dim, n_dim),
            # nn.ReLU()
        )
        self.position_embed = nn.Parameter(torch.zeros(n_pts, n_dim), requires_grad=True)
        self.query_transformer = query_transformer
        self.fold_transformer = fold_transformer
        # One learnable divisor for the scales per entry of self.transformers
        # (indexed by layer_idx in forward).
        self.tempurature_list = nn.Parameter(torch.tensor([50.0, 100.0]), requires_grad=True)
        self.transformers = nn.ModuleList([query_transformer])
        for name, config in head_config.items():
            # Each head is three stacked linear layers without activations.
            self.heads[name] = nn.Sequential(
                nn.Linear(config["in_dim"], config["hidden_dim"]),
                # nn.ReLU(),
                nn.Linear(config["hidden_dim"], config["hidden_dim"]),
                # nn.ReLU(),
                nn.Linear(config["hidden_dim"], config["out_dim"])
            )

        self.setup_functions()

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        """Forward the xformers attention switch to every supporting submodule.

        Walks all descendants of this module and calls their
        ``set_use_memory_efficient_attention_xformers(valid, attention_op)``
        when they define it.
        """
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        # Start from the children so this method does not call itself.
        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def get_scaling(self, feature, tempurature):
        """Decode positive scales from ``feature`` and divide by ``tempurature``."""
        return self.scaling_activation(self.heads["scaling"](feature)) / tempurature

    def get_rotation(self, feature):
        """Decode rotations from ``feature`` and normalise them.

        Uses ``torch.nn.functional.normalize`` with its default arguments.
        """
        return self.rotation_activation(self.heads["rotation"](feature))

    def get_xyz(self, feature, anchors=None):
        """Decode positions as head output plus ``anchors``.

        ``anchors`` defaults to the fixed sphere anchors of this module.
        """
        if anchors is None:
            anchors = self.anchors
        return self.heads["xyz"](feature) + anchors

    def get_opacity(self, feature):
        """Decode opacities in (0, 1) from ``feature`` via a sigmoid."""
        return self.opacity_activation(self.heads["opacity"](feature))

    def get_color(self, feature):
        """Decode colour coefficients, reshaped to ``(-1, 16, 3)``.

        No activation is applied. The 16 x 3 layout presumably holds 16
        spherical-harmonic coefficients per RGB channel — confirm against the
        renderer.
        """
        return self.heads["color"](feature).reshape(-1, 16, 3)

    def forward(self, renderer=None, background_color=None, camera=None, images=None):
        """Run every transformer stage and decode Gaussians after each one.

        Args:
            renderer: Optional callable invoked as ``renderer(camera, xyz,
                opacity, scales, rotations, shs, bg_color=...)``. When it is
                ``None`` the raw attributes are returned instead.
            background_color: Tensor moved to ``camera.R.device`` and passed
                to the renderer; required when ``renderer`` is given.
            camera: Passed through to the renderer.
            images: Unused by this method.

        Returns:
            A list with one entry per module in ``self.transformers``: the
            renderer output, or a dict with the keys ``means3D``,
            ``opacity``, ``scale``, ``rotation`` and ``color``.
        """
        outputs_list = []
        anchors = self.anchors
        latent_token = None
        last_idx = len(self.transformers) - 1
        # Number of children created per Gaussian between two stages; used for
        # both the anchors and the latent tokens so they stay aligned.
        split_factor = 2
        for layer_idx, transformer in enumerate(self.transformers):
            pos_embed = self.position_embedding(anchors)
            pos_embed = self.pos_mapping(pos_embed).unsqueeze(0)
            if latent_token is None:
                # The first stage starts from the positional embedding itself.
                latent_token = torch.zeros_like(pos_embed) + pos_embed
            meta = {
                'pos_embed': pos_embed
            }
            latent_token = transformer(latent_token, meta=meta).sample

            # Decode all attributes and drop the leading singleton batch axis.
            xyz = self.get_xyz(latent_token, anchors).squeeze(0)
            opacity = self.get_opacity(latent_token).squeeze(0)
            scales = self.get_scaling(latent_token, self.tempurature_list[layer_idx]).squeeze(0)
            rotations = self.get_rotation(latent_token).squeeze(0)
            shs = self.get_color(latent_token).squeeze(0)

            if renderer is not None:
                outputs = renderer(
                    camera,
                    xyz,
                    opacity,
                    scales,
                    rotations,
                    shs,
                    bg_color=background_color.to(camera.R.device)
                )
            else:
                outputs = {
                    'means3D': xyz,
                    'opacity': opacity,
                    'scale': scales,
                    'rotation': rotations,
                    'color': shs,
                }

            if layer_idx < last_idx:
                # Prepare inputs of the next stage: new anchors sampled from
                # the current Gaussians and tokens duplicated to match.
                anchors = self.split(xyz, scales, rotations, M=split_factor)
                latent_token = repeat(latent_token, 'b n d -> b (m n) d', m=split_factor)

            outputs_list.append(outputs)

        return outputs_list

    def get_covariance(self, scaling_modifier=1, feature=None, tempurature=1.0):
        """Return the covariances of the Gaussians decoded from ``feature``.

        Scales and rotations are predicted from latent features in this
        model, so ``feature`` is required; the previous implementation read
        attributes that do not exist on this class and always failed.

        Args:
            scaling_modifier: Factor multiplied onto the scales.
            feature: Latent tokens accepted by the ``"scaling"`` and
                ``"rotation"`` heads. Leading axes are flattened so that the
                covariance helper receives one row per Gaussian.
            tempurature: Divisor applied to the scales (see ``get_scaling``).

        Raises:
            ValueError: If ``feature`` is not provided.
        """
        if feature is None:
            raise ValueError(
                "get_covariance requires `feature`: scales and rotations are "
                "decoded from latent features by this model"
            )
        scaling = self.get_scaling(feature, tempurature)
        rotation = self.get_rotation(feature)
        # build_covariance_from_scaling_rotation transposes axes 1 and 2 of
        # the factor matrix, so it needs exactly one batch axis.
        scaling = scaling.reshape(-1, scaling.shape[-1])
        rotation = rotation.reshape(-1, rotation.shape[-1])
        return self.covariance_activation(scaling, scaling_modifier, rotation)

    def split(self, xyz, scales, rotation, M=2):
        """Sample ``M`` rotated, scale-weighted Gaussian offsets per point.

        Args:
            xyz: ``(N, 3)`` positions; only used for shape, dtype and device.
            scales: ``(N, 3)`` per-axis standard deviations.
            rotation: Per-point rotations accepted by ``build_rotation``.
            M: Number of samples drawn per input point.

        Returns:
            Tensor of shape ``(M * N, 3)``.
        """
        # NOTE(review): `means` is all zeros, so the samples are centred on
        # the origin rather than on `xyz` — confirm whether adding the
        # repeated `xyz` was intended here.
        means = torch.zeros_like(repeat(xyz, 'n d -> (m n) d', m=M))
        stds = repeat(scales, 'n d -> (m n) d', m=M)
        rots = build_rotation(rotation)  # N 3 3
        rots = repeat(rots, 'n d1 d2 -> (m n) d1 d2', m=M)
        samples = torch.randn_like(means) * stds
        # Rotate every sampled offset by the rotation of its source point.
        samples = torch.einsum('... n i j,... n j->... n i', rots, samples) + means
        return samples

    def construct_list_of_attributes(self):
        """Return the ordered per-point attribute names.

        Position, normal, opacity, three scale and four rotation fields.
        """
        names = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'opacity']
        names.extend('scale_{}'.format(i) for i in range(3))
        names.extend('rot_{}'.format(i) for i in range(4))
        return names

