import numpy as np
import torch
import matplotlib.pyplot as plt
import os
from utils import savefig


def generate_data(num_datapoints, train=True, degree_of_balance=None):
    """Generate a 2-D binary classification dataset.

    Every point gets label ``1`` when its first coordinate ``x1`` is negative
    and ``0`` otherwise; the second coordinate ``x2`` never enters the label.

    * Training split (``train`` truthy): ``x1`` is uniform on ``[-1, 1)`` and
      ``x2`` is tied to it -- uniform on ``[0, 1)`` when ``x1 < 0`` and on
      ``[-1, 0)`` otherwise -- so ``x2`` alone also predicts the label.
    * Test split (``train`` falsy): ``num_datapoints // 2`` points have ``x2``
      on ``[0, 1)`` and ``x1`` on ``[-degree_of_balance, 1)``; the remaining
      points have ``x2`` on ``[-1, 0)`` and ``x1`` on
      ``[-1, degree_of_balance)``.

    Rows are shuffled before labelling, and the split size is printed.

    Args:
        num_datapoints: number of rows to generate.
        train: select the training (truthy) or test (falsy) distribution.
        degree_of_balance: required for the test split; must lie in [0, 1].

    Returns:
        ``(x, y)`` where ``x`` has shape ``(num_datapoints, 2)`` (float) and
        ``y`` has shape ``(num_datapoints, 1)`` (int).

    Raises:
        ValueError: if ``train`` is falsy and ``degree_of_balance`` is missing
            or outside ``[0, 1]``.
    """
    if not train:
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and produced an opaque TypeError when degree_of_balance was None.
        if degree_of_balance is None or not 0 <= degree_of_balance <= 1:
            raise ValueError(
                f"degree_of_balance must be in [0, 1] for test data, got {degree_of_balance!r}"
            )
        total = int(num_datapoints)
        # Give the remainder to the lower half so an odd count still yields
        # exactly `total` rows (the old int(n/2) twice dropped one point).
        # Draw order is unchanged, so even counts reproduce the old samples.
        n_top = total // 2
        n_bottom = total - n_top
        x1a = np.random.uniform(low=-degree_of_balance, high=1.0, size=(n_top, 1))  # [-degree_of_balance, 1)
        x1b = np.random.uniform(low=-1.0, high=degree_of_balance, size=(n_bottom, 1))  # [-1, degree_of_balance)
        x2a = np.random.rand(n_top, 1)  # [0, 1)
        x2b = np.random.rand(n_bottom, 1) - 1  # [-1, 0)
        top_right = np.concatenate([x1a, x2a], 1)
        bottom_left = np.concatenate([x1b, x2b], 1)
        x = np.concatenate([top_right, bottom_left], 0)
    else:
        x1 = np.random.rand(num_datapoints, 1) * 2 - 1
        # Draw both candidate x2 columns, then pick per row so that x2 is
        # positive exactly when x1 is negative (the spurious correlation).
        x2a = np.random.rand(num_datapoints, 1)
        x2b = np.random.rand(num_datapoints, 1) - 1
        x2 = np.where(x1 < 0, x2a, x2b)
        x = np.concatenate([x1, x2], 1)

    np.random.shuffle(x)
    # Label after shuffling so rows and labels stay aligned.
    y = (x[:, 0] < 0)[:, None].astype(int)

    if train:
        print("training data size:", len(x))
    else:
        print("test data size:", len(x))
        print("degree of balance:", degree_of_balance)

    return x, y


def generate_data_ndim(num_datapoints, dims=2, train=True):
    """Generate a ``dims``-dimensional dataset with Bernoulli(0.5) labels.

    Each label ``y`` is drawn from ``{0, 1}`` and mapped to a sign ``2y - 1``.
    Feature magnitudes are uniform on ``[0, 1)``.

    * Training: every feature carries the sign, so all dimensions are
      predictive of the label.
    * Test: only feature 0 carries the sign; the remaining features are
      uniform on ``[-1, 1)`` independently of the label.

    Returns:
        ``(x, y)`` with ``x`` of shape ``(num_datapoints, dims)`` and ``y`` of
        shape ``(num_datapoints, 1)``.
    """
    labels = np.random.binomial(1, 0.5, num_datapoints)[:, None]
    magnitudes = np.random.rand(num_datapoints, dims)
    signs = labels * 2 - 1

    if not train:
        # Start from label-independent noise, then make only column 0 predictive.
        features = magnitudes * 2 - 1
        features[:, 0] = magnitudes[:, 0] * signs[:, 0]
        return features, labels

    return magnitudes * signs, labels


def sample_minibatch(data, batch_size):
    """Draw a random minibatch from an ``(x, y)`` pair of arrays.

    Row indices are sampled uniformly *with replacement* via numpy's global
    RNG, and the same indices are applied to both arrays so inputs and
    targets stay paired.

    Returns:
        A ``(x_batch, y_batch)`` tuple of float32 torch tensors.
    """
    inputs, targets = data
    picks = np.random.randint(0, inputs.shape[0], size=batch_size)
    return tuple(torch.tensor(arr[picks]).float() for arr in (inputs, targets))

def generate_and_plot_data(args):
    """Generate the seeded train/test datasets and save a scatter plot of them.

    Seeds torch and numpy with a fixed value, then:

    * draws 500 training points with ``generate_data(train=True)``;
    * draws 5000 test points with
      ``generate_data(train=False, degree_of_balance=args.degree_of_balance)``.

    It then plots up to 30 labelled training points per class (class 0 as
    squares, class 1 as triangles) over up to 30 unlabelled grey test points
    per class. The figure is saved via ``savefig("linear/data")``.

    Args:
        args: namespace-like object; only ``args.degree_of_balance`` is read.

    Returns:
        ``(training_data, test_data)``, each an ``(x, y)`` pair as returned by
        ``generate_data``.
    """
    # Divisor used to normalise 8-bit RGB triples into matplotlib's [0, 1] range.
    # NOTE(review): 255 is the conventional divisor; 256 is kept so the saved
    # figure's colours are unchanged.
    RGB_max = 256
    n_shown = 30  # points drawn per class and split, to keep the plot legible

    os.makedirs("figures/linear", exist_ok=True)
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    training_data = generate_data(500, train=True)
    test_data = generate_data(5000, train=False, degree_of_balance=args.degree_of_balance)

    plt.figure(figsize=(4, 4))
    tr_x, tr_y = training_data
    te_x, te_y = test_data

    # Labelled training points, drawn on top (zorder=10).
    tr_g = tr_x[tr_y.flatten() == 0][:n_shown]
    plt.scatter(
        tr_g[:, 0],
        tr_g[:, 1],
        marker="s",
        zorder=10,
        s=50,
        c=[[252 / RGB_max, 175 / RGB_max, 124 / RGB_max]],
        edgecolors="k",
        linewidth=1,
        alpha=0.8,
    )
    tr_g = tr_x[tr_y.flatten() == 1][:n_shown]
    plt.scatter(
        tr_g[:, 0],
        tr_g[:, 1],
        marker="^",
        zorder=10,
        s=60,
        c=[[135 / RGB_max, 201 / RGB_max, 195 / RGB_max]],
        edgecolors="k",
        linewidth=1,
        alpha=0.8,
    )

    # Unlabelled test points: identical grey style for both classes, drawn
    # behind the training points. Sampling per class keeps both visible.
    for label in (0, 1):
        te_g = te_x[te_y.flatten() == label][:n_shown]
        plt.scatter(
            te_g[:, 0], te_g[:, 1], zorder=0, s=40, c="silver", edgecolors="k", linewidth=1, alpha=0.6
        )

    ax = plt.gca()
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
    savefig("linear/data")

    return training_data, test_data
