import os

import numpy as np
import torch
from torch import nn
from plyfile import PlyData, PlyElement

from src.utils.sh import RGB2SH
from src.utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation, strip_symmetric, \
    build_scaling_rotation
from src.utils.point import BasicPointCloud

class GaussianModel(nn.Module):
    """Gaussian-splatting model whose per-point attributes are predicted by a head.

    Every attribute getter calls ``self.heads(self.feature_space, key)`` with one
    of the keys ``"xyz"``, ``"rgb"``, ``"opacity"``, ``"scaling"`` or
    ``"rotation"``; no per-Gaussian attribute tensors are stored on this class.
    """

    @staticmethod
    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        """Return ``strip_symmetric(M @ M^T)`` where ``M = build_scaling_rotation(...)``.

        ``M`` is built from ``scaling_modifier * scaling`` and ``rotation``.
        NOTE(review): ``strip_symmetric`` presumably keeps only the unique entries of
        each symmetric matrix -- confirm against ``src.utils.general``.
        """
        transform = build_scaling_rotation(scaling_modifier * scaling, rotation)
        # transpose(1, 2) swaps the last two axes; assumes batched (N, d, d) matrices.
        covariance = transform @ transform.transpose(1, 2)
        return strip_symmetric(covariance)

    def setup_functions(self):
        """Bind the activation used by ``get_covariance``."""
        self.covariance_activation = GaussianModel.build_covariance_from_scaling_rotation

    def __init__(self, feature_space, gs_head):
        """
        Args:
            feature_space: Object whose ``points`` attribute is passed to ``gs_head``.
            gs_head: Callable invoked as ``gs_head(points, key)`` to predict attributes.
        """
        super().__init__()

        self.feature_space = feature_space.points
        self.heads = gs_head

        self.active_sh_degree = 0
        self.max_sh_degree = 0
        # Plain tensor attribute (not a registered buffer), so Module.to() does not
        # move it; extra_params_to handles it instead.
        self.max_radii2D = torch.empty(0)
        self.setup_functions()

    def extra_params_to(self, device, dtype):
        """Move/cast the tensor attributes that ``Module.to()`` does not handle.

        ``xyz_gradient_accum`` and ``denom`` are not created in ``__init__``; they are
        converted only once something has assigned them, instead of raising
        ``AttributeError``.
        """
        for name in ('max_radii2D', 'xyz_gradient_accum', 'denom'):
            tensor = getattr(self, name, None)
            if tensor is not None:
                setattr(self, name, tensor.to(device=device, dtype=dtype))

    def get_scaling(self):
        """Return the head's ``"scaling"`` prediction for the feature points."""
        return self.heads(self.feature_space, "scaling")

    def get_rotation(self):
        """Return the head's ``"rotation"`` prediction for the feature points."""
        return self.heads(self.feature_space, "rotation")

    def get_xyz(self):
        """Return the head's ``"xyz"`` prediction for the feature points."""
        return self.heads(self.feature_space, "xyz")

    def get_opacity(self):
        """Return the head's ``"opacity"`` prediction for the feature points."""
        return self.heads(self.feature_space, "opacity")

    def get_color(self):
        """Return the head's ``"rgb"`` prediction for the feature points."""
        return self.heads(self.feature_space, "rgb")

    def get_covariance(self, scaling, rotation, scaling_modifier=1):
        """Return covariances built from ``scaling`` and ``rotation`` (see above)."""
        return self.covariance_activation(scaling, scaling_modifier, rotation)

    def construct_list_of_attributes(self):
        """Return PLY attribute names for the spherical-harmonics feature layout.

        NOTE(review): reads ``_features_dc``, ``_features_rest``, ``_scaling`` and
        ``_rotation``, none of which this class creates; ``save_ply`` does not use
        this method.
        """
        names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        dc_count = self._features_dc.shape[1] * self._features_dc.shape[2]
        names.extend(f'f_dc_{i}' for i in range(dc_count))
        rest_count = self._features_rest.shape[1] * self._features_rest.shape[2]
        names.extend(f'f_rest_{i}' for i in range(rest_count))
        names.append('opacity')
        names.extend(f'scale_{i}' for i in range(self._scaling.shape[1]))
        names.extend(f'rot_{i}' for i in range(self._rotation.shape[1]))
        return names

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a float32 NumPy array, detached from autograd, on CPU."""
        # .float() also makes bfloat16 outputs convertible; the PLY stores f4 anyway.
        return tensor.detach().float().cpu().numpy()

    @staticmethod
    def _ply_attribute_names(scale_dim, rotation_dim):
        """Return the vertex property names written by ``save_ply``, in column order."""
        names = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'r', 'g', 'b', 'opacity']
        names.extend(f'scale_{i}' for i in range(scale_dim))
        names.extend(f'rot_{i}' for i in range(rotation_dim))
        return names

    def save_ply(self, path):
        """Write the head-predicted Gaussians to ``path`` as a PLY ``vertex`` element.

        Each vertex stores x/y/z, zero normals, r/g/b, opacity and the
        ``scale_*`` / ``rot_*`` components, all as float32.

        Raises:
            ValueError: if the head outputs do not provide exactly one column per
                attribute name.
        """
        directory = os.path.dirname(path)
        if directory:  # os.makedirs('') raises for a bare file name
            os.makedirs(directory, exist_ok=True)

        # Export only: no autograd graph is needed for these evaluations.
        with torch.no_grad():
            xyz = self._to_numpy(self.get_xyz())
            color = self._to_numpy(self.get_color())
            opacities = self._to_numpy(self.get_opacity())
            scale = self._to_numpy(self.get_scaling())
            rotation = self._to_numpy(self.get_rotation())
        if opacities.ndim == 1:
            # Accept a flat (N,) opacity as well as an (N, 1) column.
            opacities = opacities[:, None]
        normals = np.zeros_like(xyz)

        names = self._ply_attribute_names(scale.shape[1], rotation.shape[1])
        attributes = np.concatenate((xyz, normals, color, opacities, scale, rotation), axis=1)
        if attributes.shape[1] != len(names):
            raise ValueError(
                f"save_ply expects {len(names)} columns {names}, "
                f"but the Gaussian heads produced {attributes.shape[1]}"
            )

        elements = np.empty(xyz.shape[0], dtype=[(name, 'f4') for name in names])
        # Column-wise assignment avoids building one Python tuple per point.
        for column, name in enumerate(names):
            elements[name] = attributes[:, column]
        PlyData([PlyElement.describe(elements, 'vertex')]).write(path)
