import random
import torch
import os
from torch_geometric.data import Data
import json
from sklearn.decomposition import PCA
from sklearn.random_projection import GaussianRandomProjection
import numpy as np
import scipy.io as sio
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score, average_precision_score


def test_eval(labels, probs):
    """Compute AUROC and AUPRC for a set of predictions.

    Args:
        labels: Ground-truth labels, either a torch tensor or anything
            ``roc_auc_score`` / ``average_precision_score`` accept.
        probs: Predicted scores, same container options as ``labels``.

    Returns:
        dict: ``{'AUROC': ..., 'AUPRC': ...}`` as returned by the sklearn
        metric functions.
    """
    # Tensor.numpy() refuses tensors that require grad (a no_grad context does
    # not change that), so detach explicitly before leaving torch.
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    if torch.is_tensor(probs):
        probs = probs.detach().cpu().numpy()
    return {
        'AUROC': roc_auc_score(labels, probs),
        'AUPRC': average_precision_score(labels, probs),
    }


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: Any scipy sparse matrix (it is converted to COO first).

    Returns:
        torch.Tensor: Sparse COO tensor with the same shape and non-zero
        entries, stored as float32 values with int64 indices.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported factory for the same layout.
    return torch.sparse_coo_tensor(indices, values, size=tuple(coo.shape))


def feat_alignment(X, edges, dims):
    """Project a node-feature matrix to exactly ``dims`` columns.

    If the input has fewer than ``dims`` columns it is first expanded with a
    Gaussian random projection, then PCA reduces the result to ``dims``
    components. Both transformers use ``random_state=0``.

    Args:
        X: Feature matrix of shape ``(num_nodes, num_features)``; a torch
            tensor or a numpy array.
        edges: Unused by the computation; kept so existing callers that pass
            the edge index continue to work.
        dims: Number of output feature columns.

    Returns:
        torch.FloatTensor: The transformed features with ``dims`` columns.
    """
    # sklearn works on host-side arrays; convert once so both paths below see
    # the same input type.
    if torch.is_tensor(X):
        X = X.detach().cpu().numpy()

    if X.shape[1] < dims:
        # Expand to at least `dims` columns (256 by default) so that PCA can
        # actually extract `dims` components afterwards.
        transformer = GaussianRandomProjection(
            n_components=max(256, dims), random_state=0)
        X = transformer.fit_transform(X)

    pca = PCA(n_components=dims, random_state=0)
    return torch.FloatTensor(pca.fit_transform(X))


def preprocess_features(features):
    """Row-normalize a sparse feature matrix and return it in dense form.

    Each row is divided by its sum; rows whose sum is zero are left as all
    zeros.

    Args:
        features: scipy sparse matrix of shape ``(num_nodes, num_features)``.

    Returns:
        numpy.matrix: Dense row-normalized features (result of ``todense()``).
    """
    rowsum = np.asarray(features.sum(1)).flatten()
    # Integer arrays cannot be raised to a negative power in NumPy, so make
    # sure the row sums are floating point before inverting them.
    if not np.issubdtype(rowsum.dtype, np.floating):
        rowsum = rowsum.astype(np.float64)
    # Empty rows produce 1/0 = inf; that is expected and zeroed right below,
    # so silence the divide-by-zero warning.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features.todense()


def normalize_adj(adj):
    """Scale an adjacency matrix by the inverse square root of its row sums.

    Computes ``(adj . D)^T . D`` where ``D`` is the diagonal matrix of
    ``rowsum ** -0.5`` (entries for zero row sums are set to 0), and returns
    the result in COO format.
    """
    coo = sp.coo_matrix(adj)
    degrees = np.array(coo.sum(1))
    inv_sqrt = np.power(degrees, -0.5).flatten()
    # Rows with zero sum yield inf; treat them as contributing nothing.
    inv_sqrt[np.isinf(inv_sqrt)] = 0.
    scaling = sp.diags(inv_sqrt)
    right_scaled = coo.dot(scaling)
    both_scaled = right_scaled.transpose().dot(scaling)
    return both_scaled.tocoo()


class Dataset:
    """Graph dataset loaded from a ``.mat`` file, with cached preprocessing.

    On construction the raw file ``{prefix}{name}.mat`` is read, node features
    are reduced to ``dims`` columns via ``feat_alignment`` and the result is
    cached in ``{prefix}{name}_{dims}.npz``; later constructions with the same
    arguments read the cache instead.

    Attributes set by ``__init__``:
        name: Dataset name as passed in.
        label: Raw label array (``'Label'`` or ``'gnd'`` entry of the file).
        adj_norm: Normalized adjacency as a torch sparse tensor.
        feat: Float tensor of node features.
        feat_norm: Mean L2 norm of the rows of ``feat``.
        graph: torch_geometric ``Data`` object holding ``x``, ``adj``,
            ``ano_labels`` and the (initially ``None``) ``x_list``,
            ``shot_idx`` and ``shot_mask`` fields.
    """

    def __init__(self, dims, name='cora', prefix='./dataset/'):
        """Load (or build and cache) the dataset.

        Args:
            dims: Number of feature columns after alignment.
            name: Base name of the ``.mat`` file (without extension).
            prefix: Path prefix prepended verbatim to the file names.
        """
        self.shot_mask = None
        self.shot_idx = None
        self.graph = None
        self.x_list = None
        self.name = name

        preprocess_filename = f'{prefix}{name}_{dims}.npz'
        if os.path.exists(preprocess_filename):
            # NOTE(security): allow_pickle is required because the raw .mat
            # dictionary is stored as a pickled object array. Only load cache
            # files produced by this code; pickles can execute arbitrary code.
            with np.load(preprocess_filename, allow_pickle=True) as f:
                data = f['data'].item()
                feat = torch.tensor(f['feat']).float()
        else:
            data = sio.loadmat(f"{prefix + name}.mat")
            # Use the same key fallback as the code below so files that store
            # the adjacency under 'A' also work on the first (uncached) load.
            adj = data['Network'] if 'Network' in data else data['A']
            # NOTE(review): 'X' is assumed to be the attribute key that goes
            # with 'A'/'gnd' — confirm against the actual data files.
            feat = data['Attributes'] if 'Attributes' in data else data['X']
            adj_sp = sp.csr_matrix(adj)
            row, col = adj_sp.nonzero()
            # Stack in numpy first; building a tensor from a list of arrays
            # is slow and triggers a warning.
            edge_index = torch.from_numpy(
                np.vstack((row, col)).astype(np.int64))
            if name in ['Amazon', 'YelpChi']:
                feat = preprocess_features(sp.lil_matrix(feat))
            else:
                feat = sp.lil_matrix(feat).toarray()
            feat = torch.FloatTensor(feat)
            feat = feat_alignment(feat, edge_index, dims)
            np.savez(preprocess_filename, data=data,
                     feat=feat.detach().cpu().numpy())

        adj = data['Network'] if 'Network' in data else data['A']
        adj_norm = sparse_mx_to_torch_sparse_tensor(normalize_adj(adj))
        label = data['Label'] if ('Label' in data) else data['gnd']

        self.label = label
        self.adj_norm = adj_norm

        self.feat = feat
        self.feat_norm = torch.norm(feat, p=2, dim=1).mean()
        ano_labels = torch.tensor(np.squeeze(np.array(self.label)),
                                  dtype=torch.float)
        # Bundle everything into a PyTorch Geometric Data object. `x` is an
        # independent copy of the features.
        self.graph = Data(x=self.feat.detach().clone().to(torch.float),
                          x_list=self.x_list,
                          adj=self.adj_norm,
                          ano_labels=ano_labels,
                          shot_idx=self.shot_idx,
                          shot_mask=self.shot_mask)

    def few_shot(self, shot=10):
        """Randomly pick up to ``shot`` nodes labelled 0 as the few-shot set.

        Stores the chosen indices in ``graph.shot_idx`` and a boolean node
        mask in ``graph.shot_mask``. Uses the global ``random`` module state.

        Args:
            shot: Maximum number of label-0 nodes to select.
        """
        y = self.graph.ano_labels
        num_nodes = y.shape[0]
        normal_idx = torch.where(y == 0)[0].tolist()
        random.shuffle(normal_idx)
        # Force an integer dtype: an empty selection would otherwise become a
        # float tensor, which cannot be used as an index.
        shot_idx = torch.tensor(normal_idx[:shot], dtype=torch.long)
        shot_mask = torch.zeros(num_nodes, dtype=torch.bool)
        shot_mask[shot_idx] = True
        self.graph.shot_idx = shot_idx
        self.graph.shot_mask = shot_mask

    def propagated(self, k):
        """Compute ``[X, AX, A^2 X, ..., A^k X]`` and store it in ``graph.x_list``.

        ``A`` is ``self.adj_norm`` and ``X`` is ``self.feat``. The computation
        runs on CUDA when available and falls back to the CPU otherwise.

        Args:
            k: Number of propagation steps.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        x = torch.as_tensor(self.feat, dtype=torch.float).to(device)
        # Move the adjacency once instead of once per propagation step.
        adj = self.adj_norm.to(device)
        h_list = [x]
        for _ in range(k):
            h_list.append(torch.spmm(adj, h_list[-1]))
        self.graph.x_list = h_list

    def dict_to_list(self, d):
        """Turn a dict keyed by non-negative integers into a dense list.

        Positions whose index is not a key of ``d`` are filled with ``0``.

        Args:
            d: Mapping from integer position to value.

        Returns:
            list: List of length ``max(d) + 1``; empty for an empty dict.
        """
        if not d:
            return []
        lst = [0] * (max(d) + 1)
        for k, v in d.items():
            lst[k] = v
        return lst

    def run_wl_sparse(self, adj_matrix, num_iterations: int = 20):
        """
        Runs the Weisfeiler-Lehman algorithm on a graph represented by a sparse adjacency matrix.

        Node labels start as the (integer-truncated) row sum of the adjacency
        matrix and are refined ``num_iterations`` times by hashing each node's
        label together with the sorted labels of its neighbors.

        Args:
            adj_matrix (csr_matrix): A sparse adjacency matrix of the graph.
            num_iterations (int): Number of WL iterations.

        Returns:
            dict: {node: WL label as integer}
        """
        import hashlib
        # Number of nodes in the graph (size of adjacency matrix)
        num_nodes = adj_matrix.shape[0]
        adj_matrix = adj_matrix.tocsr()

        # Initialize labels to be the degree of each node (for simplicity)
        row_sums = np.asarray(adj_matrix.sum(axis=1)).ravel()
        labels = {i: str(int(row_sums[i])) for i in range(num_nodes)}

        # Neighborhoods never change between iterations, so extract them once
        # instead of re-slicing the sparse matrix on every pass.
        neighbor_lists = [adj_matrix[node].nonzero()[1].tolist()
                          for node in range(num_nodes)]

        for _ in range(num_iterations):
            new_labels = {}
            for node, neighbors in enumerate(neighbor_lists):
                neighbor_labels = sorted(labels[nb] for nb in neighbors)
                combined = labels[node] + "_" + "_".join(neighbor_labels)
                new_labels[node] = hashlib.md5(combined.encode()).hexdigest()
            labels = new_labels

        # Map string labels to integer IDs
        unique_labels = sorted(set(labels.values()))
        label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
        return {node: label_to_id[label] for node, label in labels.items()}

def report_difference(train_data):
    """Return mean pairwise distances normal-normal and normal-anomaly.

    Features are the column-wise concatenation of ``train_data.x_list``.
    Nodes with ``ano_labels == 0`` are treated as normal and nodes with
    ``ano_labels == 1`` as anomalous.

    Returns:
        tuple[float, float]: (mean distance between normal nodes, mean
        distance from normal nodes to anomalous nodes).
    """
    stacked = torch.cat(train_data.x_list, dim=1)
    labels = train_data.ano_labels

    normal_rows = (labels == 0).nonzero().squeeze(1).tolist()
    anomaly_rows = (labels == 1).nonzero().squeeze(1).tolist()
    normal_feat = stacked[normal_rows, :]
    anomaly_feat = stacked[anomaly_rows, :]

    # Per-normal-node average distance to each group, then averaged again.
    normal_to_normal = torch.cdist(normal_feat, normal_feat).mean(dim=1)
    normal_to_anomaly = torch.cdist(normal_feat, anomaly_feat).mean(dim=1)
    return normal_to_normal.mean().item(), normal_to_anomaly.mean().item()


def measure_difference(train_data):
    """Return the mean pairwise Euclidean distance between all nodes.

    Features are the column-wise concatenation of ``train_data.x_list``; the
    distance matrix is averaged per node first and then across nodes.
    """
    stacked = torch.cat(train_data.x_list, dim=1)
    pairwise = torch.cdist(stacked, stacked)
    per_node_mean = pairwise.mean(dim=1)
    return per_node_mean.mean().item()

## normalize the data based on training sets.
def normalization(data_train, data_test, tau):
    """Rescale the ``feat`` tensor of every dataset in both lists.

    Steps, in order:
      1. Measure the mean pairwise distance of each dataset's ``graph``.
      2. Divide each dataset's ``feat`` by the mean L2 norm of its rows.
      3. Measure the mean pairwise distances again.
      4. Multiply each ``feat`` by
         ``max(sqrt(median_before * dist_after / (median_after * dist_before)), tau)``.
         Training sets use medians over the training measurements only; test
         sets use medians over training and test measurements together.

    Args:
        data_train: Sequence of dataset objects exposing ``feat`` and ``graph``.
        data_test: Sequence of dataset objects exposing ``feat`` and ``graph``.
        tau: Lower bound on the final multiplicative factor.

    Returns:
        tuple: ``(data_train, data_test)``, the same objects with ``feat``
        replaced.
    """

    def _measure_all(datasets):
        return [measure_difference(ds.graph) for ds in datasets]

    def _divide_by_mean_row_norm(datasets):
        for ds in datasets:
            mean_row_norm = torch.norm(ds.feat, p=2, dim=1).mean()
            ds.feat = ds.feat / mean_row_norm

    def _apply_factor(datasets, before, after, median_before, median_after):
        for ds, dist_before, dist_after in zip(datasets, before, after):
            ratio = median_before * dist_after / (median_after * dist_before)
            ds.feat = ds.feat * np.max([np.sqrt(ratio), tau])

    raw_train = _measure_all(data_train)
    raw_test = _measure_all(data_test)
    median_raw_train = np.median(raw_train)
    median_raw_test = np.median(raw_train + raw_test)

    _divide_by_mean_row_norm(data_train)
    _divide_by_mean_row_norm(data_test)

    # NOTE(review): this function only rebinds `feat`; `graph.x_list`, which
    # measure_difference reads, is not recomputed here. Unless it is refreshed
    # elsewhere, these second measurements equal the first ones — confirm
    # that this is intended.
    scaled_train = _measure_all(data_train)
    scaled_test = _measure_all(data_test)
    median_scaled_train = np.median(scaled_train)
    median_scaled_test = np.median(scaled_train + scaled_test)

    _apply_factor(data_train, raw_train, scaled_train,
                  median_raw_train, median_scaled_train)
    _apply_factor(data_test, raw_test, scaled_test,
                  median_raw_test, median_scaled_test)
    return data_train, data_test

def read_json(model,  json_dir):
    """Load ``{json_dir}/{model}.json`` and return its parsed content.

    Prints a message and returns ``None`` when the file does not exist or
    does not contain valid JSON.
    """
    path = f"{json_dir}/{model}.json"

    # Guard clause: nothing to read.
    if not os.path.exists(path):
        print(f"JSON file {path} not found.")
        return None

    with open(path, 'r') as handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError as err:
            print(f"Error decoding JSON file {path}: {err}")
            return None