import bz2
import os
import shutil
from os.path import exists
from urllib.request import urlretrieve
from sklearn.datasets import load_svmlight_file


class Dataset:
    """Identifiers of the datasets that ``get_dataset`` knows how to fetch.

    Each value is the dataset's file name stem: it is appended to the
    repository URL and to the local cache directory to build the remote and
    local paths.
    """

    # Fetched as plain (uncompressed) files.
    W8A = "w8a"
    A9A = "a9a"

    # Fetched as ``<name>.bz2`` archives.
    COLON_CANCER = "colon-cancer"
    DUKE_BREAST_CANCER = "duke"
    REAL_SIM = "real-sim"


# Base URL the dataset files are downloaded from. The trailing slash is
# required: file names are appended to it by plain string concatenation.
repository = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/"

# Local cache directory for downloaded files (relative to the current working
# directory). The trailing slash is required for the same reason as above.
local_dir = "./dataset/"


def _fetch(url, path):
    """Download *url* to *path* unless *path* already exists."""
    if exists(path):
        return
    # Download to a side file and rename only on success, so an interrupted
    # transfer never leaves a truncated file that looks like a valid cache hit.
    partial = f"{path}.part"
    urlretrieve(url, partial)
    os.replace(partial, path)


def _extract_bz2(archive_path, target_path):
    """Decompress *archive_path* into *target_path* unless it already exists."""
    if exists(target_path):
        return
    # Same write-then-rename approach as the download, for the same reason.
    partial = f"{target_path}.part"
    with bz2.BZ2File(archive_path) as archive, open(partial, "wb") as out:
        shutil.copyfileobj(archive, out)
    os.replace(partial, target_path)


def get_dataset(dataset):
    """Download (if needed), cache and load one of the supported datasets.

    Files are cached under ``local_dir``; a file that is already present is
    not downloaded again. Datasets published as ``.bz2`` archives are
    decompressed once next to the archive, and the decompressed file is the
    one that gets parsed.

    Args:
        dataset: One of the ``Dataset`` constants (e.g. ``Dataset.A9A``).

    Returns:
        A 2-tuple made of the first two items returned by
        ``load_svmlight_file`` for the dataset file (the feature matrix and
        the label vector).

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name. This is
            checked before anything is created on disk or downloaded.
    """
    # (name, is_bz2_archive) for every supported dataset. A tuple scanned
    # with ``==`` keeps the original comparison semantics for any argument
    # type, including unhashable ones.
    supported = (
        (Dataset.W8A, False),
        (Dataset.A9A, False),
        (Dataset.COLON_CANCER, True),
        (Dataset.DUKE_BREAST_CANCER, True),
        (Dataset.REAL_SIM, True),
    )
    for name, is_archive in supported:
        if dataset == name:
            break
    else:
        raise NotImplementedError('Dataset not supported.')

    # exist_ok avoids the check-then-create race and tolerates nested paths.
    os.makedirs(local_dir, exist_ok=True)

    data_path = f"{local_dir}{name}"
    if is_archive:
        archive_path = f"{data_path}.bz2"
        _fetch(f"{repository}{name}.bz2", archive_path)
        _extract_bz2(archive_path, data_path)
    else:
        _fetch(f"{repository}{name}", data_path)

    data = load_svmlight_file(data_path)
    return data[0], data[1]
