import torch
import numpy as np
import os
import pickle
import scipy.sparse as sp
import pandas as pd
from .utils import print_log, StandardScaler, vrange
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances



def sym_adj(adj):
    """Scale an adjacency matrix by the inverse square root of its row sums.

    With ``d = rowsum(adj) ** -0.5`` (entries for zero row sums forced to 0)
    and ``D = diag(d)``, this returns ``(adj @ D).T @ D`` -- i.e.
    ``D @ adj.T @ D`` -- converted to float64 and densified with ``todense``.
    For a symmetric ``adj`` this is the usual ``D^-1/2 A D^-1/2`` form.
    """
    sparse_adj = sp.coo_matrix(adj)
    degrees = np.array(sparse_adj.sum(1)).ravel()
    inv_sqrt = np.power(degrees, -0.5)
    # Rows that sum to zero yield inf; make them contribute nothing instead.
    inv_sqrt[np.isinf(inv_sqrt)] = 0.
    scaling = sp.diags(inv_sqrt)
    normalized = sparse_adj.dot(scaling).transpose().dot(scaling)
    return normalized.astype(np.float64).todense()

def load_pickle(pickle_file):
    """Unpickle and return the object stored in ``pickle_file``.

    If the default load raises ``UnicodeDecodeError`` (for example, on
    byte strings inside a Python 2 pickle), the file is read once more with
    ``encoding='latin1'``. Any other error from the first attempt is printed
    and re-raised unchanged.

    Security: ``pickle`` can execute arbitrary code while loading; only pass
    files from a trusted source.
    """
    def _unpickle(**load_kwargs):
        with open(pickle_file, 'rb') as stream:
            return pickle.load(stream, **load_kwargs)

    try:
        return _unpickle()
    except UnicodeDecodeError:
        return _unpickle(encoding='latin1')
    except Exception as err:
        print('Unable to load data ', pickle_file, ':', err)
        raise


def load_adj(directory):
    """Load the single ``.pkl`` adjacency file in ``directory`` and normalize it.

    The pickle may hold either the matrix itself or a 3-element list/tuple
    whose last item is the matrix (other items are ignored).

    Args:
        directory: directory expected to contain exactly one ``.pkl`` file.

    Returns:
        ``(adj, adj_mx)`` where ``adj`` is
        ``[sym_adj(adj_mx), sym_adj(np.transpose(adj_mx))]`` and ``adj_mx`` is
        the raw matrix taken from the pickle.

    Raises:
        ValueError: ``directory`` is not a directory, or it holds more than
            one ``.pkl`` file.
        FileNotFoundError: no ``.pkl`` file exists in ``directory``.
    """
    if not os.path.isdir(directory):
        raise ValueError(f"{directory} is not a valid directory.")
    # Only regular files can be unpickled; skip e.g. a sub-directory named "x.pkl".
    pkl_files = [
        f for f in os.listdir(directory)
        if f.endswith('.pkl') and os.path.isfile(os.path.join(directory, f))
    ]

    if not pkl_files:
        raise FileNotFoundError(f"No .pkl files found in directory: {directory}")
    if len(pkl_files) > 1:
        raise ValueError(f"Multiple .pkl files found in directory: {directory}. Please ensure only one file exists.")

    pkl_filename = os.path.join(directory, pkl_files[0])

    # Read the file once. Only unpack genuine 3-element containers: blind
    # tuple-unpacking would also "succeed" on a 3-row ndarray and silently
    # keep just its last row.
    pickle_data = load_pickle(pkl_filename)
    if isinstance(pickle_data, (list, tuple)) and len(pickle_data) == 3:
        _, _, adj_mx = pickle_data
    else:
        adj_mx = pickle_data
    adj = [sym_adj(adj_mx), sym_adj(np.transpose(adj_mx))]
    return adj, adj_mx


def _split_windows(data, split_index):
    """Gather the (input, target) windows for one split.

    Each row of ``split_index`` is read as ``(start, mid, end)``. Inputs are
    ``data[vrange(start, mid)]`` and targets are ``data[vrange(mid, end)]``,
    with targets restricted to feature channel 0.
    """
    x = data[vrange(split_index[:, 0], split_index[:, 1])]
    y = data[vrange(split_index[:, 1], split_index[:, 2])][..., :1]
    return x, y


def _make_loader(x, y, batch_size, shuffle):
    """Wrap ``x``/``y`` as float tensors in a DataLoader that drops the last partial batch."""
    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x), torch.FloatTensor(y)
    )
    # NOTE(review): drop_last=True is also used for val/test, so up to
    # batch_size - 1 samples per split are never evaluated. It is kept in case
    # the model relies on a fixed batch size -- confirm before changing.
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True
    )


def get_dataloaders_from_index_data(
    data_dir, tod=False, dow=False, dom=False, batch_size=64, log=None
):
    """Build train/val/test DataLoaders from ``data.npz`` and ``index.npz``.

    ``data.npz["data"]`` is cast to float32 and reduced to feature channels
    ``[0]``, plus ``1`` if ``tod`` and ``2`` if ``dow``. ``index.npz`` supplies
    ``train``/``val``/``test`` arrays whose columns 0/1/2 delimit the input and
    target windows. Only channel 0 of the inputs is standardized, using the
    mean/std of the training inputs; targets keep channel 0 unscaled.

    Args:
        data_dir: directory containing ``data.npz`` and ``index.npz``.
        tod: include feature channel 1 in the inputs.
        dow: include feature channel 2 in the inputs.
        dom: accepted for interface compatibility but currently ignored.
        batch_size: batch size for all three loaders.
        log: forwarded to ``print_log``.

    Returns:
        ``(trainset_loader, valset_loader, testset_loader, scaler)``.
    """
    # Context managers close the NpzFile handles (the previous version left them open).
    with np.load(os.path.join(data_dir, "data.npz")) as data_file:
        data = data_file["data"].astype(np.float32)

    features = [0]
    if tod:
        features.append(1)
    if dow:
        features.append(2)

    data = data[..., features]

    with np.load(os.path.join(data_dir, "index.npz")) as index:
        train_index = index["train"]
        val_index = index["val"]
        test_index = index["test"]

    x_train, y_train = _split_windows(data, train_index)
    x_val, y_val = _split_windows(data, val_index)
    x_test, y_test = _split_windows(data, test_index)

    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())

    x_train[..., 0] = scaler.transform(x_train[..., 0])
    x_val[..., 0] = scaler.transform(x_val[..., 0])
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    print_log(f"Trainset:\tx-{x_train.shape}\ty-{y_train.shape}", log=log)
    print_log(f"Valset:  \tx-{x_val.shape}  \ty-{y_val.shape}", log=log)
    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    trainset_loader = _make_loader(x_train, y_train, batch_size, shuffle=True)
    valset_loader = _make_loader(x_val, y_val, batch_size, shuffle=False)
    testset_loader = _make_loader(x_test, y_test, batch_size, shuffle=False)

    return trainset_loader, valset_loader, testset_loader, scaler

def cluster_data(train, n_clusters):
    """Cluster the rows of ``train.T`` with KMeans and pick a representative per cluster.

    ``train`` is transposed before clustering, so each column of ``train``
    becomes one sample (presumably train is (time, nodes), so nodes are
    clustered -- TODO confirm against callers).

    Args:
        train: 2-D array-like supporting ``.transpose()``.
        n_clusters: number of KMeans clusters (``random_state=0``).

    Returns:
        ``(labels, centers)``: ``labels[k]`` is the cluster id of sample ``k``;
        ``centers[c]`` is the index of the sample closest (Euclidean) to
        centroid ``c``.
    """
    samples = train.transpose()
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(samples)
    labels = kmeans.labels_
    dis = euclidean_distances(samples, kmeans.cluster_centers_)
    # Nearest sample for each centroid (column). np.argmin keeps float64
    # precision and breaks ties toward the lowest index. The previous float32
    # torch.topk round trip lost precision and had no documented tie rule.
    centers = np.argmin(dis, axis=0)
    return labels, centers

def construct_hyper_graph(gso, labels, centers, n_clusters):
    """Build one hyperedge per cluster: its members plus the neighbours of its center.

    Args:
        gso: square matrix (anything ``np.array`` accepts). Row ``centers[i]``
            is scanned for neighbours of cluster ``i``'s center node.
        labels: per-node cluster ids, compared against ``0..n_clusters-1``.
        centers: node index for each cluster, indexed by cluster id.
        n_clusters: number of hyperedges to build.

    Returns:
        A list of ``n_clusters`` lists of node indices. Each list holds the
        cluster members first, then the center's neighbours, with duplicates
        removed in first-seen order.
    """
    gso = np.array(gso)
    labels = np.asarray(labels)
    hyper_edges = []
    for i in range(n_clusters):
        row = gso[centers[i], :]
        # Entries equal to 0 (no edge) or exactly 1 are skipped. The ``!= 1``
        # test presumably drops self-loops on the diagonal.
        # NOTE(review): it also drops any neighbour whose weight is exactly 1.
        neighbours = np.where((row != 0) & (row != 1))[0]
        members = np.where(labels == i)[0]
        # A member may also be a neighbour of the center. Keep it only once so
        # degree counts derived from this hyperedge are not inflated.
        hyper_edge = list(dict.fromkeys(members.tolist() + neighbours.tolist()))
        hyper_edges.append(hyper_edge)
    return hyper_edges


def construct_hyper_graph_laplacian(hyper_edges, vertex_size, device):
    """Return the normalized hypergraph Laplacian ``I - Dv^-1/2 H W De^-1 H^T Dv^-1/2``.

    ``H`` is the ``vertex_size x len(hyper_edges)`` 0/1 incidence matrix, and
    ``W`` gives every hyperedge weight 1. ``Dv``/``De`` are the vertex and
    hyperedge degrees, both read from ``H``. A vertex listed twice in the same
    hyperedge therefore counts once. Zero degrees, from isolated vertices or
    empty hyperedges, contribute 0 instead of producing inf/NaN.

    Args:
        hyper_edges: iterable of sequences of vertex indices.
        vertex_size: number of vertices (rows of ``H``).
        device: torch device for the returned tensor.

    Returns:
        A ``(vertex_size, vertex_size)`` float32 tensor on ``device``.
    """
    num_edges = len(hyper_edges)
    incidence = np.zeros((vertex_size, num_edges))
    for j, hyper_edge in enumerate(hyper_edges):
        for node in hyper_edge:
            incidence[node, j] = 1

    # Degrees come from H itself so they stay consistent with the 0/1 entries.
    vertex_degree = incidence.sum(axis=1)
    edge_degree = incidence.sum(axis=0)
    with np.errstate(divide="ignore"):
        dv_inv_sqrt = np.power(vertex_degree, -0.5)
        de_inv = np.power(edge_degree, -1.0)
    dv_inv_sqrt[np.isinf(dv_inv_sqrt)] = 0.0
    de_inv[np.isinf(de_inv)] = 0.0
    edge_weights = np.ones(num_edges)  # W = I

    # Apply the diagonal matrices as broadcasts (np.mat was removed in NumPy 2.0):
    # theta = (Dv^-1/2 H) (W De^-1) (Dv^-1/2 H)^T
    scaled_h = dv_inv_sqrt[:, None] * incidence
    theta = (scaled_h * (edge_weights * de_inv)[None, :]) @ scaled_h.T
    laplacian = np.eye(vertex_size) - theta
    return torch.Tensor(laplacian).to(device)


def construct_line_graph(hyper_edges, device):
    """Build the symmetrically normalized line graph of a hypergraph.

    Node ``i`` of the line graph is hyperedge ``i``. The weight between
    hyperedges ``i != j`` is their Jaccard similarity
    ``|e_i & e_j| / |e_i | e_j|``, and the diagonal is 0. The result is
    ``D^-1/2 A D^-1/2`` with ``D`` the row sums of ``A``. Hyperedges with zero
    total similarity get an all-zero row/column instead of NaN.

    Args:
        hyper_edges: sequence of iterables of vertex indices.
        device: torch device for the returned tensor.

    Returns:
        A ``(len(hyper_edges), len(hyper_edges))`` float32 tensor on ``device``.
    """
    num_edges = len(hyper_edges)
    line_graph_adj = np.zeros((num_edges, num_edges))
    # Build each vertex set once instead of four times per pair.
    edge_sets = [set(edge) for edge in hyper_edges]

    for i in range(num_edges):
        for j in range(i + 1, num_edges):
            union = len(edge_sets[i] | edge_sets[j])
            if union > 0:
                # Jaccard similarity is symmetric, so fill both triangles at once.
                similarity = len(edge_sets[i] & edge_sets[j]) / union
                line_graph_adj[i, j] = similarity
                line_graph_adj[j, i] = similarity

    degree = line_graph_adj.sum(axis=1)
    with np.errstate(divide="ignore"):
        inv_sqrt_degree = np.power(degree, -0.5)
    # A zero degree gives inf, which would turn its row/column into NaN.
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.0
    line_graph_L = inv_sqrt_degree[:, None] * line_graph_adj * inv_sqrt_degree[None, :]
    return torch.Tensor(line_graph_L).to(device)


def generate_laplacians(train, gso, n_clusters, device):
    """Cluster ``train`` into a hypergraph and return its two graph operators.

    Steps: ``cluster_data`` -> ``construct_hyper_graph`` (using ``gso``) ->
    ``construct_hyper_graph_laplacian`` and ``construct_line_graph``. The
    vertex count passed to the Laplacian is ``train.shape[1]``.

    Returns:
        ``(hyper_graph_L, line_graph_L)`` as produced by the two builders.
    """
    num_vertices = train.shape[1]
    labels, centers = cluster_data(train, n_clusters)
    edges = construct_hyper_graph(gso, labels, centers, n_clusters)
    return (
        construct_hyper_graph_laplacian(edges, num_vertices, device),
        construct_line_graph(edges, device),
    )
