import warnings
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import LinearSegmentedColormap


def init_mpl():
    """Apply this module's font family and font sizes to matplotlib's global rcParams.

    Mutates ``matplotlib.rcParams`` in place, so it affects every text object
    created afterwards in the process.
    """
    mpl.rcParams.update({
        "font.family": "Times New Roman",
        "font.size": 20,
        "axes.titlesize": 20,
        "axes.labelsize": 18,
        "legend.fontsize": 10,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
    })

def dim_reduction(embeds, labels, method="PCA", dim=2, num_samples=5000, with_indexes=False):
    """Randomly subsample rows, then project them to ``dim`` dimensions.

    Args:
        embeds: 2-D array-like with one row per sample, or a ``torch.Tensor``
            (detached, moved to CPU and converted to NumPy).
        labels: per-row labels indexable with an integer index array (NumPy
            array, torch tensor, ...). Lists/tuples are converted to a NumPy
            array; other types are returned as the same type, subsampled.
        method: ``"PCA"``, ``"TSNE"`` or ``"UMAP"``.
        dim: number of output components.
        num_samples: maximum number of rows kept. Larger inputs are sampled
            without replacement using NumPy's global RNG.
        with_indexes: if True, also return the row indexes that were kept.

    Returns:
        ``(embeds, labels)``, or ``(embeds, labels, indexes)`` when
        ``with_indexes`` is True.

    Raises:
        KeyError: if ``method`` is not one of the supported names.
    """
    # Build only the requested reducer and import its backend lazily, so that
    # e.g. PCA works even when the optional ``umap`` package is not installed.
    if method == "PCA":
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=dim)
    elif method == "TSNE":
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=dim, random_state=0)
    elif method == "UMAP":
        from umap import UMAP
        reducer = UMAP(n_components=dim)
    else:
        raise KeyError(f"unknown method {method!r}; expected 'PCA', 'TSNE' or 'UMAP'")

    if isinstance(embeds, torch.Tensor):
        embeds = embeds.detach().cpu().numpy()
    embeds = np.asarray(embeds)
    # Plain Python sequences cannot be fancy-indexed; tensors/arrays are kept as-is.
    if isinstance(labels, (list, tuple)):
        labels = np.asarray(labels)

    n_rows = embeds.shape[0]
    if n_rows > num_samples:
        indexes = np.random.choice(n_rows, num_samples, replace=False)
    else:
        indexes = np.arange(n_rows)
    embeds = embeds[indexes]
    labels = labels[indexes]

    embeds = reducer.fit_transform(embeds)
    if with_indexes:
        return embeds, labels, indexes
    return embeds, labels


def label2marker(label):
    """Map an integer-like label to a matplotlib marker, cycling through 22 markers.

    Every entry in the table is distinct, so labels 0..21 all get different
    markers. (Index 18 was previously a duplicate ``"D"``; it is now ``"*"``.)
    """
    markers = ["o", "s", "D", "v", "^", "<", ">", "1", "2", "3", "4", "8", "p", "h", "H", "+", "x", "X", "*", "d", "|", "_"]
    return markers[int(label) % len(markers)]

def label2linestyle(label):
    """Map an integer-like label to one of four matplotlib line styles, wrapping around."""
    styles = ("-", "--", "-.", ":")
    position = int(label) % len(styles)
    return styles[position]

def label2color(label):
    """Map an integer-like label to one of ten matplotlib colour specs, wrapping around."""
    # Named ``palette`` (not ``colors``) to avoid shadowing ``matplotlib.colors``.
    palette = (
        "blue", "red", "orange", "green", "purple", "yellow", "pink",
        "#88b4ff", "#ffd180", "#2ea99d",
    )
    position = int(label) % len(palette)
    return palette[position]

def label2cmap(label):
    """Build a two-stop colormap from a near-white tint up to the label's colour.

    The top stop is ``label2color(label)``. For the bottom stop, each RGB channel
    ``c`` becomes ``1 - (1 - c) * 0.05``, i.e. white moved 5% of the way toward
    the label colour.
    """
    base = label2color(label)
    top_rgb = colors.to_rgb(base)
    bottom_rgb = tuple(1 - (1 - channel) * 0.05 for channel in top_rgb)
    return LinearSegmentedColormap.from_list(f'cmap_{str(base)}', [bottom_rgb, top_rgb])

def plot_2D(embeds, labels, fig_path='./example.pdf', title='', ifsave=False):
    """Scatter-plot 2-D embeddings with one colour per unique label.

    Args:
        embeds: array-like of shape (n, 2), or a torch tensor.
        labels: per-row labels (array-like or torch tensor) of length n.
        fig_path: output path used when ``ifsave`` is True.
        title: figure title.
        ifsave: save to ``fig_path`` if True, otherwise call ``plt.show()``.

    Returns:
        The figure that was drawn on. ``plt.gcf()`` is not used because, after a
        blocking ``plt.show()``, it may return a new, empty figure.
    """
    assert embeds.shape[1] == 2, "embeds must be 2D"
    init_mpl()
    fig = plt.figure(figsize=(5, 5))

    # Work in NumPy so boolean masks index ``embeds`` and ``labels`` consistently.
    if isinstance(embeds, torch.Tensor):
        embeds = embeds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels)

    unique_labels = np.unique(labels)
    for i, label in enumerate(unique_labels):
        mask = labels == label
        plt.scatter(embeds[mask, 0], embeds[mask, 1],
                    label=label if isinstance(label, str) else f'label "{label}"',
                    color=label2color(i), s=10, alpha=0.3)
    plt.legend()
    plt.title(title)
    if ifsave:
        fig.savefig(fig_path)
    else:
        plt.show()
    return fig

def plot_2D_with_g(embeds, labels, g_tensor, fig_path='./example.pdf', title='', ifsave=False):
    """Scatter 2-D embeddings, colouring each label group by a per-row scalar ``g``.

    Each unique label gets its own colormap (see ``label2cmap``), normalised to
    that group's min/max of ``g``, and its own small colourbar to the right of
    the plot.

    Args:
        embeds: array-like of shape (n, 2), or a torch tensor.
        labels: per-row labels (array-like or torch tensor) of length n.
        g_tensor: per-row scalar values (torch tensor or array-like) aligned
            with ``embeds``. The demo passes shape (n, 1); presumably (n,)
            works too.
        fig_path: output path used when ``ifsave`` is True.
        title: axes title.
        ifsave: save to ``fig_path`` if True, otherwise call ``plt.show()``.

    Returns:
        The figure that was drawn on (not ``plt.gcf()``, which may be a new,
        empty figure after a blocking ``plt.show()``).
    """
    assert embeds.shape[1] == 2, "embeds must be 2D"
    init_mpl()
    fig, ax = plt.subplots(figsize=(6, 5))

    # Work in NumPy throughout so boolean masks index every array the same way;
    # convert g once rather than once per label.
    if isinstance(embeds, torch.Tensor):
        embeds = embeds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels)
    if isinstance(g_tensor, torch.Tensor):
        g_all = g_tensor.detach().cpu().numpy()
    else:
        g_all = np.asarray(g_tensor)

    unique_labels = np.unique(labels)
    label_num = len(unique_labels)
    # Colourbars form a grid right of the plot: ``col_num`` bars stacked
    # vertically per grid column, and ``row_num`` (at most 4) grid columns.
    # Each grid column takes ``row_width`` of the axes width.
    col_num = 4 if label_num > 8 else 3 if label_num > 6 else 2 if label_num > 1 else 1
    row_num = min(label_num // col_num + (1 if label_num % col_num > 0 else 0), 4)
    row_width = 0.1

    # Shrink the main axes horizontally to leave room for the colourbar grid.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * (1 - row_width * row_num), box.height])

    ax.set_title(title)
    for i, label in enumerate(unique_labels):
        mask = labels == label
        data = embeds[mask]
        g_values = g_all[mask]
        cmap = label2cmap(i)
        norm = Normalize(vmin=np.min(g_values), vmax=np.max(g_values))

        ax.scatter(data[:, 0], data[:, 1], c=g_values, cmap=cmap,
                   norm=norm, s=5, label=f'label {label}', alpha=1)

        sm = ScalarMappable(norm=norm, cmap=cmap)
        sm.set_array(g_values)

        # Grid slot: column i // col_num (shifted right), row i % col_num (shifted down).
        axins = inset_axes(ax, width="3%",
                           height=f"{100 / col_num - 1}%", loc='upper right',
                           bbox_to_anchor=(0.05 + i // col_num * row_width,
                                           - (i % col_num) * (1 / col_num), 1, 1),
                           bbox_transform=ax.transAxes,
                           borderpad=0)

        cbar = fig.colorbar(sm, cax=axins)
        cbar.set_label(f'{label}', fontsize=8)
        cbar.ax.tick_params(labelsize=6)
    if ifsave:
        fig.savefig(fig_path)
    else:
        plt.show()
    return fig

def plot_3D_sphere(embeds, labels, fig_path='./example.pdf', title='', ifsave=False):
    """Scatter 3-D embeddings over a translucent unit sphere, one colour per label.

    Args:
        embeds: array-like of shape (n, 3), or a torch tensor. Axis limits are
            fixed to [-1, 1], so points presumably lie on or near the unit
            sphere.
        labels: per-row labels (array-like or torch tensor) of length n.
        fig_path: output path used when ``ifsave`` is True.
        title: figure title.
        ifsave: save to ``fig_path`` if True, otherwise call ``plt.show()``.

    Returns:
        The figure that was drawn on (not ``plt.gcf()``, which may be a new,
        empty figure after a blocking ``plt.show()``).
    """
    assert embeds.shape[1] == 3, "embeds must be 3D"
    # Apply fonts before creating the axes. Text objects created together with
    # the axes (e.g. the title) take their font family from rcParams at creation.
    init_mpl()
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')

    if isinstance(embeds, torch.Tensor):
        embeds = embeds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels)

    # Parametric unit sphere, drawn as a faint white reference surface.
    phi, theta = np.mgrid[0.0:np.pi:100j, 0.0:2.0 * np.pi:100j]
    x = np.sin(phi) * np.cos(theta)
    y = np.sin(phi) * np.sin(theta)
    z = np.cos(phi)
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color='w', alpha=0.3, linewidth=0)

    unique_labels = np.unique(labels)
    for i, label in enumerate(unique_labels):
        mask = labels == label
        ax.scatter(embeds[mask, 0], embeds[mask, 1], embeds[mask, 2],
                   label=label if isinstance(label, str) else f'label "{label}"',
                   color=label2color(i), s=10)
    ax.legend()
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_zlim([-1, 1])
    ax.set_aspect("equal")
    plt.title(title)
    plt.tight_layout()
    if ifsave:
        fig.savefig(fig_path)
    else:
        plt.show()
    return fig

def plot_line(x, y_list, labels=None, std_list=None, x_label='', y_label='', fig_path='./example.pdf', title='', ifsave=False):
    """Plot several series against a shared ``x``, optionally with ±std bands.

    Series ``i`` uses colour ``label2color(i)``. Its line style changes every
    10 series (``label2linestyle(i // 10)``), so colour and style together stay
    distinct for up to 40 series.

    Args:
        x: shared x values.
        y_list: sequence of y series, each convertible to a NumPy array.
        labels: optional legend labels, one per series. A legend is drawn only
            when this is given.
        std_list: optional per-series standard deviations. Each is shaded as
            ``y ± std`` in the series' own colour.
        x_label, y_label, title: axis labels and title.
        fig_path: output path used when ``ifsave`` is True.
        ifsave: save to ``fig_path`` if True, otherwise call ``plt.show()``.

    Returns:
        The figure that was drawn on (not ``plt.gcf()``, which may be a new,
        empty figure after a blocking ``plt.show()``).
    """
    init_mpl()
    fig = plt.figure(figsize=(6, 5))
    for i, y in enumerate(y_list):
        y = np.asarray(y)
        color = label2color(i)
        try:
            plt.plot(x, y, label=labels[i] if labels is not None else None,
                     color=color, marker="^", linestyle=label2linestyle(i // 10), linewidth=1)
        except Exception as e:
            # Best effort: skip a series that cannot be drawn (and its band)
            # rather than aborting the whole figure, but surface the reason.
            warnings.warn(f"plot_line: skipping series {i}: {e}")
            continue
        if std_list is not None:
            std = np.asarray(std_list[i])
            # Match the band colour to its line instead of the default colour cycle.
            plt.fill_between(x, y - std, y + std, color=color, alpha=0.2)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    if labels is not None:
        plt.legend()
    plt.grid(True)

    if ifsave:
        fig.savefig(fig_path, bbox_inches='tight')
    else:
        plt.show()
    return fig

# ---------------------------------

if __name__ == "__main__":
    # Demo: project random 100-D data to 2-D and colour each point by a random scalar g.
    n_samples = 1000
    n_features = 100
    embeds = torch.randn(n_samples, n_features)

    # Random integer labels in [0, 10).
    labels = torch.randint(0, 10, (n_samples,))
    g = torch.randn(n_samples, 1)

    # dim_reduction subsamples rows when n_samples exceeds its num_samples;
    # take the kept indexes so g stays aligned with the projected points.
    embeds, labels, indexes = dim_reduction(embeds, labels, method="PCA", with_indexes=True)
    g = g[torch.as_tensor(indexes)]
    plot_2D_with_g(embeds, labels, g, title='Random 100D Data Projected to 2D')






