from enum import Enum
from functools import partial

import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Shared layout of every dataset file name. The slots are, in order:
# dataset directory, split ('train' or 'val'), nmax, nmin and sz.
_SAM_PATH_TEMPLATE = './data/sam/{}/data/{}-nmax-{}-nmin-{}-sz-{}-1.npz'

# File-name parameters per member, keyed by member name:
# (directory, (nmax, nmin) when small, (nmax, nmin) otherwise,
#  sz of the train file, sz of the val file).
_PATH_SPECS = {
    'ncircle3': ('ncircle/dim3', (300, 100), (500, 300), 2000, 200),
    'ncircle6': ('ncircle/dim6', (300, 100), (500, 300), 3600, 400),
    'random': ('random', (257, 256), (300, 200), 3000, 300),
    # Both ModelNet members resolve to the shared 'modelnet' directory,
    # which differs from the names reported by return_dataset_name().
    'mn_small': ('modelnet', (200, 20), (500, 300), 3000, 300),
    'mn_large': ('modelnet', (2049, 2048), (2000, 1800), 4000, 400),
    'rna': ('rna', (200, 20), (500, 300), 3000, 300),
}

# Value returned by return_dim() for each member, keyed by member name.
_DIMS = {
    'ncircle3': 3,
    'ncircle6': 6,
    'random': 2,
    'mn_small': 3,
    'mn_large': 3,
    'rna': 2000,
}

# Value returned by return_dataset_name() for each member.
_DATASET_NAMES = {
    'ncircle3': 'ncircle/dim3',
    'ncircle6': 'ncircle/dim6',
    'random': 'random',
    'mn_small': 'mn_small',
    'mn_large': 'mn_large',
    'rna': 'rna',
}


class SetTypes(Enum):
    """Known dataset kinds, with their file paths, dimensions and names.

    The member values are kept exactly as they were (including the string
    value of ``ncircle6``) so that lookups by value keep working.
    """

    ncircle3 = 0
    ncircle6 = 'ncircle/dim6'
    random = 2
    mn_small = 3
    mn_large = 4
    rna = 5

    def return_path(self, small=True):
        """Return the ``(train_path, val_path)`` pair for this member.

        Args:
            small: Any truthy value selects the "small" nmax/nmin pair for
                the file names; a falsy value selects the alternative pair.

        Returns:
            A tuple of two relative ``.npz`` path strings, train first.
        """
        directory, small_range, large_range, train_sz, val_sz = (
            _PATH_SPECS[self.name]
        )
        nmax, nmin = small_range if small else large_range
        train_sf = _SAM_PATH_TEMPLATE.format(
            directory, 'train', nmax, nmin, train_sz)
        val_sf = _SAM_PATH_TEMPLATE.format(
            directory, 'val', nmax, nmin, val_sz)
        return train_sf, val_sf

    def return_dim(self):
        """Return the integer dimension registered for this member."""
        return _DIMS[self.name]

    def return_dataset_name(self):
        """Return the dataset name string registered for this member."""
        return _DATASET_NAMES[self.name]

def return_act(act_name: str, slope: float):
    """Return a constructor for the activation module named ``act_name``.

    The result is a callable that builds the module when called with no
    arguments; it is not a module instance.

    Args:
        act_name: One of ``'relu'``, ``'lrelu'`` or ``'tanh'``.
        slope: Passed as the first positional argument of ``nn.LeakyReLU``
            (its negative slope). Ignored for the other activations.

    Returns:
        ``nn.ReLU``, ``nn.Tanh``, or a ``functools.partial`` wrapping
        ``nn.LeakyReLU`` with ``slope`` already bound.

    Raises:
        ValueError: If ``act_name`` is not one of the supported names.
    """
    if act_name == 'relu':
        return nn.ReLU
    if act_name == 'lrelu':
        return partial(nn.LeakyReLU, slope)
    if act_name == 'tanh':
        return nn.Tanh
    # Fail with a clear message instead of falling through to an
    # unbound local variable.
    raise ValueError(
        "unknown activation {!r}; expected 'relu', 'lrelu' or 'tanh'".format(
            act_name))


class EMDPairDataset(Dataset):
    """Dataset of ``(source, target, emd)`` triples stored in parallel.

    Item ``idx`` is built from position ``idx`` of each of the three
    containers; the dataset length is ``len(emds)``.
    """

    def __init__(self, sources, targets, emds):
        """Store the three parallel containers without copying them.

        Args:
            sources: Indexable container of source items.
            targets: Indexable container of target items.
            emds: Indexable, sized container of per-pair values
                (presumably earth mover's distances; confirm with caller).
        """
        self.sources = sources
        self.targets = targets
        self.emds = emds

    def __len__(self):
        """Return the number of stored triples, taken from ``emds``."""
        return len(self.emds)

    def __getitem__(self, idx):
        """Return the ``(source, target, emd)`` triple at position ``idx``."""
        return self.sources[idx], self.targets[idx], self.emds[idx]

    def shuffle(self):
        """Reorder sources and targets with one shared random permutation.

        The containers must support indexing with an integer array (for
        example NumPy arrays); plain lists do not.
        """
        # NOTE(review): self.emds is not reordered in this part of the
        # method; confirm it is permuted with the same indices, otherwise
        # the emd values no longer line up with their pairs.
        permutation = np.random.permutation(len(self.emds))
        self.sources = self.sources[permutation]
        # Index the targets themselves; reading from the already-permuted
        # sources here would overwrite the targets with source data.
        self.targets = self.targets[permutation]