import math
import os

import numpy as np
from sklearn import svm
from torchvision.datasets.mnist import MNIST


# Relative directory under which the generated ``*_x.npy`` / ``*_y.npy`` arrays
# are saved and under which the MNIST download is stored (``DATA_DIR/mnist``).
DATA_DIR = "data"
# Second coordinate of the per-client mean vector that generate_data builds for
# the "gaussian" data type.
SECOND_DIM_SHIFT = 2


def get_data_name(
    num_clients, local_var, input_d, client_samples, mu_het, label_het, data_type="gaussian"
):
    """Return the ``(x_path, y_path)`` pair used to store a dataset on disk.

    The file stem encodes only the settings relevant to ``data_type``; the
    remaining arguments are accepted so every data type shares one signature.
    Both paths live directly under ``DATA_DIR``.

    Args:
        num_clients: Number of clients (used by every data type).
        local_var: Per-client variance (only used for "gaussian").
        input_d: Feature dimension (unused for "mnist").
        client_samples: Samples per client ("gaussian", "easy", "mnist").
        mu_het: Mean heterogeneity (only used for "gaussian").
        label_het: Label heterogeneity (only used for "mnist").
        data_type: One of "gaussian", "easy", "weird", "hard", "mnist".

    Returns:
        Tuple ``(x_name, y_name)`` of ``.npy`` file paths.

    Raises:
        NotImplementedError: If ``data_type`` is not one of the known types.
    """

    def _paths(stem):
        # Features and labels share a stem and differ only in the suffix.
        return (
            os.path.join(DATA_DIR, f"{stem}_x.npy"),
            os.path.join(DATA_DIR, f"{stem}_y.npy"),
        )

    if data_type == "gaussian":
        return _paths(
            f"clients_{num_clients}_var_{local_var}_d_{input_d}"
            f"_samples_{client_samples}_het_{mu_het}"
        )
    if data_type == "easy":
        return _paths(f"easy_clients_{num_clients}_samples_{client_samples}_d_{input_d}")
    if data_type == "weird":
        return _paths(f"weird_{num_clients}_d_{input_d}")
    if data_type == "hard":
        return _paths(f"hard_{num_clients}_d_{input_d}")
    if data_type == "mnist":
        return _paths(f"mnist_M_{num_clients}_n_{client_samples}_h_{label_het}")
    raise NotImplementedError


def _gaussian_xs(num_clients, local_var, input_d, client_samples, mu_het):
    """Sample each of the two clients' features from its own isotropic Gaussian."""
    assert num_clients == 2
    # The mean below always has a first and a second coordinate.
    assert input_d >= 2
    xs = np.zeros((num_clients, client_samples, input_d))
    client_cov = local_var * np.eye(input_d)
    for client in range(num_clients):
        # Client 0 is centred at +mu_het and client 1 at -mu_het along the first
        # axis; both share the same shift along the second axis.
        client_mu = [(-1) ** client * mu_het, SECOND_DIM_SHIFT] + [0] * (input_d - 2)
        xs[client] = np.random.multivariate_normal(
            mean=client_mu, cov=client_cov, size=client_samples
        )
    return xs


def _easy_xs(num_clients, input_d, client_samples):
    """Build mirrored points of norm ``2 ** -client`` for every client."""
    assert client_samples % 2 == 0
    half = client_samples // 2
    xs = np.zeros((num_clients, client_samples, input_d))
    for client in range(num_clients):
        margin = 2 ** (-1 * client)
        points = np.random.multivariate_normal(
            mean=np.zeros(input_d), cov=np.eye(input_d), size=half
        )
        # Rescale only the sampled half to length `margin`; normalising the
        # still-empty second half would divide by zero.
        points *= margin / np.sqrt(np.sum(points ** 2, axis=1, keepdims=True))
        xs[client, :half] = points
        # Second half is the first half with the first coordinate negated, so
        # every point has a partner with the opposite label.
        xs[client, half:] = points
        xs[client, half:, 0] *= -1
    return xs


def _weird_xs(num_clients, input_d, client_samples):
    """Build the fixed two-point-per-client "weird" configuration."""
    assert client_samples == 2
    assert input_d >= 2
    xs = np.zeros((num_clients, client_samples, input_d))
    # Client 0 has its two points on the anti-diagonal ...
    xs[0, 0, :2] = (-1.0, 1.0)
    xs[0, 1, :2] = (1.0, -1.0)
    # ... while every other client has them on the main diagonal.
    xs[1:, 0, :2] = (1.0, 1.0)
    xs[1:, 1, :2] = (-1.0, -1.0)
    # Stretch the second coordinate for everyone, then shrink client 0 overall.
    xs[:, :, 1] *= 2 ** 4
    xs[0] *= 2 ** (-4)
    return xs


def _hard_xs(num_clients, input_d, client_samples):
    """Build the fixed "hard" configuration for 2, 3 or 4 clients."""
    assert 2 <= num_clients <= 4
    assert input_d == num_clients
    assert client_samples == 2
    delta = 0.1
    gamma = 0.2
    xs = np.zeros((num_clients, client_samples, input_d))

    # Place each client's first point in the coordinates after the first one.
    if num_clients == 2:
        xs[0, 0, 1] = 1
        xs[1, 0, 1] = -1
    elif num_clients == 3:
        # Three directions spaced 120 degrees apart.
        for client in range(num_clients):
            angle = 2 * math.pi * client / 3
            xs[client, 0, 1] = math.cos(angle)
            xs[client, 0, 2] = math.sin(angle)
    elif num_clients == 4:
        signs = np.array(
            [
                [1, 1, 1],
                [1, -1, -1],
                [-1, 1, -1],
                [-1, -1, 1],
            ]
        )
        xs[:, 0, 1:] = signs / math.sqrt(3)

    # Every client shares the same first coordinate before normalisation.
    xs[:, 0, 0] = delta
    lens = np.sqrt(np.sum(xs[:, 0] ** 2, axis=1, keepdims=True))
    xs[:, 0] /= lens

    # Client norms decay geometrically from 1 (client 0) down to gamma (last).
    gamma_ratio = gamma ** (1 / (num_clients - 1))
    gammas = np.array([gamma_ratio ** i for i in range(num_clients)])
    xs[:, 0] *= np.expand_dims(gammas, 1)

    # Reflect data around origin. Doesn't change training at all, but now dataset
    # has two labels.
    xs[:, 1] = -1 * xs[:, 0]
    return xs


def generate_data(
    num_clients, local_var, input_d, client_samples, mu_het, label_het, data_type="gaussian"
):
    """Generate a per-client dataset and save it as two ``.npy`` files.

    Features are stored with shape ``(num_clients, client_samples, ...)`` and
    labels with shape ``(num_clients, client_samples)``. The output paths come
    from ``get_data_name`` called with the same arguments. Except for "mnist",
    which supplies its own labels, a sample is labelled ``+1`` when its first
    coordinate is non-negative and ``-1`` otherwise.

    Args:
        num_clients: Number of clients.
        local_var: Variance of each client's isotropic Gaussian ("gaussian").
        input_d: Feature dimension of the synthetic datasets.
        client_samples: Number of samples per client.
        mu_het: Magnitude of the per-client mean along the first axis
            ("gaussian").
        label_het: Label heterogeneity, forwarded to ``mnist_partition``.
        data_type: One of "gaussian", "easy", "weird", "hard", "mnist".

    Raises:
        AssertionError: If the sizes are not supported by the chosen data type.
        NotImplementedError: If ``data_type`` is not recognised.
    """
    local_ys = None

    if data_type == "gaussian":
        local_xs = _gaussian_xs(num_clients, local_var, input_d, client_samples, mu_het)
    elif data_type == "easy":
        local_xs = _easy_xs(num_clients, input_d, client_samples)
    elif data_type == "weird":
        local_xs = _weird_xs(num_clients, input_d, client_samples)
    elif data_type == "hard":
        local_xs = _hard_xs(num_clients, input_d, client_samples)
    elif data_type == "mnist":
        local_xs, local_ys = mnist_partition(num_clients, client_samples, label_het)
    else:
        raise NotImplementedError

    if local_ys is None:
        # Label by the sign of the first coordinate (zero counts as positive).
        local_ys = np.where(local_xs[:, :, 0] >= 0, 1.0, -1.0)

    x_name, y_name = get_data_name(
        num_clients, local_var, input_d, client_samples, mu_het, label_het, data_type
    )
    for path in (x_name, y_name):
        out_dir = os.path.dirname(path)
        if out_dir:
            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(out_dir, exist_ok=True)
    np.save(x_name, local_xs)
    np.save(y_name, local_ys)


def max_margin_predictor(x, y):
    """Return the unit-norm linear SVM direction for ``(x, y)`` and its margin.

    A linear ``SVC`` with a very large ``C`` is fitted, its coefficient vector
    is normalised to unit length, and the margin is the smallest value of
    ``y_i * <x_i, w>`` over the fitted samples. The classifier's intercept is
    not used in the margin.

    Args:
        x: Feature matrix, one sample per row.
        y: Labels aligned with the rows of ``x``.

    Returns:
        Tuple ``(w, margin)`` with the normalised weight vector and the margin
        as a Python float.

    Raises:
        ValueError: If the margin is below ``-1e-4``, i.e. the normalised
            direction misclassifies some sample.
    """
    # With a single label present, mirror the data through the origin so that
    # both classes appear in the fit.
    if len(np.unique(y)) == 1:
        x, y = np.concatenate([x, -1 * x]), np.concatenate([y, -1 * y])

    classifier = svm.SVC(kernel="linear", C=1e7)
    classifier.fit(x, y)

    direction = classifier.coef_.squeeze()
    direction = direction / np.sqrt(np.sum(direction ** 2))

    # Signed score of every sample along the unit direction.
    signed_scores = y * np.matmul(x, direction)
    margin = float(np.min(signed_scores))
    if margin < -1e-4:
        print(margin)
        raise ValueError("Dataset not linearly separable.")

    return direction, margin


def mnist_partition(num_clients, client_samples, label_het):
    """Split a random MNIST subset across clients with tunable label skew.

    ``num_clients * client_samples`` training images are drawn without
    replacement, flattened, shifted by 127 and divided by the largest feature
    norm in the subset. A ``label_het`` fraction of the subset forms a pool
    that is sorted by digit and dealt out in contiguous slices; the rest is
    dealt out in sampling order. Each client's samples are then shuffled.

    Args:
        num_clients: Number of clients.
        client_samples: Number of samples per client.
        label_het: Fraction of the subset placed in the label-sorted pool.

    Returns:
        Tuple ``(local_xs, local_ys)``: features of shape
        ``(num_clients, client_samples, n_features)`` and integer labels of
        shape ``(num_clients, client_samples)``, ``+1`` for odd digits and
        ``-1`` for even digits.
    """
    # Load the training split, creating the storage directory if needed.
    mnist_root = os.path.join(DATA_DIR, "mnist")
    if not os.path.isdir(mnist_root):
        os.makedirs(mnist_root)
    mnist = MNIST(root=mnist_root, train=True, download=True)

    # Draw the subset shared by all clients.
    available = len(mnist)
    total = num_clients * client_samples
    assert total <= available
    chosen = np.random.choice(np.arange(available), size=total, replace=False)
    features = mnist.data[chosen].float().numpy()
    digits = mnist.targets[chosen].numpy()

    # Flatten, shift, and scale so the longest feature vector has unit norm.
    features = features.reshape((total, -1))
    features -= 127
    features /= np.max(np.sqrt(np.sum(features ** 2, axis=1)))

    # The first `noniid_size` positions form the non-iid pool, the rest the iid
    # pool. Sorting the non-iid pool by digit groups equal labels together.
    noniid_size = round(total * label_het)
    positions = np.arange(total)
    iid_pool = positions[noniid_size:]
    noniid_pool = np.argsort(digits[positions[:noniid_size]])

    # Fill every client from a slice of each pool.
    local_xs = np.zeros((num_clients, client_samples) + tuple(features.shape[1:]))
    local_ys = np.zeros((num_clients, client_samples), dtype=int)
    iid_cursor = 0
    for m in range(num_clients):
        lo = round(noniid_size * m / num_clients)
        hi = round(noniid_size * (m + 1) / num_clients)
        # Whatever the non-iid slice does not cover is topped up from the iid pool.
        iid_count = client_samples - (hi - lo)
        picked = np.concatenate(
            [noniid_pool[lo:hi], iid_pool[iid_cursor: iid_cursor + iid_count]]
        )
        np.random.shuffle(picked)

        local_xs[m] = features[picked]
        local_ys[m] = digits[picked]

        iid_cursor += iid_count

    # Binarize labels: odd digit -> +1, even digit -> -1.
    local_ys = 2 * (local_ys % 2) - 1

    return local_xs, local_ys
