import torch
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from matplotlib.ticker import MaxNLocator
from relabel import neighborhood_label_distribution, nodewise_neighborhood_label_distribution
import os
import matplotlib.ticker as ticker

def index_to_mask(index, size):
    """Build a boolean mask of length ``size`` that is True at ``index``.

    Args:
        index: anything usable to index a 1-D torch tensor (list of ints,
            numpy array, tensor, ...).
        size: length of the returned mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(size,)``.
    """
    flags = torch.zeros(size, dtype=torch.bool)
    flags[index] = True
    return flags

def random_planetoid_splits(data, num_classes, percls_trn=20, val_lb=500, seed=12134):
    """Randomly split nodes into train / validation / test masks.

    For every class ``c`` in ``range(num_classes)``, ``percls_trn`` nodes with
    label ``c`` are sampled without replacement for training (classes with
    fewer nodes contribute all of them). ``val_lb`` validation nodes are then
    sampled from the remaining nodes, and everything left is the test set.

    The masks are written onto ``data`` in place (``train_mask``, ``val_mask``,
    ``test_mask``) and ``data`` is returned. Results are deterministic for a
    given ``seed``.

    Raises:
        ValueError: from ``RandomState.choice`` if fewer than ``val_lb`` nodes
            remain after the training nodes are taken.
    """
    rnd_state = np.random.RandomState(seed)
    labels = data.y.cpu()  # hoisted: the original moved y to CPU once per class

    train_idx = []
    for c in range(num_classes):
        class_idx = np.where(labels == c)[0]
        if len(class_idx) < percls_trn:
            train_idx.extend(class_idx)
        else:
            train_idx.extend(rnd_state.choice(class_idx, percls_trn, replace=False))

    # Set lookups instead of list/array membership tests: the original
    # ``i not in train_idx`` / ``i not in val_idx`` were O(n) each, O(n^2) total.
    # The order of rest_index is unchanged, so the RNG draws are unchanged too.
    train_set = {int(i) for i in train_idx}
    rest_index = [i for i in range(data.y.shape[0]) if i not in train_set]
    val_idx = rnd_state.choice(rest_index, val_lb, replace=False)
    val_set = {int(i) for i in val_idx}
    test_idx = [i for i in rest_index if i not in val_set]

    data.train_mask = index_to_mask(train_idx, size=data.num_nodes)
    data.val_mask = index_to_mask(val_idx, size=data.num_nodes)
    data.test_mask = index_to_mask(test_idx, size=data.num_nodes)

    return data

# Compute the complementary color of a class color.
def complementary_color(color):
    """Return the RGB complement of a hex color such as ``'#1f77b4'``.

    Leading ``'#'`` characters are stripped, and the first six hex digits are
    read as three 8-bit channels. Each channel ``v`` becomes ``255 - v``. The
    result is a lowercase ``'#rrggbb'`` string.
    """
    hex_digits = color.lstrip('#')
    channels = [int(hex_digits[pos:pos + 2], 16) for pos in range(0, 6, 2)]
    return '#' + ''.join(format(255 - value, '02x') for value in channels)

def plot_tsne(node_representation, node_labels, path, method):
    """Embed node representations in 2-D with t-SNE and save a scatter plot.

    One scatter series is drawn per distinct value in ``node_labels``. Axes
    are hidden. The figure is written to ``{path}.png`` and ``{path}.pdf``
    and then closed, so repeated calls do not accumulate open figures.

    Args:
        node_representation: feature matrix passed to ``TSNE.fit_transform``
            (rows are nodes).
        node_labels: per-node labels, compared element-wise with ``==``.
        path: output path without extension.
        method: ``'vanilla'`` gives the title "GT Labels",
            ``'monophily_uniform'`` gives "PosteL Labels", and any other
            value leaves the plot untitled.
    """
    embedding = TSNE(n_components=2, random_state=42).fit_transform(node_representation)

    fig, ax = plt.subplots(figsize=(10, 8))

    titles = {'vanilla': "GT Labels", 'monophily_uniform': "PosteL Labels"}
    if method in titles:
        ax.set_title(titles[method], fontsize=50)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    for label in np.unique(node_labels):
        members = np.where(node_labels == label)[0]
        ax.scatter(embedding[members, 0], embedding[members, 1], label=label)

    fig.tight_layout()
    fig.savefig(f'{path}.png')
    fig.savefig(f'{path}.pdf')
    # Release the figure; otherwise each call leaks one into pyplot's registry.
    plt.close(fig)

def plot_pca(node_representation, node_labels, path):
    """Project node representations to 2-D with PCA and save a scatter plot.

    One scatter series is drawn per distinct value in ``node_labels``. The
    figure is saved to ``path`` exactly as given (the extension chooses the
    format) and then closed.

    Args:
        node_representation: feature matrix passed to ``PCA.fit_transform``
            (rows are nodes).
        node_labels: per-node labels, compared element-wise with ``==``.
        path: output file path.
    """
    embedding = PCA(n_components=2, random_state=42).fit_transform(node_representation)

    fig, ax = plt.subplots(figsize=(10, 8))
    for label in np.unique(node_labels):
        members = np.where(node_labels == label)[0]
        ax.scatter(embedding[members, 0], embedding[members, 1], label=label)

    fig.savefig(path)
    # Release the figure; otherwise each call leaks one into pyplot's registry.
    plt.close(fig)

def kl_div(p, q):
    """Row-wise Kullback-Leibler divergence KL(p || q), reduced over ``dim=1``.

    ``p`` and ``q`` must broadcast together. Each slice along dim 1 is treated
    as a discrete distribution; normalisation is not checked.

    Entries where ``p == 0`` contribute 0 (the usual 0 * log 0 := 0
    convention) instead of NaN. Entries with ``p > 0`` and ``q == 0``
    still give ``+inf``.
    """
    # p*log(p/q) == xlogy(p, p) - xlogy(p, q); xlogy returns 0 when its first
    # argument is 0, which avoids the 0 * (-inf) = NaN of the naive form.
    return torch.sum(torch.xlogy(p, p) - torch.xlogy(p, q), dim=1)

def js_div(p, q):
    """Row-wise Jensen-Shannon divergence between ``p`` and ``q``.

    Computed as the average of ``kl_div`` from each input to their midpoint
    ``m = (p + q) / 2``. The reduction over dim 1 is inherited from ``kl_div``.
    """
    midpoint = (p + q) / 2
    return 0.5 * kl_div(p, midpoint) + 0.5 * kl_div(q, midpoint)

def plot_hist(data, num_classes, args):
    """Save, for every class k, a 2x2 bar-chart figure of neighbour-label histograms.

    The four panels compare:

    * edge-wise histograms (``neighborhood_label_distribution``) with the
      default ``for_train`` and with ``for_train=False``;
    * node-wise histograms (``nodewise_neighborhood_label_distribution``)
      with the same two settings.

    The figure suptitle labels row k as P(y_j | y_i = k). Figures are written
    to ``fig/cd_for_{args.dataset}_class{k}.pdf``; ``fig/`` is created if it
    is missing. Each figure is closed after saving.

    Args:
        data: graph object exposing ``train_mask``, ``val_mask``, ``test_mask``.
        num_classes: number of classes (x-axis length and figure count).
        args: namespace with ``num_hop`` and ``dataset`` attributes.
    """
    os.makedirs('fig', exist_ok=True)

    # Edge-wise distributions.
    train_edge_hist, _ = neighborhood_label_distribution(data, data.train_mask, data.val_mask, data.test_mask, num_classes, args.num_hop)
    all_edge_hist, _ = neighborhood_label_distribution(data, data.train_mask, data.val_mask, data.test_mask, num_classes, args.num_hop, for_train=False)

    # Node-wise distributions.
    train_node_hist, _ = nodewise_neighborhood_label_distribution(data, data.train_mask, data.val_mask, data.test_mask, num_classes)
    all_node_hist, _ = nodewise_neighborhood_label_distribution(data, data.train_mask, data.val_mask, data.test_mask, num_classes, for_train=False)

    # Panel order matches axs.flat: (0,0), (0,1), (1,0), (1,1).
    panels = (
        (train_edge_hist, 'edge-wise defined cd with train data'),
        (all_edge_hist, 'edge-wise defined cd with all data'),
        (train_node_hist, 'node-wise defined cd with train data'),
        (all_node_hist, 'node-wise defined cd with all data'),
    )

    x_axis = [f'class{i}' for i in range(num_classes)]
    for k in range(num_classes):
        fig, axs = plt.subplots(2, 2, figsize=(10, 5))
        for ax, (hist, title) in zip(axs.flat, panels):
            ax.bar(x_axis, hist[k].cpu())
            ax.set_title(title)

        fig.suptitle(f'P(y_j|y_i={k})_for_{args.dataset}', fontsize=16)
        fig.tight_layout()  # prevent subplot title/label overlap
        fig.savefig(f'fig/cd_for_{args.dataset}_class{k}.pdf')
        # One figure per class; close it so they don't pile up.
        plt.close(fig)

def plot_datawise_hist(data, num_classes, args):
    """Save one row of bar charts showing the edge-wise neighbour-label histogram per class.

    Panel k plots row k of the edge-wise histogram returned by
    ``neighborhood_label_distribution`` (default ``for_train``) and is titled
    P(Y_j | Y_i = k). Output goes to
    ``fig/histogram/cd_for_{args.dataset}.pdf`` and ``.png``. The directory is
    created if missing, and the figure is closed afterwards.

    Args:
        data: graph object exposing ``train_mask``, ``val_mask``, ``test_mask``.
        num_classes: number of classes (number of panels and x-axis length).
        args: namespace with ``num_hop`` and ``dataset`` attributes.
    """
    os.makedirs('fig/histogram/', exist_ok=True)
    train_edge_hist, _ = neighborhood_label_distribution(data, data.train_mask, data.val_mask, data.test_mask, num_classes, args.num_hop)

    x_axis = [str(i) for i in range(num_classes)]

    # Figure width grows with the number of panels (20/3 inches each).
    # squeeze=False keeps axs 2-D even when num_classes == 1, where
    # plt.subplots would otherwise return a bare Axes and axs[k] would fail.
    fig, axs = plt.subplots(1, num_classes, figsize=(20 / 3 * num_classes, 5), squeeze=False)
    for k, ax in enumerate(axs[0]):
        ax.bar(x_axis, train_edge_hist[k].cpu(), color='#415CF2')
        ax.set_title(fr'$\mathbb{{P}}(Y_j|Y_i={k})$', fontsize=40)
        ax.tick_params(axis='both', which='major', labelsize=32)

    fig.tight_layout()
    fig.savefig(f'fig/histogram/cd_for_{args.dataset}.pdf')
    fig.savefig(f'fig/histogram/cd_for_{args.dataset}.png')
    plt.close(fig)

    return

def get_rep_norm(data, representation):
    """Mean representation distance across edges, split by label agreement.

    For every edge ``(src, dst)`` in ``data.edge_index``, this takes the L2
    distance between ``representation[src]`` and ``representation[dst]``
    (norm over dim 1). Edges are then grouped by whether ``data.y[src]``
    equals ``data.y[dst]``.

    Returns:
        ``(same_mean, diff_mean)``: mean distance over same-label edges and
        over different-label edges, as 0-dim tensors. Either is NaN if its
        group has no edges (mean of an empty tensor).
    """
    src, dst = data.edge_index[0], data.edge_index[1]
    distances = (representation[src] - representation[dst]).norm(dim=1)
    homophilous = data.y[src] == data.y[dst]
    return distances[homophilous].mean(), distances[~homophilous].mean()
