import bz2
import os
import shutil
import numpy as np
from os.path import exists
from urllib.request import urlretrieve
from sklearn.datasets import load_svmlight_file


class Dataset:
    """String identifiers of the datasets known to this module.

    Every attribute is a plain ``str`` constant, so values can be compared
    directly against user-supplied dataset names.
    """

    A1A = "a1a"
    A5A = "a5a"
    A6A = "a6a"
    A9A = "a9a"

    W1A = "w1a"  # d = 300
    W6A = "w6a"  # d = 300
    W8A = "w8a"  # d = 300

    GISETTE = "gisette"  # d = 5,000
    MADELON = "madelon"  # d = 500

    DIABETES = "diabetes"
    BREAST_CANCER = "breast-cancer"
    AUSTRALIAN = "australian"
    COLON_CANCER = "colon-cancer"
    DUKE_BREAST_CANCER = "duke"

    REAL_SIM = "real-sim"  # 20,958 features, 72,309 samples

    DIRICHLET = "dirichlet"


# Base URL of the LIBSVM binary-classification dataset collection; dataset
# file names are appended directly, so the trailing slash is required.
repository = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/"
# Directory (relative to the current working directory) holding local copies;
# file names are appended directly, so the trailing slash is required.
local_dir = "./dataset/"


def _replace_atomically(path, write_to):
    """Create ``path`` by calling ``write_to(tmp)`` and renaming ``tmp`` onto it.

    Writing to ``<path>.part`` first means an interrupted download or
    extraction never leaves a truncated file at ``path`` that a later
    ``exists(path)`` check would mistake for a complete one.
    """
    partial = f"{path}.part"
    try:
        write_to(partial)
        os.replace(partial, path)
    finally:
        # Only still present when write_to() or the rename failed.
        if exists(partial):
            os.remove(partial)


def _download_if_missing(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists."""
    if not exists(path):
        _replace_atomically(path, lambda tmp: urlretrieve(url, tmp))


def _extract_bz2_if_missing(archive_path, target_path):
    """Decompress the bz2 file ``archive_path`` into ``target_path`` once."""
    if exists(target_path):
        return

    def _extract(tmp):
        with bz2.BZ2File(archive_path) as arch, open(tmp, "wb") as f:
            shutil.copyfileobj(arch, f)

    _replace_atomically(target_path, _extract)


def get_dataset(args):
    """Return the ``(features, labels)`` pair for the dataset named in ``args``.

    LIBSVM datasets are downloaded from ``repository`` into ``local_dir`` on
    first use and read with ``load_svmlight_file``; later calls reuse the
    local copy. The ``dirichlet`` dataset is generated with ``np.random``
    (global RNG state) and never touches the network.

    Args:
        args: Mapping with at least the key ``'dataset'`` (one of the
            ``Dataset`` constants). For ``Dataset.DIRICHLET`` it must also
            contain ``'n_workers'``, ``'dim'`` and ``'alpha'``.

    Returns:
        A 2-tuple ``(features, labels)``. For LIBSVM datasets these are the
        two values returned by ``load_svmlight_file``. For
        ``Dataset.DIRICHLET`` they are an array of shape
        ``(2 * n_workers, dim)`` whose rows are Dirichlet samples (each
        sample appears twice, consecutively) and an integer array of shape
        ``(2 * n_workers,)`` with values in ``{0, 1}`` duplicated the same way.

    Raises:
        NotImplementedError: If ``args['dataset']`` is not a supported name.
        KeyError: If a required key is missing from ``args``.
    """
    dataset = args['dataset']
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(local_dir, exist_ok=True)

    # Tuples (not sets) keep the original ``==`` comparison semantics and do
    # not require ``dataset`` to be hashable.
    plain_datasets = (
        Dataset.A1A, Dataset.A5A, Dataset.A6A, Dataset.A9A,
        Dataset.W1A, Dataset.W6A, Dataset.W8A,
        Dataset.DIABETES, Dataset.BREAST_CANCER, Dataset.AUSTRALIAN,
        Dataset.MADELON,
    )
    # These are published only as ``<name>.bz2`` archives.
    bz2_datasets = (
        Dataset.COLON_CANCER, Dataset.DUKE_BREAST_CANCER, Dataset.REAL_SIM,
    )

    if dataset in plain_datasets:
        data_path = f"{local_dir}{dataset}"
        _download_if_missing(f"{repository}{dataset}", data_path)
        features, labels = load_svmlight_file(data_path)
        return features, labels

    if dataset in bz2_datasets:
        archive_path = f"{local_dir}{dataset}.bz2"
        data_path = f"{local_dir}{dataset}"
        _download_if_missing(f"{repository}{dataset}.bz2", archive_path)
        # Extract once and load the extracted copy, instead of re-extracting
        # on every call and then parsing the archive anyway.
        _extract_bz2_if_missing(archive_path, data_path)
        features, labels = load_svmlight_file(data_path)
        return features, labels

    if dataset == Dataset.DIRICHLET:
        n = args['n_workers']
        d = args['dim']
        alpha = args['alpha']

        # n samples from a symmetric Dirichlet(alpha, ..., alpha) in d dims;
        # every sample (and its label) is then duplicated back-to-back.
        a = np.random.dirichlet(size=n, alpha=[alpha] * d)
        a = np.repeat(a, 2, axis=0)
        b = np.random.randint(0, 2, size=n)
        b = np.repeat(b, 2, axis=0)
        return a, b

    # TODO: an 'artificial' dataset backed by generate_artificial_data()
    # (m, d, mu and per-worker L_i taken from args) is not wired up yet.
    raise NotImplementedError('Dataset not supported.')


def generate_artificial_data(m, d, mu, L_i):  # todo
    """Generate a seeded synthetic dataset with one ``(m, d)`` block per worker.

    For worker ``k`` a matrix ``A_k = U @ S @ V`` is built, where ``U``
    (``m x m``) and ``V`` (``d x d``) are the orthogonal Q factors of random
    matrices and ``S`` carries ``min(m, d)`` random singular values rescaled
    so that the largest one equals ``sqrt(4 * m * (L_i[k] - mu))``. Labels
    are drawn uniformly from ``{0, 1}``.

    Side effect: the global NumPy RNG is reseeded with ``np.random.seed(0)``,
    so the output is deterministic and later ``np.random`` calls are affected.

    Args:
        m: Number of rows (samples) generated per worker.
        d: Number of columns (features).
        mu: Scalar subtracted from every ``L_i`` entry before scaling; each
            entry of ``L_i`` must be at least ``mu``.
        L_i: A number or a sequence of numbers, one per worker.

    Returns:
        ``(A, b)`` where ``A`` has shape ``(n_workers * m, d)`` (the worker
        blocks stacked vertically) and ``b`` has shape ``(n_workers * m,)``.

    Raises:
        TypeError: If ``L_i`` is ``None``.
        ValueError: If ``L_i`` is empty or any entry is smaller than ``mu``
            (the scale would be the square root of a negative number).
    """
    if L_i is None:
        raise TypeError("L_i must be a number or a sequence of numbers, not None.")

    # atleast_1d lets a single scalar describe a one-worker problem.
    L_i = np.atleast_1d(np.asarray(L_i, dtype=float))
    if L_i.size == 0:
        raise ValueError("L_i must contain at least one value.")
    # ``not all(>=)`` (rather than ``any(<)``) also rejects NaN entries.
    if not np.all(L_i >= mu):
        raise ValueError("Every entry of L_i must be greater than or equal to mu.")

    np.random.seed(0)
    # Target largest singular value for each worker.
    max_singular = np.sqrt(4*L_i*m - 4*mu*m)

    rank = min(m, d)
    A = []
    b = []
    for target in max_singular:
        D = np.random.rand(rank)
        D -= np.min(D)
        spread = np.max(D)
        if spread > 0:
            D /= spread  # now spans [0, 1]
        else:
            # All draws equal (always the case when rank == 1): 0 / 0 would
            # give NaN, so use a flat spectrum whose maximum is still 1.
            D = np.ones_like(D)
        D = D * target  # Singular values with maximum entry ``target``

        S = np.zeros([m, d])
        S[:rank, :rank] = np.diag(D)  # Singular values matrix
        U, _ = np.linalg.qr(np.random.rand(m, m))
        V, _ = np.linalg.qr(np.random.rand(d, d))
        A.append(U @ S @ V)  # Orthogonal @ diagonal @ orthogonal
        b.append(np.random.randint(2, size=m))

    return np.concatenate(A), np.concatenate(b)
