from flamby.datasets.fed_heart_disease import FedHeartDisease
import networkx as nx
import numpy as np

from gossip import RingGraph

def load_fed_heart_disease():
    """Build per-center scaled Gram matrices for the Fed-Heart-Disease dataset.

    For each of the 4 training centers the features are shifted by
    ``mean_of_features`` and divided by ``std_of_features`` (plus a small
    epsilon), then turned into ``(4 / total_samples) * X^T X``.  Averaging the
    stack over centers therefore gives ``sum_k X_k^T X_k / total_samples``.

    Returns:
        A: array of shape (4, 13, 13), one scaled Gram matrix per center.
        W: whatever ``RingGraph(4)`` returns (the gossip matrix for the 4 centers).
        A_mean: ``A`` averaged over its first axis, shape (13, 13).
    """
    n_centers, n_features = 4, 13
    # Presumably the total number of training rows over the 4 centers — TODO confirm.
    total_samples = 486
    weight = n_centers / total_samples
    eps = 1e-8  # keeps the division finite if a standard deviation is zero

    A = np.zeros((n_centers, n_features, n_features))
    for k in range(n_centers):
        ds = FedHeartDisease(center=k, train=True)
        raw = np.array(ds.features)
        # NOTE(review): the meaning of ``[0, 0]`` depends on the layout of
        # mean_of_features / std_of_features; if those are 2-D this picks one
        # scalar for every feature — confirm per-feature statistics were not intended.
        shift = np.array(ds.mean_of_features)[0, 0]
        spread = np.array(ds.std_of_features)[0, 0]
        X = (raw - shift) / (spread + eps)
        A[k] = weight * (X.T @ X)

    A_mean = A.mean(axis=0)
    W = RingGraph(n_centers)
    return A, W, A_mean


def load_ego_facebook(n: int, *, path="data/facebook_combined.txt", rng=None):
    """Build local matrices and a gossip matrix from the ego-Facebook graph.

    The edge list at ``path`` is read as an undirected graph with integer node
    labels and restricted to the subgraph induced by nodes ``0..n-1``.

    Args:
        n: number of nodes (agents) to keep; must be >= 1.
        path: edge-list file passed to ``networkx.read_edgelist``. Defaults to
            the previously hard-coded location.
        rng: object with a ``shuffle`` method (e.g. ``np.random.Generator`` or
            ``np.random.RandomState``) used to choose the extra gossip edges.
            ``None`` (default) uses the global ``np.random.shuffle`` exactly as
            before, so results keep following ``np.random.seed``.

    Returns:
        A: array of shape (n, n, n). After scaling by ``n``, ``A`` averaged over
            its first axis equals (up to rounding) ``I + D^{-1/2} Adj D^{-1/2}``
            of the induced subgraph: diagonal 1, and ``1 / sqrt(deg_i deg_j)``
            for each edge (i, j).  This is ``2 I - L_norm`` when no node is
            isolated.
        W: (n, n) Metropolis-Hastings gossip matrix of the induced subgraph
            plus ``n * int(log(n))`` randomly chosen extra edges. It is
            symmetric and its rows sum to 1.
        A_mean: ``A`` averaged over its first axis, shape (n, n).

    Raises:
        ValueError: if ``n < 1`` (previously ``n == 0`` failed with an
            OverflowError inside ``int(np.log(0))``).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    G = nx.read_edgelist(path, create_using=nx.Graph(), nodetype=int)
    G = G.subgraph(range(n))

    A = n * _ego_local_matrices(G, n)
    A_mean = np.mean(A, axis=0)

    # Densify a copy of the graph to improve connectivity for gossip.  Note the
    # count is n * floor(ln n), only a rough stand-in for n log n.
    G_prime = G.copy()
    shuffle = np.random.shuffle if rng is None else rng.shuffle
    _ego_add_random_edges(G_prime, n, n * int(np.log(n)), shuffle)

    W = _metropolis_weights(G_prime, n)
    return A, W, A_mean


def _ego_local_matrices(G, n):
    """Return the unscaled (n, n, n) stack of local matrices A_i for graph ``G``.

    A_i has a 1 at (i, i) and ``1 / (2 sqrt(deg_i deg_j))`` at (i, j) and (j, i)
    for every neighbour j of i, so ``sum_i A_i = I + D^{-1/2} Adj D^{-1/2}``.
    """
    A = np.zeros((n, n, n))
    for i in range(n):
        neighbors = list(G.neighbors(i))
        deg_i = len(neighbors)
        for j in neighbors:
            # Each edge appears in both A_i and A_j, so each contributes half.
            A[i, i, j] = A[i, j, i] = 0.5 / np.sqrt(deg_i * G.degree[j])
        A[i, i, i] = 1
    return A


def _ego_add_random_edges(G, n, count, shuffle):
    """Add up to ``count`` randomly chosen missing edges among nodes 0..n-1 to ``G`` in place."""
    candidates = [
        (i, j) for i in range(n) for j in range(i + 1, n) if not G.has_edge(i, j)
    ]
    shuffle(candidates)
    G.add_edges_from(candidates[:count])


def _metropolis_weights(G, n):
    """Return the (n, n) Metropolis-Hastings gossip matrix of ``G``.

    ``W[i, j] = 1 / (1 + max(deg_i, deg_j))`` for each edge and
    ``W[i, i] = 1 - sum_j W[i, j]``, so W is symmetric with unit row sums.
    """
    W = np.zeros((n, n))
    for i in range(n):
        neighbors = list(G.neighbors(i))
        deg_i = len(neighbors)
        W[i, i] = 1
        for j in neighbors:
            W[i, j] = 1 / (1 + max(deg_i, G.degree[j]))
            W[i, i] -= W[i, j]
    return W

