import numpy as np
import torch
import torch.utils.data
from os import listdir
from os.path import isfile, join
import json
import pickle
import torch.utils.data as data_utils
from torch.nn.utils.rnn import pad_sequence
import time
import shutil
from easydict import EasyDict as edict

import os
import glob
import trimesh
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# NOTE(review): PAD is not referenced anywhere in the visible part of this
# file; presumably a padding value for batched sequences — confirm before
# relying on it or removing it.
PAD = 0
# num_types = 40

import re

def atoi(text):
    """Convert ``text`` to an ``int`` when it consists only of decimal digits.

    Any other string (including the empty string) is returned unchanged.

    ``str.isdecimal`` is used instead of ``str.isdigit`` because ``isdigit``
    is also true for characters such as superscript digits, which ``int()``
    rejects with ``ValueError``.
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    '''
    Build a sort key that orders strings in human ("natural") order.

    ``text`` is split into alternating non-digit and digit runs and every
    run is passed through ``atoi``, so digit runs become integers.
    Usage: alist.sort(key=natural_keys)
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))


def collate_fn(insts):
	"""Collate ``(point, label, object_name)`` samples into one normalized batch.

	Every point cloud is centered by subtracting its mean over the
	second-to-last axis, then divided by the largest vector norm (taken over
	the last axis) found in that cloud.

	Args:
		insts: sequence of ``(point, label, object_name)`` tuples. All
			``point`` entries must share one nested-list shape so they can
			be stacked into a single tensor — assumed to be
			(num_points, dims) per sample; TODO confirm against the data.

	Returns:
		Tuple ``(point, label, object_name)``: the stacked, normalized
		floating-point tensor, the tensor of labels, and the tuple of
		object names in batch order.
	"""
	clouds, labels, object_name = zip(*insts)

	# Request a floating dtype explicitly: integer-valued coordinates would
	# otherwise yield an int64 tensor, on which torch.mean raises.
	point = torch.tensor(clouds, dtype=torch.get_default_dtype())
	point = point - torch.mean(point, dim=-2, keepdim=True)
	radius = torch.max(torch.linalg.norm(point, dim=-1, keepdim=True), dim=-2, keepdim=True)[0]
	# A cloud whose points all coincide has radius 0; divide by 1 instead so
	# it stays at the origin rather than turning into NaNs.
	radius = torch.where(radius > 0, radius, torch.ones_like(radius))
	point = point / radius

	label = torch.tensor(labels)
	return point, label, object_name


class ModelNetData:
	"""Map-style dataset over a directory holding one JSON file per sample.

	Each file must be a JSON object with the keys ``point``, ``label`` and
	``object``. Files are read lazily, on every ``__getitem__`` call.
	"""

	def __init__(self, location, split):
		"""Index the regular files found directly inside ``location/split``."""
		self.files = self.get_files(join(location, split))

	def __len__(self):
		"""Return the number of sample files."""
		return len(self.files)

	def __getitem__(self, idx):
		"""Load and return the ``(point, label, object)`` sample at ``idx``."""
		return self.read_file(self.files[idx])

	def read_file(self, f):
		"""Parse the JSON file at path ``f`` into ``(point, label, object)``.

		Raises ``KeyError`` if one of the three keys is missing.
		"""
		# Explicit encoding so parsing does not depend on the platform locale.
		with open(f, encoding="utf-8") as fhandle:
			data = json.load(fhandle)
		return data['point'], data['label'], data['object']

	def get_files(self, mypath):
		"""Return the sorted paths of the regular files directly in ``mypath``.

		``listdir`` yields entries in arbitrary order, so the paths are sorted
		(plain lexicographic order, not natural order) to make the
		index -> file mapping reproducible across runs and machines.
		Sub-directories are skipped.
		"""
		candidates = (join(mypath, entry) for entry in listdir(mypath))
		return sorted(path for path in candidates if isfile(path))


def get_dataloader(args, N=None):
	"""Build the train and test ``DataLoader`` objects for the chosen dataset.

	Args:
		args: namespace providing ``dataset`` ("ModelNet10", "ModelNet40" or
			"MNIST"), ``amount_data`` ("50%", "25%", "10%", "5%" or "1%"
			select the matching ``train_<amount>`` split; anything else
			selects ``train``), ``num_workers``, ``train_batch`` and
			``test_batch``. ``partial`` is read only when ``N`` is given.
		N: optional sequence of length 3. Only ``N[0]`` is used here: the
			number of training samples to draw at random without
			replacement.

	Returns:
		``(trainloader, None, testloader)``. The middle slot is always
		``None``.

	Raises:
		ValueError: if ``args.dataset`` is not a supported dataset name.
		NotImplementedError: if ``N`` does not have length 3 or ``N[0]``
			exceeds the number of available training samples.
	"""
	data_dirs = {
		"ModelNet10": "./data/ModelNet10_1024/",
		"ModelNet40": "./data/ModelNet40_1024/",
		"MNIST": "./data/MNIST/",
	}
	# Fail with a clear message instead of an unbound ``data_dir`` later on.
	if args.dataset not in data_dirs:
		raise ValueError(
			"Unknown dataset {!r}; expected one of {}".format(args.dataset, sorted(data_dirs))
		)
	data_dir = data_dirs[args.dataset]

	# Reduced training splits live in directories named ``train_<amount>``;
	# any other value falls back to the full ``train`` split.
	reduced_amounts = ("50%", "25%", "10%", "5%", "1%")
	if args.amount_data in reduced_amounts:
		train_split = 'train_' + args.amount_data
	else:
		train_split = 'train'

	ds_train_full = ModelNetData(data_dir, train_split)
	ds_test = ModelNetData(data_dir, 'test')

	if N is not None:
		print('############### Using Subset of Data ##################')
		if len(N) != 3:
			raise NotImplementedError('Size of Array should be 3')

		if N[0] > len(ds_train_full):
			raise NotImplementedError('More samples than present in DS')

		indices = torch.tensor(np.random.choice(len(ds_train_full), N[0], replace=False))

		if args.partial != 0:
			# The first ``partial`` fraction of the sampled indices is split
			# off and excluded from the training subset.
			cut = int(indices.shape[0] * args.partial)
			partial_indices = indices[:cut]
			indices = indices[cut:]
			# NOTE(review): this subset is only reported below; it is never
			# wrapped in a loader nor returned — confirm whether it was meant
			# to fill the middle slot of the return value.
			ds_partial_train = data_utils.Subset(ds_train_full, partial_indices)
			print("ds_partial_train len: ", len(ds_partial_train))

		ds_train = data_utils.Subset(ds_train_full, indices)
	else:
		ds_train = ds_train_full

	trainloader = torch.utils.data.DataLoader(
		ds_train,
		num_workers=args.num_workers,
		batch_size=args.train_batch,
		collate_fn=collate_fn,
		shuffle=True
	)

	testloader = torch.utils.data.DataLoader(
		ds_test,
		num_workers=args.num_workers,
		batch_size=args.test_batch,
		collate_fn=collate_fn,
		shuffle=False
	)

	print("ds_train_full len: ", len(ds_train_full))
	print("ds_train len: ", len(ds_train))
	print("ds_test len: ", len(ds_test))

	return trainloader, None, testloader

# Script entry point: only prints a marker when the module is run directly.
if __name__ == "__main__":
	print("here")
