from enum import Enum
from sklearn import datasets
import numpy as np
from numba import jit


def get_covtype():
    """Return the feature matrix of the scikit-learn Covertype dataset.

    Calls ``sklearn.datasets.fetch_covtype()`` (which may need network access
    the first time, depending on the local scikit-learn data cache) and
    returns the ``.data`` attribute of the returned object.
    """
    return datasets.fetch_covtype().data

# def get_diameter(data, num_rays):
#     max_dist = 0
#     dim = data.shape[1]
#     max_diam = 0
#     for i in range(num_rays):
#         ray = np.random.randn(dim)
#         ray = ray/np.linalg.norm(ray)
#         proj_coeff = data@ray
#         min_point = np.min(proj_coeff)
#         max_point = np.max(proj_coeff)
#         if(max_point - min_point > max_diam):
#             max_diam = max_point - min_point
#     return max_diam
from tqdm import tqdm
@jit(nopython=True)
def get_diameter_pointdiff(data, num_rays):
    """Estimate the diameter of a point set using random 1-D projections.

    For each of ``num_rays`` random unit directions, every row of ``data`` is
    projected onto the direction. The rows with the smallest and largest
    projection are selected, and the true Euclidean distance between those two
    rows is measured. The largest such distance over all rays is returned.
    Every candidate is a distance between two actual rows, so the result never
    exceeds the exact diameter.

    Parameters
    ----------
    data : 2-D array, one point per row (``data.shape[1]`` is the dimension).
    num_rays : number of random directions to sample.

    Returns
    -------
    The largest extreme-pair distance found, or 0.0 when ``num_rays`` is 0.

    Progress (the ray index every 2000 rays) and each improvement of the
    running maximum are printed to stdout.
    """
    dim = data.shape[1]
    max_diam = 0.0  # float from the start so the numba-inferred type is stable

    for i in range(num_rays):
        # Progress indicator; nopython code is otherwise silent for long runs.
        if i % 2000 == 0:
            print(i)
        ray = np.random.randn(dim)
        ray = ray / np.linalg.norm(ray)
        proj_coeff = data @ ray
        # Indices of the two extreme rows along this direction.
        lo = np.argmin(proj_coeff)
        hi = np.argmax(proj_coeff)
        # Compute the candidate distance once instead of up to three times.
        dist = np.linalg.norm(data[lo] - data[hi])
        if dist > max_diam:
            print("updated max dist from ", max_diam, " to ", dist)
            max_diam = dist
    return max_diam


def download_datasets():
    """Return the list of raw datasets that ``normalize_datasets`` processes.

    Each entry is expected to be a ``(data, name)`` tuple. ``data`` is the
    array that gets divided by its estimated diameter, and ``name`` is used
    to build the output file name. No dataset sources are registered yet, so
    this currently returns a new empty list. Nothing is downloaded or written
    to disk.
    """
    # Named ``raw`` (not ``datasets``) so the sklearn ``datasets`` module stays
    # reachable inside this function once real sources are added.
    raw = []
    # TODO: register sources here, e.g. MNIST training images flattened to
    # (60000, 784) and named "mnist":
    # raw.append((mnist.load_data()[0][0].reshape(60000, 784), "mnist"))
    return raw

def normalize_datasets(num_rays):
    """Scale each dataset from ``download_datasets`` by its estimated diameter and save it.

    The diameter of each ``(data, name)`` entry is estimated with
    ``get_diameter_pointdiff(data, num_rays)``. Then ``data / diameter`` is
    written with ``np.save`` to ``normalized_datasets/<name>_normalized.npy``.
    The ``normalized_datasets`` directory is created if it does not exist.

    Parameters
    ----------
    num_rays : number of random projection directions passed to
        ``get_diameter_pointdiff``.

    Raises
    ------
    ValueError
        If the estimated diameter of a dataset is not positive, for example
        when ``num_rays`` is 0 or all points coincide. Dividing by it would
        otherwise silently save inf/nan values.
    """
    import os

    out_dir = "normalized_datasets"
    for data, name in download_datasets():
        print("starting normalization")
        diam = get_diameter_pointdiff(data, num_rays)
        # ``not diam > 0`` also rejects NaN.
        if not diam > 0:
            raise ValueError(
                f"cannot normalize dataset {name!r}: estimated diameter is {diam}"
            )
        # Create lazily so nothing is touched on disk when there is no data.
        os.makedirs(out_dir, exist_ok=True)
        # np.save appends the ".npy" extension.
        np.save(out_dir + "/" + name + "_normalized", data / diam)

class Dataset(Enum):
    """Datasets known to this module.

    Each member's value is the base name used in on-disk file names, e.g.
    ``load_normalized_datasets`` reads
    ``normalized_datasets/<value>_normalized.npy``.
    """
    MUSHROOM = "mushroom"
    MNIST = "mnist"
    SKIN_NON_SKIN = "skin_nonskin"
    COVTYPE = "covtype"
    TEST_DATA = "test_data"  # not part of load_normalized_datasets' default list

class Algorithm(Enum):
    """Identifiers for algorithm variants.

    The member names suggest k-means / k-median clustering and
    subspace-approximation variants. Presumably the values are used to
    select or label runs. TODO confirm against callers, which are not
    visible here.
    """
    K_Z_SUBSPACE = "k_z_subspace"
    KMEDIAN = "kmedian"
    KMEANS = "kmeans"
    Z34 = "z34"
    # NOTE(review): this value's casing/naming differs from the other members;
    # left unchanged since it may be persisted or compared elsewhere.
    K_2_SUBSPACE= "K_Subspaces_l2"


def load_normalized_datasets(requested_datasets=None):
    """Load normalized datasets from the ``normalized_datasets/`` folder.

    Parameters
    ----------
    requested_datasets : iterable of ``Dataset`` members or their string
        values (e.g. ``"mnist"``), optional. Defaults to MUSHROOM, MNIST,
        SKIN_NON_SKIN and COVTYPE.

    Returns
    -------
    A list of ``(array, name)`` tuples in request order. ``array`` is loaded
    with ``np.load`` from ``normalized_datasets/<name>_normalized.npy``, and
    ``name`` is the dataset's string value.

    Raises
    ------
    ValueError
        If a string does not match any ``Dataset`` value. All requests are
        validated before any file is read.
    Errors from ``np.load`` (e.g. a missing file) propagate unchanged.
    """
    if requested_datasets is None:
        requested_datasets = [Dataset.MUSHROOM, Dataset.MNIST, Dataset.SKIN_NON_SKIN, Dataset.COVTYPE]
    # Dataset(x) returns x itself for a member and looks up the member by
    # value for a string, so both forms are accepted.
    members = [Dataset(d) for d in requested_datasets]
    return [
        (np.load("normalized_datasets/" + m.value + "_normalized.npy"), m.value)
        for m in members
    ]

@jit(nopython=True)
def find_diameter(dataset):
    """Return the exact diameter (largest pairwise Euclidean distance) of a point set.

    Brute force over every unordered pair of rows of ``dataset``, which is
    n*(n-1)/2 distance evaluations for n rows. Self-pairs and the mirrored
    (q, p) pairs are skipped because they cannot change the maximum.
    Returns 0.0 when there are fewer than two rows.
    """
    n = dataset.shape[0]
    max_dist = 0.0
    for i in range(n):
        p = dataset[i]
        for j in range(i + 1, n):
            # Compute the distance once per pair instead of twice.
            dist = np.linalg.norm(p - dataset[j])
            if dist > max_dist:
                max_dist = dist
    return max_dist

def calculate_min_distances5(data, centers):
    """Return the distance from every row of ``data`` to its nearest row of ``centers``.

    Squared Euclidean distances to all centers are formed one point at a time.
    This keeps memory at one (k, d) difference array rather than an
    (n, k, d) one. The smallest squared distance is kept, and one square
    root is taken over the whole result at the end.

    Returns a float64 array with one entry per row of ``data``. If
    ``centers`` has no rows (and ``data`` is non-empty), the min reduction
    raises ``ValueError``.
    """
    def _nearest_sq(point):
        # Squared distance from ``point`` to its closest center.
        offsets = point - centers
        return np.min(np.sum(offsets * offsets, axis=1))

    nearest_sq = np.fromiter(
        (_nearest_sq(row) for row in data),
        dtype=np.float64,
        count=data.shape[0],
    )
    return np.sqrt(nearest_sq)