
import torch
import numpy as np
import os
import random
from tqdm import tqdm

import sys
# sys.path.append('..')
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

from utils.distance import hungarian, hungarian_batched, hungarian_batched_grads


def init_np_seed(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy's global RNG from torch's seed.

    ``worker_id`` is accepted to match the ``worker_init_fn`` signature but is
    not used. The torch seed is folded into NumPy's 32-bit seed range.
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)

# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
# Maps ShapeNet synset ids (directory names under the dataset root) to
# human-readable category names.
synsetid_to_cate = {
    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
    '04554684': 'washer', '02992529': 'cellphone',
    '02843684': 'birdhouse', '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}

# Inverse mapping: category name -> synset id. The category names above are
# all distinct, so the inversion loses nothing. Insertion order follows
# synsetid_to_cate, so list(cate_to_synsetid.values()) enumerates the synsets
# in the order listed above.
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
 
class SetMNIST(torch.utils.data.Dataset):
    """Pairs of ShapeNet point clouds with precomputed Hungarian matchings.

    Despite its name, this dataset reads ShapeNetCore.v2 PC15k point clouds
    stored as ``<root>/<synset_id>/<split>/<model_id>.npy`` (15000 points each).
    It normalizes them, subsamples each cloud, builds (source, target) pairs and
    attaches the distances, gradients and matchings returned by
    ``hungarian_batched_grads``. These results are cached on disk next to this
    file, under ``cache/``.

    Each item is a dict with keys ``source``, ``target``, ``dist``,
    ``matching``, ``matching_t2s`` and ``grads``.
    """

    def __init__(self, root, categories, tr_sample_size=10000, te_sample_size=10000, split='train',
                 random_offset=False, random_subsample=False,
                 input_dim=3, npairs=10000, all_points_mean=None, all_points_std=None):
        """Load, normalize, subsample and pair the point clouds.

        Args:
            root: dataset root directory; must contain 'v2'. A leading ``~`` is expanded.
            categories: category names (keys of ``cate_to_synsetid``), or a
                collection containing ``'all'`` to use every known category.
            tr_sample_size: points kept per cloud from its first 10000 points (capped at 10000).
            te_sample_size: points kept per cloud from its last 5000 points (capped at 5000).
            split: one of 'train', 'val', 'test'.
            random_offset: stored as an attribute only; not used by this class.
            random_subsample: if True, keep a seeded random subset of the points of
                each cloud; otherwise keep the first ``*_sample_size`` points.
            input_dim: number of coordinates per point.
            npairs: requested number of pairs. NOTE: only 4 of the ``num_parts=5``
                pair families are built, so the dataset holds ``4 * (npairs // 5)`` pairs.
            all_points_mean, all_points_std: normalization statistics. When omitted,
                they are computed from this split (pass the train statistics for val/test).

        Raises:
            ValueError: unsupported root or split, missing std, malformed point
                cloud, or no point clouds found.
        """
        if 'v2' not in root:
            raise ValueError("Only supporting v2 right now.")
        if split not in ('train', 'test', 'val'):
            raise ValueError(f"Unknown split {split!r}; expected 'train', 'val' or 'test'.")

        # Expand '~' so paths such as '~/data/...' resolve on disk.
        self.root = os.path.expanduser(root)
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size

        self.cates = categories
        if 'all' in categories:
            self.subdirs = list(cate_to_synsetid.values())
        else:
            self.subdirs = [cate_to_synsetid[c] for c in self.cates]

        self.random_offset = random_offset
        self.random_subsample = random_subsample
        self.input_dim = input_dim
        if split == 'train':
            self.max = tr_sample_size
        elif split == 'val':
            self.max = te_sample_size
        else:
            self.max = max(tr_sample_size, te_sample_size)

        self._load_point_clouds()
        if not self.all_points:
            raise ValueError(
                f"No point clouds found under {self.root} for split {split!r} and categories {categories}.")

        # Shuffle the index deterministically (based on the number of examples).
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        print("Shuffling the dataset with seed 38383", split)
        print(self.shuffle_idx[:10])
        print('-'*80)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]

        self.all_points = torch.from_numpy(np.concatenate(self.all_points)).float()  # (N, 15000, D)

        if all_points_mean is not None:
            if all_points_std is None:
                raise ValueError("all_points_std must be given together with all_points_mean.")
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:
            # Per-coordinate mean, but a single scalar std over all coordinates.
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(dim=0).reshape(1, 1, input_dim)
            self.all_points_std = self.all_points.reshape(-1).std(dim=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std

        # The first 10000 points of each cloud form the "train" half, the rest the "test" half.
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)

        print(f'Total number of examples (point clouds): {len(self.train_points)}')
        print(f'Min number of points: (train) {self.tr_sample_size} (test) {self.te_sample_size}')

        # One generator for both halves, so the test draw follows the train draw.
        rng = np.random.default_rng(seed=38383)
        self.train_points = self._subsample(self.train_points, self.tr_sample_size, rng, self.random_subsample)
        self.test_points = self._subsample(self.test_points, self.te_sample_size, rng, self.random_subsample)

        self.source, self.target = self._build_pairs(npairs)

        # NOTE(review): the cache key does not include root, random_subsample or the
        # normalization statistics; changing any of them silently reuses a stale cache.
        dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache', split, *categories,
                               f'npairs-{npairs}', f'num_pt-{self.tr_sample_size}-{self.te_sample_size}')
        self._load_or_compute_distances(dirname)

    def _load_point_clouds(self):
        """Read every ``.npy`` cloud of ``self.split`` for the selected synsets.

        Fills ``self.all_points`` (list of ``(1, 15000, D)`` arrays),
        ``self.cate_idx_lst`` (index into ``self.subdirs``) and
        ``self.all_cate_mids`` (``(synset_id, "<split>/<model_id>")`` tuples).
        Missing directories and unreadable files are reported and skipped.
        """
        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(self.root, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s" % sub_path)
                continue

            # [mid] contains the split, i.e. "train/<mid>", "val/<mid>" or "test/<mid>".
            # Sorted so the seeded shuffle is reproducible across filesystems.
            all_mids = [os.path.join(self.split, x[:-len('.npy')])
                        for x in sorted(os.listdir(sub_path)) if x.endswith('.npy')]

            for mid in tqdm(all_mids):
                obj_fname = os.path.join(self.root, subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # (15k, 3)
                except Exception as exc:  # best effort: skip unreadable files, but say so
                    print(f"Skipping unreadable point cloud {obj_fname}: {exc}")
                    continue

                if point_cloud.shape[0] != 15000:
                    raise ValueError(
                        f"Expected 15000 points in {obj_fname}, got shape {point_cloud.shape}.")
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

    @staticmethod
    def _subsample(points, sample_size, rng, random_subsample):
        """Keep ``sample_size`` points per cloud of ``points`` (shape ``(B, N, D)``).

        With ``random_subsample``, each cloud gets an independent random subset of
        all N points: a random permutation from ``rng``, truncated to
        ``sample_size``. Otherwise the first ``sample_size`` points are kept.
        Returns a ``(B, sample_size, D)`` tensor.
        """
        num_clouds, num_points = points.shape[:2]
        if random_subsample:
            # Argsort the FULL noise matrix, then truncate; slicing before argsort
            # would only permute the first sample_size points.
            idxs = np.argsort(rng.normal(size=(num_clouds, num_points)), axis=1)[:, :sample_size]
        else:
            idxs = np.repeat(np.arange(sample_size)[np.newaxis, :], num_clouds, axis=0)
        rows = torch.arange(num_clouds)[:, None]
        return points[rows, torch.as_tensor(idxs, dtype=torch.long)]

    def _build_pairs(self, npairs):
        """Build (source, target) pairs from ``self.train_points``, deterministically.

        Families of ``npairs // 5`` pairs each are concatenated in this order:
        (s, t), (s, noise), (s, s + noise), (s, t + noise). All randomness comes
        from generators seeded with 38383, so repeated calls give identical pairs.
        Returns ``(source, target)`` tensors of shape ``(4 * (npairs // 5), P, D)``.
        """
        num_parts = 5  # family 4 below is disabled, so only 4 families are built
        part_size = npairs // num_parts

        rng = np.random.default_rng(seed=38383)
        gen = torch.Generator()
        gen.manual_seed(38383)

        points = self.train_points
        idxs = np.arange(points.shape[0])
        all_source, all_target = [], []

        # 1 (s, t): two independently drawn clouds.
        source_idxs = rng.choice(idxs, part_size, replace=True)
        target_idxs = rng.choice(idxs, part_size, replace=True)
        all_source.append(points[source_idxs])
        all_target.append(points[target_idxs])

        # 2 (s, n): a cloud vs. pure Gaussian noise with per-pair scale in [0.1, 1.1).
        source_idxs = rng.choice(idxs, part_size, replace=True)
        scale = 0.1 + torch.rand(part_size, generator=gen)[..., None, None]
        source = points[source_idxs]
        all_source.append(source)
        all_target.append(scale * torch.randn(*source.shape, generator=gen))

        # 3 (s, s+n): a cloud vs. a noisy copy of itself, scale in [0.1, 2.1).
        source_idxs = rng.choice(idxs, part_size, replace=True)
        scale = 0.1 + 2 * torch.rand(part_size, generator=gen)[..., None, None]
        source = points[source_idxs]
        all_source.append(source)
        all_target.append(source + scale * torch.randn(*source.shape, generator=gen))

        # 4 (s, ~s+n) is disabled: it paired a cloud with its own held-out points
        # (self.test_points) plus noise.

        # 5 (s, t+n): a cloud vs. a noisy other cloud, scale in [0.1, 2.1).
        # NOTE(review): reuses target_idxs from family 1 instead of drawing fresh
        # targets, so these targets repeat family 1's — confirm this is intended.
        source_idxs = rng.choice(idxs, part_size, replace=True)
        scale = 0.1 + 2 * torch.rand(part_size, generator=gen)[..., None, None]
        target = points[target_idxs]
        all_source.append(points[source_idxs])
        all_target.append(target + scale * torch.randn(*target.shape, generator=gen))

        return torch.cat(all_source, dim=0), torch.cat(all_target, dim=0)

    def _load_or_compute_distances(self, dirname):
        """Set ``dists``, ``matchings``, ``matching_t2s`` and ``grads``, using ``dirname`` as a cache.

        On a cache hit, ``self.source``/``self.target`` are also replaced by the
        cached pairs, so all stored quantities refer to the same pairs. On a miss,
        everything is computed with ``hungarian_batched_grads`` and written to
        ``distances.pt``, ``data.pt`` and ``grads.pt`` in ``dirname``.
        """
        filename = os.path.join(dirname, 'distances.pt')
        # Build sibling paths from dirname; str.replace on the full path would also
        # rewrite any parent directory that happens to contain "distances".
        data_filename = os.path.join(dirname, 'data.pt')
        grads_filename = os.path.join(dirname, 'grads.pt')

        if os.path.isfile(filename):
            print("Loading distances from cache")
            data = torch.load(filename)
            self.dists, self.matchings, self.matching_t2s = data['dist'], data['matchings'], data['matching_t2s']
            points = torch.load(data_filename)
            self.source, self.target = points['source'], points['target']
            self.grads = torch.load(grads_filename)['grads']
            return

        print("Computing distances")
        self.dists, self.grads, self.matchings = hungarian_batched_grads(self.source, self.target, return_matching=True)
        # Keep the second component of each matching. This presumably assumes
        # (row_ind, col_ind) pairs, e.g. from scipy's linear_sum_assignment — TODO confirm.
        self.matchings = torch.cat([match[1].unsqueeze(0) for match in torch.tensor(self.matchings)])
        # Row-wise argsort: the inverse permutation when each row is a permutation.
        self.matching_t2s = torch.sort(self.matchings, dim=1)[1]

        os.makedirs(dirname, exist_ok=True)
        print("Saving distances to cache")
        torch.save({'dist': self.dists, 'matchings': self.matchings, 'matching_t2s': self.matching_t2s}, filename)
        print('Saving data')
        torch.save({'source': self.source, 'target': self.target}, data_filename)
        print('Saving grads')
        torch.save({'grads': self.grads}, grads_filename)

    def __len__(self):
        """Number of (source, target) pairs."""
        return len(self.dists)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` together with its distance, matchings and gradients."""
        sample = {'source': self.source[idx],
                  'target': self.target[idx],
                  'dist': self.dists[idx],
                  'matching': self.matchings[idx],
                  'matching_t2s': self.matching_t2s[idx],
                  'grads': self.grads[idx]
                  }

        return sample

def build(root, categories, tr_sample_size, te_sample_size,
          random_offset, random_subsample, npairs, batch_size,
          *args, **kwargs):
    """Create the train/val datasets and their DataLoaders.

    The val split uses ``npairs // 5`` pairs and no random subsampling, and it is
    normalized with the train split's mean/std. Extra ``*args``/``**kwargs`` are
    accepted and ignored.

    Returns:
        (train_dataset, val_dataset, train_loader, val_loader)
    """
    train_dataset = SetMNIST(root=root,
                             categories=categories,
                             tr_sample_size=tr_sample_size,
                             te_sample_size=te_sample_size,
                             split='train',
                             random_offset=random_offset,
                             random_subsample=random_subsample,
                             input_dim=3,
                             npairs=npairs
                             )

    val_dataset = SetMNIST(root=root,
                           categories=categories,
                           tr_sample_size=tr_sample_size,
                           te_sample_size=te_sample_size,
                           split='val',
                           random_offset=False,
                           random_subsample=False,
                           input_dim=3,
                           npairs=npairs // 5,
                           all_points_mean=train_dataset.all_points_mean,
                           all_points_std=train_dataset.all_points_std
                           )

    # The RandomSampler does the shuffling; DataLoader rejects shuffle=True together with a sampler.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                               sampler=torch.utils.data.RandomSampler(train_dataset),
                                               pin_memory=True, drop_last=True,
                                               worker_init_fn=init_np_seed)

    # NOTE(review): with drop_last=True the val loader is empty when the val set
    # holds fewer than batch_size pairs.
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False,
                                             pin_memory=True, drop_last=True,
                                             worker_init_fn=init_np_seed)

    return train_dataset, val_dataset, train_loader, val_loader
 
if __name__ == '__main__':
    # Smoke test: build a small airplane training set and inspect one pair.
    dataset = SetMNIST(root=os.path.expanduser('~/data/ShapeNet/ShapeNetCore.v2.PC15k'),
                       categories=['airplane'],
                       tr_sample_size=500,
                       te_sample_size=150,
                       split='train',
                       random_offset=False,
                       random_subsample=False,
                       input_dim=3,
                       npairs=150)
    # TODO: SyntheticDataset(nsamples=1000, npairs=1000, npoints=5, categories=['circle', 'square'])
    sample = dataset[0]
    print(sample)
    print(len(sample))
    print(sample['source'].shape, sample['target'].shape, sample['dist'].shape)