from deeprobust.graph.data import Dataset
import numpy as np
import pickle 
import itertools
from sklearn.metrics.pairwise import cosine_similarity

class DatasetStatistics:
    """Compute and cache summary statistics for a graph dataset.

    Results live in two dictionaries:

    * ``statsDict`` -- scalar/string statistics (e.g. ``"homoEdgeRatio"``).
    * ``matDict``   -- array-valued results (e.g. ``"degrees"``).

    Every key of either dictionary is also readable and writable as an
    attribute of the instance (``stats.homoEdgeRatio``); cached entries take
    precedence over regular attributes of the same name.

    The methods read ``data.adj`` (used as a SciPy sparse adjacency matrix),
    ``data.labels``, ``data.onehot(labels)`` and ``data.name``.
    NOTE(review): presumably ``data`` is a deeprobust ``Dataset`` -- confirm.
    """

    def __init__(self, data):
        self.statsDict = dict()
        self.matDict = dict()
        self.data = data

    def __getattribute__(self, name):
        """Resolve ``name`` from ``statsDict``, then ``matDict``, then normally.

        A cache that does not exist yet is skipped instead of raising, so
        instances created without ``__init__`` (``copy``, unpickling,
        ``__new__``) still behave like ordinary objects.
        """
        plain_get = object.__getattribute__
        if name in ("statsDict", "matDict"):
            return plain_get(self, name)
        for cache_name in ("statsDict", "matDict"):
            try:
                cache = plain_get(self, cache_name)
            except AttributeError:
                continue
            if name in cache:
                return cache[name]
        return plain_get(self, name)

    def __setattr__(self, name, value):
        """Write ``name`` into the cache that already holds it, else set normally.

        Only keys already present in ``statsDict``/``matDict`` are redirected;
        new names become regular instance attributes.
        """
        if name not in ("statsDict", "matDict"):
            for cache_name in ("statsDict", "matDict"):
                try:
                    cache = object.__getattribute__(self, cache_name)
                except AttributeError:
                    continue
                if name in cache:
                    cache[name] = value
                    return
        object.__setattr__(self, name, value)

    def getHomoEdgeRatio(self):
        """Return the fraction of labelled edges joining same-label nodes.

        Only non-zero adjacency entries whose two endpoints both have a label
        greater than -1 are counted. The result is cached under
        ``"homoEdgeRatio"``; ``"heteEdgeCount"`` is cached at the same time.

        Returns:
            float: the ratio, or NaN when no entry qualifies (the ratio is
            undefined in that case).
        """
        if "homoEdgeRatio" in self.statsDict:
            return self.statsDict["homoEdgeRatio"]

        adj_coo = self.data.adj.tocoo()
        labels = np.asarray(self.data.labels)
        src_labels = labels[adj_coo.row]
        dst_labels = labels[adj_coo.col]

        # Skip explicitly stored zeros and entries touching an unlabelled node.
        counted = (adj_coo.data != 0) & (src_labels > -1) & (dst_labels > -1)
        total_edges = int(np.count_nonzero(counted))
        homophily_count = int(
            np.count_nonzero(counted & (src_labels == dst_labels))
        )

        if total_edges:
            ratio = homophily_count / total_edges
        else:
            ratio = float("nan")

        self.statsDict["homoEdgeRatio"] = ratio
        # Halved because every stored entry is counted; this presumes each
        # undirected edge appears in both directions.
        # NOTE(review): confirm the adjacency matrix is symmetric.
        self.statsDict["heteEdgeCount"] = (total_edges - homophily_count) / 2
        return ratio

    def getHeteEdgeCount(self):
        """Return the cached ``"heteEdgeCount"``, computing it if necessary."""
        if "heteEdgeCount" not in self.statsDict:
            self.getHomoEdgeRatio()
        return self.statsDict["heteEdgeCount"]

    def getDegree(self):
        """Compute the row sums of ``data.adj`` and store them as ``"degrees"``.

        The value is recomputed on every call.

        Returns:
            numpy.ndarray: 1-D array with one row sum per node.
        """
        # asarray + ravel flattens both the (n, 1) np.matrix returned by
        # sparse matrices and the 1-D array returned by sparse arrays.
        degrees = np.asarray(self.data.adj.sum(1)).ravel()
        self.matDict["degrees"] = degrees
        return degrees

    def getDegreeStats(self):
        """Store degree quantiles, mean and std in ``statsDict``.

        Writes ``"degree_dist"`` (string form of the 0/25/50/75/100 %
        quantiles), ``"degree_mean"`` and ``"degree_std"``. Returns nothing.
        """
        if "degrees" not in self.matDict:
            self.getDegree()
        degrees = self.matDict["degrees"]
        quantiles = np.quantile(degrees, [0, 0.25, 0.5, 0.75, 1])
        # tolist() yields plain Python floats so the string does not depend on
        # how the installed NumPy version prints its scalar types.
        self.statsDict["degree_dist"] = str(quantiles.tolist())
        self.statsDict["degree_mean"] = degrees.mean()
        self.statsDict["degree_std"] = degrees.std()

    def crossClassSimilarity(self, mat, allow_same_node=True):
        """Mean and std of pairwise cosine similarity between label classes.

        Args:
            mat: matrix with one row per node; rows are grouped by label.
            allow_same_node: when False, the diagonal of each within-class
                similarity matrix (a node paired with itself) is excluded.

        Returns:
            tuple: ``(clsSim, clsSimStd)``, two symmetric
            ``(numClass, numClass)`` arrays.

        Raises:
            AssertionError: if any label is negative.
        """
        assert (self.data.labels < 0).sum() == 0, "all nodes must be labelled"
        label_bool = self.data.onehot(self.data.labels).astype(bool)
        numClass = label_bool.shape[1]
        clsSim = np.zeros((numClass, numClass))
        clsSimStd = np.zeros((numClass, numClass))

        # Only i <= j is computed; the result is mirrored because it is symmetric.
        class_pairs = itertools.combinations_with_replacement(range(numClass), 2)
        for i, j in class_pairs:
            mat_i = mat[label_bool[:, i], :]
            mat_j = mat[label_bool[:, j], :]
            mat_ij_sim = cosine_similarity(mat_i, mat_j)

            if i == j and not allow_same_node:
                # Mask self-pairs with NaN so the nan-aware reductions skip them.
                np.fill_diagonal(mat_ij_sim, np.nan)
                clsSim[i, j] = np.nanmean(mat_ij_sim)
                clsSimStd[i, j] = np.nanstd(mat_ij_sim)
            else:
                clsSim[i, j] = mat_ij_sim.mean()
                clsSimStd[i, j] = mat_ij_sim.std()

            clsSim[j, i] = clsSim[i, j]
            clsSimStd[j, i] = clsSimStd[i, j]
        return clsSim, clsSimStd

    def compHistogram(self):
        """Return each node's row-normalised neighbour-label histogram.

        Computed as ``adj @ onehot(labels)`` with every row divided by its
        sum, and cached under ``"comp_hist"``. A row whose sum is zero is
        divided by zero and therefore is not a valid distribution.

        Raises:
            AssertionError: if any label is negative.
        """
        if "comp_hist" not in self.matDict:
            assert (self.data.labels < 0).sum() == 0, "all nodes must be labelled"
            label_onehot = self.data.onehot(self.data.labels)
            comp_count = self.data.adj @ label_onehot
            row_totals = comp_count.sum(1, keepdims=True)
            self.matDict["comp_hist"] = comp_count / row_totals
        return self.matDict["comp_hist"]

    def __repr__(self):
        return f"DatasetStatistics({self.data.name}, {self.statsDict})"

    @property
    def statesInfoDict(self):
        """Both caches bundled in one dict (keys ``statsDict`` and ``matDict``)."""
        return dict(statsDict=self.statsDict, matDict=self.matDict)

    def save(self, filePath):
        """Pickle both caches to ``filePath``."""
        with open(filePath, "wb") as f:
            pickle.dump(self.statesInfoDict, f)

    def load(self, filePath):
        """Replace both caches with the ones pickled in ``filePath``.

        Raises:
            KeyError: if the file lacks a ``statsDict`` or ``matDict`` entry.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(filePath, "rb") as f:
            statesInfoDict = pickle.load(f)
        self.statsDict = statesInfoDict["statsDict"]
        self.matDict = statesInfoDict["matDict"]

    def reset(self):
        """Discard every cached statistic."""
        self.statsDict = dict()
        self.matDict = dict()