import numpy as np
import plyfile

import torch
import torch.nn.functional as F
# from models.pointnet_sem_seg import get_model, get_loss
from models.pointnet_cls import get_model, get_loss
from models.pointnet_utils import feature_transform_reguliarzer

# Gaussian-splatting checkpoint whose point positions are fed to PointNet.
path = '/home/comp/csxfhuang/development/gaussian-splatting/output/2058e832-5/point_cloud/iteration_7000/point_cloud.ply'

code_length = 8    # number of message bits the network must predict
num_points = 1000  # upper bound on points taken from the front of the cloud
batch_size = 2     # those points are split evenly into this many clouds

# The classifier's output width must equal code_length so that pred can be
# compared element-wise with binary_message in the MSE loss.
pointnet = get_model(code_length, normal_channel=False)
pointnet_loss = get_loss()
plydata = plyfile.PlyData.read(path)
binary_message = torch.randint(0, 2, (1, code_length)).float().cuda()

vertices = plydata.elements[0]
xyz = np.stack((np.asarray(vertices["x"]),
                np.asarray(vertices["y"]),
                np.asarray(vertices["z"])), axis=1)  # (N, 3)

# Use at most num_points points, trimmed to a multiple of batch_size so the
# split below is exact.
usable = min(num_points, xyz.shape[0]) // batch_size * batch_size
if usable == 0:
    raise ValueError(
        f"point cloud has {xyz.shape[0]} points; need at least {batch_size}")

# Split into (batch, points_per_cloud, 3) first, then move the coordinate axis
# to dim 1 to get (batch, 3, points_per_cloud). Reshaping before permuting keeps
# each point's (x, y, z) together. Permuting first and then reshaping would
# spread x/y/z values across the three channels.
xyz = torch.tensor(xyz[:usable], dtype=torch.float, device="cuda")
xyz = xyz.reshape(batch_size, -1, 3).permute(0, 2, 1)
binary_message = binary_message.tile(xyz.shape[0], 1)  # one message copy per cloud

print(xyz.shape)
print(binary_message.shape)

pointnet = pointnet.cuda()

# Shape probe only; no autograd graph is needed for this pass.
with torch.no_grad():
    pred, feat = pointnet(xyz)
print(pred.shape)
print(feat.shape)

optimizer = torch.optim.Adam(
    pointnet.parameters(),
    lr=0.0001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-4,
)

max_steps = 100000          # hard cap on optimisation steps
stop_loss = 0.01            # stop as soon as the total loss drops below this
mat_diff_loss_weight = 1.0  # weight on the feature-transform regulariser
# NOTE(review): the cls get_loss in this repo may apply its own scale to the
# regulariser (e.g. 0.001); here it is added unscaled. Confirm that is intended.

# NOTE(review): pred is regressed onto {0, 1} targets and thresholded at 0.5,
# which assumes the classifier head emits values roughly in [0, 1]. If
# get_model's head ends in log_softmax, its outputs are <= 0 and cannot encode
# a multi-bit message. Confirm against models/pointnet_cls.py.
pointnet.train()
for step in range(max_steps):
    pred, feat = pointnet(xyz)
    mse_loss = F.mse_loss(pred, binary_message)
    mat_diff_loss = feature_transform_reguliarzer(feat)
    loss = mse_loss + mat_diff_loss_weight * mat_diff_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # One host sync, reused for logging and the early-stop test.
    loss_value = loss.item()
    print('step: ', step, ' loss: ', loss_value,
          'binary_message: ', binary_message[0],
          'pred: ', (pred[0] > 0.5).float())
    if loss_value < stop_loss:
        break
