import torch
import pandas as pd
import numpy as np
from typing import Dict,  Any
import warnings
import scipy.sparse as sp

from scipy import sparse
from sklearn.decomposition import PCA
from stellargraph.core import StellarGraph
import stellargraph as sg

from src.utils import config


def normalize(edge):
    """Return ``edge`` as a tuple with its smaller endpoint first.

    This makes ``(a, b)`` and ``(b, a)`` map to the same key.

    Args:
        edge: A two-element sequence of mutually comparable node positions.

    Returns:
        ``(min_endpoint, max_endpoint)`` as a tuple.
    """
    first, second = edge
    return (second, first) if first > second else (first, second)

def _csr_coords(indices, indptr, n_rows):
    """Locate the stored entries of the first ``n_rows`` rows of a CSR matrix.

    Args:
        indices: CSR column-index array.
        indptr: CSR row-pointer array (at least ``n_rows + 1`` entries).
        n_rows: Number of rows to expand.

    Returns:
        ``(rows, cols, span)``: row and column index arrays of every stored
        entry in storage order (explicitly stored zeros included), and the
        slice of ``indices``/``data`` that those entries occupy.
    """
    indptr = np.asarray(indptr)[:n_rows + 1]
    span = slice(indptr[0], indptr[-1])
    # Row i owns indptr[i+1] - indptr[i] consecutive stored entries.
    rows = np.repeat(np.arange(n_rows), np.diff(indptr))
    cols = np.asarray(indices)[span]
    return rows, cols, span


def from_flat_dict(data_dict: Dict[str, Any]):
    """Build a StellarGraph and a node-label Series from a flat ``.npz``-style dict.

    Sparse matrices are stored in CSR form as four entries
    ``<name><sep>data``, ``<name><sep>indices``, ``<name><sep>indptr`` and
    ``<name><sep>shape``. ``<sep>`` is ``'.'`` (current) or ``'_'``
    (deprecated, emits a ``DeprecationWarning``). Both layouts are accepted.

    * The adjacency matrix (``adj_matrix``, legacy ``adj``) supplies the edges.
      ``(i, j)`` and ``(j, i)`` are collapsed into a single edge.
    * The attribute matrix (``attr_matrix``, legacy ``attr``) supplies dense
      float32 node features. Without it, nodes get zero-width features.
    * All other entries that are not ``None`` and contain no ``None`` are kept
      as metadata. ``node_names`` and ``labels`` are required. If
      ``class_names`` is present it maps each integer label to its class
      name; otherwise the raw label is used.

    Args:
        data_dict: Mapping of array names to arrays, e.g. ``dict(np.load(...))``.
            It is not modified.

    Returns:
        A tuple ``(G, node_subjects)``. ``G`` is a ``StellarGraph`` whose nodes
        are indexed by ``node_names``. ``node_subjects`` is a ``pd.Series``
        mapping node name to class.
    """
    matrix_keys = set()  # keys belonging to CSR matrices; excluded from metadata
    edge_iloc = []       # (row, col) positions of stored adjacency entries
    features = None      # dense float32 attribute matrix, if the dict has one

    for key in data_dict:
        if not (key.endswith('_data') or key.endswith('.data')):
            continue
        if key.endswith('_data'):
            sep = '_'
            warnings.warn(
                "The separator used for sparse matrices during export (for .npz files) "
                "is now '.' instead of '_'. Please update (re-save) your stored graphs.",
                DeprecationWarning, stacklevel=2)
        else:
            sep = '.'
        matrix_name = key[:-5]
        mat_data = key
        mat_indices = '{}{}indices'.format(matrix_name, sep)
        mat_indptr = '{}{}indptr'.format(matrix_name, sep)
        mat_shape = '{}{}shape'.format(matrix_name, sep)
        if matrix_name == 'adj' or matrix_name == 'attr':
            warnings.warn(
                "Matrices are exported (for .npz files) with full names now. "
                "Please update (re-save) your stored graphs.",
                DeprecationWarning, stacklevel=2)
            matrix_name += '_matrix'

        # Dispatch on the normalised name so that both the legacy ('adj_data')
        # and the current ('adj_matrix.data') key layouts are recognised.
        if matrix_name == 'adj_matrix':
            rows, cols, _ = _csr_coords(
                data_dict[mat_indices], data_dict[mat_indptr],
                int(data_dict[mat_shape][0]))
            edge_iloc.extend(zip(rows.tolist(), cols.tolist()))
        elif matrix_name == 'attr_matrix':
            n_rows = int(data_dict[mat_shape][0])
            n_cols = int(data_dict[mat_shape][1])
            rows, cols, span = _csr_coords(
                data_dict[mat_indices], data_dict[mat_indptr], n_rows)
            features = np.zeros((n_rows, n_cols), np.float32)
            features[rows, cols] = np.asarray(data_dict[mat_data])[span]

        matrix_keys.update((mat_data, mat_indices, mat_indptr, mat_shape))

    # Metadata: every non-matrix entry that is not None and holds no None.
    init_dict = {
        key: val
        for key, val in data_dict.items()
        if key not in matrix_keys and val is not None and None not in val
    }

    node_names = init_dict["node_names"]
    class_names = init_dict.get("class_names")
    node_subj = {
        node_id: class_names[subj] if class_names is not None else subj
        for node_id, subj in zip(node_names, init_dict["labels"])
    }

    # Collapse (i, j)/(j, i) into one edge. Sort so that the edge order is
    # deterministic; plain set iteration order is arbitrary.
    unique_edges = sorted(set(map(normalize, edge_iloc)))

    if features is None:
        features = np.empty((len(node_names), 0), np.float32)
    nodes = sg.IndexedArray(values=features.reshape((len(node_names), -1)),
                            index=node_names)

    node_subjects = pd.Series(node_subj)
    edges = pd.DataFrame()
    edges['source'] = [node_names[hi] for _, hi in unique_edges]
    edges['target'] = [node_names[lo] for lo, _ in unique_edges]
    G = StellarGraph(nodes=nodes, edges=edges)

    return G, node_subjects

def load_from_npz(file_name: str):
    """Load a graph stored as an ``.npz`` archive.

    Every array in the archive is read into a plain dict, which is then
    converted by ``from_flat_dict``.

    Note: ``allow_pickle=True`` lets ``np.load`` unpickle object arrays, which
    can execute arbitrary code. Only load files from trusted sources.

    Args:
        file_name: Path to the ``.npz`` file.

    Returns:
        The ``(G, node_subjects)`` pair produced by ``from_flat_dict``.
    """
    with np.load(file_name, allow_pickle=True) as archive:
        # Materialise all arrays so they stay valid after the archive closes.
        arrays = dict(archive)
    graph, subjects = from_flat_dict(arrays)
    return graph, subjects


def pca_npz(file_name, n_components=500, save_reduced=True):
    """Reduce the node-attribute matrix of an ``.npz`` graph file with PCA.

    The attribute matrix is stored in CSR form under the legacy ``attr_*`` keys
    or under the ``attr_matrix.*`` keys. If both exist, ``attr_*`` wins. The
    matrix is densified to float32, projected onto ``n_components`` principal
    components with scikit-learn's ``PCA``, and stored back in CSR form under
    the same keys. Every other entry is copied unchanged.

    Args:
        file_name: Path to the source ``.npz`` file.
        n_components: Number of principal components to keep.
        save_reduced: If True (default), write the result next to the input as
            ``<name>_pca_<n_components>.npz``. If False, nothing is written.

    Returns:
        The dict of arrays containing the reduced attribute matrix (the same
        content that is saved when ``save_reduced`` is True).

    Raises:
        ValueError: If the file contains no attribute matrix.
    """
    # allow_pickle=True can execute arbitrary code from the file; trusted input only.
    with np.load(file_name, allow_pickle=True) as loader:
        data_dict = dict(loader)

    # Pick exactly one attribute matrix, with the same precedence as the write-back.
    if 'attr_data' in data_dict:
        prefix = 'attr_'
    elif 'attr_matrix.data' in data_dict:
        prefix = 'attr_matrix.'
    else:
        raise ValueError(
            f"No attribute matrix ('attr_data' or 'attr_matrix.data') found in {file_name}")

    # Densify the CSR matrix without a per-row Python loop.
    shape = data_dict[prefix + 'shape']
    n_rows, n_cols = int(shape[0]), int(shape[1])
    indptr = np.asarray(data_dict[prefix + 'indptr'])[:n_rows + 1]
    span = slice(indptr[0], indptr[-1])
    rows = np.repeat(np.arange(n_rows), np.diff(indptr))
    cols = np.asarray(data_dict[prefix + 'indices'])[span]
    features = np.zeros((n_rows, n_cols), np.float32)
    features[rows, cols] = np.asarray(data_dict[prefix + 'data'])[span]

    reduced_features = PCA(n_components=n_components).fit_transform(features)
    reduced_sparse = sp.csr_matrix(reduced_features)

    # Replace only the attribute-matrix entries; keep all other entries as-is.
    reduced_dict = data_dict.copy()
    reduced_dict[prefix + 'data'] = reduced_sparse.data
    reduced_dict[prefix + 'indices'] = reduced_sparse.indices
    reduced_dict[prefix + 'indptr'] = reduced_sparse.indptr
    reduced_dict[prefix + 'shape'] = np.array([reduced_sparse.shape[0], reduced_sparse.shape[1]])

    if save_reduced:
        # Rewrite only a trailing '.npz', not every occurrence in the path.
        path = str(file_name)
        base = path[:-len('.npz')] if path.endswith('.npz') else path
        reduced_file = f'{base}_pca_{n_components}.npz'
        np.savez(reduced_file, **reduced_dict)
        print(f"Saved PCA-reduced features to {reduced_file}")

    return reduced_dict


# pca_npz('../' + config.root_path + 'other_datasets/ms_academic.npz')