#!/usr/bin/env python3
"""
Create subgraphs from large Cisco graphs using various sampling strategies.
This allows training on portions of large graphs while preserving structure.
"""

import pickle
import json
import argparse
import numpy as np
import torch
import networkx as nx
from pathlib import Path
from typing import List, Dict, Any, Optional
from tqdm import tqdm
import sys

# Add src to path
sys.path.append('src')

from gatv2_ns3_ids.utils.common import GraphData, get_logger, ensure_dir

class CiscoSubgraphSampler:
    """Sample subgraphs from large Cisco enterprise networks.

    Every ``sample_*`` method takes a graph dictionary with the keys ``'x'``,
    ``'edge_index'``, ``'edge_attr'`` and ``'y_node'`` (optionally
    ``'graph_id'`` and ``'window_idx'``) and returns a ``GraphData`` holding
    the node-induced subgraph of the selected nodes. Graphs that already have
    at most ``target_nodes`` nodes are returned whole.

    ``edge_index`` is indexed as ``(2, num_edges)`` and must support
    ``.t()`` / ``.tolist()``; presumably a ``torch.LongTensor`` -- confirm
    against the preprocessing step that writes the input pickle.

    Each sampler draws from its own ``np.random.RandomState(seed)``, so the
    process-wide NumPy random state is left untouched.
    """

    def __init__(self, target_nodes: int = 500, target_edges: int = 2000):
        """Store the size targets and create the sampler's logger.

        Args:
            target_nodes: Maximum number of nodes a sampled subgraph may have.
            target_edges: Stored for reference only; none of the sampling
                strategies in this class enforces an edge budget.
        """
        self.target_nodes = target_nodes
        self.target_edges = target_edges
        self.logger = get_logger("cisco_subgraph_sampler")

    @staticmethod
    def _build_nx_graph(edge_index, num_nodes: int) -> nx.Graph:
        """Build an undirected NetworkX graph that contains every node id.

        Nodes are registered before the edges so that isolated nodes are part
        of the graph too; ``G.neighbors`` raises for nodes it does not know.
        """
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))
        G.add_edges_from(edge_index.t().tolist())
        return G

    def sample_random_walk_subgraph(self, graph_dict: Dict, num_walks: int = 10, walk_length: int = 50, seed: int = 42) -> Optional[GraphData]:
        """Sample subgraph using random walks - preserves local structure.

        Starts up to ``num_walks`` walks from uniformly random nodes and
        collects every visited node until ``target_nodes`` distinct nodes have
        been seen. A walk ends after ``walk_length`` steps or at a node without
        neighbors, so the result can hold fewer than ``target_nodes`` nodes.

        Args:
            graph_dict: Graph dictionary (see class docstring).
            num_walks: Maximum number of walks to start.
            walk_length: Maximum number of nodes visited per walk.
            seed: Seed of the private random generator.

        Returns:
            The sampled subgraph, or the whole graph when it already has at
            most ``target_nodes`` nodes.
        """
        rng = np.random.RandomState(seed)

        num_nodes = graph_dict['x'].shape[0]
        if num_nodes <= self.target_nodes:
            return self._dict_to_graphdata(graph_dict)

        G = self._build_nx_graph(graph_dict['edge_index'], num_nodes)

        visited_nodes = set()
        for _ in range(num_walks):
            if len(visited_nodes) >= self.target_nodes:
                break

            # Random starting node, kept as a plain int so the final index
            # list does not carry numpy scalars.
            current_node = int(rng.randint(0, num_nodes))

            for _ in range(walk_length):
                visited_nodes.add(current_node)
                if len(visited_nodes) >= self.target_nodes:
                    break

                neighbors = list(G.neighbors(current_node))
                if not neighbors:
                    # Dead end (isolated node): abandon this walk.
                    break

                current_node = int(rng.choice(neighbors))

        # The loops stop as soon as the target is reached, so the set never
        # exceeds target_nodes and needs no truncation.
        return self._extract_subgraph(graph_dict, list(visited_nodes))

    def sample_degree_based_subgraph(self, graph_dict: Dict, seed: int = 42) -> Optional[GraphData]:
        """Sample high-degree nodes and their neighborhoods.

        Selection happens in three stages: the top 60% of ``target_nodes`` by
        degree, then up to two neighbors of each of those nodes, then
        uniformly random nodes until ``target_nodes`` is reached.

        Args:
            graph_dict: Graph dictionary (see class docstring).
            seed: Seed of the private random generator.

        Returns:
            The sampled subgraph, or the whole graph when it already has at
            most ``target_nodes`` nodes.
        """
        rng = np.random.RandomState(seed)

        edge_index = graph_dict['edge_index']
        num_nodes = graph_dict['x'].shape[0]
        if num_nodes <= self.target_nodes:
            return self._dict_to_graphdata(graph_dict)

        # Degree = number of times a node appears as either endpoint. The
        # result is cast to float32 so top-k sees the same values as a
        # float accumulator would produce.
        degrees = torch.bincount(
            edge_index[:2].reshape(-1), minlength=num_nodes
        ).to(torch.float32)

        k_high_degree = int(self.target_nodes * 0.6)  # 60% high degree
        _, top_indices = torch.topk(degrees, k=min(k_high_degree, num_nodes))
        selected_nodes = set(top_indices.tolist())

        # The graph must contain all nodes: when few nodes have edges, the
        # top-k set includes degree-0 nodes, and G.neighbors() would raise
        # for nodes that were never added.
        G = self._build_nx_graph(edge_index, num_nodes)

        for node in list(selected_nodes):
            if len(selected_nodes) >= self.target_nodes:
                break
            # Add up to 2 neighbors per high-degree node.
            for neighbor in list(G.neighbors(node))[:2]:
                if len(selected_nodes) < self.target_nodes:
                    selected_nodes.add(neighbor)

        # Fill remaining slots with random nodes.
        remaining_nodes = set(range(num_nodes)) - selected_nodes
        if remaining_nodes and len(selected_nodes) < self.target_nodes:
            additional_needed = self.target_nodes - len(selected_nodes)
            additional_nodes = rng.choice(
                list(remaining_nodes),
                size=min(additional_needed, len(remaining_nodes)),
                replace=False
            )
            selected_nodes.update(additional_nodes.tolist())

        return self._extract_subgraph(graph_dict, list(selected_nodes))

    def sample_connected_component_subgraph(self, graph_dict: Dict, seed: int = 42) -> Optional[GraphData]:
        """Sample from largest connected components.

        Components are taken whole, largest first, while they fit into the
        remaining node budget. The first component that does not fit is
        sub-sampled uniformly at random to fill the budget exactly.

        Args:
            graph_dict: Graph dictionary (see class docstring).
            seed: Seed of the private random generator.

        Returns:
            The sampled subgraph, or the whole graph when it already has at
            most ``target_nodes`` nodes.
        """
        rng = np.random.RandomState(seed)

        num_nodes = graph_dict['x'].shape[0]
        if num_nodes <= self.target_nodes:
            return self._dict_to_graphdata(graph_dict)

        G = self._build_nx_graph(graph_dict['edge_index'], num_nodes)

        # Largest first; sorted() is stable, so equally sized components keep
        # the order in which NetworkX reported them.
        components = sorted(nx.connected_components(G), key=len, reverse=True)

        selected_nodes = set()
        for component in components:
            needed = self.target_nodes - len(selected_nodes)
            if needed <= 0:
                break

            component_nodes = list(component)
            if len(component_nodes) <= needed:
                selected_nodes.update(component_nodes)
            else:
                # Component is too large: sample just enough nodes from it.
                sampled = rng.choice(component_nodes, size=needed, replace=False)
                selected_nodes.update(sampled.tolist())

        return self._extract_subgraph(graph_dict, list(selected_nodes))

    def _dict_to_graphdata(self, graph_dict: Dict) -> GraphData:
        """Convert graph dictionary to GraphData object.

        ``graph_id`` defaults to ``'cisco_subgraph'`` and ``window_idx`` to
        ``0`` when the dictionary does not provide them.
        """
        return GraphData(
            x=graph_dict['x'],
            edge_index=graph_dict['edge_index'],
            edge_attr=graph_dict['edge_attr'],
            y_node=graph_dict['y_node'],
            graph_id=graph_dict.get('graph_id', 'cisco_subgraph'),
            window_idx=graph_dict.get('window_idx', 0)
        )

    def _extract_subgraph(self, graph_dict: Dict, node_indices: List[int]) -> GraphData:
        """Extract subgraph with given node indices.

        Nodes are renumbered ``0..len(node_indices)-1`` in the order given,
        and only edges whose two endpoints were both selected are kept.

        Args:
            graph_dict: Graph dictionary (see class docstring).
            node_indices: Original node ids to keep; expected to be unique.

        Returns:
            A ``GraphData`` with the selected node features and labels and the
            remapped edges. ``edge_attr`` is ``None`` when the source has no
            edge attributes or when no edge survives.
        """
        # Plain ints keep tensor indexing and dict lookups independent of
        # whether the caller passed numpy scalars.
        node_indices = [int(idx) for idx in node_indices]
        old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(node_indices)}

        x = graph_dict['x']
        y_node = graph_dict['y_node']
        edge_index = graph_dict['edge_index']
        edge_attr = graph_dict['edge_attr']

        new_x = x[node_indices]
        new_y_node = y_node[node_indices] if y_node is not None else None

        # Convert the endpoints to Python lists once instead of calling
        # .item() twice per edge.
        sources = edge_index[0].tolist()
        targets = edge_index[1].tolist()
        kept_edges = [
            i for i, (src, dst) in enumerate(zip(sources, targets))
            if src in old_to_new and dst in old_to_new
        ]

        if kept_edges:
            new_edge_index = torch.tensor(
                [
                    [old_to_new[sources[i]] for i in kept_edges],
                    [old_to_new[targets[i]] for i in kept_edges],
                ],
                dtype=torch.long,
            )
            new_edge_attr = edge_attr[kept_edges] if edge_attr is not None else None
        else:
            new_edge_index = torch.zeros(2, 0, dtype=torch.long)
            new_edge_attr = None

        return GraphData(
            x=new_x,
            edge_index=new_edge_index,
            edge_attr=new_edge_attr,
            y_node=new_y_node,
            graph_id=graph_dict.get('graph_id', 'cisco_subgraph'),
            window_idx=graph_dict.get('window_idx', 0)
        )

def create_cisco_subgraphs(
    input_path: str = "data/cisco_processed/cisco_graphs_processed.pkl",
    output_path: str = "data/cisco_subgraphs",
    target_nodes: int = 500,
    target_edges: int = 2000,
    sampling_strategy: str = "random_walk",
    subgraphs_per_graph: int = 3,
    seed: int = 42
):
    """Create subgraphs from large Cisco graphs.

    Loads a pickled list of graph dictionaries, samples
    ``subgraphs_per_graph`` subgraphs from every graph with more than
    ``target_nodes`` nodes (smaller graphs are kept whole), and writes
    ``cisco_subgraphs.pkl`` and ``metadata.json`` into ``output_path``.

    Args:
        input_path: Pickle file with the preprocessed graphs. Each entry must
            provide ``'num_nodes'`` and ``'num_edges'`` in addition to the
            keys the sampler reads.
        output_path: Directory that receives the two output files.
        target_nodes: Maximum number of nodes per subgraph.
        target_edges: Passed to the sampler and recorded in the metadata.
        sampling_strategy: One of ``"random_walk"``, ``"degree_based"`` or
            ``"connected_component"``.
        subgraphs_per_graph: Number of subgraphs drawn from each large graph.
        seed: Base random seed; each draw uses
            ``seed + graph_idx + sub_idx``.

    Returns:
        The list of created subgraphs.

    Raises:
        ValueError: If ``sampling_strategy`` is not a known strategy. This is
            checked before any file is read or written.
    """
    logger = get_logger("create_cisco_subgraphs")

    sampler = CiscoSubgraphSampler(target_nodes=target_nodes, target_edges=target_edges)
    strategies = {
        "random_walk": sampler.sample_random_walk_subgraph,
        "degree_based": sampler.sample_degree_based_subgraph,
        "connected_component": sampler.sample_connected_component_subgraph,
    }
    # Validate up front: raising inside the per-subgraph try/except below
    # would only be logged as a warning and yield an empty dataset.
    if sampling_strategy not in strategies:
        raise ValueError(f"Unknown sampling strategy: {sampling_strategy}")
    sample = strategies[sampling_strategy]

    ensure_dir(output_path)

    # Load preprocessed data.
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at files produced by a trusted preprocessing run.
    logger.info(f"Loading preprocessed data from {input_path}")
    with open(input_path, 'rb') as f:
        graphs_data = pickle.load(f)

    logger.info(f"Loaded {len(graphs_data)} graphs")

    all_subgraphs = []
    subgraph_info = []

    pbar = tqdm(graphs_data, desc="Creating subgraphs")

    for graph_idx, graph_dict in enumerate(pbar):
        # int() keeps the metadata JSON-serializable even when the counts
        # were stored as numpy or tensor scalars.
        original_nodes = int(graph_dict['num_nodes'])
        original_edges = int(graph_dict['num_edges'])

        # Graphs that are already small enough are kept as they are.
        if original_nodes <= target_nodes:
            all_subgraphs.append(sampler._dict_to_graphdata(graph_dict))
            subgraph_info.append({
                'original_graph_id': graph_idx,
                'subgraph_id': 0,
                'original_nodes': original_nodes,
                'original_edges': original_edges,
                'subgraph_nodes': original_nodes,
                'subgraph_edges': original_edges,
                'sampling_method': 'none_needed'
            })
            continue

        # Create multiple subgraphs from large graphs.
        for sub_idx in range(subgraphs_per_graph):
            try:
                # NOTE(review): this seed can repeat across graphs (graph 0 /
                # sub 1 and graph 1 / sub 0 share one); left unchanged so
                # earlier runs stay reproducible.
                subgraph = sample(graph_dict, seed=seed + graph_idx + sub_idx)

                if subgraph is not None:
                    all_subgraphs.append(subgraph)
                    subgraph_info.append({
                        'original_graph_id': graph_idx,
                        'subgraph_id': sub_idx,
                        'original_nodes': original_nodes,
                        'original_edges': original_edges,
                        'subgraph_nodes': subgraph.x.shape[0],
                        'subgraph_edges': subgraph.edge_index.shape[1],
                        'sampling_method': sampling_strategy
                    })

            except Exception as e:
                # Best effort: one failed draw must not abort the whole run.
                logger.warning(f"Failed to create subgraph {sub_idx} from graph {graph_idx}: {e}")

        pbar.set_postfix({'Subgraphs': len(all_subgraphs)})

    logger.info(f"Created {len(all_subgraphs)} subgraphs")

    # Save subgraphs
    output_file = Path(output_path) / "cisco_subgraphs.pkl"
    with open(output_file, 'wb') as f:
        pickle.dump(all_subgraphs, f)

    # Collect the per-subgraph sizes once for all statistics.
    node_counts = [info['subgraph_nodes'] for info in subgraph_info]
    edge_counts = [info['subgraph_edges'] for info in subgraph_info]
    num_records = len(subgraph_info)

    metadata = {
        "num_subgraphs": len(all_subgraphs),
        "target_nodes": target_nodes,
        "target_edges": target_edges,
        "sampling_strategy": sampling_strategy,
        "subgraphs_per_graph": subgraphs_per_graph,
        "seed": seed,
        "source": input_path,
        "subgraph_info": subgraph_info,
        "statistics": {
            "total_nodes": sum(node_counts),
            "total_edges": sum(edge_counts),
            "avg_nodes": sum(node_counts) / num_records if num_records else 0,
            "avg_edges": sum(edge_counts) / num_records if num_records else 0,
            "max_nodes": max(node_counts, default=0),
            "max_edges": max(edge_counts, default=0),
        }
    }

    metadata_file = Path(output_path) / "metadata.json"
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved {len(all_subgraphs)} subgraphs to {output_file}")
    logger.info(f"Saved metadata to {metadata_file}")

    # Print summary
    if all_subgraphs:
        stats = metadata["statistics"]
        print("\n📊 Cisco Subgraphs Dataset Summary:")
        print(f"  Subgraphs: {len(all_subgraphs)}")
        print(f"  Total nodes: {stats['total_nodes']:,}")
        print(f"  Total edges: {stats['total_edges']:,}")
        print(f"  Avg nodes per subgraph: {stats['avg_nodes']:.1f}")
        print(f"  Avg edges per subgraph: {stats['avg_edges']:.1f}")
        print(f"  Max nodes: {stats['max_nodes']}")
        print(f"  Max edges: {stats['max_edges']}")
        print(f"  Sampling strategy: {sampling_strategy}")
        print("\n✅ Ready for scalable training!")

    return all_subgraphs

def main():
    """Parse the command-line options and run the subgraph creation."""
    cli = argparse.ArgumentParser(description="Create subgraphs from large Cisco graphs")

    # (flag, add_argument keyword options), registered in help-output order.
    option_specs = [
        ("--input_path", {
            "default": "data/cisco_processed/cisco_graphs_processed.pkl",
            "help": "Input preprocessed dataset",
        }),
        ("--output_path", {
            "default": "data/cisco_subgraphs",
            "help": "Output directory",
        }),
        ("--target_nodes", {
            "type": int,
            "default": 500,
            "help": "Target nodes per subgraph",
        }),
        ("--target_edges", {
            "type": int,
            "default": 2000,
            "help": "Target edges per subgraph",
        }),
        ("--strategy", {
            "choices": ["random_walk", "degree_based", "connected_component"],
            "default": "random_walk",
            "help": "Sampling strategy",
        }),
        ("--subgraphs_per_graph", {
            "type": int,
            "default": 3,
            "help": "Number of subgraphs per large graph",
        }),
        ("--seed", {
            "type": int,
            "default": 42,
            "help": "Random seed",
        }),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # The CLI flag is called --strategy; the function expects sampling_strategy.
    create_cisco_subgraphs(
        input_path=opts.input_path,
        output_path=opts.output_path,
        target_nodes=opts.target_nodes,
        target_edges=opts.target_edges,
        sampling_strategy=opts.strategy,
        subgraphs_per_graph=opts.subgraphs_per_graph,
        seed=opts.seed,
    )

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
