import networkx as nx
from itertools import combinations
import numpy as np
import random
from partitions import make_partitions
import argparse
from torch_geometric.sampler import EdgeSamplerInput, BaseSampler
from torch_geometric.utils import from_networkx
# --- Sampling strategy ---------------------------------------------------
# Accumulators for the node combinations gathered across all sampling rounds
# below. Each name is bound to its own, separate empty list.
all_two_node_combinations_landbird, all_three_node_combinations_landbird = [], []
all_two_node_combinations_waterbird, all_three_node_combinations_waterbird = [], []
all_nodes = []  # NOTE(review): not used anywhere in the visible code

# Load the concept graph (path is resolved relative to the current working
# directory) and split its nodes into two partitions. Presumably
# make_partitions returns (landbird_nodes, waterbird_nodes) as sequences of
# node ids -- confirm against partitions.make_partitions.
G = nx.read_gexf('../concept_graphs/concept_graph.gexf')
landbird_nodes, waterbird_nodes = make_partitions(G)
# SAMPLING TYPE 1
# Run 10 rounds. In each round, draw 10 distinct nodes uniformly at random
# (no weights are passed to np.random.choice) from each partition. Then
# record every 2-node and 3-node combination of that draw. Combinations from
# all rounds accumulate in the module-level all_* lists, so the same
# combination can appear more than once across rounds.
# np.random.choice(..., replace=False) raises ValueError if a partition has
# fewer than 10 nodes.
for _round in range(10):  # underscore name: avoid shadowing builtin round()

    # Each class's sample is drawn from its own partition.
    sampled_entries_landbird = np.random.choice(landbird_nodes, size=10, replace=False)
    sampled_entries_waterbird = np.random.choice(waterbird_nodes, size=10, replace=False)

    # Per-round 2-node and 3-node combinations.
    two_node_combinations_landbird = list(combinations(sampled_entries_landbird, 2))
    three_node_combinations_landbird = list(combinations(sampled_entries_landbird, 3))
    two_node_combinations_waterbird = list(combinations(sampled_entries_waterbird, 2))
    three_node_combinations_waterbird = list(combinations(sampled_entries_waterbird, 3))

    # Extend the accumulators in place. Rebinding the all_* names here would
    # discard every earlier round's combinations.
    all_two_node_combinations_landbird.extend(two_node_combinations_landbird)
    all_three_node_combinations_landbird.extend(three_node_combinations_landbird)
    all_two_node_combinations_waterbird.extend(two_node_combinations_waterbird)
    all_three_node_combinations_waterbird.extend(three_node_combinations_waterbird)




# Keep 50 combinations from each accumulated pool (lists of node tuples).
# random.sample draws without replacement and raises ValueError if a pool
# holds fewer than 50 entries. The generator is consumed in order, so the
# pools are sampled landbird-2, waterbird-2, landbird-3, waterbird-3.
(
    sampled_two_node_combinations_landbird,
    sampled_two_node_combinations_waterbird,
    sampled_three_node_combinations_landbird,
    sampled_three_node_combinations_waterbird,
) = (
    random.sample(pool, 50)
    for pool in (
        all_two_node_combinations_landbird,
        all_two_node_combinations_waterbird,
        all_three_node_combinations_landbird,
        all_three_node_combinations_waterbird,
    )
)

# Print each selection on its own line, in the order they were drawn.
print(
    sampled_two_node_combinations_landbird,
    sampled_two_node_combinations_waterbird,
    sampled_three_node_combinations_landbird,
    sampled_three_node_combinations_waterbird,
    sep="\n",
)
