import os
import json
import argparse
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from plyfile import PlyData, PlyElement

from utils.geom_utils import canonicalize

def get_args():
    """Parse the command-line options for pose-dataset training.

    Returns:
        argparse.Namespace with the attributes ``dataset_root``,
        ``batch_size``, ``num_workers``, ``epochs``, ``gpu``, ``lr`` and
        ``num_points``.
    """
    # (flag, value type, default, help text), registered in this order.
    option_specs = (
        ("--dataset_root", str, "FlatLab_Stage_2_Strategy_B",
         "Root directory containing train / test_seen / test_unseen"),
        ("--batch_size", int, 8, "Training batch size"),
        ("--num_workers", int, 4, "Num workers for DataLoader"),
        ("--epochs", int, 100, "Number of training epochs"),
        ("--gpu", str, 'cuda:0', "GPU device (e.g. cuda:0, cpu)"),
        ("--lr", float, 1e-4, "Initial learning rate"),
        ("--num_points", int, 8192, "Number of points sampled per point cloud"),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    return cli.parse_args()

def load_pose_txt(txt_path):
    """Read the left/right end-effector poses from a JSON-encoded ``.txt`` file.

    The file must contain a JSON object with ``"ee_pose_L"`` and
    ``"ee_pose_R"`` entries (any other keys are ignored).

    Args:
        txt_path: path to the pose file.

    Returns:
        float32 array: ``ee_pose_L`` followed by ``ee_pose_R``, concatenated
        along axis 0.

    Raises:
        KeyError: if either pose key is missing.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification; don't rely on the platform's
    # locale-dependent default encoding.
    with open(txt_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return np.concatenate(
        [np.asarray(data[key], dtype=np.float32) for key in ("ee_pose_L", "ee_pose_R")],
        axis=0,
    )

def load_ply_xyz(ply_path, num_points=8192, rng=None):
    """Load a PLY point cloud and resample it to exactly ``num_points`` rows.

    Despite the name, colour is returned as well: each row is
    ``[x, y, z, r, g, b]``. Colour channels are divided by 255, which assumes
    8-bit channels (TODO confirm for every data source). Clouds without
    ``red``/``green``/``blue`` vertex properties get all-zero colour.

    Points are drawn without replacement when the cloud has at least
    ``num_points`` vertices. Otherwise they are drawn with replacement, so the
    result contains duplicates.

    Args:
        ply_path: path to the ``.ply`` file.
        num_points: number of rows in the result.
        rng: optional ``np.random.Generator`` or ``np.random.RandomState``
            used for sampling (for reproducibility). ``None`` uses the global
            ``np.random`` state, as before.

    Returns:
        float32 array of shape ``(num_points, 6)``.

    Raises:
        ValueError: if the file has no vertices but ``num_points > 0``.
    """
    vertex = PlyData.read(ply_path)["vertex"]
    xyz = np.vstack([vertex["x"], vertex["y"], vertex["z"]]).T.astype(np.float32)

    if {"red", "green", "blue"}.issubset(vertex.data.dtype.names):
        rgb = np.vstack([
            vertex["red"], vertex["green"], vertex["blue"]
        ]).T.astype(np.float32) / 255.0
    else:
        rgb = np.zeros_like(xyz)

    xyzrgb = np.concatenate([xyz, rgb], axis=1)

    total = xyzrgb.shape[0]
    if total == 0 and num_points > 0:
        # Otherwise np.random.choice fails with an opaque "a cannot be empty" error.
        raise ValueError(
            f"Point cloud {ply_path!r} has no vertices; cannot sample {num_points} points"
        )

    # Only resample with replacement when there are too few points.
    sampler = np.random if rng is None else rng
    idx = sampler.choice(total, num_points, replace=total < num_points)
    return xyzrgb[idx]

class PoseDataset(Dataset):
    """Dataset of point clouds paired with start/end end-effector poses.

    Expected on-disk layout, as scanned by ``__init__``::

        <root>/<split>/<object>/<recording>/initial_env.ply
                                            start_pos.txt
                                            end_pos.txt

    Recordings missing any of the three files are skipped. Non-directory
    entries at the object and recording levels are ignored. Objects and
    recordings are visited in sorted order, so sample indices are stable
    across runs.
    """

    # Values per pose. Only the first three (the position) are canonicalized.
    # The remaining four are left untouched; presumably an orientation
    # quaternion (TODO confirm).
    _POSE_DIM = 7

    def __init__(self, root, split="train", num_points=8192):
        self.root = root
        self.split = split
        self.num_points = num_points
        # One (initial_env.ply, start_pos.txt, end_pos.txt) path tuple per recording.
        self.samples = []

        split_dir = os.path.join(root, split)
        for obj in sorted(os.listdir(split_dir)):
            obj_dir = os.path.join(split_dir, obj)
            if not os.path.isdir(obj_dir):
                continue
            for rec in sorted(os.listdir(obj_dir)):
                rec_dir = os.path.join(obj_dir, rec)
                if not os.path.isdir(rec_dir):
                    continue
                paths = tuple(
                    os.path.join(rec_dir, name)
                    for name in ("initial_env.ply", "start_pos.txt", "end_pos.txt")
                )
                if all(os.path.exists(p) for p in paths):
                    self.samples.append(paths)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _transform_pose_positions(pose, trans):
        """Return a float32 copy of *pose* with every position canonicalized.

        *pose* is a flat vector of concatenated 7-value poses. For each pose,
        the first three values become
        ``(xyz - trans["center"]) / trans["scale"]``. The input is not
        modified.

        Raises:
            ValueError: if *pose* is not 1-D with a length divisible by 7.
        """
        p = np.array(pose, dtype=np.float32)  # always a fresh copy
        if p.ndim != 1 or p.size % PoseDataset._POSE_DIM != 0:
            raise ValueError(
                f"Expected a flat vector of {PoseDataset._POSE_DIM}-value poses, "
                f"got shape {p.shape}"
            )
        # Reshaping the freshly allocated, contiguous copy yields a view, so the
        # assignment below writes straight into p.
        poses = p.reshape(-1, PoseDataset._POSE_DIM)
        poses[:, :3] = (poses[:, :3] - trans["center"]) / trans["scale"]
        return p

    def __getitem__(self, idx):
        """Load, canonicalize and return one sample.

        Returns:
            (pc_tensor, gt_tensor, init_ply, transform):
            pc_tensor -- float32 tensor of shape (num_points, 6): canonicalized
                xyz followed by rgb.
            gt_tensor -- float32 tensor: the start pose then the end pose
                (each as returned by load_pose_txt), with positions
                canonicalized.
            init_ply -- path of the loaded point cloud file.
            transform -- the second value returned by canonicalize; only its
                "center" and "scale" entries are used here.
        """
        init_ply, start_txt, end_txt = self.samples[idx]

        pc6 = load_ply_xyz(init_ply, self.num_points)
        xyz_norm, transform = canonicalize(pc6[:, :3])
        pc6[:, :3] = xyz_norm

        gt = np.concatenate([load_pose_txt(start_txt), load_pose_txt(end_txt)], axis=0)
        # NOTE(review): positions are only shifted and scaled here. If
        # canonicalize also rotates the cloud, the orientation part of each
        # pose would need the same rotation; confirm against
        # utils.geom_utils.canonicalize.
        gt = self._transform_pose_positions(gt, transform)

        pc_tensor = torch.from_numpy(pc6.astype(np.float32, copy=False))
        gt_tensor = torch.from_numpy(gt)  # already a fresh float32 array
        return pc_tensor, gt_tensor, init_ply, transform