
import numpy as np
import torch
import torch.utils.data
import warnings
import open3d as o3d
import trimesh

from models.metrics import get_metrics, compute_joint_error, get_rot_matrix
import pytorch3d

from utils.reconstruct import *
from diff_utils.helpers import *


from optimizer import optimize_pose,PoseEstimator
from ArtImage_data import arti_utils, Art_DataGen
import pyvista as pv
from dataloader.pc_loader import PCloader
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Meshes
import os
import torch

def compute_chamfer(recon_pts, gt_pts, device=None):
    """Return the pytorch3d Chamfer distance between two point clouds as a Python float.

    Args:
        recon_pts: reconstructed points in the layout accepted by
            ``pytorch3d.loss.chamfer_distance`` (presumably a (1, P, 3) tensor —
            TODO confirm against callers).
        gt_pts: reference points in the same layout as ``recon_pts``.
        device: device to run the computation on. Defaults to CUDA when it is
            available and to CPU otherwise.

    Returns:
        float: the Chamfer loss. ``batch_reduction=None`` keeps one value per
        batch entry, and ``.item()`` only succeeds when there is exactly one
        entry, so a batch with more than one cloud raises.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        dist, _ = chamfer_distance(recon_pts.to(device), gt_pts.to(device), batch_reduction=None)
        return dist.item()

def test_mesh(gt_mesh, recon_mesh, out_file, mesh_name, return_value=False, return_sampled_pc=False, prioritize_cov=False,
              pc_size=None):
    """Estimate the Chamfer distance between a ground-truth mesh and a reconstruction.

    Both meshes must expose ``vertices`` and ``faces`` arrays (the caller in
    this file passes ``trimesh`` objects). Each surface is sampled with 30000
    points through pytorch3d. The result is the mean of
    ``compute_chamfer(recon, gt)`` and ``compute_chamfer(gt, recon)``.

    ``out_file``, ``mesh_name``, ``return_value``, ``return_sampled_pc``,
    ``prioritize_cov`` and ``pc_size`` are accepted for call compatibility but
    are not read by this implementation.

    Returns:
        float: the averaged Chamfer distance.
    """
    def _as_p3d(src):
        # pytorch3d expects float32 vertices and int64 face indices, one list entry per mesh.
        verts = torch.tensor(src.vertices, dtype=torch.float32)
        tris = torch.tensor(src.faces, dtype=torch.int64)
        return Meshes(verts=[verts], faces=[tris])

    num_points = 30000
    gt_p3d = _as_p3d(gt_mesh)
    recon_p3d = _as_p3d(recon_mesh)

    # Sampling is random: the ground truth is sampled first, then the reconstruction.
    sampled_gt = sample_points_from_meshes(gt_p3d, num_samples=num_points)
    sampled_recon = sample_points_from_meshes(recon_p3d, num_samples=num_points)

    recon_to_gt = compute_chamfer(sampled_recon, sampled_gt)
    gt_to_recon = compute_chamfer(sampled_gt, sampled_recon)
    return (recon_to_gt + gt_to_recon) * 0.5

def _strip_prefix(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with that prefix removed."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _load_model():
    """Load ``CombinedModel`` as selected by ``args.resume`` and return it on CUDA in eval mode.

    * ``'finetune'``: start from ``specs["modulation_ckpt_path"]`` (non-strict load).
      Then strictly load the diffusion model from ``specs["diffusion_ckpt_path"]``
      (keys under ``diffusion_model.``) and the score network from
      ``specs["score_ckpt_path"]`` (keys under ``scorenet.``).
    * ``'last'``: load ``<exp_dir>/last.ckpt``.
    * anything else: load ``<exp_dir>/epoch=<resume>.ckpt``.

    Reads the module-level ``args`` and ``specs`` that are set in ``__main__``.
    """
    if args.resume == 'finetune':
        with warnings.catch_warnings():
            # Suppress any warnings raised while loading and merging the checkpoints.
            warnings.simplefilter("ignore")
            model = CombinedModel.load_from_checkpoint(specs["modulation_ckpt_path"], specs=specs, strict=False)
            diffusion_ckpt = torch.load(specs["diffusion_ckpt_path"])
            model.diffusion_model.load_state_dict(
                _strip_prefix(diffusion_ckpt['state_dict'], "diffusion_model."), strict=True)
            score_ckpt = torch.load(specs["score_ckpt_path"])
            model.scorenet.load_state_dict(
                _strip_prefix(score_ckpt['state_dict'], "scorenet."), strict=True)
            return model.cuda().eval()

    ckpt = "{}.ckpt".format(args.resume) if args.resume == 'last' else "epoch={}.ckpt".format(args.resume)
    resume = os.path.join(args.exp_dir, ckpt)
    return CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()


def _transform_points(points, transform):
    """Apply a 4x4 homogeneous ``transform`` to an (N, 3) array of ``points`` using Open3D.

    Returns the transformed points as a NumPy array read back from the Open3D cloud.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.transform(transform)
    return np.asarray(pcd.points)


def test_generation():
    """Sample meshes unconditionally, or evaluate the full pipeline on the test split.

    Relies on the module-level ``args``, ``specs`` and ``recon_dir`` that are set in ``__main__``.

    Unconditional mode (``specs["diffusion_model_specs"]["cond"]`` is falsy):
    draws ``args.num_samples`` latents, decodes them, and writes one mesh per
    sample under ``recon_dir``.

    Conditional mode: for every file listed in ``specs["TestSplit"]``:
      1. predict the part segmentation of the observed cloud and print its accuracy;
      2. predict the base pose and joint parameters and print their errors;
      3. move the joint and the observed cloud into the canonical frame using the
         inverse of the predicted base pose;
      4. generate ``args.num_samples`` meshes from that cloud with a zero
         articulation condition. For each one, show it next to the ground-truth
         mesh in a (blocking) Open3D viewer and print its Chamfer distance;
      5. segment the last generated mesh, refine the pose with ``optimize_pose``,
         and print the base rotation difference (``arti_utils.rot_diff_degree``)
         and the translation distance to the ground truth.
    """
    model = _load_model()
    conditional = specs["diffusion_model_specs"]["cond"]

    if not conditional:
        samples = model.diffusion_model.generate_unconditional(args.num_samples)
        plane_features = model.vae_model.decode(samples)
        for i in range(len(plane_features)):
            plane_feature = plane_features[i].unsqueeze(0)
            mesh.create_mesh(model.sdf_model, plane_feature, recon_dir + "/{}_recon".format(i), N=128,
                             max_batch=2 ** 21, from_plane_features=True)
        return

    with open(specs["TestSplit"], "r") as split_file:
        test_split = [line.strip() for line in split_file]

    test_dataset = PCloader(specs["DataSource"], test_split, pc_size=specs.get("PCsize", 1024),
                            return_filename=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)
    num_parts = specs['num_parts']

    with tqdm(test_dataloader) as pbar:
        for idx, data in enumerate(pbar):
            pbar.set_description("Files generated: {}/{}".format(idx, len(test_dataloader)))

            atc = data['atc']
            filename = data['file_name'][0]  # batch_size=1, so take the only entry

            cls_name = filename.split("/")[-4]
            mesh_name = filename.split("/")[-1].split('.')[0]

            print(mesh_name)
            outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
            os.makedirs(outdir, exist_ok=True)

            # --- Part segmentation of the observed partial cloud ---
            camera_partial_pc = data['pts'].cuda()
            pts_feat = model.scorenet.pts_encoder(camera_partial_pc)
            data['pts_feat'] = pts_feat

            gt_seg = data['seg']
            pred_seg = torch.argmax(model.scorenet.seg_encoder(camera_partial_pc), dim=2)
            correct = (pred_seg.cpu() == gt_seg).sum()
            total_points = gt_seg.numel()
            seg_acc = correct.float() / total_points
            print('seg_acc:', seg_acc)

            # --- Base pose prediction ---
            data['type'] = 'pose'
            pred_pose, pred_pose_q_wxyz, average_pred_pose_q_wxyz = model.scorenet.pred_pose_func(
                data=data, repeat_num=1, save_path=None, return_average_res=True)

            gt_repeat_pose = data['gt_pose'].repeat(1, pred_pose.shape[1], 1)[0]
            rot_error, trans_error = get_metrics(
                pred_pose[0],
                gt_repeat_pose.cuda(),
                pose_mode=specs['pose_mode'],
                o2c_pose=specs['o2c_pose'],
            )

            # --- Joint prediction ---
            data['type'] = 'joint'
            data['gt_joint'] = torch.cat((data['gt_xyz'], data['gt_rpy']), dim=1)
            pred_joint, average_pred_joint = model.scorenet.pred_joint_func(
                data=data, repeat_num=1, save_path=None, return_average_res=True)

            joint_xyz_error, joint_rpy_error = compute_joint_error(
                average_pred_joint.cpu(),
                data['gt_joint']
            )
            print('joint_xyz_error:', joint_xyz_error, 'joint_rpy_error:', joint_rpy_error)

            # Average the pose hypotheses: the first 6 entries are a 6D rotation, the rest the translation.
            base_pose = torch.mean(pred_pose, dim=1)
            pred_base_rot = pytorch3d.transforms.rotation_6d_to_matrix(base_pose[0, :6]).permute(1, 0)
            pred_base_trans = base_pose[0, 6:]

            pred_base_T = torch.eye(4)
            pred_base_T[:3, :3] = pred_base_rot
            pred_base_T[:3, 3] = pred_base_trans
            pred_base_T_inv = np.linalg.inv(pred_base_T)

            # Joint origin as a (1, 3) array in the canonical frame.
            joint_xyz = _transform_points(average_pred_joint[0, :3].unsqueeze(0).cpu().numpy(), pred_base_T_inv)
            # joint_rpy rotated by the inverse base rotation (row vector times inv(R).T == inv(R) @ v).
            joint_rpy = np.dot(average_pred_joint[0, 3:].cpu().numpy(), np.linalg.inv(pred_base_rot.cpu().numpy()).T)

            # Move the observed cloud into the canonical frame with the same inverse base pose.
            # Assumes data['pts'] is (1, N, 3) — TODO confirm.
            canonical_pts = _transform_points(
                camera_partial_pc[0].detach().cpu().numpy().astype(np.float64), pred_base_T_inv)
            estimate_pts = torch.from_numpy(canonical_pts).float().unsqueeze(0)

            # --- Shape generation conditioned on the canonical cloud ---
            canonical_atc = torch.from_numpy(np.array([0.00])).float()
            samples, perturbed_pc = model.diffusion_model.generate_from_pc(estimate_pts.cuda(),
                                                                           batch=args.num_samples,
                                                                           save_pc=outdir, return_pc=True,
                                                                           perturb_pc=False)

            # Append the embedding of a zero articulation state to every latent before decoding.
            atc_emb = model.vae_model.condition_encoder(canonical_atc.float().cuda())
            atc_emb = atc_emb.repeat(args.num_samples, 1)
            samples_atc = torch.cat([samples, atc_emb], dim=1)
            plane_features = model.vae_model.decode(samples_atc)

            for i in range(len(plane_features)):
                plane_feature = plane_features[i].unsqueeze(0)
                mesh.create_mesh(model.sdf_model, plane_feature, outdir + "/{}_recon".format(i), N=128,
                                 max_batch=2 ** 16, from_plane_features=True)

            # The ground truth is the same for every sample, so load it once.
            mesh_path_gt = f'data/train/watertight_obj_annotations/{mesh_name}.obj'
            gt_mesh = trimesh.load(mesh_path_gt)
            mesh_o3d_gt = o3d.io.read_triangle_mesh(mesh_path_gt)
            for i in range(len(plane_features)):
                mesh_path_pred = outdir + "/{}_recon.ply".format(i)
                pred_mesh = trimesh.load(mesh_path_pred)
                mesh_o3d_pred = o3d.io.read_triangle_mesh(mesh_path_pred)
                # Blocks until the viewer window is closed.
                o3d.visualization.draw_geometries([mesh_o3d_pred, mesh_o3d_gt])
                cd = test_mesh(gt_mesh, pred_mesh, None, None, return_value=True, prioritize_cov=True)
                print('pred_cd:', cd)

            # --- Pose refinement against the last generated mesh ---
            last_idx = len(plane_features) - 1
            mesh_ply = o3d.io.read_triangle_mesh(outdir + "/{}_recon.ply".format(last_idx))
            mesh_pts = np.asarray(mesh_ply.sample_points_poisson_disk(
                number_of_points=1024,
                init_factor=5
            ).points)

            mesh_pts = torch.from_numpy(mesh_pts).unsqueeze(0).to(torch.float32).cuda()
            pred_mesh_seg = torch.argmax(model.scorenet.seg_encoder(mesh_pts), dim=2)

            # Split the generated-mesh points and the observed points by predicted part label.
            cad_pts_lst = [mesh_pts[pred_mesh_seg == part_id, :].double() for part_id in range(num_parts)]
            camera_pts_lst = [camera_partial_pc[pred_seg == part_id, :].double() for part_id in range(num_parts)]

            joint_xyz = torch.from_numpy(joint_xyz).squeeze(0).cuda()
            # NOTE(review): the joint state is initialised from data['atc'] (dataset-provided) — confirm
            # that using it is intended for evaluation.
            pose_estimator = PoseEstimator(num_parts=num_parts, init_base_r=pred_base_rot, init_base_t=pred_base_trans,
                                           init_joint_state=atc, device='cuda:0',
                                           joint_type='revolute', reg_weight=0)

            part_weight = [1, 10]
            print(torch.mean(rot_error), torch.mean(trans_error))
            base_transform, relative_transform_all, new_joint_params_all, joint_state = optimize_pose(
                pose_estimator, camera_pts_lst, cad_pts_lst, joint_xyz, joint_rpy, part_weight)

            gt_base_rot = data['gt_rot'][0]
            gt_base_trans = data['gt_trans'][0]

            rot_base_diff = arti_utils.rot_diff_degree(base_transform[:3, :3].cpu().numpy(),
                                                       gt_base_rot[:3, :3].cpu().numpy())
            dis_base_diff = np.linalg.norm(base_transform[:3, 3].cpu().numpy() - gt_base_trans.cpu().numpy())
            print(rot_base_diff, dis_base_diff)


if __name__ == "__main__":

    import argparse
    import json  # explicit import: the module-level imports do not include json

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--exp_dir", "-e", required=True,
        help="This directory should include experiment specifications in 'specs_Art_test.json,' and logging will be done in this directory as well.",
    )
    arg_parser.add_argument(
        "--resume", "-r", default=None,
        help="continue from previous saved logs, integer value, 'last', or 'finetune'",
    )

    arg_parser.add_argument("--num_samples", "-n", default=5, type=int,
                            help='number of samples to generate and reconstruct')

    arg_parser.add_argument("--filter", default=False, help='whether to filter when sampling conditionally')
    arg_parser.add_argument("--class_name", "-c", default='laptop', type=str)

    args = arg_parser.parse_args()
    if args.resume is None:
        # Without a value, test_generation() would try to load "epoch=None.ckpt".
        arg_parser.error("--resume is required: pass an epoch number, 'last', or 'finetune'")

    with open(os.path.join(args.exp_dir, "specs_Art_test.json")) as specs_file:
        specs = json.load(specs_file)
    print(specs["Description"])

    recon_dir = os.path.join(args.exp_dir, "recon")
    os.makedirs(recon_dir, exist_ok=True)

    test_generation()