import numpy as np
import os
import os.path as osp
import random
import matplotlib.pyplot as plt

import torch

from pathlib import Path
from scipy.stats import gaussian_kde
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from model.encoder import Encoder
from model.vq import VectorQuantize


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator below.
    """
    # Python-level randomness. Note: PYTHONHASHSEED set at runtime only
    # affects interpreters launched afterwards, not the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator plus the generators of every visible CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_mask(num_samples: int, train_ratio: float = 0.1, test_ratio: float = 0.1):
    """Randomly partition ``range(num_samples)`` into train/valid/test index tensors.

    A single ``torch.randperm`` permutation is cut into three consecutive
    slices:
      * ``'train'``: the first ``int(num_samples * train_ratio)`` indices;
      * ``'valid'``: the next ``int(num_samples * test_ratio)`` indices;
      * ``'test'``: every remaining index.

    NOTE(review): ``test_ratio`` sizes the *validation* slice while the test
    slice receives the remainder — confirm this naming is intended.

    Returns:
        dict with keys ``'train'``, ``'valid'``, ``'test'`` mapping to
        1-D index tensors.
    """
    assert train_ratio + test_ratio < 1
    n_train = int(num_samples * train_ratio)
    n_valid = int(num_samples * test_ratio)
    order = torch.randperm(num_samples)

    cut_a = n_train
    cut_b = n_train + n_valid
    return dict(
        train=order[:cut_a],
        valid=order[cut_a:cut_b],
        test=order[cut_b:],
    )


def check_path(path):
    """Make sure directory ``path`` exists, creating it with parents if needed.

    Returns:
        The argument unchanged when it already exists; otherwise a
        ``pathlib.Path`` for the newly created directory.

    NOTE(review): the return type therefore depends on whether the directory
    already existed — callers should not rely on ``str``-only operations.
    """
    if osp.exists(path):
        return path
    created = Path(path)
    created.mkdir(parents=True, exist_ok=True)
    return created


def get_device(params, optimized_params=None):
    """Choose the CUDA device to run on.

    Args:
        params: mapping whose ``'device'`` entry is the CUDA ordinal; only read
            when ``optimized_params`` is ``None`` or empty.
        optimized_params: optional sized collection; when non-empty the
            generic ``"cuda"`` device is returned instead of a specific index.

    Returns:
        ``torch.device('cuda:<params["device"]>')`` or ``torch.device('cuda')``.
    """
    use_generic = optimized_params is not None and len(optimized_params) > 0
    if use_generic:
        return torch.device("cuda")
    return torch.device(f"cuda:{params['device']}")


def get_scheduler(optimizer, use_scheduler=True, epochs=1000):
    """Build a half-cosine learning-rate schedule, or return ``None``.

    The multiplicative factor is ``0.5 * (1 + cos(pi * epoch / epochs))``:
    1 at epoch 0, 0.5 halfway, 0 at ``epoch == epochs``.

    Args:
        optimizer: the optimizer whose learning rate is scaled.
        use_scheduler: when falsy no scheduler is created.
        epochs: number of epochs over which the factor decays to 0.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` or ``None``.
    """
    if not use_scheduler:
        return None

    def cosine_factor(epoch):
        return (1 + np.cos(epoch * np.pi / epochs)) * 0.5

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)


def get_device_from_model(model):
    """Return the device holding ``model``'s first parameter.

    Raises:
        StopIteration: if the model has no parameters.
    """
    first_param = next(iter(model.parameters()))
    return first_param.device


def active_code(encoder, vq, data):
    """Encode ``data`` and report which quantizer codes are actually used.

    Args:
        encoder: callable invoked as ``encoder(data.x, data.edge_index, data.edge_attr)``.
        vq: quantizer returning a 4-tuple whose second element holds code
            indices; must expose ``codebook_size`` and ``heads``.
        data: object with ``x``, ``edge_index`` and ``edge_attr`` attributes.

    Returns:
        ``(used, ratio)``, where ``used`` is the sorted tensor of distinct code
        indices and ``ratio = used.numel() / (vq.codebook_size * vq.heads)``.

    NOTE(review): if the heads share one index range ``[0, codebook_size)``,
    ``ratio`` can never exceed ``1 / heads`` — confirm the denominator
    matches how ``vq`` numbers its per-head codes.
    """
    z = encoder(data.x, data.edge_index, data.edge_attr)
    _, indices, _, _ = vq(z)
    # ``unique`` sorts its input; compute it once instead of twice.
    used = indices.unique()
    total_codes = vq.codebook_size * vq.heads
    return used, used.numel() / total_codes


def visualize_tsne(embedding, label, name, save_path="pic/empirical/embedding.pdf"):
    """Project embeddings to 2-D with t-SNE and save a scatter plot coloured by label.

    Args:
        embedding: tensor with one row per sample (converted via
            ``.detach().cpu().numpy()``).
        label: tensor of class ids aligned with the rows of ``embedding``.
            Ids 0-6 are shown with the domain names in ``classes``; any other
            id is shown as its string value.
        name: not used by this function; kept so existing callers still work.
        save_path: output file. The default reproduces the previous
            hard-coded location. Missing parent directories are created.
    """
    embeddings_np = embedding.detach().cpu().numpy()
    labels_np = label.detach().cpu().numpy()

    X_embedded = TSNE(n_components=2).fit_transform(embeddings_np)

    unique_labels = np.unique(labels_np)
    # One colour per distinct label. ``plt.cm.get_cmap`` was removed in
    # Matplotlib 3.9; ``plt.get_cmap(name, lut)`` is the supported equivalent.
    colors = plt.get_cmap('tab10', len(unique_labels))

    classes = {0: 'Citation',
               1: 'Wikipedia',
               2: 'Knowledge',
               3: 'Molecular',
               4: 'E-commerce',
               5: 'GQA',
               6: 'Social'
    }

    plt.figure(figsize=(8, 6))

    for i, lbl in enumerate(unique_labels):
        mask = labels_np == lbl
        plt.scatter(X_embedded[mask, 0],
                    X_embedded[mask, 1],
                    color=colors(i),
                    # Unknown ids get a readable fallback instead of a KeyError.
                    label=classes.get(lbl.item(), str(lbl)),
                    s=25,
                    alpha=1
        )

    plt.legend(bbox_to_anchor=(1, 0),
               loc='lower right',
               prop={'size': 15}
    )
    plt.axis('off')
    plt.tight_layout()
    out_dir = osp.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def visualize_hist(embedding, bins=None, density=False,
                   save_path="pic/empirical/codebook.pdf"):
    """Plot a histogram of the 1-D PCA projection of ``embedding`` and save it.

    Args:
        embedding: tensor with one row per vector (converted via
            ``.detach().cpu().numpy()``).
        bins: histogram bin edges. Defaults to 35 edges evenly spaced over
            ``[-0.25, 0.85]``, the previous hard-coded range. Projected values
            outside the edges are not counted.
        density: forwarded to ``plt.hist``. When False the bars are raw counts.
        save_path: output file. The default reproduces the previous
            hard-coded location. Missing parent directories are created.
    """
    plt.figure(figsize=(8, 6))
    pca = PCA(n_components=1)
    emb_1d = pca.fit_transform(embedding.detach().cpu().numpy())

    if bins is None:
        bins = np.linspace(-0.25, 0.85, 35)
    hist_args = {
        'bins': bins,
        'color': '#2a5f8a',
        'edgecolor': 'w',
        'linewidth': 0.7,
        'alpha': 0.85,
        'density': density
    }
    plt.hist(emb_1d, **hist_args)

    plt.xlabel("Embedding Value", fontsize=20)
    # With density=False the bar heights are counts, so the old "Density"
    # label was misleading; label the axis by what is actually plotted.
    plt.ylabel('Density' if density else 'Count', fontsize=20)
    plt.tick_params(labelsize=15)

    plt.tight_layout()
    out_dir = osp.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.close()


def visualize_features(X_orig, X_recon, x_range=(-0.5, 1.2),
                       save_path="pic/empirical/features.pdf"):
    """Compare original and reconstructed features with a shared 1-D PCA + KDE plot.

    A single PCA is fitted on the row-wise concatenation of both matrices, so
    the two projected distributions share one axis. A Gaussian KDE of each is
    then drawn as a filled curve.

    Args:
        X_orig: tensor of original features, one row per sample.
        X_recon: tensor of reconstructed features. It must have the same
            trailing dimensions as ``X_orig`` so the two can be concatenated.
        x_range: ``(low, high)`` interval on which both KDEs are evaluated
            (1000 points). The default reproduces the previous hard-coded range.
        save_path: output file. The default reproduces the previous
            hard-coded location. Missing parent directories are created.
    """
    X_orig = X_orig.detach().cpu().numpy()
    X_recon = X_recon.detach().cpu().numpy()
    combined = np.concatenate([X_orig, X_recon], axis=0)
    pca = PCA(n_components=1)
    combined_pca = pca.fit_transform(combined)
    n_orig = len(X_orig)
    X_orig_pca = combined_pca[:n_orig]
    X_recon_pca = combined_pca[n_orig:]

    plt.figure(figsize=(8, 6))

    x_vals = np.linspace(x_range[0], x_range[1], 1000)
    kde_orig = gaussian_kde(X_orig_pca.reshape(-1))
    plt.fill_between(x_vals, kde_orig(x_vals), color='blue', alpha=0.6, label='Original')
    kde_recon = gaussian_kde(X_recon_pca.reshape(-1))
    plt.fill_between(x_vals, kde_recon(x_vals), color='orange', alpha=0.6, label='Reconstructed')

    plt.legend(prop={'size': 15})
    plt.tick_params(labelsize=15)
    plt.xlabel("Embedding Value", fontsize=20)
    plt.ylabel("Density", fontsize=20)

    plt.tight_layout()
    out_dir = osp.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.close()


def sample_proto_instances(labels, split, num_instances_per_class):
    """Randomly pick up to ``num_instances_per_class`` indices of each class from ``split``.

    The classes considered are the distinct values of ``labels[split]``,
    visited in ascending order. For each class, the candidate indices (the
    positions where ``labels`` equals the class that also appear in ``split``)
    are shuffled in place with NumPy's global RNG. The first
    ``num_instances_per_class`` are then kept.

    Args:
        labels: tensor of per-sample labels (moved to CPU here).
        split: indices into ``labels`` to sample from.
        num_instances_per_class: maximum number of indices kept per class.

    Returns:
        1-D integer NumPy array of selected indices, grouped by class.
    """
    label_arr = labels.cpu().numpy()
    picked = []
    for cls in np.unique(label_arr[split]):
        members = np.intersect1d(np.where(label_arr == cls)[0], split)
        np.random.shuffle(members)
        picked.append(members[:num_instances_per_class])

    if not picked:
        return np.array([]).astype(int)
    return np.concatenate(picked).astype(int)


def sample_proto_instances_for_graph(labels, split, num_instances_per_class):
    """Sample positive and negative prototype indices for every binary task.

    ``labels`` is treated as a ``(num_samples, num_tasks)`` matrix. A 1-D
    input is reshaped to a single task column. For each task column, the
    indices in ``split`` whose label equals 1 (positives) and 0 (negatives)
    are collected. Entries that are neither 1 nor 0 are ignored. Each group
    is shuffled in place with NumPy's global RNG (positives first, then
    negatives) and truncated to ``num_instances_per_class``.

    Args:
        labels: label matrix or vector, as a NumPy array or torch tensor.
        split: indices into ``labels`` to sample from.
        num_instances_per_class: maximum number of positives and of negatives
            kept per task.

    Returns:
        ``(proto_idx, proto_labels)``: two dicts keyed by task index.
        ``proto_idx[t]`` holds the selected positives followed by the selected
        negatives. ``proto_labels[t]`` holds the matching 1s then 0s. Both are
        integer arrays.
    """
    y = labels.reshape(-1, 1) if labels.ndim == 1 else labels
    if isinstance(y, torch.Tensor):
        y = y.cpu().numpy()
    num_tasks = y[split].shape[1]

    proto_idx, proto_labels = {}, {}
    for task in range(num_tasks):
        column = y[:, task]
        pos = np.intersect1d(np.where(column == 1)[0], split)
        neg = np.intersect1d(np.where(column == 0)[0], split)

        np.random.shuffle(pos)
        np.random.shuffle(neg)
        pos = pos[:num_instances_per_class]
        neg = neg[:num_instances_per_class]

        proto_idx[task] = np.concatenate((pos, neg)).astype(int)
        proto_labels[task] = np.concatenate(
            (np.ones(len(pos)), np.zeros(len(neg)))
        ).astype(int)

    return proto_idx, proto_labels


def mask2idx(mask):
    """Convert a mask into the indices of its non-zero entries.

    Tensors are used as-is. Any other input (list, NumPy array, ...) is first
    converted with ``torch.as_tensor(..., dtype=torch.bool)``.

    Returns:
        1-D LongTensor with the positions of non-zero entries along the first
        dimension.
    """
    source = mask if isinstance(mask, torch.Tensor) else torch.as_tensor(mask, dtype=torch.bool)
    return source.nonzero(as_tuple=True)[0]


def idx2mask(idx, num_nodes):
    """Build a boolean mask of size ``num_nodes`` that is True exactly at ``idx``.

    Args:
        idx: anything usable to index a torch tensor (index tensor, list, ...).
        num_nodes: size passed to ``torch.zeros``.

    Returns:
        ``torch.bool`` tensor on the default device.
    """
    flags = torch.zeros(num_nodes, dtype=torch.bool)
    flags[idx] = True
    return flags
