import torch
import numpy as np
import math

# import sys
# sys.path.append('..')
# from utils.distance import hungarian

def init_np_seed(worker_id):
    """Seed NumPy's global RNG from the current torch seed.

    Intended as a ``DataLoader`` ``worker_init_fn``. ``worker_id`` is required
    by that callback signature but is not used here. The torch seed is reduced
    modulo 2**32 because ``np.random.seed`` only accepts 32-bit unsigned seeds.
    """
    torch_seed = torch.initial_seed()
    np.random.seed(torch_seed % 2 ** 32)

def generate_squares(nsamples, npoints):
    """Sample points uniformly along the outlines of random squares.

    For every square, four uniform parameters are drawn (in this order, from
    NumPy's global RNG, before the per-point draws):

    * ``params[:, 0]``: rotation; corners sit at angles ``(k + p0) * pi / 2``
      for ``k = 0..3``.
    * ``params[:, 1]``: scale; corners lie on a circle of radius
      ``0.5 * p1 + 0.5`` (in [0.5, 1)) around the square's centre.
    * ``params[:, 2:]``: translation; the centre is ``p[2:] - 0.5``, i.e. in
      [-0.5, 0.5)^2.

    Each point picks an edge uniformly (``floor`` of a value in [0, 4)) and a
    uniform position along it (the fractional part), linearly interpolating
    between that edge's two corners.

    Args:
        nsamples: number of squares.
        npoints: number of points per square.

    Returns:
        ndarray of shape ``(nsamples, npoints, 2)``.
    """
    params = np.random.rand(nsamples, 4)
    rotation = params[:, 0, None, None]            # (N, 1, 1)
    scale = 0.5 * params[:, 1:2, None] + 0.5       # (N, 1, 1)

    perimeter = np.random.rand(nsamples, npoints)[:, :, None] * 4   # (N, P, 1)
    edge = np.floor(perimeter)
    start = (edge + rotation) * math.pi * 0.5      # angle of the edge's first corner
    end = start + math.pi * 0.5                    # angle of the edge's second corner
    t = perimeter - edge                           # position along the edge in [0, 1)

    along_x = np.cos(start) * (1 - t) + np.cos(end) * t
    along_y = np.sin(start) * (1 - t) + np.sin(end) * t
    outline = np.concatenate((along_x, along_y), 2)

    return outline * scale + params[:, None, 2:] - 0.5

def generate_circles(num_samples, npoints):
    """Sample points on the circumferences of random circles.

    Draws from NumPy's global RNG, in this order:

    1. a centre per circle, uniform in [0, 1)^2;
    2. a radius per circle, uniform in [0.5, 1);
    3. ``npoints`` angles per circle, uniform in [0, 2*pi).

    Points are uniform in angle on the circle outline (not inside the disc).

    NOTE(review): ``generate_squares`` translates its shapes by an offset in
    [-0.5, 0.5)^2, while circle centres here lie in [0, 1)^2. Confirm whether
    that difference between categories is intended.

    Args:
        num_samples: number of circles.
        npoints: number of points per circle.

    Returns:
        ndarray of shape ``(num_samples, npoints, 2)``.
    """
    offsets = np.random.rand(num_samples, 1, 2)                   # (N, 1, 2)
    radius = 0.5 + 0.5 * np.random.rand(num_samples)[..., None]   # (N, 1)
    angle = np.random.rand(num_samples, npoints) * 2 * np.pi      # (N, P)

    unit = np.stack([np.cos(angle), np.sin(angle)], axis=-1)      # (N, P, 2)
    return radius[..., None] * unit + offsets

class SyntheticDatasetGen(torch.utils.data.Dataset):
    """In-memory dataset of 2-D point sets sampled from simple shape outlines.

    All sets are generated once in ``__init__`` and stored as a single float32
    tensor ``self.source`` of shape ``(total, npoints, 2)``. Sets are ordered
    by category: every set of the first category, then the next, and so on.
    """

    # Category names accepted by ``__init__``.
    CATEGORIES = ('circle', 'square')

    def __init__(self, nsamples, npoints=200, categories=('circle', 'square')):
        """Generate and store the point sets.

        Args:
            nsamples: total number of sets. It is split as evenly as possible
                across ``categories``; the first ``nsamples % len(categories)``
                categories receive one extra set, so ``len(self) == nsamples``.
            npoints: number of points in every set.
            categories: iterable of names from ``CATEGORIES``.

        Raises:
            ValueError: if ``categories`` is empty or contains an unknown name.
                Validation happens before any data is generated.
        """
        # A tuple default avoids the shared-mutable-default pitfall; materialise
        # the argument so len() and repeated iteration work for any iterable.
        categories = tuple(categories)
        if not categories:
            raise ValueError('categories must contain at least one category')
        for category in categories:
            if category not in self.CATEGORIES:
                raise ValueError('Unknown category: {}'.format(category))

        generators = {'circle': generate_circles, 'square': generate_squares}
        base, remainder = divmod(nsamples, len(categories))
        chunks = [
            generators[category](base + (1 if i < remainder else 0), npoints)
            for i, category in enumerate(categories)
        ]

        self.source = torch.from_numpy(np.concatenate(chunks, axis=0)).float()

    def __len__(self):
        """Return the number of stored sets."""
        return len(self.source)

    def __getitem__(self, idx):
        """Return one set as a dict.

        Keys:
            'set': tensor of shape ``(npoints, 2)``.
            'set_mask': all-False bool tensor of shape ``(npoints,)``.
                Presumably True would mark padded points and every point here
                is real; confirm with the consumer.
            'cardinality': scalar long tensor holding ``npoints``.
        """
        points = self.source[idx]
        n_points = points.shape[0]
        return {
            'set': points,
            'set_mask': torch.zeros(n_points, dtype=torch.bool),
            'cardinality': torch.tensor(n_points, dtype=torch.long),
        }

def build(nsamples, npoints, categories, batch_size, train_val_ratio, *args,
          num_workers=0, **kwargs):
    """Create the synthetic dataset, split it, and wrap both splits in loaders.

    Args:
        nsamples: total number of point sets to generate.
        npoints: number of points per set.
        categories: shape category names passed to ``SyntheticDatasetGen``.
        batch_size: batch size for both loaders.
        train_val_ratio: fraction of sets placed in the training split; must
            be in (0, 1]. The split uses a fixed generator seed (42).
        num_workers: DataLoader worker processes (0 loads in the main process).
        *args, **kwargs: accepted and ignored, so a larger config dict can be
            splatted into this function.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader)``. The train
        loader shuffles and drops the last incomplete batch. The val loader
        keeps every sample, so its final batch may be smaller than
        ``batch_size``.

    Raises:
        ValueError: if ``train_val_ratio`` is outside (0, 1] or the training
            split would be empty.
    """
    # Without this check, e.g. 1.5 yields a negative val length that
    # random_split silently turns into an empty split.
    if not 0 < train_val_ratio <= 1:
        raise ValueError(
            'train_val_ratio must be in (0, 1], got {}'.format(train_val_ratio))

    full_dataset = SyntheticDatasetGen(nsamples=nsamples,
                                       npoints=npoints,
                                       categories=categories)

    train_size = int(len(full_dataset) * train_val_ratio)
    if train_size == 0:
        raise ValueError('training split is empty; increase nsamples or train_val_ratio')
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, len(full_dataset) - train_size],
        generator=torch.Generator().manual_seed(42))

    # Pinned host memory only helps (and is only supported without warnings)
    # when a CUDA device is present.
    pin_memory = torch.cuda.is_available()

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=True,
        collate_fn=None, worker_init_fn=init_np_seed)

    # drop_last=False: validation must see every held-out sample.
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=False,
        collate_fn=None, worker_init_fn=init_np_seed)

    return train_dataset, val_dataset, train_loader, val_loader
 
if __name__ == '__main__':
    # Smoke test: build a small dataset and inspect it.
    dataset = SyntheticDatasetGen(nsamples=1000, npoints=5, categories=['circle', 'square'])
    print(dataset[0])
    # Number of sets in the dataset (len(dataset[0]) would only count dict keys).
    print(len(dataset))