#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 11 15:03:33 2020


"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

from sklearn.datasets import make_moons
from sklearn.datasets import make_blobs
from sklearn.datasets import load_svmlight_file
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import torch
from torchvision import datasets
import scipy.io as sio

from mpl_toolkits.mplot3d import Axes3D


def generate_dataset_moons_blobs(
    n_pos,
    n_unl_neg,
    n_unl_pos=None,
    dim="2d",
    center_blobs=None,
    std_blobs=None,
    rotated=False,
):
    """Generate a toy positive-unlabeled dataset.

    The positives are drawn according to a two-moons shape and the negatives
    according to one or several blobs.

    Parameters
    ----------
    n_pos: number of positive examples of the P dataset

    n_unl_neg: number of negative examples of the U dataset

    n_unl_pos: number of positive examples of the U dataset (default: None)
        If None, the unlabeled positives are a copy of the P dataset.
        Otherwise they are a new two-moons draw (noise 0.05, same seed as P).

    dim: string (default: '2d')
        Choose between a '2d' or '3d' P dataset. With '3d' the moons use a
        smaller noise (0.05 instead of 0.1) and P gets a third column
        'feature3' drawn from a normal distribution (mean 1.0, std 0.07).

    center_blobs: array (default: None)
        Coordinates of the center of the blobs. If None, a single blob
        centered on (-0.5, -0.75) is used.

    std_blobs: float or array (default: None)
        Std of the blobs. If None, 0.1 is used.

    rotated: boolean (default: False)
        If True, every point of the U dataset (positives and negatives) is
        rotated by -60 degrees in the (feature1, feature2) plane and then
        shifted by (-0.5, +0.5).

    Returns
    -------
    P: pandas.DataFrame
        Positive dataset, columns 'feature1', 'feature2' (and 'feature3'
        when dim == '3d').

    U: pandas.DataFrame
        Unlabeled dataset: the unlabeled positives followed by the negatives,
        indexed 0..len(U)-1.

    y_unl: pandas.Series
        Labels of U: 1 for the unlabeled positives, 0 for the negatives.

    Notes
    -----
    'feature3' is sampled from the global numpy random state, so it is the
    only part of the output that is not fixed by the internal seed.
    When dim == '3d' and n_unl_pos is None, U inherits the 'feature3' column
    from the copy of P and the negative rows hold NaN in that column.
    """
    seed = 76543210

    # Positive dataset (the 3d variant uses less noisy moons).
    noise = 0.05 if dim == "3d" else 0.1
    P, _ = make_moons(n_samples=n_pos, noise=noise, shuffle=True, random_state=seed)
    P = pd.DataFrame(P, columns=["feature1", "feature2"])
    if dim == "3d":
        P["feature3"] = np.random.normal(loc=1.0, scale=0.07, size=P.shape[0])

    # Negative part of the unlabeled dataset.
    if center_blobs is None:
        center_blobs = [[-0.5, -0.75]]
    if std_blobs is None:
        std_blobs = 0.1
    X_n, _ = make_blobs(
        n_samples=n_unl_neg,
        centers=center_blobs,
        cluster_std=std_blobs,
        random_state=seed,
    )
    X_n = pd.DataFrame(X_n, columns=["feature1", "feature2"])

    # Positive part of the unlabeled dataset.
    if n_unl_pos is not None:
        X_u_p, _ = make_moons(
            n_samples=n_unl_pos, noise=0.05, shuffle=True, random_state=seed
        )
        X_u_p = pd.DataFrame(X_u_p, columns=["feature1", "feature2"])
    else:
        X_u_p = P.copy()
        n_unl_pos = n_pos

    # Rebuild the index: without it U would carry duplicated labels (0..n-1
    # twice) and could not be aligned with y_unl, whose index is 0..len(U)-1.
    U = pd.concat([X_u_p, X_n], ignore_index=True)
    if rotated:
        angle = math.radians(-60)
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x_old, y_old = U["feature1"], U["feature2"]
        # Both new coordinates are computed from the un-rotated columns.
        x_new = x_old * cos_a - y_old * sin_a
        y_new = x_old * sin_a + y_old * cos_a
        U["feature1"] = x_new - 0.5
        U["feature2"] = y_new + 0.5

    y_unl = pd.Series(np.concatenate((np.ones(n_unl_pos), np.zeros(n_unl_neg))))

    return P, U, y_unl


def plot_dataset(P, U, y, dim="2d", ax=None, y_hat=None, transp=None):
    """Plot a positive-unlabeled dataset.

    Parameters
    ----------
    P: pandas dataframe, shape=(n_p, d_p)
        Positive dataset. Must expose the columns 'feature1' and 'feature2'
        (and 'feature3' when dim == '3d').

    U: pandas dataframe, shape=(n_u, d_u)
        Unlabeled dataset. Only 'feature1' and 'feature2' are used; in 3d
        the unlabeled points are drawn in the plane z = 0.

    y: array, len=n_u
        Labels on the unlabeled dataframe. Should be 0 (negatives) or 1 (pos)

    dim: string (default: '2d')
        Choose between a '2d' or '3d' plot

    ax: matplotlib axes (default: None)
        Axes to draw on. If None, a new figure and axes are created.

    y_hat: array, len=n_u (default: None)
        Predicted labels of the unlabeled dataframe.
        If None, the unlabeled points are displayed according to y only.
        Otherwise (2d only) correctly classified points are drawn in blue
        and misclassified ones in red.

    transp: 2d array, shape=(n_p, n_u) (default: None)
        If given, a line is drawn between the i-th point of P and the j-th
        point of U for every entry transp[i, j] > 1e-5.
        Presumably a transport plan between P and U -- confirm with caller.

    Raises
    ------
    ValueError
        If dim is neither '2d' nor '3d'.
    """
    # Validate first so that an unsupported `dim` never leaves a stray,
    # half-built figure behind.
    if dim not in ("2d", "3d"):
        raise ValueError("dim argument takes either '2d' or '3d' argument")
    is_3d = dim == "3d"

    if ax is None:
        fig = plt.figure()
        if is_3d:
            ax = fig.add_subplot(111, projection="3d")
        else:
            ax = fig.add_subplot(111)

    # Positional coordinates of the unlabeled points.
    u_x = np.asarray(U.feature1)
    u_y = np.asarray(U.feature2)

    if is_3d:
        ax.scatter(
            P.feature1,
            P.feature2,
            P.feature3,
            c="k",
            marker="o",
            linewidth=1,
            label="P",
        )
        ax.view_init(elev=40.0, azim=100)
        if y_hat is None:
            ax.scatter(u_x, u_y, zs=0, zdir="z", label="U")
        else:
            ax.scatter(u_x, u_y, zs=0, zdir="z", label="U", c="b")
    else:
        ax.scatter(
            P.feature1, P.feature2, c="k", marker="o", linewidth=1, s=50, label="P"
        )
        # Coerce to arrays so that lists and pandas objects with an arbitrary
        # index are all handled positionally.
        y_true = np.asarray(y).ravel()
        is_pos = y_true == 1
        is_neg = y_true == 0
        hollow = {"facecolors": "none"}
        cross = {"marker": "+", "linewidth": 1, "alpha": 0.5}
        if y_hat is None:
            ax.scatter(
                u_x[is_pos], u_y[is_pos], s=70, edgecolors="b", label="P+U", **hollow
            )
            ax.scatter(u_x[is_neg], u_y[is_neg], s=70, c="b", label="N+U", **cross)
        else:
            y_pred = np.asarray(y_hat).ravel()
            # One scatter call per category instead of one per point: faster,
            # and each legend label appears once.
            categories = (
                (is_pos & (y_pred == 1), dict(hollow, edgecolors="b", label="P+U")),
                (is_neg & (y_pred == 0), dict(cross, c="b", label="N+U")),
                (is_neg & (y_pred == 1), dict(hollow, edgecolors="r")),
                (is_pos & (y_pred == 0), dict(cross, c="r")),
            )
            for mask, style in categories:
                if mask.any():
                    ax.scatter(u_x[mask], u_y[mask], s=70, **style)

    # Hide every tick and tick label.
    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
    ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)
    if is_3d:
        ax.tick_params(axis="z", which="both", left=False, right=False, labelleft=False)

    if transp is not None:
        p_x = np.asarray(P.feature1)
        p_y = np.asarray(P.feature2)
        p_z = np.asarray(P.feature3) if is_3d else None
        # Only visit the entries above the threshold (row-major order).
        for i, j in np.argwhere(np.asarray(transp) > 1e-5):
            xs = [p_x[i], u_x[j]]
            ys = [p_y[i], u_y[j]]
            if is_3d:
                ax.plot(xs, ys, [p_z[i], 0], "k", alpha=0.05)
            else:
                ax.plot(xs, ys, "k", alpha=0.15)


def annotate_transp_matrix(ax):
    """Add block labels and separators to an axes.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to decorate (presumably showing a transport matrix -- confirm
        with caller).

    Four text labels are written in axes coordinates: 'Unl. positives' and
    'Unl. negatives' in blue below the axes, 'Positives' and 'D' rotated by
    90 degrees on its left. A vertical line at x=5.5 and a horizontal line at
    y=9.5 (data coordinates, hard-coded) are drawn, and all ticks and tick
    labels are hidden.
    """
    shared = {
        "horizontalalignment": "center",
        "verticalalignment": "center",
        "transform": ax.transAxes,
        "size": 20,
    }
    below = {"c": "b"}
    on_left = {"rotation": 90}
    annotations = (
        (0.3, -0.05, "Unl. positives", below),
        (0.85, -0.05, "Unl. negatives", below),
        (-0.05, 0.51, "Positives", on_left),
        (-0.05, 0.06, "D", on_left),
    )
    for x_pos, y_pos, caption, extra in annotations:
        ax.text(x_pos, y_pos, caption, **shared, **extra)

    # Separators between the blocks of the matrix.
    ax.axvline(x=5.5, color="k", linewidth=2)
    ax.axhline(y=9.5, color="k")

    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
    ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)


def make_data(dataset="mnist"):
    """Load a dataset and binarize its labels.

    Files are read from the relative folder data/; MNIST is read from
    ~/data/mnist and downloaded there if missing.

    Parameters
    ----------
    dataset: name of the dataset
        One of 'mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4',
        'spambase', 'mnist', 'mnist_color_change_p', 'mnist_color_change_u',
        'surf_<domain>' or 'decaf_<domain>'.

    Returns
    -------
    np_array that contains the data, one row per example

    np_array that contains the labels: 1 for the positive class; for most
        datasets the other classes are mapped to 0 ('spambase' labels are
        returned as stored in the file)

    Raises
    ------
    ValueError
        If the name of the dataset is not recognised.
    """
    # Probability of moving an image to the second colour channel, for each
    # supported MNIST variant.
    mnist_color_proba = {
        "mnist": 0,
        "mnist_color_change_p": 0.1,
        "mnist_color_change_u": 1,
    }

    # Piece of code for the mnist dataset
    def make_environment(images, labels, e):
        def torch_bernoulli(p, size):
            # Resets the global torch seed so the draw is reproducible.
            torch.manual_seed(0)
            return (torch.rand(size) < p).float()

        def torch_xor(a, b):
            # NOTE(review): despite its name this ignores `a` and returns
            # |b|, so the colour does not depend on the label -- confirm
            # this is intended.
            return b.abs()  # Assumes both inputs are either 0 or 1

        # 2x subsample for computational convenience
        images = images.reshape((-1, 28, 28))[:, ::2, ::2]
        # Binary label based on the digit (modifies `labels` in place)
        labels[labels > 1] = 0.0  # positives: class1, negatives: others
        labels = labels.float()
        # Colour index of each image: 1 with probability e, else 0
        colors = torch_xor(labels, torch_bernoulli(e, len(labels)))
        # Apply the color to the image by zeroing out the other color channel
        images = torch.stack([images, images], dim=1)
        images[torch.arange(len(images)), (1 - colors).long(), :, :] *= 0
        return {"images": (images.float() / 255.0), "labels": labels[:, None]}

    if dataset == "mushrooms":
        x, t = load_svmlight_file("data/mushrooms")
        x = x.toarray()
        x = np.delete(x, 77, 1)  # contains only one value
        t[t == 1] = 1
        t[t == 2] = 0
    elif dataset == "shuttle":
        x_train, t_train = load_svmlight_file("data/shuttle.scale")
        x_train = x_train.toarray()
        x_test, t_test = load_svmlight_file("data/shuttle.scale.t")
        x_test = x_test.toarray()
        # train and test parts are merged into a single dataset
        x = np.concatenate([x_train, x_test])
        t = np.concatenate([t_train, t_test])
        t[~(t == 1)] = 0
    elif dataset == "pageblocks":
        data = np.loadtxt("data/page-blocks.data")
        x, t = data[:, :-1], data[:, -1]
        t[~(t == 1)] = 0
    elif dataset == "usps":
        x_train, t_train = load_svmlight_file("data/usps")
        x_train = x_train.toarray()
        x_test, t_test = load_svmlight_file("data/usps.t")
        x_test = x_test.toarray()
        # train and test parts are merged into a single dataset
        x = np.concatenate([x_train, x_test])
        t = np.concatenate([t_train, t_test])
        t[t == 1] = 1
        t[t > 1] = 0
    elif dataset == "connect-4":
        x, t = load_svmlight_file("data/connect-4")
        x = x.toarray()
        t[t == -1] = 0
    elif dataset == "spambase":
        data = np.loadtxt("data/spambase.data", delimiter=",")
        x, t = data[:, :-1], data[:, -1]
    elif dataset.startswith("mnist"):
        # Reject unknown variants before downloading anything; they used to
        # fail later with an UnboundLocalError.
        if dataset not in mnist_color_proba:
            raise ValueError("Check the name of the dataset")
        mnist = datasets.MNIST("~/data/mnist", train=True, download=True)
        # Only every second image is kept.
        env = make_environment(
            mnist.data[::2], mnist.targets[::2], mnist_color_proba[dataset]
        )
        data = env["images"]
        # One flattened (2 channels x 14 x 14) row per image.
        x = data.reshape(data.shape[0], -1).numpy().astype(np.float64)
        t = np.array(env["labels"]).flatten()
    elif dataset.startswith("surf"):
        domain = dataset[5:]  # text after the 'surf_' prefix
        mat = sio.loadmat("data/" + domain + "_zscore_SURF_L10.mat")
        if domain == "dslr":
            x = mat["Xs"]
            t = mat["Ys"]
        else:
            x = mat["Xt"]
            t = mat["Yt"]
        t[t != 1] = 0
        t = t.flatten()
        # NOTE(review): the PCA is fitted on the transposed matrix and its
        # components are used as the features of each example -- confirm
        # this is intended.
        pca = PCA(n_components=10, random_state=0)
        pca.fit(x.T)
        x = pca.components_.T
    elif dataset.startswith("decaf"):
        domain = dataset[6:]  # text after the 'decaf_' prefix
        mat = sio.loadmat("data/" + domain + "_decaf.mat")
        x = mat["feas"]
        t = mat["labels"]
        t[t != 1] = 0
        t = t.flatten()
        # Same construction as for the surf features, with 40 components.
        pca = PCA(n_components=40, random_state=0)
        pca.fit(x.T)
        x = pca.components_.T
    else:
        raise ValueError("Check the name of the dataset")
    return x, t


def _min_max_scale(x):
    """Rescale every column of x to [0, 1]; constant columns become 0."""
    low = np.min(x, axis=0)
    span = np.max(x, axis=0) - low
    span[span == 0] = 1  # avoid a division by zero on constant columns
    return (x - low) / span


def _draw_subset(x, size, seed):
    """Draw `size` rows of x without replacement; return (drawn, remaining)."""
    if size == 0:
        # train_test_split rejects an empty train set; a zero-sized draw is
        # simply an empty selection.
        return x[:0], x
    drawn, remaining = train_test_split(x, train_size=size, random_state=seed)
    return drawn, remaining


def draw_p_u_dataset_scar(dataset_p, dataset_u, size_p, size_u, prior, seed_nb=None):
    """Draw a Positive and Unlabeled dataset "at random"

    Each dataset is min-max scaled feature by feature before sampling. When
    dataset_u is the same as dataset_p, the unlabeled positives are drawn
    among the positives that were not selected for the positive dataset.
    When they differ, dataset_u is loaded and scaled on its own and provides
    both the unlabeled positives and the negatives.

    Parameters
    ----------
    dataset_p: name of the dataset among which the positives are drawn

    dataset_u: name of the dataset among which the unlabeled are drawn

    size_p: number of points in the positive dataset

    size_u: number of points in the unlabeled dataset

    prior: fraction of positives in the unlabeled dataset
        The unlabeled dataset holds int(prior * size_u) positives, the rest
        being negatives. 0 and 1 are accepted.

    seed_nb: seed used for every draw (default: None, i.e. not reproducible)

    Returns
    -------
    pandas.DataFrame of shape (n_p, d_p)
        Positive dataset

    pandas.DataFrame of shape (n_u, d_u)
        Unlabeled dataset (positives first, then negatives)

    pandas.Series of len (n_u)
        labels of the unlabeled dataset (1: positive, 0: negative)
    """
    x, t = make_data(dataset=dataset_p)
    x = _min_max_scale(x)

    size_u_p = int(prior * size_u)
    size_u_n = size_u - size_u_p

    xp, positives_left = _draw_subset(x[t == 1], size_p, seed_nb)

    if dataset_u != dataset_p:
        # The unlabeled points come from another dataset: its own positives
        # and negatives are used from here on.
        x, t = make_data(dataset=dataset_u)
        x = _min_max_scale(x)
        positives_left = x[t == 1]

    xup, _ = _draw_subset(positives_left, size_u_p, seed_nb)
    xun, _ = _draw_subset(x[t == 0], size_u_n, seed_nb)

    xu = np.concatenate([xup, xun], axis=0)
    yu = np.concatenate((np.ones(len(xup)), np.zeros(len(xun))))
    return pd.DataFrame(xp), pd.DataFrame(xu), pd.Series(yu)
