import random
from itertools import islice

import numpy as np

"""
For all of these functions:
Args:
    num_of_samples: How many samples we take from the dataset. By default we take the entire dataset.
    num_of_tests: How many of the samples to set aside for testing. By default we take 20% of dataset.
Returns:
    X: Features of the samples
    y: Answers to the samples
    X_test: Features of the tests
    y_test: Answers to the tests
"""

def read_higgs(num_of_samples=300000, num_of_tests=None, path='training_data/higgs.csv'):
    """Read the first rows of the Higgs CSV, shuffle them, and split train/test.

    Each line is parsed as comma-separated floats. Column 0 is the label and
    the remaining columns are the features.

    Args:
        num_of_samples: Maximum number of lines to read from the top of the
            file. Reading stops early if the file is shorter.
        num_of_tests: How many of the shuffled rows to hold out for testing.
            Defaults to ``num_of_samples // 5``.
        path: CSV file to read. Defaults to ``'training_data/higgs.csv'``.

    Returns:
        X, y, X_test, y_test as NumPy arrays.
    """
    if num_of_tests is None:
        num_of_tests = num_of_samples // 5
    with open(path) as f:
        # islice stops cleanly at EOF instead of feeding '' to float().
        data = [list(map(float, line.split(','))) for line in islice(f, num_of_samples)]
    random.shuffle(data)
    data = np.array(data)
    # Use an explicit split index: `data[:-0]` is empty, so negative-count
    # slicing would put every row in the test set when num_of_tests == 0.
    split = max(len(data) - num_of_tests, 0)
    X, y = data[:split, 1:], data[:split, 0]
    X_test, y_test = data[split:, 1:], data[split:, 0]
    print("Done reading Higgs!")
    return X, y, X_test, y_test

def read_boone(num_of_samples=130065, num_of_tests=None, path='training_data/boone.txt'):
    """Read the Boone data file, label its rows, shuffle, and split train/test.

    The first line holds two integers, ``ones`` and ``zeros``. The next
    ``ones`` lines are labeled 1 and the following ``zeros`` lines are
    labeled 0. Each line is parsed as whitespace-separated floats, and the
    label is appended as the last column.

    Args:
        num_of_samples: How many shuffled rows to keep.
        num_of_tests: How many of the kept rows to hold out for testing.
            Defaults to ``num_of_samples // 5``.
        path: Data file to read. Defaults to ``'training_data/boone.txt'``.

    Returns:
        X, y, X_test, y_test as NumPy arrays.
    """
    if num_of_tests is None:
        num_of_tests = num_of_samples // 5
    with open(path) as f:
        ones, zeros = map(int, f.readline().split())
        data = [list(map(float, f.readline().split())) + [1] for _ in range(ones)]
        data += [list(map(float, f.readline().split())) + [0] for _ in range(zeros)]
    random.shuffle(data)
    data = np.array(data[:num_of_samples])
    # Use an explicit split index: `data[:-0]` is empty, so negative-count
    # slicing would put every row in the test set when num_of_tests == 0.
    split = max(len(data) - num_of_tests, 0)
    X, y = data[:split, :-1], data[:split, -1]
    X_test, y_test = data[split:, :-1], data[split:, -1]
    print("Done reading Boone!")
    return X, y, X_test, y_test

def read_forest_cover(num_of_samples=495141, num_of_tests=None):
    """Load sklearn's covtype data, keep two classes, shuffle, and split.

    Target labels are shifted down by 1 and appended as the last column.
    Only rows whose shifted label is 0 or 1 are kept.

    Args:
        num_of_samples: How many shuffled rows to keep.
        num_of_tests: How many of the kept rows to hold out for testing.
            Defaults to ``num_of_samples // 5``.

    Returns:
        X, y, X_test, y_test as NumPy arrays.
    """
    if num_of_tests is None:
        num_of_tests = num_of_samples // 5
    from sklearn.datasets import fetch_covtype
    d = fetch_covtype()
    data = np.hstack((d.data, d.target[:, np.newaxis] - 1))
    data = data[data[:, -1] <= 1]
    # random.shuffle on a 2-D ndarray swaps rows through views and ends up
    # duplicating rows. Shuffle a list of row indices instead. Using the
    # `random` module keeps seeding consistent with the other readers.
    order = list(range(len(data)))
    random.shuffle(order)
    data = data[order[:num_of_samples]]
    # Use an explicit split index: `data[:-0]` is empty, so negative-count
    # slicing would put every row in the test set when num_of_tests == 0.
    split = max(len(data) - num_of_tests, 0)
    X, y = data[:split, :-1], data[:split, -1]
    X_test, y_test = data[split:, :-1], data[split:, -1]
    print("Done reading Forest Cover!")
    return X, y, X_test, y_test

def read_diabetes(num_of_samples=768, num_of_tests=None, path='training_data/diabetes.csv'):
    """Read the diabetes CSV, shuffle the rows, and split train/test.

    The first line is skipped as a header. Each remaining line is parsed as
    comma-separated floats. The last column is the label and the others are
    the features.

    Args:
        num_of_samples: Maximum number of data lines to read after the
            header. Reading stops early if the file is shorter.
        num_of_tests: How many of the shuffled rows to hold out for testing.
            Defaults to ``num_of_samples // 5``.
        path: CSV file to read. Defaults to ``'training_data/diabetes.csv'``.

    Returns:
        X, y, X_test, y_test as NumPy arrays.
    """
    if num_of_tests is None:
        num_of_tests = num_of_samples // 5
    with open(path) as f:
        next(f, None)  # skip the header line
        # islice stops cleanly at EOF instead of feeding '' to float().
        data = [list(map(float, line.split(','))) for line in islice(f, num_of_samples)]
    random.shuffle(data)
    data = np.array(data)
    # Use an explicit split index: `data[:-0]` is empty, so negative-count
    # slicing would put every row in the test set when num_of_tests == 0.
    split = max(len(data) - num_of_tests, 0)
    X, y = data[:split, :-1], data[:split, -1]
    X_test, y_test = data[split:, :-1], data[split:, -1]
    print("Done reading Diabetes!")
    return X, y, X_test, y_test

"""
This one is special, since the tests will just be the entire universe.
"""
def read_adversarial(u, num_of_samples):
    """Build a synthetic dataset over the integers ``0 .. u-1``, all labeled 1.

    Args:
        u: Size of the universe. The test set contains every integer in it.
        num_of_samples: Number of training points, drawn with
            ``np.random.randint(0, u, num_of_samples)``.

    Returns:
        X: Column array of shape ``(num_of_samples, 1)`` with the drawn integers.
        y: ``num_of_samples`` ones.
        X_test: Column array of shape ``(u, 1)`` holding ``0 .. u-1``.
        y_test: ``u`` ones.
    """
    universe = np.arange(u).reshape(-1, 1)
    universe_labels = np.ones(u)
    draws = np.random.randint(0, u, num_of_samples).reshape(-1, 1)
    draw_labels = np.ones(num_of_samples)
    return draws, draw_labels, universe, universe_labels