import torch 
import torch.nn as nn 
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from utils import generate_uniform_unit_sphere_projections
import os

from data.shapenet_dataset2 import ShapeNet15kPointClouds


# Points kept per shape (used as both tr_sample_size and te_sample_size below).
npoints = 2048
# Root directory of the ShapeNetCore v2 15k-point clouds.
dataroot = "data/ShapeNetV2/ShapeNetCore.v2.PC15k"


# Categories to export, in processing order (55 entries).
categories = [
    'airplane', 'bag', 'basket', 'bathtub', 'bed',
    'bench', 'bottle', 'bowl', 'bus', 'cabinet',
    'can', 'camera', 'cap', 'car', 'chair',
    'clock', 'dishwasher', 'monitor', 'table', 'telephone',
    'tin_can', 'tower', 'train', 'keyboard', 'earphone',
    'faucet', 'file', 'guitar', 'helmet', 'jar',
    'knife', 'lamp', 'laptop', 'speaker', 'mailbox',
    'microphone', 'microwave', 'motorcycle', 'mug', 'piano',
    'pillow', 'pistol', 'pot', 'printer', 'remote_control',
    'rifle', 'rocket', 'skateboard', 'sofa', 'stove',
    'vessel', 'washer', 'cellphone', 'birdhouse', 'bookshelf',
]

# Output layout: <parent_dir>/train/<category>.pt and <parent_dir>/val/<category>.pt.
# Create the root and both split directories (no-op if they already exist).
parent_dir = "preprocessed_dataset/point_cloud"
for _subdir in ("", "/train", "/val"):
    os.makedirs(parent_dir + _subdir, exist_ok=True)
del _subdir


# Fixed seed so the randomly subsampled train point clouds written to disk are
# reproducible across runs instead of being a fresh random draw every time.
subsample_seed = 0


def _stack_points(dataloader, category, split):
    """Concatenate the ``"train_points"`` tensor of every batch along dim 0.

    Args:
        dataloader: iterable of dict batches, each holding a ``"train_points"``
            tensor.
        category: category name; used only in the error message.
        split: split name (``"train"`` / ``"val"``); used only in the error
            message.

    Returns:
        All batches' ``"train_points"`` tensors concatenated along dim 0.

    Raises:
        ValueError: if *dataloader* yields no batches. Without this check,
            ``torch.cat`` fails with a message that names neither the category
            nor the split.
    """
    batches = [batch["train_points"] for batch in dataloader]
    if not batches:
        raise ValueError(
            f"No point clouds found for category {category!r}, split {split!r}"
        )
    return torch.cat(batches, dim=0)


def _save_split(dataloader, category, split):
    """Stack one split's point clouds and save them to <parent_dir>/<split>/<category>.pt."""
    points = _stack_points(dataloader, category, split)
    print(f"Category: {category}, shape {points.shape}")
    torch.save(points, f"{parent_dir}/{split}/{category}.pt")


for category in categories:

    print(category)

    # Seed per category so each saved file is reproducible on its own,
    # regardless of which categories were processed before it.
    # NOTE(review): assumes the dataset's random_subsample draws from the
    # numpy and/or torch global RNG; confirm against ShapeNet15kPointClouds.
    np.random.seed(subsample_seed)
    torch.manual_seed(subsample_seed)

    train_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=[category], split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        reflow=False,
        normalize_per_shape=True,
        normalize_std_per_axis=False,
        random_subsample=True)

    # The val dataset reuses the train dataset's normalization statistics.
    test_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=[category], split='val',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        reflow=False,
        normalize_per_shape=True,
        normalize_std_per_axis=False,
        all_points_mean=train_dataset.all_points_mean,
        all_points_std=train_dataset.all_points_std)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1)

    _save_split(train_dataloader, category, "train")
    # NOTE(review): the val split is also read from the "train_points" key,
    # as in the original script; confirm whether "test_points" was intended.
    _save_split(test_dataloader, category, "val")