import os
from typing import Tuple
import pickle
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset

import sys
sys.path.insert(0, "~/repos/bbox_scoring")
from data_utils import get_fov_flag, get_transform
from heuristics import check_type_and_convert

sys.path.insert(0, "~/repos/modest_pp/generate_cluster_mask")
from discovery_utils.pointcloud_utils import load_velo_scan
from pcdet.utils import calibration_kitti


# Command-line options.
# `type=os.path.expanduser` is also applied by argparse to string defaults, so a
# "~"-prefixed --out-dir (the default or a quoted user value) is expanded to the
# home directory; os.makedirs would otherwise create a literal "./~" directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--size-ratio",
    type=int,
    default=4,
    help="for each box, keep points whose get_scale() value is below this ratio; "
         "also appended to --out-dir as a suffix",
)
parser.add_argument(
    "--out-dir",
    type=os.path.expanduser,
    default="~/datasets/boxed_shape_unscaled/lyft_train",
    help="output root directory ('~' is expanded; '_<size-ratio>' is appended)",
)
args = parser.parse_args()


class LyftTrain(Dataset):
    """Lyft training frames with per-point PP scores and ground-truth boxes.

    Item ``i`` is ``(ptc, pp_score, boxes, categories)`` for frame ``train_set[i]``:
    the point cloud and its per-point scores filtered to the camera field of view,
    plus the ``gt_boxes_lidar`` / ``name`` annotation fields from the OpenPCDet
    info pickle.

    All hard-coded locations are "~"-relative and are expanded here, because
    ``open``, ``os.path.isdir`` and ``np.load`` do not expand "~" themselves.
    """

    def __init__(self):
        # Frame indices, one integer per line (file name says: frames with objects only).
        idx_path = os.path.expanduser("~/datasets/lyft_fw70_2m_train_idx_hasobj.txt")  # filtered out scenes with no objects
        with open(idx_path) as f:
            # Skip blank lines (e.g. a stray empty line) instead of failing in int().
            self.train_set = [int(line) for line in f if line.strip()]

        info_path = os.path.expanduser("~/modest_pp/downstream/OpenPCDet/data/lyft/kitti_infos_train.pkl")
        # NOTE: pickle can execute arbitrary code on load; only use trusted, locally generated info files.
        with open(info_path, "rb") as f:
            gt_info = pickle.load(f)
        # lidar frame index -> annotation dict (later duplicates overwrite earlier ones).
        self.gt_dict = {int(gt['point_cloud']['lidar_idx']): gt['annos'] for gt in gt_info}

        self.ptc_path = self._first_existing_dir(
            "~/lyft/training/velodyne/",
            "~/datasets/lyft_release_test/training/velodyne/",
        )
        self.calib_path = os.path.expanduser("~/datasets/lyft_release_test/training/calib/")
        self.p2score_path = self._first_existing_dir(
            "~/lyft/training/pp_score_fw70_2m_r0.3/",
            "~/datasets/lyft_release_test/training/pp_score_fw70_2m_r0.3/",
        )

    @staticmethod
    def _first_existing_dir(*candidates: str) -> str:
        """Return the first candidate ("~" expanded) that is an existing directory.

        If none exists, the last candidate is returned unchecked, so a missing
        dataset shows up as a file-not-found error when frames are loaded.
        """
        expanded = [os.path.expanduser(c) for c in candidates]
        for path in expanded:
            if os.path.isdir(path):
                return path
        return expanded[-1]

    def __len__(self):
        return len(self.train_set)

    def __getitem__(self, item):
        """Load frame ``self.train_set[item]`` and restrict it to the camera field of view.

        Returns:
            ptc: output of ``load_velo_scan`` masked by the FOV flag.
            pp_score: the frame's ``.npy`` scores masked by the same flag
                (assumes one score per point — TODO confirm).
            curr_boxes: ``annos['gt_boxes_lidar']`` for the frame (not FOV-filtered).
            categories: ``annos['name']`` for the frame.
        """
        idx = self.train_set[item]
        frame = f"{idx:06d}"
        ptc = load_velo_scan(os.path.join(self.ptc_path, f"{frame}.bin"))
        pp_score = np.load(os.path.join(self.p2score_path, f"{frame}.npy"))
        annos = self.gt_dict[idx]
        curr_boxes = annos['gt_boxes_lidar']
        categories = annos['name']

        curr_calib = calibration_kitti.Calibration(os.path.join(self.calib_path, f"{frame}.txt"))
        # (1024, 1224): image size passed to get_fov_flag — presumably the Lyft
        # camera image (height, width); confirm against get_fov_flag.
        curr_fov_flag = get_fov_flag(ptc, (1024, 1224), curr_calib)
        ptc = ptc[curr_fov_flag]
        pp_score = pp_score[curr_fov_flag]

        return ptc, pp_score, curr_boxes, categories


def get_scale(
    ptc: torch.Tensor,  # (N_points, 4)
    boxs: torch.Tensor,  # (M, 7)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Transform every point by every box's matrix and measure its normalized extent.

    Args:
        ptc: point cloud; only the first three columns (x, y, z) are used.
        boxs: boxes; the whole tensor is passed to ``get_transform`` to build one
            4x4 matrix per box, and columns 3:6 are used as the box size.

    Returns:
        ``(points, scale)``. ``points`` is (M, N_points, 3): the homogeneous points
        multiplied (as row vectors) by each box's transform. ``scale`` is
        (M, N_points): per box and point, the largest of ``|coord| / (0.5 * size)``
        over the three axes. Presumably the transform maps into the box-centred
        frame, so ``scale <= 1`` means "inside the box" — confirm against
        ``get_transform``.
    """
    n_boxes = boxs.shape[0]
    n_points = ptc.shape[0]

    homog = torch.cat([ptc[:, :3], torch.ones_like(ptc[:, :1])], dim=1)  # (N_points, 4)
    homog = homog.unsqueeze(0).expand(n_boxes, n_points, 4)  # (M, N_points, 4), view without copy

    transforms = get_transform(boxs)  # (M, 4, 4)
    local = torch.bmm(homog, transforms)[..., :3]  # (M, N_points, 3)

    half_size = boxs[:, 3:6].unsqueeze(1) * 0.5  # (M, 1, 3), broadcasts over points
    ratios = (local / half_size).abs()
    scale = ratios.max(dim=2).values  # (M, N_points)
    return local, scale


if __name__ == "__main__":
    # Suffix the output root with the size ratio. Expand "~" because os.makedirs
    # does not, and would otherwise create a literal "./~" directory.
    args.out_dir = os.path.expanduser(f"{args.out_dir}_{args.size_ratio}")
    print(f"save to {args.out_dir}")
    os.makedirs(args.out_dir, exist_ok=True)
    dataset = LyftTrain()

    # category name -> index of the last file written for it (files are numbered from 0).
    category_count = {}
    for ptc, pp_score, boxes, categories in tqdm(dataset):
        ptc, pp_score, boxes, _np_input, _shape = check_type_and_convert(ptc, pp_score, boxes, extreme=0.0)
        # ptc becomes (M, N_points, 3): all points transformed by each box's matrix;
        # scale is (M, N_points).
        ptc, scale = get_scale(ptc, boxes)
        for box_id in range(boxes.shape[0]):
            category = str(categories[box_id])
            c_idx = category_count.get(category, -1) + 1
            category_count[category] = c_idx
            category_dir = os.path.join(args.out_dir, category)
            if c_idx == 0:
                # exist_ok: re-running into an existing out_dir must not crash here.
                os.makedirs(category_dir, exist_ok=True)
            save_path = os.path.join(category_dir, f"{c_idx:06d}.pkl")

            # Keep the points whose scale w.r.t. this box is below the size ratio.
            mask = scale[box_id] < args.size_ratio
            box_ptc = ptc[box_id][mask]
            # assumes pp_score is 1-D (one score per point) — TODO confirm; stored as (K, 1).
            box_pp = pp_score[mask].unsqueeze(dim=1)
            data_dict = {
                "ptc": box_ptc.cpu().numpy(),
                "pp_score": box_pp.cpu().numpy(),
                "size": boxes[box_id, 3:6].cpu().numpy(),
                "translation": boxes[box_id, 0:3].cpu().numpy(),
            }
            with open(save_path, "wb") as f:
                pickle.dump(data_dict, f)
