import os

import numpy as np
import torch
from torch import nn
from plyfile import PlyData, PlyElement

from src.utils.sh import RGB2SH
from simple_knn._C import distCUDA2
from src.utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation, strip_symmetric, \
    build_scaling_rotation
from src.utils.point import BasicPointCloud

import numpy as np
import matplotlib.pyplot as plt

# Let's define a custom activation function
def custom_activation(x, alpha=10, beta=0.5):
    """
    Bump-shaped activation built from the product of two sigmoids.

    Computes ``sigmoid(alpha * (x - beta)) * (1 - sigmoid(alpha * (x - (1 - beta))))``.
    With the default ``beta = 0.5`` both sigmoids coincide, so the result is
    ``s * (1 - s)``: a single bump that peaks at 0.25 for ``x == 0.5`` and
    decays to 0 on both sides. For ``beta < 0.5`` the first factor rises at
    ``beta`` and the second falls at ``1 - beta``, giving a window between them.

    Args:
        x: Input tensor.
        alpha: Steepness of both sigmoid transitions.
        beta: Position of the rising transition; the falling one sits at ``1 - beta``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # torch.sigmoid is used instead of a hand-written 1 / (1 + exp(-z)):
    # the explicit exp overflows to inf for strongly negative inputs and the
    # backward pass then evaluates 0 * inf, producing NaN gradients.
    rising = torch.sigmoid(alpha * (x - beta))
    falling = torch.sigmoid(alpha * (x - (1 - beta)))

    # Combining the two sigmoids
    return rising * (1 - falling)


class GaussianModel(nn.Module):
    @staticmethod
    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        """Build covariance data from scale and rotation parameters.

        The scales are multiplied by ``scaling_modifier``, combined with
        ``rotation`` by ``build_scaling_rotation`` into a batch of matrices
        ``M``, and the product ``M @ M^T`` is formed per batch entry. The
        result is passed through ``strip_symmetric`` before being returned
        (presumably a compact encoding of the symmetric matrix -- confirm in
        ``src.utils.general``).

        Args:
            scaling: Per-point scale tensor.
            scaling_modifier: Multiplier applied to ``scaling``.
            rotation: Per-point rotation tensor.

        Returns:
            Whatever ``strip_symmetric`` yields for the batched covariance.
        """
        scaled = scaling_modifier * scaling
        factor = build_scaling_rotation(scaled, rotation)
        # M @ M^T over the last two dimensions of the batch.
        covariance = torch.matmul(factor, factor.transpose(1, 2))
        return strip_symmetric(covariance)

    def setup_functions(self):

        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.covariance_activation = GaussianModel.build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize
        self.color_activation = torch.sigmoid


    def __init__(self, n_pts, active_sh_degree = 0, max_sh_degree=3):
        super().__init__()

        self.active_sh_degree = active_sh_degree
        self.max_sh_degree = max_sh_degree
        features = torch.zeros((n_pts, 3, (self.max_sh_degree + 1) ** 2))
        features_dc = features[:, :, 0:1].transpose(1, 2).contiguous()
        feature_rest = features[:, :, 1:].transpose(1, 2).contiguous()
        
        self.xyz = nn.Parameter(torch.randn(n_pts, 3), requires_grad=True)
        self.color = nn.Parameter(torch.randn(n_pts, 3), requires_grad=True)
        self.opacity = nn.Parameter(torch.randn(n_pts, 1), requires_grad=True)
        # self.opacity = nn.Parameter(torch.zeros(n_pts, 1), requires_grad=True)
        self.scaling = nn.Parameter(torch.randn(n_pts, 3), requires_grad=True)
        self.rotation = nn.Parameter(torch.randn(n_pts, 4), requires_grad=True)
        self._features_dc = nn.Parameter(features_dc.requires_grad_(True))
        self._features_rest = nn.Parameter(feature_rest.requires_grad_(True))

        self.setup_functions()


    def create_from_pcd(self, pcd: BasicPointCloud, deivce):
        """Re-initialise positions, scales and rotations from a point cloud.

        Only ``self.xyz``, ``self.scaling`` and ``self.rotation`` are replaced.
        ``self.color``, ``self.opacity`` and the feature tensors keep whatever
        they held before the call, so their row count is not adjusted to the
        point cloud here. ``pcd.colors`` is not read: earlier revisions
        converted it to SH coefficients on the GPU but never stored the result.

        Args:
            pcd: Point cloud whose ``points`` attribute is convertible with
                ``np.asarray`` (assumed to be an (N, 3) array -- confirm
                against ``BasicPointCloud``).
            deivce: Target device for the new parameters. The misspelled name
                is kept so existing keyword callers continue to work.
        """
        # Convert once on the host, then place copies where they are needed.
        host_points = torch.tensor(np.asarray(pcd.points)).float()
        fused_point_cloud = host_points.to(deivce)
        n_points = fused_point_cloud.shape[0]

        print("Number of points at initialisation : ", n_points)

        # the parameter device may be "cpu", so tensor must move to cuda before calling distCUDA2()
        dist2 = torch.clamp_min(distCUDA2(host_points.cuda()), 0.0000001).to(deivce)
        # Stored in log space (scaling_activation is exp); same value on all three axes.
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)

        # First component 1, remaining components 0 for every point.
        rots = torch.zeros((n_points, 4), device=deivce)
        rots[:, 0] = 1

        # NOTE: opacity is intentionally left untouched (its initialisation is disabled).
        self.xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self.scaling = nn.Parameter(scales.requires_grad_(True))
        self.rotation = nn.Parameter(rots.requires_grad_(True))


    def training_setup(self, training_args, scene_extent: float):
        self.spatial_lr_scale = scene_extent
        # override spatial_lr_scale if provided
        if training_args.spatial_lr_scale > 0:
            self.spatial_lr_scale = training_args.spatial_lr_scale

        # some tensor may still in CPU, move to the same device as the _xyz
        l = [
            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
            {'params': [self.xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
            {'params': [self.opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
            {'params': [self.scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
            {'params': [self.rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
            {'params': [self.color], 'lr': training_args.opacity_lr, "name": "color"},
        ]

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
    #     self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init * self.spatial_lr_scale,
    #                                                 lr_final=training_args.position_lr_final * self.spatial_lr_scale,
    #                                                 lr_delay_mult=training_args.position_lr_delay_mult,
    #                                                 max_steps=training_args.position_lr_max_steps)
        
    # def update_learning_rate(self, iteration):
    #     ''' Learning rate scheduling per step '''
    #     for param_group in self.optimizer.param_groups:
    #         if param_group["name"] == "xyz":
    #             lr = self.xyz_scheduler_args(iteration)
    #             param_group['lr'] = lr
    #             return lr

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1


    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op=None
    ) -> None:
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)


    def get_scaling(self):
        # return self.scaling_activation(self._scaling)
        return self.scaling_activation(self.scaling)

    def get_rotation(self):
        # return self.rotation_activation(self._rotation)
        return self.rotation_activation(self.rotation)    

    def get_xyz(self):
        # return self._xyz
        return self.xyz

    def get_opacity(self):
        # return self.opacity_activation(self._opacity)
        return self.opacity_activation(self.opacity)
        # return torch.ones_like(self.opacity)

    def get_color(self):
        # return self.color_activation(self._color)
        return self.color_activation(self.color)
    
    def get_covariance(self, scaling_modifier=1):
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
    
    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        l.append('opacity')
        for i in range(3):
            l.append('scale_{}'.format(i))
        for i in range(4):
            l.append('rot_{}'.format(i))
        return l

