

from torchvision import datasets, transforms
import torch
import numpy as np
from sklearn import datasets as sk_datasets
from torch.utils.data import TensorDataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
from v3.util.general import set_seed
import pandas as pd
import seaborn as sns


def basic_data(args, batch_sz):
    """Build train/test DataLoaders over a synthetic 2-D, 5-class blobs dataset.

    A single ``make_blobs`` draw of ``args.C * 60`` points is split by position
    into the first ``args.C * 10`` points (train) and the remaining
    ``args.C * 50`` points (test), so both splits come from the same clusters.
    Because the draw is shuffled, the per-class counts inside each split are
    not guaranteed to be equal.

    No ``random_state`` is passed to ``make_blobs``; reproducibility therefore
    depends on whatever seeding the caller has done beforehand (presumably of
    NumPy's global RNG, per scikit-learn's ``random_state=None`` convention).

    Side effects:
        * Overwrites ``args.C`` (set to 5) and ``args.in_dim`` (set to 2).
        * Writes a scatter plot of both splits to ``<args.out_dir>/data.png``,
          creating ``args.out_dir`` if it does not exist.
        * Closes *all* open pyplot figures, not only the one created here.

    Args:
        args: Namespace-like object. ``out_dir`` is read; ``C`` and ``in_dim``
            are written.
        batch_sz: Batch size used for both returned DataLoaders.

    Returns:
        ``(train_dl, test_dl)``. The training loader reshuffles every epoch,
        the test loader iterates in a fixed order. Inputs are float32 tensors
        of shape ``(N, 2)`` and labels are int64 tensors of shape ``(N,)``.
    """
    args.C = 5
    args.in_dim = 2

    num_train_samples = args.C * 10
    num_test_samples = args.C * 50

    xs, ys = sk_datasets.make_blobs(
        n_samples=num_train_samples + num_test_samples,
        centers=args.C,
        n_features=args.in_dim,
        cluster_std=0.7,
        shuffle=True,
    )
    # Sanity check on the generator output before it is indexed by column.
    assert len(xs.shape) == 2 and xs.shape[1] == args.in_dim

    x_train, y_train = xs[:num_train_samples], ys[:num_train_samples]
    x_test, y_test = xs[num_train_samples:], ys[num_train_samples:]

    # One scatter call per class so matplotlib's colour cycle assigns each
    # class its own colour; train on the top axes, test on the bottom axes.
    f, ax = plt.subplots(2, figsize=(4, 2 * 4))
    for c in range(args.C):
        train_mask = y_train == c
        test_mask = y_test == c
        ax[0].scatter(x_train[train_mask, 0], x_train[train_mask, 1])
        ax[1].scatter(x_test[test_mask, 0], x_test[test_mask, 1])

    plt.tight_layout()
    # An empty out_dir means "current directory"; os.makedirs("") would raise.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    f.savefig(os.path.join(args.out_dir, "data.png"), bbox_inches="tight")
    plt.close("all")

    # Labels are requested as torch.long explicitly: the dtype inferred from
    # the NumPy array follows the platform's default integer width, and
    # classification losses expect int64 class indices.
    train_data_orig = TensorDataset(
        torch.tensor(x_train, dtype=torch.float),
        torch.tensor(y_train, dtype=torch.long),
    )
    test_data = TensorDataset(
        torch.tensor(x_test, dtype=torch.float),
        torch.tensor(y_test, dtype=torch.long),
    )

    train_dl = torch.utils.data.DataLoader(train_data_orig, batch_size=batch_sz, shuffle=True)
    test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_sz, shuffle=False)

    return train_dl, test_dl



def basic_data_instance(args, batch_sz, data_instance):
    """Build train/test DataLoaders for one seeded instance of the blobs dataset.

    A single ``make_blobs`` draw of ``args.C * 60`` points, seeded with
    ``random_state=data_instance``, is split by position into the first
    ``args.C * 10`` points (train) and the remaining ``args.C * 50`` points
    (test). Because the draw is shuffled, the per-class counts inside each
    split are not guaranteed to be equal.

    Side effects:
        * Overwrites ``args.C`` (set to 5) and ``args.in_dim`` (set to 2).
        * Sets the global seaborn style to ``"dark"``.
        * Writes a side-by-side train/test scatter plot to
          ``<args.out_dir>/data_<data_instance>.png``, creating
          ``args.out_dir`` if it does not exist.
        * Closes *all* open pyplot figures, not only the one created here.

    Args:
        args: Namespace-like object. ``out_dir`` is read; ``C`` and ``in_dim``
            are written.
        batch_sz: Batch size used for both returned DataLoaders.
        data_instance: Integer used both as the ``make_blobs`` random seed and
            as the ``%d`` suffix of the saved figure's file name.

    Returns:
        ``(train_dl, test_dl)``. The training loader reshuffles every epoch,
        the test loader iterates in a fixed order. Inputs are float32 tensors
        of shape ``(N, 2)`` and labels are int64 tensors of shape ``(N,)``.
    """
    sns.set_style("dark")

    args.C = 5
    args.in_dim = 2

    num_train_samples = args.C * 10
    num_test_samples = args.C * 50

    xs, ys = sk_datasets.make_blobs(
        n_samples=num_train_samples + num_test_samples,
        centers=args.C,
        n_features=args.in_dim,
        cluster_std=0.7,
        shuffle=True,
        random_state=data_instance,
    )
    # Sanity check on the generator output before it is indexed by column.
    assert len(xs.shape) == 2 and xs.shape[1] == args.in_dim

    x_train, y_train = xs[:num_train_samples], ys[:num_train_samples]
    x_test, y_test = xs[num_train_samples:], ys[num_train_samples:]

    x0_col, x1_col = r"$x_0$", r"$x_1$"

    def _as_frame(points, labels, train_flag):
        # Stable sort by label: rows are grouped class 0, 1, ... while keeping
        # the original sample order inside each class, so the draw order of
        # overlapping markers is deterministic.
        order = np.argsort(labels, kind="stable")
        return pd.DataFrame({
            x0_col: points[order, 0],
            x1_col: points[order, 1],
            "Class": labels[order],
            "Train": train_flag,
        })

    df_train = _as_frame(x_train, y_train, "True")
    df_test = _as_frame(x_test, y_test, "False")

    f, ax = plt.subplots(1, 2, figsize=(2 * 5, 4))
    for axis, frame, title in (
        (ax[0], df_train, "Training data"),
        (ax[1], df_test, "Test data"),
    ):
        sns.scatterplot(data=frame, x=x0_col, y=x1_col, hue="Class",
                        palette="colorblind", s=40, ax=axis)
        axis.set_title(title)
        # Fixed limits keep figures from different data instances comparable.
        axis.set_xlim(-4, 12)
        axis.set_ylim(-5, 10)

    # Both panels use the same class/colour mapping, so only the right-hand
    # panel shows the legend.
    ax[1].legend(title="Class", loc="upper right")
    ax[0].legend().set_visible(False)

    # An empty out_dir means "current directory"; os.makedirs("") would raise.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    f.savefig(os.path.join(args.out_dir, "data_%d.png" % data_instance), bbox_inches="tight")
    plt.close("all")

    # Labels are requested as torch.long explicitly: the dtype inferred from
    # the NumPy array follows the platform's default integer width, and
    # classification losses expect int64 class indices.
    train_data_orig = TensorDataset(
        torch.tensor(x_train, dtype=torch.float),
        torch.tensor(y_train, dtype=torch.long),
    )
    test_data = TensorDataset(
        torch.tensor(x_test, dtype=torch.float),
        torch.tensor(y_test, dtype=torch.long),
    )

    train_dl = torch.utils.data.DataLoader(train_data_orig, batch_size=batch_sz, shuffle=True)
    test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_sz, shuffle=False)

    return train_dl, test_dl