import networkx as nx
import numpy as np
from typing import List, Tuple
import torch

def populate_matrix_dists(dists, num_gpus_per_server, num_nodes_per_leaf):
    """
    Expand a leaf-switch-level distance matrix into a per-GPU distance matrix.

    The topology consists of leaf switches. Each leaf connects
    ``num_nodes_per_leaf`` servers (for example one rack, commonly 4 to 8
    servers), and each server hosts ``num_gpus_per_server`` GPUs (commonly 4
    or 8).

    Args:
        dists: distance matrix between leaves, indexed ``[leaf_i, leaf_j]``.
            Anything ``torch.as_tensor`` accepts (numpy array, nested list,
            tensor). Presumably square (num_leaves x num_leaves) — the leaf
            count is read from its first dimension.
        num_gpus_per_server: GPUs per server.
        num_nodes_per_leaf: servers attached to each leaf switch.

    Returns:
        A new tensor in which every leaf entry is expanded into a
        ``P x P`` block, with ``P = num_gpus_per_server * num_nodes_per_leaf``:
        - GPUs under different leaves keep the leaf-to-leaf distance.
        - GPUs under the same leaf but on different servers get 2.
        - GPUs on the same server get 0.
        The input ``dists`` is never modified.
    """
    gpus_per_leaf = num_gpus_per_server * num_nodes_per_leaf
    # as_tensor accepts lists/arrays/tensors (torch.tensor warns on tensor
    # input); clone so the result can never alias the caller's data.
    base = torch.as_tensor(dists).clone()
    num_leaves = base.shape[0]
    populated = base.repeat_interleave(gpus_per_leaf, dim=0).repeat_interleave(gpus_per_leaf, dim=1)

    # Every GPU pair under the same leaf switch gets distance 2; the
    # same-server sub-blocks are overwritten with 0 below.
    for leaf in range(num_leaves):
        leaf_block = slice(leaf * gpus_per_leaf, (leaf + 1) * gpus_per_leaf)
        populated[leaf_block, leaf_block] = 2

    # GPUs on the same server are at distance 0.
    for server in range(num_leaves * num_nodes_per_leaf):
        server_block = slice(server * num_gpus_per_server, (server + 1) * num_gpus_per_server)
        populated[server_block, server_block] = 0
    return populated

def prepare_network_topology(num_servers=32, topology_type='dragonfly', return_topology=False):
    """
    Build a network topology and return its server-to-server hop-distance matrix.

    Follows the same pattern as metaevolve/problems/moe_pack/moe.py.

    Args:
        num_servers: number of servers in the topology.
        topology_type: one of 'dragonfly', 'dragonfly_sparse' (ring of groups,
            no extra diameter links), 'fat_tree' (single spine) or
            'fat_tree_2_level' (two-level spine, each first-level spine serving
            half of the aggregate routers).
        return_topology: if True, also return the topology object.

    Returns:
        ``distance_matrix`` when ``return_topology`` is False, otherwise
        ``(distance_matrix, topology)``. ``distance_matrix`` is a float numpy
        array of shape ``(num_servers, num_servers)``; entry ``[i, j]`` is the
        hop count reported by ``topology.compute_ospf_path(i, j)``, which is
        ``inf`` when the servers are unreachable.

    Raises:
        ValueError: if ``topology_type`` is not recognised.
    """
    if topology_type == 'dragonfly':
        topology = DragonFlyGraph(num_servers)
    elif topology_type == 'dragonfly_sparse':
        topology = DragonFlyGraph(num_servers, num_diameter_links=0)
    elif topology_type == 'fat_tree':
        topology = FatTreeGraph(num_servers, share_connects_per_spine=1.0)
    elif topology_type == 'fat_tree_2_level':
        topology = FatTreeGraph(num_servers, share_connects_per_spine=0.5)
    else:
        raise ValueError(f"Invalid topology type: {topology_type}")

    distance_matrix = np.zeros((num_servers, num_servers))
    # Both topology classes are built on an undirected nx.Graph, so hop counts
    # are symmetric: compute the upper triangle (diagonal included) and mirror
    # it, halving the number of shortest-path searches.
    for i in range(num_servers):
        for j in range(i, num_servers):
            hops = topology.compute_ospf_path(i, j)[1]
            distance_matrix[i, j] = hops
            distance_matrix[j, i] = hops

    if return_topology:
        return distance_matrix, topology
    return distance_matrix

class TopologyGraph:
    """
    Base class for server/router topologies backed by a networkx graph.

    Subclasses are expected to populate ``self.graph`` and to name server
    nodes ``"server_<id>"``.
    """

    def compute_ospf_path(self, source: int, destination: int) -> Tuple[List[str], float]:
        """
        Compute a shortest path between two servers, counted in hops.

        Despite the name, no OSPF link costs are modelled: ``nx.shortest_path``
        is called without a weight, so every edge counts as one hop.

        Args:
            source: id of the source server (node ``"server_<source>"``).
            destination: id of the destination server.

        Returns:
            ``(path, hops)``, where ``path`` is the list of node names from
            source to destination and ``hops == len(path) - 1``. If no path
            exists, returns ``([], float('inf'))``; hence the ``float``
            annotation.

        Raises:
            networkx.NodeNotFound: propagated from ``nx.shortest_path`` when a
                server node is missing from the graph (not caught here).
        """
        source_name = f"server_{source}"
        destination_name = f"server_{destination}"
        try:
            path = nx.shortest_path(self.graph, source_name, destination_name)
        except nx.NetworkXNoPath:
            return [], float('inf')
        return path, len(path) - 1


class DragonFlyGraph(TopologyGraph):
    """
    DragonFly network topology implementation from metaevolve.

    A dragonfly consists of groups of routers with all-to-all connections
    within groups and sparse connections between groups. Here every server
    hangs off its own router (``server_<id>`` -- ``router_<group>_<r>``),
    the routers of a group form a clique, and all inter-group links attach to
    router 0 of each group.
    """

    def __init__(self, num_servers: int, num_diameter_links: int = None):
        """
        num_servers: number of servers in the topology.
        num_diameter_links: None links router 0 of every group to router 0 of
            every other group (dense all-to-all). An integer builds a ring of
            groups instead: each group links to its previous and next group
            plus ``num_diameter_links`` extra links spaced evenly around the
            ring.
        """
        self.num_servers = num_servers
        self.group_size, self.num_groups = self._compute_groups(num_servers)
        self.graph = nx.Graph()
        self.server_nodes = []
        self.routers = []
        self.num_diameter_links = num_diameter_links
        self._build_topology()

        super().__init__()

    def _compute_groups(self, num_servers: int) -> Tuple[int, int]:
        """
        Choose ``(group_size, num_groups)`` for the dragonfly.

        Prefers an exact factorisation with group size 4, then 3, that yields
        at least two groups. Otherwise it falls back to groups of 4, with
        enough groups to hold every server. The last group may then be only
        partially filled; ``_build_topology`` skips the empty slots.
        """
        # ideal_group_size = max(4, int(num_servers ** 0.5)) was an alternative
        # heuristic; a fixed size of 4 is used instead.
        ideal_group_size = 4

        for group_size in range(ideal_group_size, 2, -1):
            if num_servers % group_size == 0:
                num_groups = num_servers // group_size
                if num_groups >= 2:
                    return group_size, num_groups

        # Ceiling division so every server gets a slot; floor division would
        # leave the remainder servers uncreated and break "server_<id>" lookups.
        return ideal_group_size, max(1, -(-num_servers // ideal_group_size))

    def _build_topology(self):
        """Build servers, routers, intra-group cliques and inter-group links."""
        group_size, num_groups = self.group_size, self.num_groups

        # One server and one router per slot; slots past num_servers (partial
        # last group) are left empty.
        server_id = 0
        for group in range(num_groups):
            for router in range(group_size):
                if server_id < self.num_servers:
                    server_name = f"server_{server_id}"
                    self.graph.add_node(server_name, type='server',
                                        group=group, router=router, server_id=server_id)
                    self.server_nodes.append(server_name)

                    router_name = f"router_{group}_{router}"
                    self.graph.add_node(router_name, type='router', group=group, router=router)
                    self.routers.append(router_name)

                    self.graph.add_edge(server_name, router_name)

                    server_id += 1

        # All-to-all connections among the routers that exist in each group.
        for group in range(num_groups):
            group_routers = [f"router_{group}_{r}" for r in range(group_size)
                             if f"router_{group}_{r}" in self.graph]
            for i, router1 in enumerate(group_routers):
                for router2 in group_routers[i + 1:]:
                    self.graph.add_edge(router1, router2)

        if self.num_diameter_links is None:
            # Dense: router 0 of every group links to router 0 of every other group.
            for group1 in range(num_groups):
                for group2 in range(group1 + 1, num_groups):
                    self._connect_group_gateways(group1, group2)
        else:
            # Sparse ring: previous/next group plus num_diameter_links links
            # spaced evenly around the ring of groups.
            offset = max(1, num_groups // (self.num_diameter_links + 1))
            for group1 in range(num_groups):
                # Wrap with % num_groups like the prev/next links; unwrapped
                # targets past the last group would otherwise be dropped.
                next_connects = {(group1 + offset * i) % num_groups
                                 for i in range(1, self.num_diameter_links + 1)}
                next_connects.add((group1 - 1) % num_groups)  # prev
                next_connects.add((group1 + 1) % num_groups)  # next
                # Avoid self-loops (e.g. when there is only one group).
                next_connects.discard(group1)

                for group2 in next_connects:
                    self._connect_group_gateways(group1, group2)

    def _connect_group_gateways(self, group1: int, group2: int):
        """Link router 0 of ``group1`` to router 0 of ``group2`` when both exist."""
        router1 = f"router_{group1}_0"
        router2 = f"router_{group2}_0"
        if router1 in self.graph and router2 in self.graph:
            self.graph.add_edge(router1, router2)


class FatTreeGraph(TopologyGraph):
    """
    Fat-tree style topology: server -> edge router -> aggregate routers -> spine(s).

    Every server has its own edge router. Each group has ``group_size``
    aggregate routers, each connected to every edge router of that group. The
    aggregate routers connect either to one shared spine router or to a
    two-level spine hierarchy, depending on ``share_connects_per_spine``.
    """

    def __init__(self, num_servers: int, num_diameter_links: int = None, share_connects_per_spine: float = 1.0):
        """
        num_servers: number of servers in the topology.
        num_diameter_links: stored on the instance but not used when building
            the fat tree (presumably kept for signature parity with
            DragonFlyGraph).
        share_connects_per_spine: simulate two-level hierarchy (if 1.0 - only
            one level). With 1.0 a single spine router connects to every
            aggregate router. Otherwise each first-level spine router serves
            ``int(share * num_aggregate_routers)`` (at least 1) aggregate
            routers, and a top spine router joins all first-level spines.
        """
        self.num_servers = num_servers
        self.group_size, self.num_groups = self._compute_groups(num_servers)
        self.graph = nx.Graph()
        self.server_nodes = []
        self.routers = []
        self.num_diameter_links = num_diameter_links
        self.share_connects_per_spine = share_connects_per_spine
        # Not read by _build_topology; kept as a public attribute.
        self.server_connects_per_spine = max(1, int(share_connects_per_spine * self.num_groups))

        self._build_topology()

        super().__init__()

    def _compute_groups(self, num_servers: int) -> Tuple[int, int]:
        """
        Choose ``(group_size, num_groups)`` for the fat tree.

        Prefers an exact factorisation with group size 4, then 3, that yields
        at least two groups. Otherwise it falls back to groups of 4, with
        enough groups to hold every server. The last group may then be only
        partially filled; ``_build_topology`` skips the empty slots.
        """
        # ideal_group_size = max(4, int(num_servers ** 0.5)) was an alternative
        # heuristic; a fixed size of 4 is used instead.
        ideal_group_size = 4

        for group_size in range(ideal_group_size, 2, -1):
            if num_servers % group_size == 0:
                num_groups = num_servers // group_size
                if num_groups >= 2:
                    return group_size, num_groups

        # Ceiling division so every server gets a slot; floor division would
        # leave the remainder servers uncreated and break "server_<id>" lookups.
        return ideal_group_size, max(1, -(-num_servers // ideal_group_size))

    def _build_topology(self):
        """Build the fat tree topology."""
        group_size, num_groups = self.group_size, self.num_groups

        # One server and one edge router per slot; slots past num_servers
        # (partial last group) are left empty.
        server_id = 0
        for group in range(num_groups):
            for router in range(group_size):
                if server_id < self.num_servers:
                    server_name = f"server_{server_id}"
                    self.graph.add_node(server_name, type='server',
                                        group=group, router=router, server_id=server_id)
                    self.server_nodes.append(server_name)

                    router_name = f"edge_router_{group}_{router}"
                    self.graph.add_node(router_name, type='edge_router', group=group, router=router)
                    self.routers.append(router_name)

                    self.graph.add_edge(server_name, router_name)

                    server_id += 1

        # group_size aggregate routers per group, each linked to every edge
        # router that exists in that group.
        for group in range(num_groups):
            for router in range(group_size):
                aggregate_router_name = f"aggregate_router_{group}_{router}"
                self.graph.add_node(aggregate_router_name, type='aggregate_router', group=group, router=router)
                self.routers.append(aggregate_router_name)

                for edge_router in range(group_size):
                    edge_router_name = f"edge_router_{group}_{edge_router}"
                    # Skip empty slots of a partial group: add_edge would
                    # otherwise create bare, attribute-less phantom nodes.
                    if edge_router_name in self.graph:
                        self.graph.add_edge(aggregate_router_name, edge_router_name)

        if self.share_connects_per_spine == 1.0:
            # For hop measurement a single spine router is sufficient.
            spine_router_name = "spine_router"
            self.graph.add_node(spine_router_name, type='router', group=num_groups, router=group_size)

            for group in range(num_groups):
                for router in range(group_size):
                    aggregate_router_name = f"aggregate_router_{group}_{router}"
                    self.graph.add_edge(spine_router_name, aggregate_router_name)
        else:
            # Two-level spine hierarchy:
            #   level 1: spine_router_<k> each serve a contiguous slice of
            #            aggregate routers;
            #   level 2: top_spine_router connects to every spine_router_<k>.
            aggregate_routers_list = [n for n in self.routers if n.startswith('aggregate_router_')]

            aggregates_per_spine = max(1, int(len(aggregate_routers_list) * self.share_connects_per_spine))

            # Ceiling division so every aggregate router is covered.
            num_spine_routers = max(1, (len(aggregate_routers_list) + aggregates_per_spine - 1) // aggregates_per_spine)

            for spine_id in range(num_spine_routers):
                spine_router_name = f"spine_router_{spine_id}"
                self.graph.add_node(spine_router_name, type='spine_router', group=num_groups + spine_id, router=spine_id)
                self.routers.append(spine_router_name)

            top_spine_router_name = "top_spine_router"
            self.graph.add_node(top_spine_router_name, type='top_spine_router', group=num_groups + num_spine_routers, router=0)
            self.routers.append(top_spine_router_name)

            for agg_idx, agg_router_name in enumerate(aggregate_routers_list):
                spine_id = agg_idx // aggregates_per_spine
                spine_router_name = f"spine_router_{spine_id}"
                self.graph.add_edge(spine_router_name, agg_router_name)

            for spine_id in range(num_spine_routers):
                spine_router_name = f"spine_router_{spine_id}"
                self.graph.add_edge(top_spine_router_name, spine_router_name)
                