import torch
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

def get_batch_size(n):
    """Pick a mini-batch size from the number of samples ``n``.

    Larger datasets get larger batches: ``n > 50000`` -> 1024,
    ``n > 10000`` -> 512, ``n > 5000`` -> 256, ``n > 1000`` -> 128,
    otherwise 64. Thresholds are strict (``>``).
    """
    # Checked from the largest threshold downwards; first strict match wins.
    tiers = ((50000, 1024), (10000, 512), (5000, 256), (1000, 128))
    for lower_bound, batch in tiers:
        if n > lower_bound:
            return batch
    return 64
        

def load_data_ad(dataset_name, seed=42, data_dir="/home/Classical/", max_samples=50000):
    """Load an anomaly-detection dataset stored as ``<data_dir>/<dataset_name>.npz``.

    The archive must contain arrays ``X`` (samples) and ``y`` (labels). Labels
    are cast to ``int`` and flattened to 1-D. If more than ``max_samples`` rows
    are present, a random subset of ``max_samples`` rows is drawn without
    replacement to bound computational cost.

    Args:
        dataset_name: File stem of the ``.npz`` archive.
        seed: Seed applied to NumPy's *global* RNG (``np.random.seed``) before
            subsampling; note this also affects any later global-RNG use.
        data_dir: Directory containing the archive.
        max_samples: Maximum number of rows kept.

    Returns:
        ``(samples, labels, cat_cols, cat_cardinality, num_cols)`` where
        ``cat_cols`` and ``cat_cardinality`` are empty lists and ``num_cols``
        lists every column index (all features are treated as numerical).
    """
    np.random.seed(seed)

    path = os.path.join(data_dir, dataset_name + '.npz')
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it once the arrays have been read.
    with np.load(path) as data:
        samples = data['X']
        labels = data['y'].astype(int).reshape(-1)

    cat_cardinality = []
    cat_cols = []
    num_cols = list(range(samples.shape[1]))

    # Subsample large datasets to keep training/evaluation cost manageable.
    if len(labels) > max_samples:
        print(f'subsampling for dataset {dataset_name}...')
        idx_sample = np.random.choice(np.arange(len(labels)), max_samples, replace=False)
        samples = samples[idx_sample]
        labels = labels[idx_sample]

    return samples, labels, cat_cols, cat_cardinality, num_cols
    

def split_data_ad(X, y, tasktype, num_indices=None, seed=0, device='cuda'):
    """Split anomaly-detection data into train (inliers only) and evaluation sets.

    Rows with ``y == 0`` are inliers and rows with ``y == 1`` are outliers;
    rows with any other label are dropped. The first half (floor) of the
    inliers, in their existing order, forms the training set with all-zero
    labels. The remaining inliers followed by all outliers form the evaluation
    set, labelled 0 for inliers and 1 for outliers. Validation and test are
    built from the same evaluation arrays.

    Args:
        X: 2-D NumPy feature array.
        y: 1-D NumPy label array aligned with ``X``.
        tasktype: Forwarded to ``prep_data_ad``.
        num_indices: Forwarded to ``prep_data_ad`` (``None`` means ``[]``).
        seed: Currently unused; the split is deterministic and order-based.
        device: Torch device the tensors are placed on.

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test), y_std)`` as
        returned by ``prep_data_ad``; label tensors have shape ``(n, 1)``.
    """
    if num_indices is None:
        num_indices = []

    inliers = X[y == 0]
    outliers = X[y == 1]

    num_split = len(inliers) // 2
    # With an odd inlier count, the evaluation set receives one more inlier
    # than the training set, so the outlier offset is not simply num_split.
    num_test_inliers = len(inliers) - num_split

    train_data = inliers[:num_split]
    train_label = np.zeros((num_split, 1))

    test_data = np.concatenate([inliers[num_split:], outliers], 0)
    test_label = np.zeros((test_data.shape[0], 1))
    test_label[num_test_inliers:] = 1

    def _to_tensor(arr):
        # Separate tensors per split so val/test never alias each other.
        return torch.from_numpy(arr).type(torch.float32).to(device)

    X_train = _to_tensor(train_data)
    X_val = _to_tensor(test_data)
    X_test = _to_tensor(test_data)

    y_train = _to_tensor(train_label)
    y_val = _to_tensor(test_label)
    y_test = _to_tensor(test_label)

    (X_train, y_train), (X_val, y_val), (X_test, y_test), y_std = prep_data_ad(
        X_train, X_val, X_test, y_train, y_val, y_test,
        num_indices=num_indices, tasktype=tasktype,
    )

    return (X_train, y_train), (X_val, y_val), (X_test, y_test), y_std


def prep_data_ad(X_train, X_val, X_test, y_train, y_val, y_test, num_indices=None, tasktype='multiclass'):
    """Standardize features with statistics fitted on the training split only.

    A ``StandardScaler`` is fitted on ``X_train`` and applied to ``X_val`` and
    ``X_test``. The scaled arrays are converted back to tensors on the same
    device as ``X_train``. Labels are passed through unchanged.

    Args:
        X_train, X_val, X_test: 2-D feature tensors (CPU or CUDA).
        y_train, y_val, y_test: Label tensors, returned as-is.
        num_indices: Currently unused (``None`` means ``[]``).
        tasktype: Currently unused.

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test), 1)``. The
        trailing ``1`` is the label scale (``y_std``); labels are not scaled.
    """
    if num_indices is None:
        num_indices = []

    # Tensor.get_device() returns -1 for CPU tensors, which torch.tensor
    # rejects as a device; `.device` works for both CPU and CUDA tensors.
    device = X_train.device

    standard_transformer = StandardScaler()
    X_train = torch.tensor(standard_transformer.fit_transform(X_train.detach().cpu().numpy()), device=device)
    X_val = torch.tensor(standard_transformer.transform(X_val.detach().cpu().numpy()), device=device)
    X_test = torch.tensor(standard_transformer.transform(X_test.detach().cpu().numpy()), device=device)

    return (X_train, y_train), (X_val, y_val), (X_test, y_test), 1


class TabularDataset(torch.utils.data.Dataset):
    """Holds pre-split train/val/test tensors for a tabular dataset.

    ``__init__`` loads the data and splits it immediately. ``len(ds)`` and
    ``ds[i]`` operate on the training split by default. Pass
    ``data="val"`` to select the validation split. Any other value selects
    the test split.
    """

    def __init__(self, openml_id, tasktype, device, seed=1):
        if tasktype == "anomaly":
            loader, splitter = load_data_ad, split_data_ad
        else:
            # NOTE(review): `load_data` / `split_data` are not defined in the
            # visible part of this module; presumably defined elsewhere — verify.
            loader, splitter = load_data, split_data

        # NOTE(review): `seed` is forwarded to the splitter only; the loader
        # runs with its own default seed.
        X, y, self.X_cat, self.X_cat_cardinality, self.X_num = loader(openml_id)
        self.tasktype = tasktype

        (self.X_train, self.y_train), (self.X_val, self.y_val), (self.X_test, self.y_test), self.y_std = splitter(
            X, y, self.tasktype, num_indices=self.X_num, seed=seed, device=device
        )
        print("input dim: %i, cat: %i, num: %i" % (self.X_train.size(1), len(self.X_cat), len(self.X_num)))

        self.batch_size = get_batch_size(len(self.X_train))

    def _select(self, data):
        """Return the ``(X, y)`` pair for split name ``data``."""
        if data == "train":
            return self.X_train, self.y_train
        if data == "val":
            return self.X_val, self.y_val
        return self.X_test, self.y_test

    def __len__(self, data="train"):
        # `data` defaults to "train" so the builtin len(ds) (and DataLoader,
        # which calls __len__ with no extra args) works.
        return len(self._select(data)[0])

    def _indv_dataset(self):
        return (self.X_train, self.y_train), (self.X_val, self.y_val), (self.X_test, self.y_test)

    def __getitem__(self, idx, data="train"):
        # `data` defaults to "train" so ds[idx] works; the explicit
        # two-argument form is still accepted for the val/test splits.
        X, y = self._select(data)
        return X[idx], y[idx]