import numpy as np
import torch 

import pandas as pd

# Registry of feature-file locations, indexed first by split ("train" / "test")
# and then by dataset name.  The loaders in this module run each path through
# ``str.format(seed=seed)`` and pick a reader from the file suffix, so the
# placeholder paths below must be replaced with real file names (for example
# ones ending in ``.npy`` or ``.pth``) before any dataset can be loaded.
DATASET_FEATURES_DICT = {
    "train": {
        "CIFAR10": "/path/of/cifar-10/feature/train",
        "CIFAR10_imbalanced": "/path/of/cifar10_imbalanced/feature/train",
        "CIFAR10_all_imbalanced": "/path/of/cifar10_all_imbalanced/feature/train",
        "optdigits": "/path/of/optdigits/feature/train",
        "phishing": "/path/of/phishing/feature/train",
        "TRPB_balanced": "/path/of/TRPB_balanced/feature/train",
    },
    "test": {
        "CIFAR10": "/path/of/cifar-10/feature/test",
        "CIFAR10_imbalanced": "/path/of/cifar10_imbalanced/feature/test",
        "CIFAR10_all_imbalanced": "/path/of/cifar10_all_imbalanced/feature/test",
        "optdigits": "/path/of/optdigits/feature/test",
        "phishing": "/path/of/phishing/feature/test",
        "TRPB_balanced": "/path/of/TRPB_balanced/feature/test",
    },
}

def load_features(ds_name, seed=1, train=True, normalized=True):
    """Load the pretrained feature matrix registered for a dataset.

    Args:
        ds_name: Key into ``DATASET_FEATURES_DICT[split]``.
        seed: Substituted for a ``{seed}`` placeholder in the registered
            path, if the path contains one.
        train: Read the ``"train"`` split when true, otherwise ``"test"``.
        normalized: When true, divide every row by its Euclidean (L2) norm.

    Returns:
        The object produced by ``np.load`` (``.npy``) or ``torch.load``
        (``.pth``), row-normalized when ``normalized`` is true.

    Raises:
        KeyError: If ``ds_name`` is not registered for the split.
        ValueError: If the path ends in neither ``.npy`` nor ``.pth``.
    """
    split = "train" if train else "test"
    fname = DATASET_FEATURES_DICT[split][ds_name].format(seed=seed)
    if fname.endswith('.npy'):
        features = np.load(fname)
    elif fname.endswith('.pth'):
        # NOTE: torch.load unpickles the file; only point this at trusted files.
        features = torch.load(fname)
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(
            f"Unsupported filetype: {fname!r} (expected .npy or .pth)"
        )
    if normalized:
        # Indexing with axis=1 means the features must be 2-D (rows = samples).
        norms = np.linalg.norm(features, axis=1, keepdims=True)
        # A zero norm means the row is all zeros; divide by 1 so the row stays
        # zero instead of turning into NaN.
        norms[norms == 0] = 1
        features = features / norms
    return features

def load_features_labelexclu(ds_name, seed=1, train=True, normalized=False):
    """Load a dataset whose file stores features and labels together.

    The original note on this loader says it is used for the trpb and
    trpb_umap features.  The column layout depends on the file suffix:

    * ``.npy``: column 0 holds the labels, the remaining columns the features.
    * ``.tra`` / ``.tes``: header-less CSV; the last column holds the labels,
      the preceding columns the features.

    Args:
        ds_name: Key into ``DATASET_FEATURES_DICT[split]``.
        seed: Substituted for a ``{seed}`` placeholder in the registered
            path, if the path contains one.
        train: Read the ``"train"`` split when true, otherwise ``"test"``.
        normalized: When true, center every feature column on its mean and
            divide it by the column's Euclidean norm.

    Returns:
        A ``(features, labels)`` tuple.

    Raises:
        KeyError: If ``ds_name`` is not registered for the split.
        ValueError: If the path has none of the supported suffixes.
    """
    split = "train" if train else "test"
    fname = DATASET_FEATURES_DICT[split][ds_name].format(seed=seed)
    if fname.endswith('.npy'):
        # NOTE: allow_pickle=True can execute code embedded in the file; only
        # point this at trusted files.
        data = np.array(np.load(fname, allow_pickle=True))
        features = data[:, 1:]
        labels = data[:, 0]
    elif fname.endswith(('.tra', '.tes')):
        # Both suffixes share the same header-less CSV layout.
        df = pd.read_csv(fname, header=None)
        features = df.iloc[:, :-1].values
        labels = df.iloc[:, -1].values
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(
            f"Unsupported filetype: {fname!r} (expected .npy, .tra or .tes)"
        )

    if normalized:
        # NOTE(review): the divisor is the norm of the *uncentered* columns,
        # so this is not a z-score; confirm that this scaling is intended.
        col_norms = np.linalg.norm(features, axis=0, keepdims=True)
        # A zero norm means the column is all zeros; divide by 1 so the
        # centered column stays zero instead of becoming NaN (0 / 0).
        col_norms[col_norms == 0] = 1
        features = (features - features.mean(axis=0, keepdims=True)) / col_norms
    return features, labels

def load_features_test700(ds_name, seed=1, train=False, normalized=True):
    """Load a dataset's pretrained features and return them as an ndarray.

    Compared with ``load_features`` this loader defaults to the test split,
    permits pickled ``.npy`` files, and always passes the result through
    ``np.array`` before returning it.

    Args:
        ds_name: Key into ``DATASET_FEATURES_DICT[split]``.
        seed: Substituted for a ``{seed}`` placeholder in the registered
            path, if the path contains one.
        train: Read the ``"train"`` split when true, otherwise ``"test"``.
        normalized: When true, divide every row by its Euclidean (L2) norm.

    Returns:
        ``np.array`` of the loaded (and optionally row-normalized) features.

    Raises:
        KeyError: If ``ds_name`` is not registered for the split.
        ValueError: If the path ends in neither ``.npy`` nor ``.pth``.
    """
    split = "train" if train else "test"
    fname = DATASET_FEATURES_DICT[split][ds_name].format(seed=seed)
    if fname.endswith('.npy'):
        # NOTE: allow_pickle=True can execute code embedded in the file; only
        # point this at trusted files.
        features = np.load(fname, allow_pickle=True)
    elif fname.endswith('.pth'):
        # NOTE: torch.load unpickles the file; same trust requirement applies.
        features = torch.load(fname)
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(
            f"Unsupported filetype: {fname!r} (expected .npy or .pth)"
        )
    if normalized:
        # Indexing with axis=1 means the features must be 2-D (rows = samples).
        norms = np.linalg.norm(features, axis=1, keepdims=True)
        # A zero norm means the row is all zeros; divide by 1 so the row stays
        # zero instead of turning into NaN.
        norms[norms == 0] = 1
        features = features / norms
    return np.array(features)