
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import umap
from pathlib import Path
from umap.parametric_umap import ParametricUMAP
import os

def plot_2D(X, y, title, marker, ax=None):
    """Scatter a binary-labelled 2D dataset, colouring points by class.

    Rows of ``X`` whose label is 1 are drawn in green as 'minority'; rows
    whose label is 0 are drawn in blue as 'majority'. Only the first two
    columns of ``X`` are plotted.

    Args:
        X: 2D array supporting boolean-mask row indexing and ``[:, i]``
            column slicing (e.g. a NumPy array).
        y: Labels aligned with the rows of ``X``; compared against 1 and 0.
        title: Axes title. In standalone mode it is also the file stem of
            the saved image.
        marker: Matplotlib marker style used for both classes.
        ax: Optional existing axes to draw on. When given, no figure is
            created, nothing is saved and nothing is closed -- the caller
            owns the figure. When omitted, a new 8x8 figure is created,
            saved to ``experiments/2D/{title}.png`` and then closed.
    """
    standalone = ax is None
    fig = None
    if standalone:
        # Create a figure only when we own it. Creating one unconditionally
        # would replace the caller's current pyplot figure and leak an empty one.
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(1, 1, 1)
    try:
        # Minority first, then majority, so the draw and legend order is stable.
        for label_value, color, name in ((1, 'green', 'minority'), (0, 'blue', 'majority')):
            points = X[y == label_value]
            ax.scatter(points[:, 0], points[:, 1], c=color, marker=marker, alpha=0.7, label=name)
        ax.set_ylabel("Feature #1")
        ax.set_xlabel("Feature #0")
        ax.legend(loc='upper right')
        ax.set_title(title)
        if standalone:
            os.makedirs('experiments/2D', exist_ok=True)
            fig.savefig(f'experiments/2D/{title}.png')
    finally:
        if standalone:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

def plot_loss(losses: dict, bit_map: dict, title):
    """Plot the enabled loss curves against epoch number and save the figure.

    Args:
        losses: Maps a loss name to its sequence of values; the i-th value
            is plotted at x = i + 1 on the 'Epoch' axis.
        bit_map: Maps every name in ``losses`` to a flag; only losses whose
            flag is truthy are drawn. A name missing from ``bit_map`` raises
            ``KeyError``.
        title: Figure title and file stem; the image is written to
            ``experiments/{title}.png``.
    """
    colors = ['r', 'b', 'g', 'k', 'c', 'm', 'y']
    # Resolve the enabled series first so that a bad ``bit_map`` fails before
    # any figure is allocated.
    enabled = [(name, values) for name, values in losses.items() if bit_map[name]]

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.set_xlabel('Epoch', fontsize=15)
        ax.set_ylabel('Loss', fontsize=15)
        ax.set_title(f'{title}', fontsize=20)
        for index, (name, values) in enumerate(enabled):
            # Cycle through the palette rather than indexing past its end
            # when more losses are enabled than there are colours.
            color = colors[index % len(colors)]
            ax.scatter(range(1, len(values) + 1), values, c=color, label=name)
        ax.grid()
        ax.legend()
        out_path = f'experiments/{title}.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

def _plot_pca(X, y, title):
    """Fit a 2-component PCA on ``X`` and save a scatter of its projection.

    Points are grouped by label value 0-5, each with its own colour; labels
    4 and 5 are drawn three times larger. The legend names only the first
    two groups ('Majority', 'Minority').

    Args:
        X: Feature matrix the PCA is fitted on and then projected.
        y: Labels aligned with the rows of ``X``.
        title: Used in the axes title and as the file stem; the image is
            written to ``experiments/{title}.png``.
    """
    targets = [0, 1, 2, 3, 4, 5]
    colors = ['g', 'r', 'm', 'b', 'k', 'y']
    projected = PCA(n_components=2).fit(X).transform(X)

    fig = plt.figure(figsize=(12, 6))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.set_xlabel('PC-1', fontsize=15)
        ax.set_ylabel('PC-2', fontsize=15)
        ax.set_title(f'2D-PCA {title}', fontsize=20)
        for target, color in zip(targets, colors):
            points = projected[y == target]
            size = 30 if target in (4, 5) else 10
            ax.scatter(points[:, 0], points[:, 1], c=color, s=size)
        # Labels attach to the scatters in draw order, i.e. to labels 0 and 1.
        ax.legend(['Majority', 'Minority'])
        ax.grid()
        out_path = f'experiments/{title}.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

class Visualizer(object):
    """Save side-by-side 2D/3D scatter plots of PCA and UMAP projections.

    Reducers are fitted on the data given to the first ``plot_*`` call of
    each kind and cached on the instance, so later calls (for example on
    oversampled data) are projected into the same embedding space. PCA,
    UMAP and parametric UMAP each keep their own cache.

    Points are grouped by the value of ``y``; labels 0-5 are drawn. Going by
    the legend text, 0/1 are the train majority/minority classes and 2/3 the
    test majority/minority classes. Labels 4 and 5 get no legend entry and
    are drawn three times larger in the 2D panel -- presumably synthetic
    samples; confirm against the callers.

    Every figure is written to ``{results_dir}/{title}.png`` and then closed.
    """

    # Instance installed by the module-level ``configure`` helper.
    CURRENT = None

    # Label values drawn by every plot, with the per-label colour and z-order.
    _TARGETS = (0, 1, 2, 3, 4, 5)
    _COLORS = ('g', 'r', 'm', 'b', 'k', 'y')
    _PARAM_UMAP_COLORS = ('b', 'g', 'r', 'm', 'k', 'y')
    # Labels 1 and 3 (the minority classes) are drawn above the others.
    _ORDERS = (1, 2, 1, 2, 1, 1)
    # Labels whose markers are enlarged in the 2D panel.
    _ENLARGED_TARGETS = (4, 5)

    def __init__(self, results_dir):
        """Create the output directory if needed and start with empty caches.

        Args:
            results_dir: Directory that receives the saved figures; it is
                created, including missing parents, when it does not exist.
        """
        self.results_dir = Path(f'{results_dir}')
        # ``exist_ok`` avoids the check-then-create race of isdir + makedirs.
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.pca_2d_reducer = None
        self.pca_3d_reducer = None
        self.umap_2d_reducer = None
        self.umap_3d_reducer = None
        # Parametric UMAP has its own cache so it can never be mistaken for,
        # or overwritten by, the plain UMAP reducers.
        self.param_umap_2d_reducer = None
        self.param_umap_3d_reducer = None

    def _save_projection_figure(self, X, y, title, reducer_2d, reducer_3d, axis_prefix, colors):
        """Project ``X`` with both reducers, draw the two panels and save.

        Args:
            X: Data passed to ``reducer.transform``.
            y: Labels aligned with the rows of ``X``.
            title: File stem; also selects the four-entry legend when it
                contains 'test' or 'TEST'.
            reducer_2d: Fitted reducer producing two components.
            reducer_3d: Fitted reducer producing three components.
            axis_prefix: Axis label prefix, e.g. 'PC' gives 'PC-1', 'PC-2'.
            colors: One colour per entry of ``_TARGETS``.
        """
        if 'test' in title or 'TEST' in title:
            legend_labels = ['Train Maj', 'Train Min', 'Test Maj', 'Test Min']
        else:
            legend_labels = ['Train Maj', 'Train Min']

        fig = plt.figure(figsize=(12, 6))
        try:
            # Left panel: 2D projection.
            ax = fig.add_subplot(1, 2, 1)
            ax.set_xlabel(f'{axis_prefix}-1', fontsize=15)
            ax.set_ylabel(f'{axis_prefix}-2', fontsize=15)
            embedding = reducer_2d.transform(X)
            for target, color, order in zip(self._TARGETS, colors, self._ORDERS):
                points = embedding[y == target]
                size = 30 if target in self._ENLARGED_TARGETS else 10
                ax.scatter(points[:, 0], points[:, 1], c=color, s=size, zorder=order)
            # A scatter is drawn for every label, even an empty one, so the
            # legend entries always line up with labels 0, 1, 2, 3 in order.
            ax.legend(legend_labels)
            ax.grid()

            # Right panel: 3D projection. ``computed_zorder=False`` makes the
            # 3D axes honour the explicit ``zorder`` values.
            ax = fig.add_subplot(1, 2, 2, projection='3d', computed_zorder=False)
            ax.set_xlabel(f'{axis_prefix}-1', fontsize=15)
            ax.set_ylabel(f'{axis_prefix}-2', fontsize=15)
            ax.set_zlabel(f'{axis_prefix}-3', fontsize=15)
            embedding = reducer_3d.transform(X)
            for target, color, order in zip(self._TARGETS, colors, self._ORDERS):
                points = embedding[y == target]
                ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=color, s=10, zorder=order)
            ax.legend(legend_labels)
            ax.grid()

            fig.savefig(f'{self.results_dir}/{title}.png')
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def plot_pca(self, X, y, title):
        """Save 2D and 3D PCA scatter plots of ``X`` coloured by ``y``.

        The PCA reducers are fitted on ``X`` the first time this is called
        and reused afterwards.

        Example::

            import torch as th
            x = th.rand(10, 7)
            y = th.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0])
            Visualizer(".").plot_pca(x, y, 'test')
        """
        if self.pca_2d_reducer is None:
            self.pca_2d_reducer = PCA(n_components=2).fit(X)
            self.pca_3d_reducer = PCA(n_components=3).fit(X)
        self._save_projection_figure(
            X, y, title,
            self.pca_2d_reducer, self.pca_3d_reducer,
            'PC', self._COLORS,
        )

    def plot_umap(self, X, y, title):
        """Save 2D and 3D UMAP scatter plots of ``X`` coloured by ``y``.

        The UMAP reducers are fitted on ``(X, y)`` the first time this is
        called and reused afterwards.
        """
        if self.umap_2d_reducer is None:
            # n_neighbors: the default is 15; fewer gives worse results and
            # 30-45 looks good.
            self.umap_2d_reducer = umap.UMAP(random_state=42,
                                             n_neighbors=45,
                                             n_components=2).fit(X, y)
            self.umap_3d_reducer = umap.UMAP(random_state=42,
                                             n_neighbors=45,
                                             n_components=3).fit(X, y)
        self._save_projection_figure(
            X, y, title,
            self.umap_2d_reducer, self.umap_3d_reducer,
            'UMAP', self._COLORS,
        )

    def plot_parametric_umap(self, X, y, title):
        """Save 2D and 3D parametric-UMAP scatter plots of ``X``.

        The parametric reducers are fitted on ``(X, y)`` the first time this
        is called and cached separately from the ``plot_umap`` reducers.
        This plot uses a different colour order from the other two.
        """
        if self.param_umap_2d_reducer is None:
            # n_neighbors: the default is 15; fewer gives worse results and
            # 30-45 looks good.
            self.param_umap_2d_reducer = umap.ParametricUMAP(random_state=42,
                                                             n_neighbors=45,
                                                             n_components=2).fit(X, y)
            self.param_umap_3d_reducer = umap.ParametricUMAP(random_state=42,
                                                             n_neighbors=45,
                                                             n_components=3).fit(X, y)
        self._save_projection_figure(
            X, y, title,
            self.param_umap_2d_reducer, self.param_umap_3d_reducer,
            'UMAP', self._PARAM_UMAP_COLORS,
        )

    def visualize_oversampled(self, x_all, y_all, method_name):
        """Plot oversampled data with the already-fitted PCA and UMAP reducers.

        Both plots are saved as ``Oversampled_{method_name}.png``, so the
        UMAP figure overwrites the PCA one.

        Raises:
            AssertionError: If ``plot_pca`` or ``plot_umap`` has not been
                called yet to fit the reducers.
        """
        assert self.pca_2d_reducer is not None, 'call plot_pca first to fit the PCA reducers'
        assert self.pca_3d_reducer is not None, 'call plot_pca first to fit the PCA reducers'
        assert self.umap_2d_reducer is not None, 'call plot_umap first to fit the UMAP reducers'
        assert self.umap_3d_reducer is not None, 'call plot_umap first to fit the UMAP reducers'
        self.plot_pca(x_all, y_all, f'Oversampled_{method_name}')
        self.plot_umap(x_all, y_all, f'Oversampled_{method_name}')

    def visualize_oversampled_pca(self, x_all, y_all, method_name):
        """Plot oversampled data with the already-fitted PCA reducers.

        Raises:
            AssertionError: If ``plot_pca`` has not been called yet.
        """
        assert self.pca_2d_reducer is not None, 'call plot_pca first to fit the PCA reducers'
        assert self.pca_3d_reducer is not None, 'call plot_pca first to fit the PCA reducers'
        self.plot_pca(x_all, y_all, f'Oversampled_{method_name}')

    def visualize_oversampled_umap(self, x_all, y_all, method_name):
        """Plot oversampled data with the already-fitted UMAP reducers.

        Raises:
            AssertionError: If ``plot_umap`` has not been called yet.
        """
        assert self.umap_2d_reducer is not None, 'call plot_umap first to fit the UMAP reducers'
        assert self.umap_3d_reducer is not None, 'call plot_umap first to fit the UMAP reducers'
        self.plot_umap(x_all, y_all, f'Oversampled_{method_name}')

    def visualize_oversampled_param_umap(self, x_all, y_all, method_name):
        """Plot oversampled data with the already-fitted parametric UMAP reducers.

        Raises:
            AssertionError: If ``plot_parametric_umap`` has not been called yet.
        """
        assert self.param_umap_2d_reducer is not None, \
            'call plot_parametric_umap first to fit the parametric UMAP reducers'
        assert self.param_umap_3d_reducer is not None, \
            'call plot_parametric_umap first to fit the parametric UMAP reducers'
        self.plot_parametric_umap(x_all, y_all, f'Oversampled_{method_name}')

def configure(dir):
    """Build a ``Visualizer`` for ``dir`` and store it in ``Visualizer.CURRENT``."""
    # NOTE: ``dir`` shadows the builtin; the name is kept because callers may
    # pass it by keyword.
    active = Visualizer(dir)
    Visualizer.CURRENT = active

def get_current():
    """Return the object stored in ``Visualizer.CURRENT`` (its class-level default is ``None``)."""
    current = Visualizer.CURRENT
    return current

def plot_pca(*args, **kwargs):
    """Call ``plot_pca`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().plot_pca(*args, **kwargs)

def plot_umap(*args, **kwargs):
    """Call ``plot_umap`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().plot_umap(*args, **kwargs)

def plot_param_umap(*args, **kwargs):
    """Call ``plot_parametric_umap`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().plot_parametric_umap(*args, **kwargs)

def visualize_oversampled(*args, **kwargs):
    """Call ``visualize_oversampled`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().visualize_oversampled(*args, **kwargs)

def visualize_oversampled_pca(*args, **kwargs):
    """Call ``visualize_oversampled_pca`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().visualize_oversampled_pca(*args, **kwargs)

def visualize_oversampled_umap(*args, **kwargs):
    """Call ``visualize_oversampled_umap`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().visualize_oversampled_umap(*args, **kwargs)

def visualize_oversampled_param_umap(*args, **kwargs):
    """Call ``visualize_oversampled_param_umap`` on the visualizer returned by ``get_current()``.

    Positional and keyword arguments are forwarded unchanged, so the
    method's parameters may be given either way. If ``get_current()``
    returns ``None`` the attribute lookup raises ``AttributeError``.
    """
    return get_current().visualize_oversampled_param_umap(*args, **kwargs)


""" EXAMPLE - UMAP
import torch as th
import numpy as np
from imblearn.over_sampling import RandomOverSampler
x_train, y_train = th.load("datasets/Keel/preprocessed/abalone9-18/abalone9-18_train.pt")
configure("experiments/results")
#get_current().plot_parametric_umap(x_train, y_train, "Original_Parametric_UMAP")
get_current().plot_umap(x_train, y_train, "Original_UMAP")

ros_model = RandomOverSampler(random_state=42)
x_all, y_all = ros_model.fit_resample(x_train, y_train)
#get_current().plot_parametric_umap(x_all, y_all, "ROS_Parametric_UMAP")
get_current().plot_umap(x_all, y_all, "ROS_UMAP")

# One at a time (reuses the 2D reducer fitted by plot_umap above)
reducer_2d = get_current().umap_2d_reducer
x_all = th.from_numpy(x_all)
y_all = th.from_numpy(y_all)
for i in range(x_all.shape[0]):
    if i==0:
        X = reducer_2d.transform(x_all[i].reshape(1, -1))
    else:
        X = np.concatenate((X,reducer_2d.transform(x_all[i].reshape(1, -1))),0)
fig = plt.figure(figsize=(6, 6))
targets = [0, 1]
colors = ['g', 'r']
ax = fig.add_subplot(1, 1, 1)
ax.set_title(f'2D-UMAP Ros', fontsize=20)
for target, color in zip(targets, colors):
    idx = y_all == target
    ax.scatter(X[idx][:, 0], X[idx][:, 1], c=color, s=2)
ax.legend(['Majority', 'Minority'])
ax.grid()
plt.savefig(f'experiments/results/ROS.png')
"""