import json
import pickle
import pandas as pd
import numpy as np
import scipy.sparse as sp
from scipy.sparse import linalg
from haversine import haversine, Unit
from torch_geometric.utils import dense_to_sparse

def calculate_symmetric_normalized_laplacian(adj: np.ndarray) -> sp.spmatrix:
    """
    Calculate the symmetric normalized Laplacian.

    The symmetric normalized Laplacian matrix is given by:
    L^{Sym} = I - D^{-1/2} A D^{-1/2}, where I is the identity matrix,
    D is the degree matrix (row sums of A), and A is the adjacency matrix.

    Nodes whose degree is zero get a zero entry in D^{-1/2}, so the
    corresponding row and column of D^{-1/2} A D^{-1/2} are zero and the
    diagonal entry of the Laplacian is 1.

    Args:
        adj (np.ndarray): Adjacency matrix A.

    Returns:
        scipy.sparse matrix: Symmetric normalized Laplacian L^{Sym}.
    """

    adj = sp.coo_matrix(adj)
    degree = np.array(adj.sum(1)).flatten()

    # Zero degrees produce inf here on purpose (they are zeroed right below),
    # so the divide-by-zero warning would only be noise.
    with np.errstate(divide='ignore'):
        degree_inv_sqrt = np.power(degree, -0.5)
    degree_inv_sqrt[np.isinf(degree_inv_sqrt)] = 0.0
    matrix_degree_inv_sqrt = sp.diags(degree_inv_sqrt)

    laplacian = sp.eye(adj.shape[0]) - matrix_degree_inv_sqrt.dot(adj).dot(matrix_degree_inv_sqrt).tocoo()
    return laplacian

def calculate_scaled_laplacian(adj: np.ndarray, lambda_max: int = 2, undirected: bool = True) -> np.matrix:
    """
    Build the rescaled normalized Laplacian used by Chebyshev polynomial filters.

    The symmetric normalized Laplacian L is mapped to (2 / lambda_max) * L - I so
    that its eigenvalues are intended to lie within [-1, 1].

    Args:
        adj (np.ndarray): Adjacency matrix A.
        lambda_max (int, optional): Maximum eigenvalue used for the rescaling, defaults to 2.
            When None is passed, the largest-magnitude eigenvalue of the Laplacian is
            computed with `eigsh` and used instead.
        undirected (bool, optional): If True, the adjacency matrix is symmetrized with the
            element-wise maximum of A and its transpose first, defaults to True.

    Returns:
        np.matrix: Scaled Laplacian matrix (built from scipy CSR operands).
    """

    symmetric_adj = np.maximum(adj, adj.T) if undirected else adj
    norm_lap = calculate_symmetric_normalized_laplacian(symmetric_adj)

    if lambda_max is None:
        # Only the single largest-magnitude eigenvalue is needed.
        eigenvalues, _ = linalg.eigsh(norm_lap, 1, which='LM')
        lambda_max = eigenvalues[0]

    norm_lap = sp.csr_matrix(norm_lap)
    size = norm_lap.shape[0]
    eye = sp.identity(size, format='csr', dtype=norm_lap.dtype)

    return (2 / lambda_max) * norm_lap - eye

def calculate_symmetric_message_passing_adj(adj: np.ndarray) -> sp.spmatrix:
    """
    Calculate the renormalized message-passing adjacency matrix as proposed in GCN.

    The message-passing adjacency matrix is defined as A' = D^{-1/2} (A + I) D^{-1/2},
    where D holds the row sums of (A + I). For a symmetric A the result is symmetric.

    Args:
        adj (np.ndarray): Adjacency matrix A (self-loops are added by this function).

    Returns:
        scipy.sparse matrix: Renormalized message-passing adjacency matrix (float32).
    """

    # Add self-loops: A + I.
    adj = adj + np.eye(adj.shape[0], dtype=np.float32)
    adj = sp.coo_matrix(adj)

    row_sum = np.array(adj.sum(1)).flatten()
    # Zero row sums yield inf on purpose (zeroed right below); silence the warning.
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(row_sum, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0

    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    # Scale rows and columns by D^{-1/2}. Transposing the row-scaled product before the
    # second multiplication would give A^T D^{-1} instead of the symmetric normalization.
    mp_adj = d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).astype(np.float32)

    return mp_adj

def calculate_transition_matrix(adj: np.ndarray) -> np.matrix:
    """
    Calculate the transition matrix as proposed in DCRNN and Graph WaveNet.

    The transition matrix is defined as P = D^{-1} A, where D is the degree matrix
    (row sums of A). Rows whose sum is zero are left as all-zero rows.

    Args:
        adj (np.ndarray): Adjacency matrix A. Integer and floating dtypes are accepted.

    Returns:
        np.matrix: Dense float32 transition matrix P.
    """

    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1)).flatten()

    # A float exponent keeps this valid for integer adjacency matrices: NumPy refuses
    # to raise integers to a negative *integer* power. Zero row sums yield inf on
    # purpose (zeroed right below), so the divide-by-zero warning is silenced.
    with np.errstate(divide='ignore'):
        d_inv = np.power(row_sum, -1.0)
    d_inv[np.isinf(d_inv)] = 0.0

    d_mat = sp.diags(d_inv)
    prob_matrix = d_mat.dot(adj).astype(np.float32).todense()

    return prob_matrix

def get_regular_settings(dataset_name: str) -> dict:
    """
    Return the 'regular_settings' entry of a dataset's descriptor.

    Args:
        dataset_name (str): Name of the dataset.

    Returns:
        dict: Value stored under the 'regular_settings' key of the dataset descriptor.
    """

    # The descriptor is read from datasets/<dataset_name>/desc.json by load_dataset_desc.
    return load_dataset_desc(dataset_name)['regular_settings']

def load_dataset_desc(dataset_name: str) -> dict:
    """
    Load the JSON descriptor of a dataset.

    Reads the file datasets/<dataset_name>/desc.json, resolved relative to the
    current working directory.

    Args:
        dataset_name (str): Name of the dataset (its directory under datasets/).

    Returns:
        dict: Parsed content of the descriptor file.

    Raises:
        FileNotFoundError: If the descriptor file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """

    # JSON text is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(f'datasets/{dataset_name}/desc.json', 'r', encoding='utf-8') as f:
        desc = json.load(f)
    return desc

def load_dataset_data(dataset_name: str) -> np.ndarray:
    """
    Load a dataset's data.dat file through a numpy memmap.

    The array shape is taken from the 'shape' entry of the dataset descriptor and
    the file datasets/<dataset_name>/data.dat is interpreted as float32 values.

    Args:
        dataset_name (str): Name of the dataset (its directory under datasets/).

    Returns:
        np.ndarray: Copy of the memory-mapped data.
    """

    descriptor = load_dataset_desc(dataset_name)
    dims = tuple(descriptor['shape'])

    mapped = np.memmap(f'datasets/{dataset_name}/data.dat', mode='r', dtype=np.float32, shape=dims)
    # Return a copy of the read-only mapping rather than the mapping itself.
    return mapped.copy()

def load_pkl(pickle_file: str) -> object:
    """
    Read and unpickle the content of a file.

    If unpickling fails with a UnicodeDecodeError, the file is read again with
    encoding='latin1' (presumably to support pickles written by Python 2 — verify
    against the data in use). Any other failure is reported on stdout and re-raised.

    Args:
        pickle_file (str): Path to the pickle file.

    Returns:
        object: Loaded object from the pickle file.
    """

    # NOTE: unpickling can execute arbitrary code; only load files from trusted sources.
    try:
        with open(pickle_file, 'rb') as stream:
            return pickle.load(stream)
    except UnicodeDecodeError:
        # Second attempt with a byte-preserving text encoding.
        with open(pickle_file, 'rb') as stream:
            return pickle.load(stream, encoding='latin1')
    except Exception as err:
        print(f'Unable to load data from {pickle_file}: {err}')
        raise

def dump_pkl(obj: object, file_path: str):
    """
    Serialize an object with pickle and write it to a file.

    The target file is opened in binary write mode, so existing content is replaced.

    Args:
        obj (object): Object to save.
        file_path (str): Path to the file.
    """

    with open(file_path, 'wb') as sink:
        pickle.Pickler(sink).dump(obj)

def load_adj(file_path: str, adj_type: str):
    """
    Load and preprocess an adjacency matrix.

    The pickle file may contain either the adjacency matrix itself or a 3-item
    tuple/list whose last item is the adjacency matrix.

    Args:
        file_path (str): Path to the pickle file containing the adjacency matrix.
        adj_type (str): Type of adjacency matrix preprocessing. Options:
            'scalap' (scaled Laplacian), 'normlap' (symmetric normalized Laplacian),
            'symnadj' (renormalized message-passing adjacency),
            'transition' (transposed transition matrix),
            'doubletransition' (transposed transition matrices of A and of A^T),
            'identity' (identity matrix of matching size), 'original' (matrix as loaded).

    Returns:
        list: List of processed adjacency matrices.
        np.ndarray: Raw adjacency matrix.

    Raises:
        ValueError: If `adj_type` is not one of the supported options.
    """

    # Load once and inspect the container type. Tuple-unpacking the loaded object
    # directly would also "succeed" on a raw matrix that happens to have exactly
    # three rows and silently return its last row as the adjacency matrix.
    loaded = load_pkl(file_path)
    if isinstance(loaded, (tuple, list)) and len(loaded) == 3:
        adj_mx = loaded[2]
    else:
        adj_mx = loaded

    if adj_type == 'scalap':
        adj = [calculate_scaled_laplacian(adj_mx).astype(np.float32).todense()]
    elif adj_type == 'normlap':
        adj = [calculate_symmetric_normalized_laplacian(adj_mx).astype(np.float32).todense()]
    elif adj_type == 'symnadj':
        adj = [calculate_symmetric_message_passing_adj(adj_mx).astype(np.float32).todense()]
    elif adj_type == 'transition':
        adj = [calculate_transition_matrix(adj_mx).T]
    elif adj_type == 'doubletransition':
        # Forward and backward transitions, each transposed.
        adj = [calculate_transition_matrix(adj_mx).T, calculate_transition_matrix(adj_mx.T).T]
    elif adj_type == 'identity':
        adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
    elif adj_type == 'original':
        adj = [adj_mx]
    else:
        raise ValueError('Undefined adjacency matrix type.')

    return adj, adj_mx
