#!/usr/bin/env python3
"""
Create smaller, manageable subgraphs from the large Cisco dataset.
Provides multiple sampling strategies for different use cases.
"""

import os
import sys
import pickle
import argparse
import numpy as np
import torch
import networkx as nx
from pathlib import Path
from typing import List, Tuple, Dict, Optional
from tqdm import tqdm

# Make packages under ./src importable; the entry is relative to the current working directory
sys.path.append('src')

from gatv2_ns3_ids.data.cisco_dataset import CiscoDatasetLoader
from gatv2_ns3_ids.utils.common import GraphData, get_logger, ensure_dir

class CiscoGraphSampler:
    """Sample smaller subgraphs from large Cisco enterprise networks.

    Every ``sample_subgraph_*`` method follows the same contract:

    * a graph with zero nodes yields ``None``;
    * a graph with at most ``max_nodes`` nodes is returned unchanged (the
      very same object, not a copy);
    * otherwise a new ``GraphData`` is built from at most ``max_nodes``
      selected nodes and the edges whose two endpoints were both selected.

    Each method calls ``np.random.seed(seed)``, i.e. it reseeds NumPy's
    *global* random state as a side effect.

    NOTE(review): ``max_edges`` is stored on the instance but no method in
    this class enforces it, so sampled graphs may contain more edges than
    ``max_edges``.
    """

    def __init__(self, max_nodes: int = 500, max_edges: int = 2000):
        """Store the size limits and create the sampler's logger.

        Args:
            max_nodes: Upper bound on the number of nodes in a sampled graph.
            max_edges: Intended upper bound on edges (currently not enforced).
        """
        self.max_nodes = max_nodes
        self.max_edges = max_edges
        self.logger = get_logger("cisco_sampler")

    def sample_subgraph_random(self, graph: "GraphData", seed: int = 42) -> Optional["GraphData"]:
        """Keep ``max_nodes`` nodes chosen uniformly at random without replacement.

        Args:
            graph: Graph to sample from.
            seed: Value passed to ``np.random.seed`` before sampling.

        Returns:
            ``None`` for an empty graph, ``graph`` itself when it is already
            small enough, otherwise the induced subgraph on the sampled nodes.
        """
        num_nodes = graph.x.shape[0]
        if num_nodes == 0:
            return None

        np.random.seed(seed)

        if num_nodes <= self.max_nodes:
            return graph

        node_indices = np.random.choice(num_nodes, size=self.max_nodes, replace=False)
        return self._extract_subgraph(graph, np.sort(node_indices))

    def sample_subgraph_degree_based(self, graph: "GraphData", seed: int = 42) -> Optional["GraphData"]:
        """Keep the highest-degree nodes plus a random remainder.

        70% of the ``max_nodes`` budget (rounded down) goes to the nodes with
        the most edge endpoints; the rest of the budget is filled with nodes
        drawn at random from those not already selected.

        Args:
            graph: Graph to sample from.
            seed: Value passed to ``np.random.seed`` before sampling.

        Returns:
            ``None`` for an empty graph, ``graph`` itself when it is already
            small enough, otherwise the induced subgraph on the chosen nodes.
        """
        num_nodes = graph.x.shape[0]
        if num_nodes == 0:
            return None

        np.random.seed(seed)

        if num_nodes <= self.max_nodes:
            return graph

        degrees = self._node_degrees(graph.edge_index, num_nodes)

        k_high_degree = int(self.max_nodes * 0.7)  # 70% high degree
        k_random = self.max_nodes - k_high_degree   # 30% random

        _, top_indices = torch.topk(degrees, k=min(k_high_degree, num_nodes))
        hub_indices = top_indices.numpy()

        # Random fill is drawn only from nodes that are not already hubs.
        remaining_nodes = set(range(num_nodes)) - set(top_indices.tolist())
        if remaining_nodes and k_random > 0:
            random_indices = np.random.choice(
                list(remaining_nodes),
                size=min(k_random, len(remaining_nodes)),
                replace=False
            )
            node_indices = np.concatenate([hub_indices, random_indices])
        else:
            node_indices = hub_indices

        return self._extract_subgraph(graph, np.sort(node_indices))

    def sample_subgraph_connected_component(self, graph: "GraphData", seed: int = 42) -> Optional["GraphData"]:
        """Keep one connected component of the undirected view of the graph.

        Components are ordered from largest to smallest and the first one
        with at most ``max_nodes`` nodes is kept whole. Only when *every*
        component is larger than ``max_nodes`` is the largest component
        randomly subsampled down to ``max_nodes`` nodes.

        NOTE(review): because any component that fits is preferred, a graph
        made of one oversized component plus a few isolated nodes yields a
        very small sample (possibly a single node).

        Args:
            graph: Graph to sample from.
            seed: Value passed to ``np.random.seed`` before sampling.

        Returns:
            ``None`` for an empty graph, ``graph`` itself when it is already
            small enough, otherwise the induced subgraph on the chosen nodes.
        """
        num_nodes = graph.x.shape[0]
        if num_nodes == 0:
            return None

        np.random.seed(seed)

        if num_nodes <= self.max_nodes:
            return graph

        # Undirected NetworkX view; isolated nodes are added explicitly so
        # they show up as single-node components.
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))
        G.add_edges_from(graph.edge_index.t().tolist())

        components = sorted(nx.connected_components(G), key=len, reverse=True)

        fitting = next((c for c in components if len(c) <= self.max_nodes), None)
        if fitting is not None:
            selected_nodes = list(fitting)
        else:
            # No component fits: subsample the largest one.
            selected_nodes = np.random.choice(
                list(components[0]),
                size=self.max_nodes,
                replace=False
            ).tolist()

        return self._extract_subgraph(graph, np.sort(selected_nodes))

    @staticmethod
    def _node_degrees(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        """Count the edge endpoints incident to each node.

        Every column of ``edge_index`` contributes one to its source node and
        one to its destination node, so a self-loop adds two to its node.

        Args:
            edge_index: Integer tensor whose two rows hold source and
                destination node ids (assumed non-negative).
            num_nodes: Number of nodes; sets the minimum length of the result.

        Returns:
            Float CPU tensor with one count per node.
        """
        endpoints = edge_index.reshape(-1).long().cpu()
        # A single bincount replaces a Python loop over every edge.
        return torch.bincount(endpoints, minlength=num_nodes).float()

    @staticmethod
    def _induced_edges(
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        node_indices: np.ndarray,
        num_nodes: int,
    ):
        """Select and renumber the edges whose endpoints are both kept.

        Args:
            edge_index: ``(2, E)`` integer tensor of node ids (assumed
                non-negative).
            edge_attr: Optional per-edge tensor indexed along dim 0 by edge.
            node_indices: Ids of the nodes to keep; position ``i`` in this
                array becomes node ``i`` of the subgraph.
            num_nodes: Number of nodes in the original graph.

        Returns:
            Tuple ``(new_edge_index, new_edge_attr)``. Edge order is
            preserved. ``new_edge_index`` has shape ``(2, K)`` (``K`` may be
            zero). ``new_edge_attr`` is ``None`` only when ``edge_attr`` is
            ``None``; with zero surviving edges it is an empty tensor that
            still carries the feature dimensions.
        """
        device = edge_index.device
        edge_index = edge_index.long()
        kept = torch.as_tensor(np.asarray(node_indices), dtype=torch.long, device=device)

        # lookup[old_id] -> new_id, or -1 for nodes that were not selected.
        # Sized to also cover any endpoint id beyond num_nodes so that such
        # edges are simply dropped instead of raising an indexing error.
        table_size = num_nodes
        if edge_index.numel() > 0:
            table_size = max(table_size, int(edge_index.max()) + 1)
        lookup = torch.full((table_size,), -1, dtype=torch.long, device=device)
        lookup[kept] = torch.arange(kept.numel(), dtype=torch.long, device=device)

        remapped = lookup[edge_index]
        keep_mask = (remapped >= 0).all(dim=0)

        new_edge_index = remapped[:, keep_mask]
        new_edge_attr = None
        if edge_attr is not None:
            new_edge_attr = edge_attr[keep_mask.to(edge_attr.device)]
        return new_edge_index, new_edge_attr

    def _extract_subgraph(self, graph: "GraphData", node_indices: np.ndarray) -> "GraphData":
        """Build the subgraph induced by ``node_indices``.

        Node features (and node labels, when present) are gathered in the
        order given by ``node_indices``; edges are filtered and renumbered by
        ``_induced_edges``. ``graph_id`` and ``window_idx`` are copied from
        the source graph.

        Args:
            graph: Source graph.
            node_indices: Ids of the nodes to keep.

        Returns:
            A new ``GraphData`` describing the induced subgraph.
        """
        new_edge_index, new_edge_attr = self._induced_edges(
            graph.edge_index, graph.edge_attr, node_indices, graph.x.shape[0]
        )

        new_y_node = graph.y_node[node_indices] if graph.y_node is not None else None

        return GraphData(
            x=graph.x[node_indices],
            edge_index=new_edge_index,
            edge_attr=new_edge_attr,
            y_node=new_y_node,
            graph_id=graph.graph_id,
            window_idx=graph.window_idx
        )

def _make_synthetic_graphs(num_graphs: int, max_nodes: int, seed: int) -> list:
    """Build random Erdos-Renyi stand-in graphs, reproducibly for a given seed.

    Each graph gets 10 random node features, binary node labels, and 5 random
    features per directed edge; every undirected edge is stored in both
    directions.

    Args:
        num_graphs: Number of graphs to generate.
        max_nodes: Exclusive upper bound on the node count when it is above
            50; otherwise every graph gets ``max(1, max_nodes)`` nodes.
        seed: Seeds NumPy's global RNG, a local torch generator, and (offset
            by the graph index) the NetworkX graph generator.

    Returns:
        List of ``GraphData`` objects.
    """
    np.random.seed(seed)
    # Local generator: makes features/labels reproducible without touching
    # torch's global RNG state.
    torch_rng = torch.Generator().manual_seed(seed)

    graphs = []
    for i in range(num_graphs):
        # np.random.randint requires low < high, so small limits are handled
        # separately instead of raising.
        if max_nodes > 50:
            num_nodes = int(np.random.randint(50, max_nodes))
        else:
            num_nodes = max(1, max_nodes)

        G = nx.erdos_renyi_graph(num_nodes, p=0.1, seed=seed + i)

        edges = list(G.edges())
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t()
            # Add reverse edges for undirected graph
            edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
        else:
            edge_index = torch.zeros(2, 0, dtype=torch.long)

        node_features = torch.randn(num_nodes, 10, generator=torch_rng)
        node_labels = torch.randint(0, 2, (num_nodes,), generator=torch_rng)
        # With zero edges this is an empty (0, 5) tensor.
        edge_attr = torch.randn(edge_index.shape[1], 5, generator=torch_rng)

        graphs.append(GraphData(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y_node=node_labels
        ))

    return graphs


def create_small_cisco_dataset(
    cisco_data_path: str = "data/cisco",
    output_path: str = "data/cisco_small",
    max_nodes: int = 500,
    max_edges: int = 2000,
    sampling_strategy: str = "degree_based",
    num_graphs: int = 10,
    seed: int = 42
):
    """Create, save and summarise a dataset of small Cisco graphs.

    Loads the enterprise graphs via ``CiscoDatasetLoader``; if loading raises,
    falls back to synthetic random graphs. The first ``num_graphs`` graphs are
    sampled with ``CiscoGraphSampler`` (graph ``i`` uses ``seed + i``), then
    written to ``<output_path>/cisco_graphs_small.pkl`` together with a
    ``metadata.json`` file, and a summary is printed.

    Args:
        cisco_data_path: Directory handed to ``CiscoDatasetLoader``.
        output_path: Directory that receives the pickle and metadata files.
        max_nodes: Node limit passed to the sampler.
        max_edges: Edge limit passed to the sampler and recorded in metadata.
        sampling_strategy: ``"random"``, ``"degree_based"`` or
            ``"connected_component"``.
        num_graphs: Maximum number of graphs to sample (and the number of
            synthetic graphs generated by the fallback).
        seed: Base random seed.

    Returns:
        The list of sampled graphs (possibly empty).

    Raises:
        ValueError: If ``sampling_strategy`` is not a known strategy. Raised
            before any data is loaded or written.
    """
    import json

    strategy_methods = {
        "random": "sample_subgraph_random",
        "degree_based": "sample_subgraph_degree_based",
        "connected_component": "sample_subgraph_connected_component",
    }
    # Fail fast instead of discovering a bad strategy after loading the data.
    if sampling_strategy not in strategy_methods:
        raise ValueError(f"Unknown sampling strategy: {sampling_strategy}")

    logger = get_logger("create_small_cisco")
    ensure_dir(output_path)

    # Load original dataset
    logger.info(f"Loading Cisco dataset from {cisco_data_path}")
    loader = CiscoDatasetLoader(cisco_data_path)

    try:
        graphs = loader.load_enterprise_graphs()
        logger.info(f"Loaded {len(graphs)} enterprise graphs")

        # Log original sizes of the first few graphs
        for i, graph in enumerate(graphs[:3]):
            nodes = graph.x.shape[0] if graph.x is not None else 0
            edges = graph.edge_index.shape[1] if graph.edge_index is not None else 0
            logger.info(f"Original graph {i}: {nodes:,} nodes, {edges:,} edges")

    except Exception as e:
        # Deliberate best-effort: any loading problem switches to synthetic data.
        logger.error(f"Failed to load graphs: {e}")
        logger.info("Creating synthetic small graphs instead...")
        graphs = _make_synthetic_graphs(num_graphs, max_nodes, seed)
        logger.info(f"Created {len(graphs)} synthetic graphs")

    # Sample smaller graphs
    sampler = CiscoGraphSampler(max_nodes=max_nodes, max_edges=max_edges)
    sample = getattr(sampler, strategy_methods[sampling_strategy])
    small_graphs = []

    logger.info(f"Sampling with strategy: {sampling_strategy}")

    pbar = tqdm(graphs[:num_graphs], desc="Sampling graphs")
    for i, graph in enumerate(pbar):
        small_graph = sample(graph, seed=seed + i)

        # The sampler returns None for graphs without nodes; those are skipped.
        if small_graph is not None:
            small_graphs.append(small_graph)
            pbar.set_postfix({
                'Nodes': small_graph.x.shape[0],
                'Edges': small_graph.edge_index.shape[1],
            })

    logger.info(f"Created {len(small_graphs)} small graphs")

    # Save dataset
    output_file = Path(output_path) / "cisco_graphs_small.pkl"
    with open(output_file, 'wb') as f:
        pickle.dump(small_graphs, f)

    # Save metadata
    metadata = {
        "num_graphs": len(small_graphs),
        "max_nodes": max_nodes,
        "max_edges": max_edges,
        "sampling_strategy": sampling_strategy,
        "seed": seed,
        "graph_stats": [
            {
                "graph_id": i,
                "nodes": graph.x.shape[0],
                "edges": graph.edge_index.shape[1],
                "node_features": graph.x.shape[1],
                "edge_features": graph.edge_attr.shape[1] if graph.edge_attr is not None else 0,
                "attack_ratio": (graph.y_node == 1).float().mean().item() if graph.y_node is not None else 0.0
            }
            for i, graph in enumerate(small_graphs)
        ]
    }

    metadata_file = Path(output_path) / "metadata.json"
    with open(metadata_file, 'w', encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved dataset to {output_file}")
    logger.info(f"Saved metadata to {metadata_file}")

    # Print summary; an empty result reports zeros instead of raising on
    # division by zero or max() of an empty sequence.
    node_counts = [g.x.shape[0] for g in small_graphs]
    edge_counts = [g.edge_index.shape[1] for g in small_graphs]
    total_nodes = sum(node_counts)
    total_edges = sum(edge_counts)
    avg_nodes = total_nodes / len(small_graphs) if small_graphs else 0.0
    avg_edges = total_edges / len(small_graphs) if small_graphs else 0.0

    print(f"\n📊 Small Cisco Dataset Summary:")
    print(f"  Graphs: {len(small_graphs)}")
    print(f"  Total nodes: {total_nodes:,}")
    print(f"  Total edges: {total_edges:,}")
    print(f"  Avg nodes per graph: {avg_nodes:.1f}")
    print(f"  Avg edges per graph: {avg_edges:.1f}")
    print(f"  Max nodes: {max(node_counts, default=0)}")
    print(f"  Max edges: {max(edge_counts, default=0)}")

    return small_graphs

def main():
    """Parse command-line options and build the small Cisco dataset.

    Every option maps onto the keyword argument of the same name of
    ``create_small_cisco_dataset``, except ``--strategy``, which is passed
    as ``sampling_strategy``.
    """
    cli = argparse.ArgumentParser(description="Create small Cisco graphs for efficient training")

    # Path options
    cli.add_argument("--cisco_data_path", default="data/cisco", help="Path to Cisco dataset")
    cli.add_argument("--output_path", default="data/cisco_small", help="Output path for small graphs")

    # Size limits
    size_limits = (
        ("--max_nodes", 500, "Maximum nodes per graph"),
        ("--max_edges", 2000, "Maximum edges per graph"),
    )
    for flag, default_value, help_text in size_limits:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    cli.add_argument(
        "--strategy",
        choices=["random", "degree_based", "connected_component"],
        default="degree_based",
        help="Sampling strategy",
    )
    cli.add_argument("--num_graphs", type=int, default=10, help="Number of graphs to create")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")

    options = vars(cli.parse_args())
    # The CLI flag is --strategy, the function parameter is sampling_strategy.
    chosen_strategy = options.pop("strategy")

    create_small_cisco_dataset(sampling_strategy=chosen_strategy, **options)

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
