import sys
sys.path.append('/data/home/ifb5104/K_server_RL')

import sys
import os
import torch
import csv
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
from KServerEnv import KServerEnv 
from Policies.DQN_10 import DQNAgent_10

# Experiment parameters (keep these in sync with the arguments and structure
# of the script whose graphs are being analysed).

# Graph sizes to analyse: the perfect squares 3*3 .. 10*10.
number_nodes = [side * side for side in range(3, 11)]

# Graph-type identifiers handed to KServerEnv, in processing order.
# NOTE(review): the numeric suffix presumably selects a seed/variant - confirm
# against KServerEnv.
graph_types = [
    'grid_gre_51',
    'grid_gre_50',
    'grid_gre_52',
    'grid_gre_53',
    'grid_gre_54',
    'tree_50',
    'tree_51',
    'tree_52',
    'tree_53',
    'tree_54',
]

# Example seed; not referenced by the code visible in this script.
seed = 42

# Collected degree distributions, keyed by (num_nodes, graph_category).
degree_distribution_results = {}

def get_degree_distribution(graph):
    """Count how many nodes of *graph* have each degree.

    Args:
        graph: Object whose ``degree()`` call yields ``(node, degree)``
            pairs (as used here, the ``graph`` attribute of a KServerEnv).

    Returns:
        A ``collections.Counter`` mapping degree -> number of nodes with
        that degree. Empty when ``degree()`` yields nothing.
    """
    return Counter(node_degree for _node, node_degree in graph.degree())

# Build one environment per (node count, graph type), measure the degree
# distribution of its graph, accumulate it per category, and plot it.
for num_nodes in number_nodes:
    for graph_type in graph_types:
        env = KServerEnv(num_nodes, num_servers=round(num_nodes / 6), batch_size=10,
                         graph_type=graph_type, device='cpu', uniform_random=False)
        graph = env.graph

        # Determine the graph category from the graph-type identifier
        if 'grid_gre' in graph_type:
            graph_category = 'Grid'
        elif 'tree' in graph_type:
            graph_category = 'Tree'
        else:
            graph_category = 'Unknown'

        degree_distribution = get_degree_distribution(graph)

        # Several graph types share one category, so a plain assignment would
        # overwrite earlier graphs and keep only the last one per
        # (num_nodes, category). Accumulate the counts instead, so the stored
        # result covers every graph of that category and size.
        result_key = (num_nodes, graph_category)
        if result_key not in degree_distribution_results:
            degree_distribution_results[result_key] = Counter()
        degree_distribution_results[result_key].update(degree_distribution)

        # Nothing to plot for a graph without nodes; unpacking the zip of an
        # empty distribution would raise ValueError.
        if not degree_distribution:
            continue

        # Plot this graph's own distribution on a fresh figure so bars from
        # previously processed graphs never leak into the current plot.
        deg, cnt = zip(*degree_distribution.items())
        fig, ax = plt.subplots()
        ax.bar(deg, cnt)
        ax.set_xlabel('Degree')
        ax.set_ylabel('Frequency')
        ax.set_title(f'Degree Distribution for {graph_category} with {num_nodes} nodes')
        plt.show()
        # Release the figure; otherwise all 80 figures stay alive in memory.
        plt.close(fig)

# Export the collected distributions as a flat CSV table with one row per
# (node count, graph category, degree) combination.
output_file = 'degree_distribution_results.csv'
header = ['Num_Nodes', 'Graph_Category', 'Degree', 'Count']
with open(output_file, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(header)
    for (num_nodes, graph_category), distribution in degree_distribution_results.items():
        # Rows for each result are emitted in ascending degree order.
        writer.writerows(
            [num_nodes, graph_category, degree, distribution[degree]]
            for degree in sorted(distribution)
        )

print(f"Degree distribution results saved to {output_file}")
