import torch
import numpy as np
import math
import os

import sys
sys.path.append('..')
from utils.distance import hungarian, hungarian_batched_grads

def init_np_seed(worker_id):
    """Seed NumPy's global RNG from torch's initial seed.

    Intended as a ``worker_init_fn`` for ``torch.utils.data.DataLoader``.

    Args:
        worker_id: index supplied by the DataLoader; not used here.
    """
    # np.random.seed only accepts 32-bit values, so fold the torch seed
    # into the range [0, 2**32).
    np.random.seed(torch.initial_seed() % 2 ** 32)

def generate_squares(nsamples, npoints, params=None, return_params=False):
    """Sample points on the outlines of randomly posed squares.

    Args:
        nsamples: number of squares to generate.
        npoints: number of points sampled on each square.
        params: optional array of shape (nsamples, 4) with columns
            (rotation, scale, x_center, y_center). When omitted, the values
            are drawn uniformly from [0, 1). When given, a randomly jittered
            *copy* is used; the caller's array is left untouched.
        return_params: if True, also return the parameters actually used.

    Returns:
        An array of shape (nsamples, npoints, 2), or the tuple
        ``(points, params)`` when ``return_params`` is True.
    """
    if params is None:
        params = np.random.rand(nsamples, 4)  # [0, 1) (rotation, scale, x_center, y_center)
    else:
        # Work on a float copy: the jitter must neither leak back into the
        # caller's array nor be truncated by an integer dtype.
        params = np.array(params, dtype=float)
        count = len(params)
        params[:, 0] += 0.1 * np.random.normal(size=count)   # rotation
        params[:, 1] *= 0.95 + 0.1 * np.random.rand(count)   # scale
        params[:, 2] += 0.1 * np.random.rand(count)          # x_center
        params[:, 3] += 0.1 * np.random.rand(count)          # y_center

    # The integer part of r picks one of the four sides, the fractional part
    # picks the position along that side.
    r = np.random.rand(nsamples, npoints)[:, :, None] * 4  # [0, 4)
    side = np.floor(r)
    # Angles of the two corners bounding the chosen side (corners lie on the
    # unit circle). For rotation in [0, 1]: u in [0, 2*pi], v in [pi/2, 5*pi/2].
    u = (side + params[:, 0][:, None, None]) * math.pi * 0.5
    v = u + math.pi * 0.5
    w = r - side  # [0, 1) interpolation weight between the two corners

    x = np.concatenate(
        (np.cos(u) * (1 - w) + np.cos(v) * w,
         np.sin(u) * (1 - w) + np.sin(v) * w),
        2)

    # Scale by (0.5 * scale + 0.5), then translate by (center - 0.5).
    x = (x * (0.5 * params[:, 1:2, None] + 0.5)) + params[:, None, 2:] - 0.5

    if return_params:
        return x, params
    return x

# Function to generate points on the circumference of a circle; angles are sampled uniformly at random
def generate_circles(num_samples, npoints, params=None, return_params=False):
    """Sample points on the circumference of randomly placed circles.

    Args:
        num_samples: number of circles to generate.
        npoints: number of points sampled on each circle.
        params: optional array of shape (num_samples, 3) with columns
            (x_center, y_center, radius). When omitted, the values are drawn
            uniformly from [0, 1). When given, a randomly jittered *copy* is
            used; the caller's array is left untouched.
        return_params: if True, also return the parameters actually used.

    Returns:
        An array of shape (num_samples, npoints, 2), or the tuple
        ``(points, params)`` when ``return_params`` is True.
    """
    if params is None:
        params = np.random.rand(num_samples, 3)  # N x 3 (x_center, y_center, radius)
    else:
        # Work on a float copy: the jitter must neither leak back into the
        # caller's array nor be truncated by an integer dtype.
        params = np.array(params, dtype=float)
        count = len(params)
        params[:, 0] += 0.1 * np.random.rand(count)          # x_center
        params[:, 1] += 0.1 * np.random.rand(count)          # y_center
        params[:, 2] *= 0.95 + 0.1 * np.random.rand(count)   # radius

    # Angles are uniform in [0, 2*pi); the radius is fixed per circle, so all
    # points lie on the circumference.
    theta = np.random.rand(num_samples, npoints) * 2 * np.pi  # N x npoints
    radius = params[:, 2:]                                    # N x 1
    x = radius * np.cos(theta)                                # N x npoints
    y = radius * np.sin(theta)                                # N x npoints
    # Stack to N x npoints x 2 and shift every circle to its center.
    data = np.stack([x, y], axis=-1) + params[:, :2].reshape(-1, 1, 2)

    if return_params:
        return data, params
    return data


def generate_paired_data(npairs, nsamples=None, data=None, categories=['circle', 'square'], npoints=200):
    """Draw random (source, target) pairs of shapes.

    Args:
        npairs: number of pairs to return.
        nsamples: total number of shapes to generate when ``data`` is None;
            each category receives ``nsamples // len(categories)`` shapes.
            Ignored when ``data`` is given.
        data: optional pre-generated array of shapes, indexed along axis 0.
        categories: shape kinds to generate ('circle' and/or 'square').
            The default list is only read, never modified.
        npoints: points per generated shape (only used when ``data`` is None).

    Returns:
        Tuple ``(source, target)``; each holds ``npairs`` entries of ``data``
        picked independently and with replacement.

    Raises:
        ValueError: if both ``data`` and ``nsamples`` are None, or if an
            unknown category is requested.
    """
    if data is None:
        if nsamples is None:
            raise ValueError('nsamples must be given when data is None')
        per_category = nsamples // len(categories)
        shapes = []
        for category in categories:
            if category == 'circle':
                shapes.append(generate_circles(per_category, npoints))
            elif category == 'square':
                shapes.append(generate_squares(per_category, npoints))
            else:
                raise ValueError('Unknown category: {}'.format(category))
        data = np.concatenate(shapes, axis=0)

    # Index over what actually exists: when nsamples is not a multiple of
    # len(categories), fewer than nsamples shapes were generated, so drawing
    # indices from range(nsamples) could run past the end of data.
    idxs = np.arange(len(data))
    source_idxs = np.random.choice(idxs, npairs)
    target_idxs = np.random.choice(idxs, npairs)

    return data[source_idxs], data[target_idxs]

def generate_paired_data_mixed(npairs, nsamples=None, data=None, categories=['circle', 'square'], npoints=200):
    """Build a mixed set of (source, target) pairs with varying difficulty.

    The result concatenates five parts, in this order:
      1. (s, t):      two independently drawn shapes.
      2. (s, n):      a shape paired with pure Gaussian noise.
      3. (s, s + n):  a shape paired with a noisy copy of itself.
      4. (s, ~s + n): a freshly generated circle (then square) paired with a
                      noisy shape generated from jittered parameters.
      5. (s, t + n):  a shape paired with a different, noisy shape.

    Parts 1, 2, 3 and 5 contribute ``npairs // 5`` pairs each; part 4
    contributes ``npairs // 10`` circles plus ``npairs // 10`` squares, so
    fewer than ``npairs`` pairs are returned when ``npairs`` is not a
    multiple of 10.

    NOTE(review): part 4 always generates circles and squares, regardless of
    ``categories``.

    Args:
        npairs: requested number of pairs (see rounding note above).
        nsamples: total number of shapes to generate when ``data`` is None;
            each category receives ``nsamples // len(categories)`` shapes.
            Ignored when ``data`` is given.
        data: optional pre-generated array of shape (N, P, 2).
        categories: shape kinds to generate ('circle' and/or 'square').
            The default list is only read, never modified.
        npoints: points per shape. When ``data`` is given, its own point
            count (``data.shape[1]``) is used instead so that the noise and
            the part-4 shapes match the supplied shapes.

    Returns:
        Tuple ``(source, target)`` of arrays concatenated along axis 0.

    Raises:
        ValueError: if both ``data`` and ``nsamples`` are None, or if an
            unknown category is requested.
    """
    if data is None:
        if nsamples is None:
            raise ValueError('nsamples must be given when data is None')
        per_category = nsamples // len(categories)
        shapes = []
        for category in categories:
            if category == 'circle':
                shapes.append(generate_circles(per_category, npoints))
            elif category == 'square':
                shapes.append(generate_squares(per_category, npoints))
            else:
                raise ValueError('Unknown category: {}'.format(category))
        data = np.concatenate(shapes, axis=0)
    else:
        # Noise and freshly generated shapes must match the supplied shapes.
        npoints = data.shape[1]

    # Index over what actually exists: when nsamples is not a multiple of
    # len(categories), fewer than nsamples shapes were generated.
    idxs = np.arange(len(data))

    num_parts = 5
    part_size = npairs // num_parts
    half_size = npairs // (num_parts * 2)

    def gaussian_noise(count, scale):
        # Per-pair scale broadcast over an isotropic Gaussian cloud.
        return scale * np.random.normal(size=(count, npoints, 2))

    all_source, all_target = [], []

    # 1 (s, t)
    source_idxs = np.random.choice(idxs, part_size)
    target_idxs = np.random.choice(idxs, part_size)
    all_source.append(data[source_idxs])
    all_target.append(data[target_idxs])

    # 2 (s, n): noise scale in [0.1, 1.1)
    source_idxs = np.random.choice(idxs, part_size)
    scale = 0.1 + np.random.rand(part_size)[..., None, None]
    all_source.append(data[source_idxs])
    all_target.append(gaussian_noise(part_size, scale))

    # 3 (s, s + n): noise scale in [0, 2)
    source_idxs = np.random.choice(idxs, part_size)
    scale = 2 * np.random.rand(part_size)[..., None, None]
    all_source.append(data[source_idxs])
    all_target.append(data[source_idxs] + gaussian_noise(part_size, scale))

    # 4 (s, ~s + n): the target is regenerated from jittered source parameters
    for generate in (generate_circles, generate_squares):
        source, params = generate(half_size, npoints, return_params=True)
        scale = 2 * np.random.rand(half_size)[..., None, None]
        target = generate(half_size, npoints, params=params) + gaussian_noise(half_size, scale)
        all_source.append(source)
        all_target.append(target)

    # 5 (s, t + n): draw fresh targets instead of reusing the indices of part 1
    source_idxs = np.random.choice(idxs, part_size)
    target_idxs = np.random.choice(idxs, part_size)
    scale = 2 * np.random.rand(part_size)[..., None, None]
    all_source.append(data[source_idxs])
    all_target.append(data[target_idxs] + gaussian_noise(part_size, scale))

    return np.concatenate(all_source, axis=0), np.concatenate(all_target, axis=0)

class SyntheticDataset(torch.utils.data.Dataset):
    """Paired synthetic point clouds with precomputed matchings.

    Each item is a dict with the keys 'source', 'target', 'dist', 'matching',
    'matching_t2s' and 'grads'. Results are cached on disk next to this file
    under ``cache/val/<categories...>/npairs-<npairs>/num_pt-<npoints>-aug-<augment>``
    in three files: ``distances.pt``, ``data.pt`` and ``grads.pt``.

    NOTE(review): ``nsamples`` is not part of the cache path, so a cache
    written with a different ``nsamples`` is reused as-is.
    """

    def __init__(self, nsamples, npairs, npoints=200, categories=['circle', 'square'], augment=False):
        """Load the dataset from the cache, or generate and cache it.

        Args:
            nsamples: number of shapes to generate before pairing.
            npairs: number of (source, target) pairs requested.
            npoints: number of points per shape.
            categories: shape kinds to generate; also part of the cache path.
                The default list is only read, never modified.
            augment: if True, pairs come from ``generate_paired_data_mixed``,
                otherwise from ``generate_paired_data``.
        """
        # Recorded up front so the attribute exists on both the cache-hit and
        # the cache-miss path.
        self.augment = augment

        dirname = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache', 'val', *categories, f'npairs-{npairs}', f'num_pt-{npoints}-aug-{augment}')
        # Build every path explicitly; deriving the siblings with
        # str.replace('distances', ...) on the full path would also rewrite
        # any parent directory whose name contains 'distances'.
        dist_file = os.path.join(dirname, 'distances.pt')
        data_file = os.path.join(dirname, 'data.pt')
        grads_file = os.path.join(dirname, 'grads.pt')

        # Only trust the cache when all three files exist; an interrupted
        # earlier run may have written just some of them.
        if all(os.path.isfile(path) for path in (dist_file, data_file, grads_file)):
            print("Loading distances from cache")
            cached = torch.load(dist_file)
            self.dists, self.matchings, self.matching_t2s = cached['dist'], cached['matchings'], cached['matching_t2s']
            points = torch.load(data_file)
            self.source, self.target = points['source'], points['target']
            grads = torch.load(grads_file)
            self.grads = grads['grads']
        else:
            if self.augment:
                self.source, self.target = generate_paired_data_mixed(npairs=npairs, nsamples=nsamples, categories=categories, npoints=npoints)
            else:
                self.source, self.target = generate_paired_data(npairs=npairs, nsamples=nsamples, categories=categories, npoints=npoints)

            self.source = torch.from_numpy(self.source).float()
            self.target = torch.from_numpy(self.target).float()

            print("Computing distances")
            self.dists, self.grads, self.matchings = hungarian_batched_grads(self.source, self.target, return_matching=True)
            # Keep entry [1] of every matching.
            # NOTE(review): presumably the target index assigned to each
            # source point -- confirm against hungarian_batched_grads.
            self.matchings = torch.cat([match[1].unsqueeze(0) for match in torch.tensor(self.matchings)])
            # Argsort of each row; when a row is a permutation this is its
            # inverse.
            self.matching_t2s = torch.sort(self.matchings, dim=1)[1]

            os.makedirs(dirname, exist_ok=True)
            print("Saving distances to cache")
            torch.save({'dist': self.dists, 'matchings': self.matchings, 'matching_t2s': self.matching_t2s}, dist_file)
            print('Saving data')
            torch.save({'source': self.source, 'target': self.target}, data_file)
            print('Saving grads')
            torch.save({'grads': self.grads}, grads_file)

    def __len__(self):
        """Return the number of pairs (the length of the distance tensor)."""
        return len(self.dists)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair and its precomputed quantities as a dict."""
        sample = {'source': self.source[idx],
                  'target': self.target[idx],
                  'dist': self.dists[idx],
                  'matching': self.matchings[idx],
                  'matching_t2s': self.matching_t2s[idx],
                  'grads': self.grads[idx]}
        return sample

def build(nsamples, npairs, npoints, categories, batch_size, train_val_ratio, *args, **kwargs):
    """Create the synthetic dataset, split it and wrap both parts in loaders.

    Args:
        nsamples, npairs, npoints, categories: forwarded to SyntheticDataset.
        batch_size: batch size of both loaders.
        train_val_ratio: fraction of the pairs assigned to the training split.
        *args: accepted but not used.
        **kwargs: only 'augment' is read (defaults to False).

    Returns:
        Tuple ``(train_dataset, val_dataset, train_loader, val_loader)``.
    """
    use_augment = kwargs.get('augment', False)
    full_dataset = SyntheticDataset(
        nsamples=nsamples,
        npairs=npairs,
        npoints=npoints,
        categories=categories,
        augment=use_augment,
    )

    # Fixed generator seed keeps the train/val split identical across runs.
    total = len(full_dataset)
    n_train = int(total * train_val_ratio)
    split_rng = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [n_train, total - n_train], generator=split_rng)

    # Options shared by both loaders; num_workers is left at its default.
    common = dict(
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
        collate_fn=None,
        worker_init_fn=init_np_seed,
    )

    # Training order is randomized by an explicit sampler, so shuffle stays off.
    train_sampler = torch.utils.data.RandomSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, shuffle=False, sampler=train_sampler, **common)

    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, shuffle=False, **common)

    return train_dataset, val_dataset, train_loader, val_loader
 
if __name__ == '__main__':
    # Smoke test: build (or load from cache) a small dataset and show its
    # first item together with the number of fields that item carries.
    dataset = SyntheticDataset(
        nsamples=1000,
        npairs=1000,
        npoints=5,
        categories=['circle', 'square'],
    )
    first_item = dataset[0]
    print(first_item)
    print(len(first_item))