import os

import numpy as np
import torch
from torch import nn
from plyfile import PlyData, PlyElement

from src.utils.sh import RGB2SH
from simple_knn._C import distCUDA2
from src.utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation, strip_symmetric, \
    build_scaling_rotation
from src.utils.point import BasicPointCloud
import math
from einops import rearrange, repeat
import numpy as np
import matplotlib.pyplot as plt
from pytorch3d.renderer import HarmonicEmbedding
from src.models.gs.position_encoding import build_position_encoding

def gen_sineembed_for_position(pos_tensor, d_model=256):
    """Sinusoidally embed 2 or 4 coordinates per position.

    ``pos_tensor`` is indexed as ``pos_tensor[:, :, k]``, so it must be 3-D with
    2 or 4 values on the last axis. Each coordinate is scaled by 2*pi and divided
    by the frequencies ``10000 ** (2 * (i // 2) / (d_model // 2))`` for
    ``i in range(d_model // 2)``. The result is then encoded as sin at even indices
    and cos at odd indices, interleaved pairwise.

    Args:
        pos_tensor: 3-D tensor whose last axis holds (x, y) or (x, y, w, h).
        d_model: controls the width; each coordinate contributes ``d_model // 2`` features.

    Returns:
        Per-coordinate embeddings concatenated in (y, x) or (y, x, w, h) order
        along the last axis.

    Raises:
        ValueError: if the last axis has a size other than 2 or 4.
    """
    half_dim = d_model // 2
    freq = torch.arange(half_dim, dtype=torch.float32, device=pos_tensor.device)
    freq = 10000 ** (2 * (freq // 2) / half_dim)

    def _embed(channel):
        # Angle for every frequency, then sin on even slots / cos on odd slots.
        angles = (pos_tensor[:, :, channel] * (2 * math.pi))[:, :, None] / freq
        return torch.stack((angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3).flatten(2)

    # x and y are always encoded before the size check.
    x_part = _embed(0)
    y_part = _embed(1)

    n_coords = pos_tensor.size(-1)
    if n_coords == 2:
        return torch.cat((y_part, x_part), dim=2)
    if n_coords == 4:
        return torch.cat((y_part, x_part, _embed(2), _embed(3)), dim=2)
    raise ValueError("Unknown pos_tensor shape(-1):{}".format(n_coords))

# Let's define a custom activation function
def custom_activation(x, alpha=10, beta=0.5):
    """Soft window built from two shifted logistic sigmoids.

    Computes ``sigmoid(alpha * (x - beta)) * (1 - sigmoid(alpha * (x - (1 - beta))))``.
    The first factor rises around ``x = beta`` and the second falls around
    ``x = 1 - beta``. For ``beta < 0.5`` the output approaches 1 inside
    ``(beta, 1 - beta)`` (more sharply as ``alpha`` grows) and 0 outside it.
    With the default ``beta = 0.5`` both sigmoids coincide, so the output is
    ``s * (1 - s)``. That is a single bump peaking at 0.25 at ``x = 0.5``, not a
    double peak.

    Args:
        x: input tensor.
        alpha: steepness of both sigmoids (default 10).
        beta: location of the rising edge; the falling edge sits at ``1 - beta``.

    Returns:
        Tensor with the same shape as ``x`` and values in [0, 1].
    """
    # torch.sigmoid instead of 1 / (1 + exp(-t)): for strongly negative inputs
    # exp(-t) overflows to inf, and the hand-written form back-propagates NaN.
    rising = torch.sigmoid(alpha * (x - beta))
    falling = torch.sigmoid(alpha * (x - (1 - beta)))
    return rising * (1 - falling)

def fibonacci_sphere(n_pts):
    """Return ``n_pts`` roughly uniformly distributed points on the unit sphere.

    Uses the Fibonacci lattice. The cosine of the polar angle is spaced evenly in
    (-1, 1), and the azimuth advances by ``2*pi / golden_ratio`` per point.

    Args:
        n_pts: number of points to generate.

    Returns:
        A float32 CPU tensor of shape (n_pts, 3) holding (x, y, z) per point.
    """
    idx = torch.arange(0, n_pts, dtype=torch.float32)
    golden = (1 + 5 ** 0.5) / 2

    # Polar angle in [0, pi].
    polar = torch.acos(1 - 2 * (idx + 0.5) / n_pts)
    # Azimuth is not wrapped into [0, 2*pi); sin/cos are periodic, so no wrap is needed.
    azimuth = 2 * torch.pi * idx / golden

    # Spherical -> Cartesian coordinates.
    ring = torch.sin(polar)
    return torch.stack(
        (ring * torch.cos(azimuth), ring * torch.sin(azimuth), torch.cos(polar)),
        dim=1,
    )

def initialize_points_in_grid(n, extent=0.8):
    """Place about ``n`` points on a regular k x k x k grid spanning [-extent, extent]^3.

    ``k = round(n ** (1/3))``, so the real point count is ``k ** 3``. This can
    differ from ``n`` when ``n`` is not a perfect cube.

    Args:
        n: requested number of points.
        extent: half-width of the cube (default 0.8, the previously hard-coded value).

    Returns:
        ``(points, actual_n)``:
        - ``points`` is a float tensor of shape (k**3, 3). It is ordered with x
          varying slowest and z fastest.
        - ``actual_n`` is ``k ** 3``.
    """
    k = round(n ** (1 / 3))
    actual_n = k ** 3
    axis = torch.linspace(-extent, extent, steps=k)
    # Vectorized replacement for a nested Python comprehension over 0-d tensors.
    # 'ij' indexing keeps the x-outer / z-inner order. An explicit reshape gives
    # shape (0, 3), not (0,), when k == 0.
    gx, gy, gz = torch.meshgrid(axis, axis, axis, indexing="ij")
    points = torch.stack((gx, gy, gz), dim=-1).reshape(actual_n, 3)
    return points, actual_n


def list2dict(list_of_dicts):
    """Collate a list of dicts into one dict of stacked tensors.

    The keys are taken from the first dict. For each key, the values from every
    dict are stacked with ``torch.stack`` along a new leading axis.

    Raises:
        IndexError: if ``list_of_dicts`` is empty.
        KeyError: if a later dict lacks a key present in the first one.
    """
    first = list_of_dicts[0]
    stacked = {}
    for key in first.keys():
        stacked[key] = torch.stack([entry[key] for entry in list_of_dicts])
    return stacked


class GaussianModel(nn.Module):
    @staticmethod
    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
        actual_covariance = L @ L.transpose(1, 2)
        symm = strip_symmetric(actual_covariance)
        return symm

    def setup_functions(self):

        # self.scaling_activation = torch.exp
        # self.scaling_activation = nn.Softplus(beta=1.)
        self.scaling_activation = nn.Sigmoid()
        # self.scaling_inverse_activation = torch.log

        self.covariance_activation = GaussianModel.build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize
        self.color_activation = torch.sigmoid
        self.inverse_xyz_activation = inverse_sigmoid
        self.xyz_activation = torch.sigmoid


    def __init__(self, n_pts, n_dim, query_transformer, fold_transformer, 
                 encoder, head_config: dict):
        super().__init__()
        self.heads = nn.ModuleDict()
        # uniform sphere init
        anchors = fibonacci_sphere(n_pts)
        # anchors, _ = initialize_points_in_grid(n_pts)

        self.anchors = nn.Parameter(anchors, requires_grad=True)

        self.position_embedding = HarmonicEmbedding()
        input_dim = self.position_embedding.get_output_dim()
        
        self.pos_mapping = nn.Sequential(
            nn.Linear(input_dim, query_transformer.inner_dim),
            # nn.ReLU(),
            nn.Linear(query_transformer.inner_dim, query_transformer.inner_dim),
            # nn.ReLU()
        )
        
        self.position_embedding_hw = build_position_encoding(query_transformer.inner_dim)
        distances = distCUDA2(anchors.float().cuda())
        max_scale = distances.max() * 15
        self.max_scale = nn.Parameter(torch.tensor(max_scale), requires_grad=True)
        
        self.encoder = encoder

        self.query_transformer = query_transformer
        self.fold_transformer = fold_transformer

        # self.tempurature = nn.Parameter(torch.tensor(50.0), requires_grad=True)
        self.tempurature_list = nn.Parameter(torch.tensor([50.0, 100.0]), requires_grad=True)
        # self.transformers = nn.ModuleList([query_transformer])
        self.transformers = nn.ModuleList([query_transformer, fold_transformer])
        self.n_dim = n_dim
        self.n_pts = n_pts

        for name, config in head_config.items():
            self.heads[name] = nn.Sequential(
                nn.Linear(config["in_dim"], config["hidden_dim"]),
                # nn.ReLU(),
                nn.Linear(config["hidden_dim"], config["hidden_dim"]),
                # nn.ReLU(),
                nn.Linear(config["hidden_dim"], config["out_dim"])
            )
            # zero init
            # self.heads[name][-1].weight.data.zero_()
            # self.heads[name][-1].bias.data.zero_()
        
        self.setup_functions()

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

    def get_scaling(self, feature):
        # return self.scaling_activation(self.heads["scaling"](feature)) * 0.1
        return self.scaling_activation(self.heads["scaling"](feature)) * self.max_scale
        # return self.scaling_activation(self.heads["scaling"](feature)) / self.tempurature_list[0]

    def get_rotation(self, feature):
        # return self.rotation_activation(self._rotation)
        return self.rotation_activation(self.heads["rotation"](feature))        

    def get_xyz(self, feature, anchors = None):
        # return self._xyz
        if anchors is None:
            anchors = self.anchors
        # return self.xyz_activation(self.heads["xyz"](feature) + self.inverse_xyz_activation((anchors * 0.5 + 0.5))) * 2 - 1
        # return self.heads["xyz"](feature) + self.anchors
        return self.heads["xyz"](feature) + anchors

    def get_opacity(self, feature):
        # return self.opacity_activation(self._opacity)
        return self.opacity_activation(self.heads["opacity"](feature))
        # return torch.ones_like(self.heads["opacity"](feature))

    def get_color(self, feature):
        # return self.color_activation(self._color)
        return self.color_activation(self.heads["color"](feature))
        b, _, _ = feature.shape
        return self.heads["color"](feature).reshape(b, -1, 16, 3)

    def forward(self, renderer=None, background_color=None, cameras=None, images=None):
        outputs_list = []
        anchors = self.anchors
        latent_token = None

        if images is not None:
            self.cache_images = images
        else:
            images = self.cache_images
        b, n, c, h, w = images.shape

        images = rearrange(images, 'b n c h w -> (b n) c h w')
        context = self.encoder(images)
        size = (384 // 16)
        context = rearrange(context[:, 1:], 'b (h w) ... -> b h w ...', h=size, w = size)
        pos_embed_hw = self.position_embedding_hw(context)
        context = rearrange(context, 'b h w c -> b (h w) c')
        pos_embed_hw = rearrange(pos_embed_hw, 'b c h w -> b (h w) c')

        for layer_idx, transformer in enumerate(self.transformers):
            pos_embed = self.position_embedding(anchors)
            pos_embed = self.pos_mapping(pos_embed)
            if layer_idx == 0:
                pos_embed = pos_embed.unsqueeze(0)
            if latent_token is None:
                latent_token = torch.zeros(b, self.n_pts, self.n_dim, device=images.device)
                                                 
            meta = {
                'q_pos_embed': pos_embed,
                'k_pos_embed': pos_embed_hw,

            }
            latent_token = transformer(latent_token, context, meta=meta).sample
            xyz = self.get_xyz(latent_token, anchors)
            opacity = self.get_opacity(latent_token)
            scales = self.get_scaling(latent_token)
            rotations = self.get_rotation(latent_token)
            shs = self.get_color(latent_token)
            if renderer is not None:
                outputs = []
                for i, _cameras in enumerate(cameras):
                    _outputs = []
                    for camera in _cameras:
                        output = renderer(
                            camera,
                            xyz[i],
                            opacity[i],
                            scales[i],
                            rotations[i],
                            shs[i],
                            bg_color = background_color.to(camera.R.device)
                        )
                        _outputs.append(output)
                    _outputs = list2dict(_outputs)
                    outputs.append(_outputs)
                outputs = list2dict(outputs)
            else:
                outputs = {
                    'means3D': xyz,
                    'opacity': opacity,
                    'scale': scales,
                    'rotation': rotations,
                    'color': shs,
                }

            if layer_idx < len(self.transformers) - 1:
                split_xyz = self.split(xyz, scales, rotations, M=2)
                latent_token = repeat(latent_token, 'b n d -> b (m n) d', m=2)
                anchors = split_xyz

            outputs_list.append(outputs)

        return outputs_list

    def get_covariance(self, scaling_modifier=1):
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)

    def split(self, xyz, scales, rotation, M=2):
        # xyz : N 3
        b, n, _ = xyz.shape
        means = repeat(xyz, 'b n d -> b (m n) d', m=M)
        stds = repeat(scales, 'b n d -> b (m n) d', m=M)
        rotation = rearrange(rotation, 'b n d -> (b n) d')
        rots = build_rotation(rotation) 
        rots = repeat(rots, '(b n) d1 d2 -> b (m n) d1 d2', m=M, b=b)
        samples = torch.randn_like(means) * stds
        samples = torch.einsum('... n i j,... n j->... n i', rots, samples) + means
        return samples
    
    
    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        l.append('opacity')
        for i in range(3):
            l.append('scale_{}'.format(i))
        for i in range(4):
            l.append('rot_{}'.format(i))
        return l

