# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
# Added automatic download and collate_fn (Erik Bekkers 27 May 2024)
import os
# import open3d as o3d
import time
import torch
import numpy as np
from loguru import logger
from torch.utils.data import Dataset
from torch.utils import data
import random
import tqdm
# from datasets.data_path import get_path
from PIL import Image
# Debug switch: when truthy, get_datasets disables random subsampling and
# ShapeNet15kPointClouds keeps only the first 40 (shuffled) shapes.
OVERFIT = 0
import zipfile



def get_path(dataname=None):
    """Return the first existing local directory registered for ``dataname``.

    Args:
        dataname: key of the registry below (``'pointflow'`` or
            ``'clip_forge_image'``). When ``None`` the whole registry
            (name -> list of candidate paths) is returned instead.

    Returns:
        The first candidate path that exists on disk, or the registry dict
        when ``dataname`` is ``None``.

    Raises:
        AssertionError: ``dataname`` is not a registered name.
        ValueError: none of the candidate paths for ``dataname`` exists.
    """
    dataset_path = {
        'pointflow': ['./datasets/ShapeNetCore.v2.PC15k/'],
        'clip_forge_image': ['./datasets/shapenet_render/'],
    }

    if dataname is None:
        return dataset_path

    assert dataname in dataset_path, \
        f'not found {dataname}, only: {list(dataset_path.keys())}'
    for p in dataset_path[dataname]:
        print(f'searching: {dataname}, get: {p}')
        if os.path.exists(p):
            return p
    # The original only constructed this exception and fell through to an
    # implicit ``return None``; raise it so callers fail here, not later.
    raise ValueError(
        f'all path not found for {dataname}, please double check: {dataset_path[dataname]}; or edit the datasets/data_path.py ')


def get_cache_path(cache_list=None):
    """Return the first existing data-statistics cache directory.

    Args:
        cache_list: candidate directories, searched in order. Defaults to
            the two ``/workspace/...`` locations this module has always used.

    Returns:
        The first candidate that exists on disk.

    Raises:
        ValueError: none of the candidates exists.
    """
    if cache_list is None:
        cache_list = ['/workspace/data_cache_local/data_stat/',
                      '/workspace/data_cache/data_stat/']
    for p in cache_list:
        if os.path.exists(p):
            return p
    # Previously the exception was constructed but never raised, so the
    # function silently returned None.
    raise ValueError(
        f'all path not found for {cache_list}, please double check: or edit the datasets/data_path.py ')
    



# taken from https://github.com/optas/latent_3d_points/blob/
# 8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
# ShapeNetCore synset id -> human-readable category name. The synset ids are
# also the per-category directory names looked up under
# ShapeNetCore.v2.PC15k by ShapeNet15kPointClouds.
synsetid_to_cate = {
    '02691156': 'airplane',
    '02773838': 'bag',
    '02801938': 'basket',
    '02808440': 'bathtub',
    '02818832': 'bed',
    '02828884': 'bench',
    '02876657': 'bottle',
    '02880940': 'bowl',
    '02924116': 'bus',
    '02933112': 'cabinet',
    '02747177': 'can',
    '02942699': 'camera',
    '02954340': 'cap',
    '02958343': 'car',
    '03001627': 'chair',
    '03046257': 'clock',
    '03207941': 'dishwasher',
    '03211117': 'monitor',
    '04379243': 'table',
    '04401088': 'telephone',
    '02946921': 'tin_can',
    '04460130': 'tower',
    '04468005': 'train',
    '03085013': 'keyboard',
    '03261776': 'earphone',
    '03325088': 'faucet',
    '03337140': 'file',
    '03467517': 'guitar',
    '03513137': 'helmet',
    '03593526': 'jar',
    '03624134': 'knife',
    '03636649': 'lamp',
    '03642806': 'laptop',
    '03691459': 'speaker',
    '03710193': 'mailbox',
    '03759954': 'microphone',
    '03761084': 'microwave',
    '03790512': 'motorcycle',
    '03797390': 'mug',
    '03928116': 'piano',
    '03938244': 'pillow',
    '03948459': 'pistol',
    '03991062': 'pot',
    '04004475': 'printer',
    '04074963': 'remote_control',
    '04090263': 'rifle',
    '04099429': 'rocket',
    '04225987': 'skateboard',
    '04256520': 'sofa',
    '04330267': 'stove',
    '04530566': 'vessel',
    '04554684': 'washer',
    '02992529': 'cellphone',
    '02843684': 'birdhouse',
    '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}
# Inverse lookup (category name -> synset id). Its insertion order follows
# synsetid_to_cate, which fixes the category order used for categories=['all'].
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


# Modified to require a root directory; also adds automatic download of the dataset archive
class ShapeNet15kPointClouds(Dataset):
    """ShapeNetCore.v2 point clouds (the "PC15k" release) as a torch Dataset.

    Shapes are read from
    ``<root>/ShapeNetCore.v2.PC15k/<synset_id>/<split>/<mid>.npy`` (the
    archive is downloaded on first use). They are shuffled with a fixed seed,
    normalised according to the selected scheme and cut to at most the first
    10000 points per shape. ``__getitem__`` returns a dict holding the
    (sub)sampled points plus the statistics needed to undo the normalisation.
    """

    def __init__(self,
                 root,
                 categories=None,
                 tr_sample_size=10000,
                 te_sample_size=10000,
                 split='train',
                 scale=1.,
                 normalize_per_shape=False,
                 normalize_shape_box=False,
                 random_subsample=False,
                 sample_with_replacement=1,
                 normalize_std_per_axis=False,
                 normalize_global=False,
                 recenter_per_shape=False,
                 all_points_mean=None,
                 all_points_std=None,
                 all_points_std_global=None,
                 input_dim=3,
                 clip_forge_enable=0, clip_model=None
                 ):
        """Load, shuffle and normalise every shape of the requested categories/split.

        Args:
            root: directory holding ``ShapeNetCore.v2.PC15k``. If that folder
                is missing, the archive is downloaded and extracted here.
            categories: a category name, or a list of names (keys of
                ``cate_to_synsetid``), or a list containing ``'all'``.
                Defaults to ``['airplane']``.
            tr_sample_size: number of points returned per shape (capped at 10000).
            te_sample_size: stored (capped at 5000); not read by ``__getitem__``.
            split: one of ``'train'``, ``'test'``, ``'val'``.
            scale: must be 1 (non-unit scaling is deprecated).
            normalize_per_shape: per-shape mean/std normalisation.
            normalize_shape_box: per-shape bounding-box centre and half of the
                largest box side.
            random_subsample: draw a random subset of points in
                ``__getitem__`` instead of the first ``tr_sample_size``.
            sample_with_replacement: when subsampling randomly, draw with
                replacement.
            normalize_std_per_axis: use a per-axis std instead of a scalar one.
            normalize_global: dataset-wide mean/std normalisation.
            recenter_per_shape: per-shape bounding-box normalisation. It is
                applied when the two per-shape flags above are off.
            all_points_mean, all_points_std, all_points_std_global:
                precomputed statistics (e.g. from the train split). They are
                used when neither ``normalize_shape_box``,
                ``normalize_per_shape`` nor ``recenter_per_shape`` is set.
            input_dim: coordinate dimensionality of the stored points.
            clip_forge_enable: also return 5 CLIP-preprocessed rendered
                views per shape. This requires the ``clip`` package and
                rendered images.
            clip_model: model name passed to ``clip.load``.

        Raises:
            ValueError: a category/split directory is missing, or no ``.npy``
                files were found.
            NotImplementedError: no normalisation scheme applies.
        """
        if categories is None:
            categories = ['airplane']
        self.root = root
        self.download_and_extract()

        self.clip_forge_enable = clip_forge_enable
        img_path = None
        if clip_forge_enable:
            import clip
            _, self.clip_preprocess = clip.load(clip_model)
            self.img_path = []
            img_path = get_path('clip_forge_image')

        self.normalize_shape_box = normalize_shape_box
        self.root_dir = os.path.join(self.root, 'ShapeNetCore.v2.PC15k')
        logger.info('[DATA] cat: {}, split: {}, full path: {}; norm global={}, norm-box={}',
                    categories, split, self.root_dir, normalize_global, normalize_shape_box)

        self.split = split
        assert self.split in ['train', 'test', 'val']
        if isinstance(categories, str):
            categories = [categories]
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
        self.subdirs = self.synset_ids
        self.gravity_axis = 1

        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.scale = scale
        self.random_subsample = random_subsample
        self.sample_with_replacement = sample_with_replacement
        self.input_dim = input_dim

        self._load_point_clouds(img_path)

        # Shuffle the index deterministically (based on the number of examples)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
        if self.clip_forge_enable:
            self.img_path = [self.img_path[i] for i in self.shuffle_idx]

        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        self.recenter_per_shape = recenter_per_shape
        # NOTE(review): the per-shape branches below never set
        # all_points_std_global, which __getitem__ reads; confirm what value
        # is intended for those modes.
        if self.normalize_shape_box:  # per-shape bounding-box normalization
            self.all_points_mean, self.all_points_std = self._bbox_stats(
                self.all_points, input_dim)
        elif self.normalize_per_shape:  # per-shape mean/std normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(
                B, 1, input_dim)
            logger.info('all_points shape: {}. mean over axis=1',
                        self.all_points.shape)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(
                    B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(
                    B, -1).std(axis=1).reshape(B, 1, 1)
        elif (all_points_mean is not None and all_points_std is not None
              and not self.recenter_per_shape):
            # Reuse statistics computed elsewhere (e.g. on the train split).
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
            if all_points_std_global is None and np.size(all_points_std) == 1:
                # A single-element std is a dataset-wide scalar std; with
                # normalize_global and no per-axis std, the train split stores
                # that very array as its std_global. Without this fallback
                # __getitem__ would index None.
                all_points_std_global = all_points_std
            self.all_points_std_global = all_points_std_global
        elif self.recenter_per_shape:
            # Per-shape bounding-box centre and half of the largest box side.
            self.all_points_mean, self.all_points_std = self._bbox_stats(
                self.all_points, input_dim)
        elif normalize_global:  # normalize across the dataset
            flat = self.all_points.reshape(-1, input_dim)
            self.all_points_mean = flat.mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = flat.std(axis=0).reshape(1, 1, input_dim)
                self.all_points_std_global = self.all_points.reshape(-1).std(
                    axis=0).reshape(1, 1, 1)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(
                    axis=0).reshape(1, 1, 1)
                self.all_points_std_global = self.all_points_std

            logger.info('[DATA] normalize_global: mean={}, std={}',
                        self.all_points_mean.reshape(-1),
                        self.all_points_std.reshape(-1))
        else:
            raise NotImplementedError('No Normalization')
        self.all_points = (self.all_points - self.all_points_mean) / \
            self.all_points_std
        logger.info('[DATA] shape={}, all_points_mean:={}, std={}, max={:.3f}, min={:.3f}; num-pts={}',
                    self.all_points.shape,
                    self.all_points_mean.shape, self.all_points_std.shape,
                    self.all_points.max(), self.all_points.min(), tr_sample_size)

        if OVERFIT:
            self.all_points = self.all_points[:40]

        # Keep at most the first 10k stored points of every shape.
        self.train_points = self.all_points[:, :min(
            10000, self.all_points.shape[1])]
        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        assert self.scale == 1, "Scale (!= 1) is deprecated"

        # Default display axis order
        self.display_axis_order = [0, 1, 2]

    def _load_point_clouds(self, img_path):
        """Read every ``.npy`` shape of ``self.subdirs``/``self.split``.

        Fills ``all_points`` (list of (1, P, D) arrays), ``cate_idx_lst`` and
        ``all_cate_mids``. When CLIP images are enabled it also fills
        ``img_path``. Files are taken in sorted name order per synset.
        """
        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        tic = time.time()
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(self.root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s " % (sub_path))
                raise ValueError('check the data path')

            file_names = os.listdir(sub_path)
            # NOTE: [mid] contains the split: i.e. "train/<mid>"
            # or "val/<mid>" or "test/<mid>"
            all_mids = sorted(
                os.path.join(self.split, x[:-len('.npy')])
                for x in file_names if x.endswith('.npy'))
            logger.info('[DATA] number of file [{}] under: {} ',
                        len(file_names), sub_path)
            for mid in all_mids:
                if self.clip_forge_enable:
                    # basename() also handles the OS path separator that
                    # os.path.join put into `mid`.
                    render_img_path = os.path.join(
                        img_path, subd, os.path.basename(mid), 'img_choy2016')
                    self.img_path.append(render_img_path)
                    assert os.path.exists(render_img_path), \
                        f'render img path not find: {render_img_path}'

                obj_fname = os.path.join(self.root_dir, subd, mid + ".npy")
                point_cloud = np.load(obj_fname)  # (15k, 3)
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

        logger.info('[DATA] Load data time: {:.1f}s | dir: {} | '
                    'sample_with_replacement: {}; num points: {}', time.time() - tic, self.subdirs,
                    self.sample_with_replacement, len(self.all_points))
        if not self.all_points:
            # np.concatenate would otherwise fail with an opaque message.
            raise ValueError(
                f'no .npy point clouds found under {self.root_dir} '
                f'for synsets {self.subdirs}, split {self.split!r}')

    @staticmethod
    def _bbox_stats(points, input_dim):
        """Per-shape bounding-box centre (B, 1, input_dim) and half of the largest side (B, 1, 1)."""
        B = points.shape[0]
        pmax = np.amax(points, axis=1).reshape(B, 1, input_dim)
        pmin = np.amin(points, axis=1).reshape(B, 1, input_dim)
        center = (pmax + pmin) / 2
        half_extent = np.amax(pmax - pmin, axis=-1).reshape(B, 1, 1) / 2
        return center, half_extent

    def download_and_extract(self):
        """Download and unpack the PC15k archive into ``self.root`` unless it is already there.

        An existing ``ShapeNetCore.v2.PC15k.zip`` in ``self.root`` is reused
        instead of being downloaded. The zip is deleted after extraction.
        """
        dataset_path = os.path.join(self.root, 'ShapeNetCore.v2.PC15k')
        if os.path.exists(dataset_path):
            return
        os.makedirs(self.root, exist_ok=True)
        zip_path = os.path.join(self.root, 'ShapeNetCore.v2.PC15k.zip')
        if not os.path.isfile(zip_path):
            # Imported lazily so gdown is only required when a download is needed.
            import gdown
            print("Downloading dataset...")
            gdown.download('https://drive.google.com/uc?id=1sw9gdk_igiyyt7MqALyxZhRrtPvAn0sX', zip_path, quiet=False)
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.root)
        os.remove(zip_path)
        print("Dataset downloaded and extracted successfully!")

    def get_pc_stats(self, idx):
        """Return ``(mean, std)`` used to normalise shape ``idx``.

        Per-shape schemes return that shape's row, reshaped to (1, input_dim)
        and (1, k). Dataset-wide schemes return the shared statistics,
        flattened to (1, k).
        """
        if self.recenter_per_shape or self.normalize_per_shape or self.normalize_shape_box:
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s

        return self.all_points_mean.reshape(1, -1), \
            self.all_points_std.reshape(1, -1)

    def get_std_global(self, idx):
        """Return the same ``(mean, std)`` pair as ``get_pc_stats``.

        NOTE(review): despite its name this never returned
        ``all_points_std_global``; the behaviour is kept for existing callers.
        """
        return self.get_pc_stats(idx)

    def renormalize(self, mean, std):
        """Undo the current normalisation, apply ``mean``/``std`` instead and refresh ``train_points``."""
        self.all_points = self.all_points * self.all_points_std + \
            self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (self.all_points - self.all_points_mean) / \
            self.all_points_std
        self.train_points = self.all_points[:, :min(
            10000, self.all_points.shape[1])]

    def __len__(self):
        return len(self.train_points)

    def __getitem__(self, idx):
        """Return the sample dict for shape ``idx``.

        ``tr_points`` / ``input_pts`` hold ``tr_sample_size`` points as a float
        tensor, and ``select_idx`` their indices into ``train_points[idx]``.
        ``mean`` / ``std`` / ``std_global`` hold the normalisation statistics.
        With CLIP enabled, ``tr_img`` holds 5 preprocessed rendered views.
        """
        tr_out = self.train_points[idx]
        if self.random_subsample and self.sample_with_replacement:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        elif self.random_subsample:
            tr_idxs = np.random.permutation(
                np.arange(tr_out.shape[0]))[:self.tr_sample_size]
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
        m, s = self.get_pc_stats(idx)
        s_global = self.all_points_std_global[0, None]

        sid, mid = self.all_cate_mids[idx]
        output = {
            'idx': idx,
            'select_idx': tr_idxs,
            'tr_points': tr_out,
            'input_pts': tr_out,
            'mean': m,
            'std': s,
            'std_global': s_global,
            'cate_idx': self.cate_idx_lst[idx],
            'sid': sid,
            'mid': mid,
            'display_axis_order': self.display_axis_order,
        }
        if self.clip_forge_enable:
            output['tr_img'] = self._load_clip_images(self.img_path[idx])
        return output

    def _load_clip_images(self, img_dir, num_views=5):
        """Stack ``num_views`` randomly chosen (with replacement) rendered views of one shape."""
        names = os.listdir(img_dir)
        img_list = [os.path.join(img_dir, p) for p in names if 'jpg' in p or 'png' in p]
        assert len(img_list) > 0, f'get empty list at {img_dir}: {names}'
        picks = np.random.choice(len(img_list), num_views)
        views = []
        for o in picks:
            # Context manager so the underlying file handle is released promptly.
            with Image.open(img_list[o]) as img:
                views.append(self.clip_preprocess(img.convert('RGB')))
        return torch.stack(views, dim=0)  # (num_views, 3, H, W)


def init_np_seed(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy's global RNG from torch's seed.

    ``worker_id`` is accepted only to match the ``worker_init_fn`` signature.
    """
    # torch.initial_seed() may exceed 2**32 - 1, which np.random.seed rejects,
    # so fold it into NumPy's accepted range.
    np.random.seed(torch.initial_seed() % 2 ** 32)


def get_datasets(cfg, args, root='./datasets'):
    """Build the train and evaluation ``ShapeNet15kPointClouds`` datasets.

    Args:
        cfg: the ``config.data`` sub-part. The attributes read are visible in
            the calls below (cates, sample sizes, normalisation flags, ...).
        args: run arguments; only ``eval_split`` is read (default ``"val"``).
        root: directory that contains (or will receive)
            ``ShapeNetCore.v2.PC15k``. The default matches the ``'pointflow'``
            entry of ``get_path``. The dataset class requires this argument;
            the original calls omitted it and always raised TypeError.

    Returns:
        ``(tr_dataset, te_dataset)``. The evaluation set reuses the train
        split's mean/std statistics.
    """
    random_subsample = 0 if OVERFIT else cfg.random_subsample
    logger.info(f'get_datasets: tr_sample_size={cfg.tr_max_sample_points}, '
                f' te_sample_size={cfg.te_max_sample_points}; '
                f' random_subsample={random_subsample}'
                f' normalize_global={cfg.normalize_global}'
                f' normalize_std_per_axix={cfg.normalize_std_per_axis}'
                f' normalize_per_shape={cfg.normalize_per_shape}'
                f' recenter_per_shape={cfg.recenter_per_shape}'
                )
    tr_dataset = ShapeNet15kPointClouds(
        root=root,
        categories=cfg.cates,
        split='train',
        tr_sample_size=cfg.tr_max_sample_points,
        te_sample_size=cfg.te_max_sample_points,
        sample_with_replacement=cfg.sample_with_replacement,
        scale=cfg.dataset_scale,
        normalize_shape_box=cfg.normalize_shape_box,
        normalize_per_shape=cfg.normalize_per_shape,
        normalize_std_per_axis=cfg.normalize_std_per_axis,
        normalize_global=cfg.normalize_global,
        recenter_per_shape=cfg.recenter_per_shape,
        random_subsample=random_subsample,
        clip_forge_enable=cfg.clip_forge_enable,
        clip_model=cfg.clip_model,
    )

    eval_split = getattr(args, "eval_split", "val")
    # te_dataset has random_subsample as False, therefore not using sample_with_replacement
    te_dataset = ShapeNet15kPointClouds(
        root=root,
        categories=cfg.cates,
        split=eval_split,
        tr_sample_size=cfg.tr_max_sample_points,
        te_sample_size=cfg.te_max_sample_points,
        scale=cfg.dataset_scale,
        normalize_shape_box=cfg.normalize_shape_box,
        normalize_per_shape=cfg.normalize_per_shape,
        normalize_std_per_axis=cfg.normalize_std_per_axis,
        normalize_global=cfg.normalize_global,
        recenter_per_shape=cfg.recenter_per_shape,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        # Also hand over the train split's global std; otherwise the eval set
        # stores None and its __getitem__ fails. getattr because the per-shape
        # normalisation modes never set this attribute.
        all_points_std_global=getattr(tr_dataset, 'all_points_std_global', None),
        clip_forge_enable=cfg.clip_forge_enable,
        clip_model=cfg.clip_model,
    )
    return tr_dataset, te_dataset


def get_data_loaders(cfg, args):
    """Build the train and test ``DataLoader``s.

    The train loader shuffles unless ``args.eval_trainnll`` is set. With
    ``args.distributed`` the shuffling is delegated to a
    ``DistributedSampler``. The test loader never shuffles and keeps the last
    partial batch.

    Returns:
        ``{"test_loader": ..., "train_loader": ...}``
    """
    tr_dataset, te_dataset = get_datasets(cfg, args)
    # Evaluating train-set NLL needs a fixed order; otherwise shuffle.
    shuffle_train = not args.eval_trainnll
    kwargs = {}
    if args.distributed:
        # The sampler itself must be told not to shuffle: DataLoader's
        # `shuffle` flag has no effect once a sampler is supplied.
        kwargs['sampler'] = data.distributed.DistributedSampler(
            tr_dataset, shuffle=shuffle_train)
    else:
        kwargs['shuffle'] = shuffle_train
    train_loader = data.DataLoader(dataset=tr_dataset,
                                   batch_size=cfg.batch_size,
                                   num_workers=cfg.num_workers,
                                   drop_last=cfg.train_drop_last == 1,
                                   pin_memory=False, **kwargs)
    test_loader = data.DataLoader(dataset=te_dataset,
                                  batch_size=cfg.batch_size_test,
                                  shuffle=False,
                                  num_workers=cfg.num_workers,
                                  pin_memory=False,
                                  drop_last=False,
                                  )
    logger.info(
        f'[Batch Size] train={cfg.batch_size}, test={cfg.batch_size_test}; drop-last={cfg.train_drop_last}')
    loaders = {
        "test_loader": test_loader,
        'train_loader': train_loader,
    }
    return loaders

def collate_fn(batch):
    """Collate dataset samples into a flat, graph-style batch.

    The ``input_pts`` of all samples (B clouds of P points with D coordinates
    each) are stacked and flattened to ``pos`` of shape (B*P, D). ``batch``
    maps every point to its cloud index, and ``x`` is a (B*P, 1) tensor of
    ones used as constant node features. ``mean``/``std``/``std_global`` are
    passed through as collated; the graph fields ``vec``, ``y`` and
    ``edge_index`` are ``None``.

    D is taken from the data instead of being hard-coded to 3, so datasets
    built with another ``input_dim`` collate too.
    """
    batch = torch.utils.data.default_collate(batch)
    pos = batch['input_pts']  # already normalised by the dataset
    num_clouds, num_points, dim = pos.shape[0], pos.shape[1], pos.shape[-1]
    batch_idx = torch.arange(num_clouds, device=pos.device).repeat_interleave(num_points)
    pos = pos.reshape(-1, dim)
    x = torch.ones_like(pos[:, 0]).unsqueeze(1)
    return {'pos': pos, 'x': x, 'vec': None, 'y': None, 'batch': batch_idx, 'edge_index': None,
            'mean': batch['mean'], 'std': batch['std'], 'std_global': batch['std_global']}