
import torch
import numpy as np
import os
import random
from tqdm import tqdm
import h5py
import json

import sys
# sys.path.append('..')
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

from utils.distance import hungarian, hungarian_batched, hungarian_batched_grads

# ModelNet40 category names. A category's position in this list is the integer
# label looked up through ``cates.index(...)``, so the order must not change.
cates = [
    'airplane', 'bathtub', 'bed', 'bench', 'bookshelf',
    'bottle', 'bowl', 'car', 'chair', 'cone',
    'cup', 'curtain', 'desk', 'door', 'dresser',
    'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp',
    'laptop', 'mantel', 'monitor', 'night_stand', 'person',
    'piano', 'plant', 'radio', 'range_hood', 'sink',
    'sofa', 'stairs', 'stool', 'table', 'tent',
    'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox',
]


def init_np_seed(worker_id):
    """Seed NumPy's global RNG from torch's initial seed.

    Intended as a ``worker_init_fn`` for ``torch.utils.data.DataLoader``.
    ``worker_id`` is accepted to match that callback signature but is not used.
    """
    # np.random.seed only accepts 32-bit values, so fold the torch seed into that range.
    numpy_seed_range = 2 ** 32
    torch_seed = torch.initial_seed()
    np.random.seed(torch_seed % numpy_seed_range)

class ModelNet40(torch.utils.data.Dataset):
    """Dataset of (source, target) point-cloud pairs built from ModelNet40 HDF5 shards.

    On construction the class:

    1. reads every ``ply_data_train*`` (``split='train'``) or ``ply_data_test*``
       (``split='val'`` and ``split='test'``) HDF5 shard under ``root`` together
       with its ``*_id2file.json`` companion,
    2. normalises all points with a per-axis mean and a single global std
       (either supplied by the caller or computed from the loaded data),
    3. keeps ``tr_sample_size`` points per cloud,
    4. builds four equally sized groups of pairs: (s, t), (s, noise),
       (s, s + noise) and (s, t + noise),
    5. loads the pair distances / matchings / gradients from an on-disk cache
       next to this file, or computes them with ``hungarian_batched_grads``
       and writes the cache.

    Each item is a dict with keys ``source``, ``target``, ``dist``,
    ``matching``, ``matching_t2s`` and ``grads``.

    Args:
        root: Directory containing the HDF5 shards (``~`` is expanded).
        tr_sample_size: Requested number of points per cloud; clamped to
            10000 and to the number of points actually stored per cloud.
        split: One of ``'train'``, ``'val'``, ``'test'``.
        random_subsample: If true, keep a seeded random subset of each cloud's
            points; otherwise keep the first ``tr_sample_size`` points.
        input_dim: Point dimensionality used when computing the mean.
        npairs: Total number of pairs requested; each of the four groups gets
            ``npairs // 4`` pairs.
        all_points_mean: Optional precomputed mean (requires ``all_points_std``).
        all_points_std: Optional precomputed std.

    NOTE(review): the cache directory is keyed only by split, ``npairs`` and
    the point count, so a cache written with different ``random_subsample`` or
    normalisation statistics is reused silently.
    """

    def __init__(self, root, tr_sample_size=10000, split='train', random_subsample=False,
                 input_dim=3, npairs=10000, all_points_mean=None, all_points_std=None):

        self.root = root

        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.in_tr_sample_size = tr_sample_size

        self.random_subsample = random_subsample
        self.input_dim = input_dim

        self.max = tr_sample_size

        # 'val' and 'test' both read the test shards.
        prefix = 'ply_data_train' if split == 'train' else 'ply_data_test'
        data_root = os.path.expanduser(root)
        # os.listdir order is arbitrary; sort so that the seeded pair
        # construction below sees the clouds in the same order on every machine.
        # Names starting with ``prefix + '_'`` are the JSON companions, not shards.
        self.files = [os.path.join(data_root, x) for x in sorted(os.listdir(data_root))
                      if x.startswith(prefix) and not x.startswith(prefix + '_')]
        if not self.files:
            raise ValueError(f"No '{prefix}*' files found in {data_root}")

        self.cate_idx_lst = []
        self.all_points = []

        for file in self.files:
            print(file)
            # Close the HDF5 handle as soon as the array has been copied out.
            with h5py.File(file, 'r') as h5_file:
                point_cloud = h5_file['data'][:]
            print(point_cloud.shape)
            self.all_points.append(point_cloud)

            # '<stem><d>.h5' -> '<stem>_<d>_id2file.json'; relies on a
            # three-character extension preceded by a single-character index.
            id_file = file[:-4] + '_' + file[-4] + '_id2file.json'

            with open(id_file, 'r') as json_file:
                id2file = json.load(json_file)

            # Each entry looks like '<category>/<file>'; the label is the
            # category's position in the module-level ``cates`` list.
            self.cate_idx_lst.extend(cates.index(x.split('/')[0]) for x in id2file)

        self.all_points = np.concatenate(self.all_points, axis=0)  # (N, 2048, 3)

        # Shuffle the index deterministically (based on the number of examples)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        print("Shuffling the dataset with seed 38383", split)
        print(self.shuffle_idx[:10])
        print('-'*80)
        # NOTE(review): only the labels are permuted here; the points stay in
        # file order, so cate_idx_lst[i] does not describe all_points[i].
        # Labels are not returned by __getitem__, so the pairs are unaffected.
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]

        self.all_points = torch.from_numpy(self.all_points).float()

        if all_points_mean is not None:
            assert all_points_std is not None
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:
            # Per-axis mean, one global std shared by all axes.
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std

        self.train_points = self.all_points

        # Never ask for more points than a cloud actually holds; otherwise the
        # fixed-prefix indexing below would run out of bounds.
        self.tr_sample_size = min(10000, tr_sample_size, self.all_points.shape[1])

        print(f'Total number of examples (point clouds): {len(self.train_points)}')
        print(f'Min number of points: ({split}) {self.tr_sample_size}')

        tr_idxs = self._subsample_indices(self.train_points.shape[0], self.train_points.shape[1],
                                          self.tr_sample_size, self.random_subsample)
        row_idxs = torch.arange(len(self.train_points))[..., None]
        self.train_points = self.train_points[row_idxs, torch.as_tensor(tr_idxs, dtype=torch.long)]

        self.source, self.target = self._build_pairs(npairs)

        dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache', 'ModelNet40', split, f'npairs-{npairs}', f'num_pt-{self.tr_sample_size}')
        filename = os.path.join(dirname, 'distances.pt')
        # Build sibling paths explicitly instead of str.replace on the full
        # path, which would also rewrite a parent directory named "distances".
        data_filename = os.path.join(dirname, 'data.pt')
        grads_filename = os.path.join(dirname, 'grads.pt')

        # Only trust the cache when all three files are present; a partial
        # cache (e.g. from an interrupted run) is recomputed and overwritten.
        if all(os.path.isfile(path) for path in (filename, data_filename, grads_filename)):
            print("Loading distances from cache")
            data = torch.load(filename)
            self.dists, self.matchings, self.matching_t2s = data['dist'], data['matchings'], data['matching_t2s']
            # The cached pairs replace the freshly built ones so that points
            # and distances always belong together.
            points = torch.load(data_filename)
            self.source, self.target = points['source'], points['target']
            grads = torch.load(grads_filename)
            self.grads = grads['grads']
        else:
            print("Computing distances")
            print(self.source.shape, self.target.shape)
            self.dists, self.grads, self.matchings = hungarian_batched_grads(self.source, self.target, return_matching=True)
            # Keep entry 1 of every per-pair matching
            # (presumably the column/target indices of the assignment - TODO confirm).
            self.matchings = torch.cat([match[1].unsqueeze(0) for match in torch.tensor(self.matchings)])
            # argsort of each row, i.e. the inverse mapping when a row is a permutation.
            self.matching_t2s = torch.sort(self.matchings, dim=1)[1]

            os.makedirs(dirname, exist_ok=True)
            print("Saving distances to cache")
            torch.save({'dist': self.dists, 'matchings':self.matchings, 'matching_t2s':self.matching_t2s}, filename)
            print('Saving data')
            torch.save({'source': self.source, 'target':self.target}, data_filename)
            print('Saving grads')
            torch.save({'grads': self.grads}, grads_filename)

    @staticmethod
    def _subsample_indices(num_clouds, num_points, sample_size, random_subsample, seed=38383):
        """Return a (num_clouds, sample_size) integer array of point indices per cloud.

        With ``random_subsample`` every row is a seeded random subset (without
        repetition) of ``range(num_points)``; otherwise every row is
        ``0 .. sample_size - 1``.
        """
        if random_subsample:
            rng = np.random.default_rng(seed=seed)
            noise = rng.normal(size=(num_clouds, num_points))
            # Rank ALL points first and truncate afterwards; truncating the
            # noise before argsort would only permute the first sample_size points.
            return np.argsort(noise, axis=1)[:, :sample_size]
        return np.repeat(np.arange(sample_size)[np.newaxis, ...], num_clouds, 0)

    def _build_pairs(self, npairs):
        """Build the four groups of (source, target) pairs from ``self.train_points``.

        Returns the concatenated ``(source, target)`` tensors. Index sampling,
        scales and noise all come from seeded generators so that repeated
        constructions produce the same pairs.
        """
        num_parts = 4
        part_size = npairs // num_parts

        rng = np.random.default_rng(seed=38383)
        gen = torch.Generator()
        gen.manual_seed(38383)
        # Separate seeded stream for the Gaussian noise: keeps the scale stream
        # of ``gen`` untouched while making the noise itself reproducible.
        noise_gen = torch.Generator()
        noise_gen.manual_seed(38384)

        all_source, all_target = [], []
        idxs = np.arange(self.train_points.shape[0])

        # 1 (s, t)
        source_idxs = rng.choice(idxs, part_size, replace=True)
        target_idxs = rng.choice(idxs, part_size, replace=True)
        all_source.append(self.train_points[source_idxs])
        all_target.append(self.train_points[target_idxs])

        # 2 (s, n)
        source_idxs = rng.choice(idxs, part_size, replace=True)
        scale = 0.1 + torch.rand(part_size, generator=gen)[..., None, None]
        source = self.train_points[source_idxs]
        all_source.append(source)
        all_target.append(scale * torch.randn(source.shape, generator=noise_gen))

        # 3 (s, s+n)
        source_idxs = rng.choice(idxs, part_size, replace=True)
        scale = 0.1 + 2*torch.rand(part_size, generator=gen)[..., None, None]
        source = self.train_points[source_idxs]
        all_source.append(source)
        all_target.append(source + scale * torch.randn(source.shape, generator=noise_gen))

        # 5 (s, t+n) - reuses the target indices drawn for group 1.
        source_idxs = rng.choice(idxs, part_size, replace=True)
        scale = 0.1 + 2*torch.rand(part_size, generator=gen)[..., None, None]
        target_base = self.train_points[target_idxs]
        all_source.append(self.train_points[source_idxs])
        all_target.append(target_base + scale * torch.randn(target_base.shape, generator=noise_gen))

        return torch.cat(all_source, dim=0), torch.cat(all_target, dim=0)

    def __len__(self):
        """Return the number of pairs (length of the distance tensor)."""
        return len(self.dists)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair with its distance, matchings and gradients."""
        sample = {'source': self.source[idx],
                  'target': self.target[idx],
                  'dist': self.dists[idx],
                  'matching': self.matchings[idx],
                  'matching_t2s': self.matching_t2s[idx],
                  'grads': self.grads[idx]
                  }

        return sample

    # def __getitem__(self, idx):
    #     tr_out = self.train_points[idx]
    #     te_out = self.test_points[idx]

    #     tr_ofs = tr_out.mean(0, keepdim=True)
    #     te_ofs = te_out.mean(0, keepdim=True)

    #     if self.standardize_per_shape:
    #         # If standardize_per_shape, centering in/out
    #         tr_out -= tr_ofs
    #         te_out -= te_ofs
    #     if self.random_offset:
    #         # scale data offset
    #         if random.uniform(0., 1.) < 0.2:
    #             scale = random.uniform(1., 1.5)
    #             tr_out -= tr_ofs
    #             te_out -= te_ofs
    #             tr_ofs *= scale
    #             te_ofs *= scale
    #             tr_out += tr_ofs
    #             te_out += te_ofs

    #     m, s = self.get_pc_stats(idx)
    #     m, s = torch.from_numpy(np.asarray(m)), torch.from_numpy(np.asarray(s))
    #     cate_idx = self.cate_idx_lst[idx]
    #     sid, mid = self.all_cate_mids[idx]

    #     return {
    #         'idx': idx,
    #         'set': tr_out if self.split == 'train' else te_out,
    #         'offset': tr_ofs if self.split == 'train' else te_ofs,
    #         'mean': m, 'std': s, 'label': cate_idx,
    #         'sid': sid, 'mid': mid
    #     }

def build(root, tr_sample_size, random_subsample, npairs, batch_size, *args, **kwargs):
    """Create the train/val ModelNet40 pair datasets and their data loaders.

    The validation dataset never uses random subsampling and is normalised
    with the mean/std taken from the training dataset. Extra positional and
    keyword arguments are accepted but ignored.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader)``.
    """
    # Arguments shared by both dataset splits.
    shared_dataset_kwargs = dict(root=root, tr_sample_size=tr_sample_size, input_dim=3, npairs=npairs)

    train_dataset = ModelNet40(split='train',
                               random_subsample=random_subsample,
                               **shared_dataset_kwargs)
    val_dataset = ModelNet40(split='val',
                             random_subsample=False,
                             all_points_mean=train_dataset.all_points_mean,
                             all_points_std=train_dataset.all_points_std,
                             **shared_dataset_kwargs)

    # Arguments shared by both loaders; note that both drop the last partial batch.
    shared_loader_kwargs = dict(batch_size=batch_size, pin_memory=True, drop_last=True,
                                collate_fn=None, worker_init_fn=init_np_seed)

    # Training order is randomised by an explicit sampler, so ``shuffle`` stays off.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               shuffle=False,
                                               sampler=torch.utils.data.RandomSampler(train_dataset),
                                               **shared_loader_kwargs)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             shuffle=False,
                                             **shared_loader_kwargs)

    return train_dataset, val_dataset, train_loader, val_loader
 
if __name__ == '__main__':
    # Smoke test: build the validation pairs and print the first item's shapes and distance.
    # '~' is not expanded by os.listdir/open, so resolve it before handing the path over.
    dataset = ModelNet40(root=os.path.expanduser('~/data/ModelNet/modelnet40_ply_hdf5_2048/'),
                        tr_sample_size=1024,
                        split='val',
                        random_subsample=False,
                        input_dim=3,
                        npairs=25000) #TODO #SyntheticDataset(nsamples=1000, npairs=1000, npoints=5, categories=['circle', 'square'])
    # print(dataset[0])
    # print(len(dataset[0]))
    first_item = dataset[0]
    print(first_item['source'].shape, first_item['target'].shape, first_item['dist'])