from abc import ABC, abstractmethod
import networkx as nx
import numpy as np
import csv
import pickle
import random
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.linalg import expm
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx
from scipy.sparse.linalg import expm_multiply

# Seed Python's `random` module and NumPy's global RNG at import time so that
# the np.random-based sampling in this module is repeatable across runs.
# NOTE: this is a module-level side effect; it reseeds the global generators
# for every other user of `random` / `np.random` in the process as well.
random.seed(12)
np.random.seed(42)

class GraphLoader(ABC):
    """Abstract base for dataset loaders that yield a graph and its signals.

    Concrete loaders in this module override ``get_graph_and_signals`` with
    their own dataset-specific parameters.
    """

    @abstractmethod
    def get_graph_and_signals(self):
        """Return the graph together with its node signals.

        The base implementation does nothing and returns ``None``; subclasses
        must provide the real loading logic.
        """

class IBBLoader(GraphLoader):
    """Loader for the IBB sub-graph datasets stored under ``./data/ibb_subgraphs``."""

    def __init__(self):
        # The scaler is re-fitted on every call to ``get_graph_and_signals``.
        self.standard_scaler = StandardScaler()

    def _load_scaled_signals(self, graph_name, num_signals):
        """Load channel 0 of the first ``num_signals`` rows and standard-scale it."""
        signals_path = f"./data/ibb_subgraphs/{graph_name}/signals.npy"
        raw_signals = np.load(signals_path)[:num_signals, :, 0]
        return self.standard_scaler.fit_transform(raw_signals)

    def get_graph_and_signals(self, graph_name, num_signals=168, generated_graph=None):
        """
        graph_name: ibb1 or ibb2 or ibb3 or ibb4 or ibb5 or ibb6
        num_signals: number of signals (timestamps) to return
        generated_graph:
            If None, return original graph and signals
            Else, it means that a generated graph topology is given as argument.
            Its nodes are relabelled to 0..k-1 (in iteration order) and the
            signal columns are selected/reordered to match; the node labels of
            the generated graph are used as column indices into the stored
            signals.

        Returns a ``(graph, signals)`` tuple where ``signals`` is a 2-D array
        with one row per timestamp and one column per node of ``graph``.
        """
        # Compare against None explicitly: an empty networkx graph is falsy and
        # would otherwise be mistaken for "no generated graph supplied".
        if generated_graph is None:
            graph_path = f"./data/ibb_subgraphs/{graph_name}/graph.txt"
            graph = nx.read_edgelist(graph_path, nodetype=int)
            return graph, self._load_scaled_signals(graph_name, num_signals)

        all_signals = self._load_scaled_signals(graph_name, num_signals)

        original_ids = list(generated_graph.nodes())
        old2new = {old: new for new, old in enumerate(original_ids)}
        graph = nx.relabel_nodes(generated_graph, old2new)

        # Column-select instead of filling a (num_signals, k) buffer: this keeps
        # working when the file holds fewer than ``num_signals`` timestamps.
        signals = all_signals[:, original_ids]
        return graph, signals

class PEMSLoader(GraphLoader):
    """Loader for the PEMS datasets stored under ``./data/pems``."""

    def __init__(self):
        # The scaler is re-fitted on every call to ``get_graph_and_signals``.
        self.standard_scaler = StandardScaler()

    def _load_scaled_signals(self, graph_name, num_signals):
        """Load channel 0 of the first ``num_signals`` rows and standard-scale it."""
        signals_path = f"./data/pems/{graph_name}/signals.npy"
        raw_signals = np.load(signals_path)[:num_signals, :, 0]
        return self.standard_scaler.fit_transform(raw_signals)

    def get_graph_and_signals(self, graph_name, num_signals=168, generated_graph=None):
        """
        graph_name: pems04 or pems08
        num_signals: number of signals (timestamps) to return
        generated_graph:
            If None, return original graph and signals
            Else, it means that a generated graph topology is given as argument.
            Its nodes are relabelled to 0..k-1 (in iteration order) and the
            signal columns are selected/reordered to match; the node labels of
            the generated graph are used as column indices into the stored
            signals.

        Returns a ``(graph, signals)`` tuple where ``signals`` is a 2-D array
        with one row per timestamp and one column per node of ``graph``.
        """
        # Compare against None explicitly: an empty networkx graph is falsy and
        # would otherwise be mistaken for "no generated graph supplied".
        if generated_graph is None:
            graph_path = f"./data/pems/{graph_name}/graph.txt"
            graph = nx.read_edgelist(graph_path, nodetype=int)
            return graph, self._load_scaled_signals(graph_name, num_signals)

        all_signals = self._load_scaled_signals(graph_name, num_signals)

        original_ids = list(generated_graph.nodes())
        old2new = {old: new for new, old in enumerate(original_ids)}
        graph = nx.relabel_nodes(generated_graph, old2new)

        # Column-select instead of filling a (num_signals, k) buffer: this keeps
        # working when the file holds fewer than ``num_signals`` timestamps.
        signals = all_signals[:, original_ids]
        return graph, signals
    

class CiteseerLoader(GraphLoader):
    """Loader for the largest connected component of the CiteSeer graph."""

    def get_graph_and_signals(self, num_signals, **kwargs):
        """Build the CiteSeer graph and generate synthetic signals on it.

        num_signals: number of signals to generate (used by the "gaussian"
            signal type only, see the note below).
        kwargs:
            signal_type: "gaussian" or "heat" (default "heat").
            sample_from: initial-heat distribution forwarded to
                ``heat_diffusion`` (default "normal"); heat signals only.

        Returns a ``(graph, signals)`` tuple. The graph is the largest
        connected component, relabelled 0..n-1 in BFS order.

        Raises ValueError if ``signal_type`` is not recognised.
        """
        signal_type = kwargs.get("signal_type", "heat")
        # Validate up front so a bad option fails before the dataset is loaded.
        # (Raising a plain string, as before, is itself a TypeError in Python 3.)
        if signal_type not in ("gaussian", "heat"):
            raise ValueError("Unexpected signal type!")

        dataset = Planetoid(root="/tmp/Cora", name="CiteSeer")
        data = dataset[0]

        graph = to_networkx(data, to_undirected=True)
        # largest connected component
        graph = graph.subgraph(max(nx.connected_components(graph), key=len)).copy()

        # BFS from node 1 when it belongs to the component; otherwise fall back
        # to the component's first node instead of failing on a missing source.
        bfs_source = 1 if 1 in graph else next(iter(graph.nodes()))
        bfs_ordering = nx.bfs_tree(graph, bfs_source)
        new_node_labels = {n: i for i, n in enumerate(bfs_ordering)}
        graph = nx.relabel_nodes(graph, new_node_labels)
        laplacian = nx.laplacian_matrix(graph).toarray()

        if signal_type == "gaussian":
            signals = spectral_graph_signals(laplacian, num_signals)
        else:
            sample_from = kwargs.get("sample_from", "normal")
            # NOTE(review): the heat branch ignores ``num_signals`` and always
            # produces 24 steps with 2 channels - confirm this is intended.
            signals = heat_diffusion(laplacian, num_signals=24, diffusion_rate=0.2,
                                     n_channels=2, sample_from=sample_from, smooth_init=True)
        return graph, signals
    
class GnutellaLoader(GraphLoader):
    """Loader for the largest connected component of the pickled Gnutella graph."""

    def get_graph_and_signals(self, num_signals, **kwargs):
        """Load the Gnutella graph and generate synthetic signals on it.

        num_signals: number of signals to generate (used by the "gaussian"
            signal type only, see the note below).
        kwargs:
            signal_type: "gaussian" or "heat" (default "heat").
            sample_from: initial-heat distribution forwarded to
                ``heat_diffusion`` (default "normal"); heat signals only.

        Returns a ``(graph, signals)`` tuple. The graph is the largest
        connected component, relabelled 0..n-1 in BFS order.

        Raises ValueError if ``signal_type`` is not recognised.
        """
        signal_type = kwargs.get("signal_type", "heat")
        # Validate up front so a bad option fails before the file is read.
        # (Raising a plain string, as before, is itself a TypeError in Python 3.)
        if signal_type not in ("gaussian", "heat"):
            raise ValueError("Unexpected signal type!")

        # SECURITY: unpickling can execute arbitrary code; this file must come
        # from a trusted source.
        with open("./data/gnutella/05.pkl", "rb") as f:
            graph = pickle.load(f)

        # largest connected component
        graph = graph.subgraph(max(nx.connected_components(graph), key=len)).copy()

        # BFS from node 1 when it belongs to the component; otherwise fall back
        # to the component's first node instead of failing on a missing source.
        bfs_source = 1 if 1 in graph else next(iter(graph.nodes()))
        bfs_ordering = nx.bfs_tree(graph, bfs_source)
        new_node_labels = {n: i for i, n in enumerate(bfs_ordering)}
        graph = nx.relabel_nodes(graph, new_node_labels)
        laplacian = nx.laplacian_matrix(graph).toarray()

        if signal_type == "gaussian":
            signals = spectral_graph_signals(laplacian, num_signals)
        else:
            sample_from = kwargs.get("sample_from", "normal")
            # NOTE(review): the heat branch ignores ``num_signals`` and always
            # produces 24 steps with 2 channels - confirm this is intended.
            signals = heat_diffusion(laplacian, num_signals=24, diffusion_rate=0.2,
                                     n_channels=2, sample_from=sample_from, smooth_init=True)
        return graph, signals
    
class SyntheticLoader(GraphLoader):
    """Loader for synthetic graphs stored as edge lists under ``./data/synthetic``."""

    def get_graph_and_signals(self, graph_name, size, num_signals, signal_type="gaussian", **kwargs):
        """Read a synthetic graph and generate signals on it.

        graph_name: sub-directory of ``./data/synthetic`` holding the graph.
        size: size tag used in the file name ``graph_{size}.txt``.
        num_signals: number of signals to generate.
        signal_type: "gaussian" (``spectral_graph_signals``) or "heat"
            (``heat_diffusion`` with its default settings).
        kwargs: accepted for call-site compatibility; not used here.

        Returns a ``(graph, signals)`` tuple.

        Raises ValueError if ``signal_type`` is not recognised.
        """
        # Validate before touching the file system. Raising a plain string, as
        # before, is itself a TypeError in Python 3 and hid the real message.
        if signal_type not in ("gaussian", "heat"):
            raise ValueError("Unexpected signal type!")

        graph_path = f"./data/synthetic/{graph_name}/graph_{size}.txt"
        graph = nx.read_edgelist(graph_path, nodetype=int)
        laplacian = nx.laplacian_matrix(graph).toarray()

        if signal_type == "gaussian":
            signals = spectral_graph_signals(laplacian, num_signals)
        else:
            signals = heat_diffusion(laplacian, num_signals)
        return graph, signals
    
def spectral_graph_signals(L, num_signals):
    """Sample noisy Gaussian graph signals whose covariance is the pseudo-inverse of L.

    Spectral coefficients are drawn from N(0, pinv(diag(eigenvalues))), mapped
    back through the eigenvectors of ``L``, and perturbed with i.i.d. Gaussian
    noise of standard deviation 0.2.

    L: dense symmetric matrix (e.g. the Laplacian of an undirected graph).
        Only its lower triangle is read by the symmetric eigensolver.
    num_signals: number of signals to draw.

    Returns an array of shape ``(num_signals, L.shape[0])``.
    """
    # eigh is the solver for symmetric matrices: unlike eig it guarantees real
    # eigenvalues and orthonormal eigenvectors, also for repeated eigenvalues.
    eigenvalues, eigenvectors = np.linalg.eigh(L)

    # Keep the descending eigenvalue order used when drawing the coefficients.
    order = eigenvalues.argsort()[::-1]
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[:, order]

    # pinv zeroes the (numerically) null eigenvalue instead of inverting it.
    covariance = np.linalg.pinv(np.diag(eigenvalues))
    mean = np.zeros(eigenvalues.shape[0])
    coefficients = np.random.multivariate_normal(mean, covariance, num_signals)

    signals = np.dot(eigenvectors, coefficients.T)
    signals = signals + 0.2 * np.random.randn(*signals.shape)
    return signals.T

def heat_diffusion(L, num_signals, diffusion_rate=0.1, 
                         n_channels=3, sample_from="normal", smooth_init=True):
    """Simulate noisy multi-channel heat diffusion on a graph.

    Starting from a random initial state, each step applies
    ``expm(-diffusion_rate * L)`` to every channel and then adds Gaussian
    noise with standard deviation 0.05.

    L: square matrix of shape (n, n) (graph Laplacian).
    num_signals: number of time steps to return, including the initial state.
    diffusion_rate: scale applied to ``L`` inside the matrix exponential.
    n_channels: number of independent channels per node.
    sample_from: distribution of the initial state - "normal", "discrete",
        "bimodal" or "uniform".
    smooth_init: selects the smaller spread for "normal" (std 1 instead of 20)
        and "discrete" (added noise std 1 instead of 5); it has no effect on
        "bimodal" and "uniform".

    Returns an array of shape ``(num_signals, n, n_channels)``.

    Raises ValueError if ``sample_from`` is not one of the supported names.
    """
    n = L.shape[0]
    
    # --- Initial heat vector, shape (n, n_channels) ---
    if sample_from == "normal":
        scale = 1.0 if smooth_init else 20.0
        initial_heat = np.random.normal(0, scale, size=(n, n_channels))
    elif sample_from == "discrete":
        initial_heat = np.random.choice([0, 10], size=(n, n_channels), p=[0.9, 0.1]).astype(float)
        noise_std = 1 if smooth_init else 5
        initial_heat += np.random.normal(0, noise_std, size=(n, n_channels))
    elif sample_from == "bimodal":
        mean1, std1 = 0, 1
        mean2, std2 = 10, 1
        samples1 = np.random.normal(mean1, std1, size=(n//2, n_channels))
        samples2 = np.random.normal(mean2, std2, size=(n - n//2, n_channels))
        initial_heat = np.vstack([samples1, samples2])
        # Shuffles rows, so the two modes are mixed across nodes.
        np.random.shuffle(initial_heat)
    elif sample_from == "uniform":
        initial_heat = np.random.uniform(-10, 10, size=(n, n_channels))
    else:
        # Previously fell through and crashed later with an UnboundLocalError.
        raise ValueError(
            f"Unexpected sample_from {sample_from!r}; expected "
            "'normal', 'discrete', 'bimodal' or 'uniform'"
        )
    
    # --- Heat diffusion ---
    heats = np.zeros((num_signals, n, n_channels))
    if num_signals == 0:
        # Nothing to record; avoids indexing into an empty array below.
        return heats

    heat = initial_heat.copy()
    heats[0] = heat

    # Loop-invariant generator of the diffusion step.
    generator = -diffusion_rate * L
    
    for t in range(1, num_signals):
        # Apply diffusion independently for each channel
        for c in range(n_channels):
            heat[:, c] = expm_multiply(generator, heat[:, c])
        heat += np.random.normal(0, 0.05, size=(n, n_channels))  # small noise
        heats[t] = heat
    
    return heats

# def heat_diffusion(L, num_signals, diffusion_rate=0.1, sample_from="normal"):
#     if sample_from == "normal": # initial heat values are sampled from normal distribution
#         initial_heat = np.random.randn(L.shape[0])
#     elif sample_from == "discrete":
#         initial_heat = np.random.choice([0,100], size=L.shape[0], p=[0.8,0.2])
#     elif sample_from == "bimodal":
#         mean1,std1 = 0,0.2
#         mean2,std2 = 1,0.2
#         samples1 = np.random.normal(mean1, std1, L.shape[0]//2)
#         samples2 = np.random.normal(mean2, std2, L.shape[0]-len(samples1))
#         initial_heat = np.concatenate([samples1, samples2])
#         np.random.shuffle(initial_heat)
#     heat = initial_heat
#     heats = np.zeros((num_signals, len(heat)))
#     heats[0] = heat
#     diffusion_matrix = expm(-diffusion_rate * L)
#     for i in range(1,num_signals):
#         current_heat = diffusion_matrix @ heat
#         heats[i] = current_heat
#         heat = current_heat
#     return heats

def get_graph_data(name, num_signals, **kwargs):
    """Return ``(graph, signals)`` for a dataset selected by name.

    name: one of the supported dataset names (ibb1-ibb6, pems04/pems08,
        citeseer, gnutella, or a synthetic ``<family>_<size>`` name).
    num_signals: number of signals requested from the loader.
    kwargs: extra options forwarded to the citeseer, gnutella and synthetic
        loaders (e.g. ``signal_type``, ``sample_from``); ignored for the
        ibb and pems datasets.

    Raises ValueError if ``name`` is not a supported dataset.
    """
    possible = ("ibb1", "ibb2", "ibb3", "ibb4", "ibb5", "ibb6", "pems04", "pems08",
                "er_small", "er_medium", "er_large", "ba_small", "ba_medium", "ba_large",
                "community_small", "community_medium", "community_large", "citeseer",
                "ego_small", "ego_medium", "ego_large", "gnutella")

    if name not in possible:
        raise ValueError("Graph not found!")

    if name.startswith("ibb"):
        return IBBLoader().get_graph_and_signals(name, num_signals)
    if name.startswith("pems"):
        return PEMSLoader().get_graph_and_signals(name, num_signals)
    if name.startswith("citeseer"):
        return CiteseerLoader().get_graph_and_signals(num_signals, **kwargs)
    if name.startswith("gnutella"):
        return GnutellaLoader().get_graph_and_signals(num_signals, **kwargs)

    # Synthetic graphs are named "<family>_<size>".
    family, size = name.split("_")
    # Pop signal_type from a copy of kwargs: passing it positionally while it
    # is still inside **kwargs raised "got multiple values for argument".
    loader_kwargs = dict(kwargs)
    signal_type = loader_kwargs.pop("signal_type", "heat")
    return SyntheticLoader().get_graph_and_signals(
        family, size, num_signals, signal_type, **loader_kwargs
    )
