#!/usr/bin/env python3
"""
Complete baseline communication calculator for EP optimization
Calculate both inter-node and intra-node communication for sequential assignment strategy
Implements DeepEP-style communication pattern with node-level routing and intra-node all-to-all
"""

import argparse
import json
import logging
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Set
import sys

# Make the directory that contains this file's parent directory importable
# by placing it at the front of the module search path.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# The "Megatron-LM" directory under the project root goes in front of it,
# unless the search path already lists that directory.
megatron_path = project_root.joinpath("Megatron-LM")
if sys.path.count(str(megatron_path)) == 0:
    sys.path.insert(0, str(megatron_path))

from megatron.core.datasets import indexed_dataset

# Root logging setup: INFO threshold with a "time - level - message" layout.
# (basicConfig leaves an already-configured root logger untouched.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


class CompleteBaselineCommunicationCalculator:
    """Calculate complete baseline communication for sequential assignment with DeepEP-style routing"""
    
    def __init__(self, num_nodes: int = 4, gpus_per_node: int = 8, experts_per_gpu: int = 2, a2a_mode=False):
        """Record the cluster shape and create empty expert placement tables.

        Args:
            num_nodes: Number of nodes in the modelled cluster.
            gpus_per_node: Number of GPUs on each node.
            experts_per_gpu: Stored on the instance as given; the placement
                itself is built later by ``setup_expert_distribution``.
            a2a_mode: Stored on the instance; the communication counters
                switch their counting rule on it.
        """
        self.num_nodes = num_nodes
        self.gpus_per_node = gpus_per_node
        self.experts_per_gpu = experts_per_gpu
        self.a2a_mode = a2a_mode
        self.total_gpus = num_nodes * gpus_per_node

        # Every (node_id, gpu_id) pair, node-major: all GPUs of node 0 first, then node 1, ...
        self.node_gpu_tuple = []
        for node in range(num_nodes):
            self.node_gpu_tuple.extend((node, gpu) for gpu in range(gpus_per_node))

        # Placement tables start empty: expert id -> GPU key, and GPU key -> set of expert ids.
        self.expert_to_gpu = {}
        self.gpu_to_experts = {gpu_key: set() for gpu_key in self.node_gpu_tuple}
        
    def load_dispatch_data(self, dispatch_path: str) -> indexed_dataset.IndexedDataset:
        """Open the dispatch dataset stored at ``dispatch_path``.

        Args:
            dispatch_path: Path handed unchanged to ``indexed_dataset.IndexedDataset``.

        Returns:
            The constructed ``IndexedDataset``.

        Raises:
            Exception: Whatever is raised while constructing the dataset or
                while logging its length; it is logged at ERROR level and
                then re-raised unchanged.
        """
        try:
            dataset = indexed_dataset.IndexedDataset(dispatch_path)
            # len() is evaluated inside the try, so a failure here is reported as a load failure too
            logger.info(f"Loaded dispatch dataset with {len(dataset)} documents")
            return dataset

        except Exception as e:
            # Log for visibility, then let the caller see the original exception
            logger.error(f"Failed to load dispatch data: {e}")
            raise
    
    def setup_expert_distribution(self, num_experts: int):
        """Place expert ids ``0 .. num_experts - 1`` on GPUs in contiguous blocks.

        GPUs are visited in node-major order (node 0 / GPU 0, node 0 / GPU 1,
        ...).  Every GPU receives ``num_experts // total_gpus`` experts and the
        first ``num_experts % total_gpus`` GPUs receive one additional expert,
        so every expert id ends up in ``expert_to_gpu`` even when the count is
        not a multiple of the GPU count.  Placements left over from an earlier
        call are discarded before the new ones are written.

        Args:
            num_experts: Total number of experts to place.
        """
        base_count, extra_count = divmod(num_experts, self.total_gpus)

        # Start from a clean slate so a second call does not mix old and new placements.
        self.expert_to_gpu.clear()
        for hosted in self.gpu_to_experts.values():
            hosted.clear()

        next_expert = 0
        gpu_index = 0
        for node_id in range(self.num_nodes):
            for gpu_id in range(self.gpus_per_node):
                gpu_key = (node_id, gpu_id)

                # The first `extra_count` GPUs absorb the remainder, one expert each.
                share = base_count + (1 if gpu_index < extra_count else 0)
                for _ in range(share):
                    self.expert_to_gpu[next_expert] = gpu_key
                    self.gpu_to_experts[gpu_key].add(next_expert)
                    next_expert += 1

                gpu_index += 1

        logger.info(f"Distributed {num_experts} experts across {self.total_gpus} GPUs")
        logger.info(f"Experts per GPU: {base_count} (base) + {1 if extra_count > 0 else 0} (remaining)")
    
    def assign_samples_to_gpus(self, samples: List[List[int]], micro_batch_size: int) -> Dict[Tuple[int, int], List[List[int]]]:
        """Deal consecutive chunks of ``samples`` to the GPUs in round-robin order.

        ``samples`` is cut into consecutive chunks of
        ``micro_batch_size // gpus_per_node`` samples (the last chunk may be
        shorter).  Chunk 0 goes to the first entry of ``node_gpu_tuple``,
        chunk 1 to the second, and so on, wrapping around after the last GPU.

        Args:
            samples: Sequence of samples; each sample is kept as-is.
            micro_batch_size: Number of samples one node handles per micro
                batch; it is divided by ``gpus_per_node`` to get the chunk size.

        Returns:
            Mapping from ``(node_id, gpu_id)`` to the list of chunks (micro
            batches) dealt to that GPU; each chunk is a slice of ``samples``.

        Raises:
            ValueError: If ``micro_batch_size`` is smaller than
                ``gpus_per_node`` or there are no GPUs to deal to.  Either
                condition would otherwise make the distribution loop never
                advance.
        """
        chunk_size = micro_batch_size // self.gpus_per_node
        if chunk_size < 1:
            raise ValueError(
                f"micro_batch_size ({micro_batch_size}) must be at least "
                f"gpus_per_node ({self.gpus_per_node})"
            )

        gpu_keys = self.node_gpu_tuple
        if not gpu_keys:
            raise ValueError("cannot assign samples: the calculator has no GPUs")

        gpu_assignments = {gpu_key: [] for gpu_key in gpu_keys}

        # Chunk k always lands on GPU (k mod number_of_gpus).
        for chunk_index, start in enumerate(range(0, len(samples), chunk_size)):
            gpu_key = gpu_keys[chunk_index % len(gpu_keys)]
            gpu_assignments[gpu_key].append(samples[start:start + chunk_size])

        # Log assignment statistics
        for gpu_key, gpu_samples in gpu_assignments.items():
            total_samples = sum(len(micro_batch) for micro_batch in gpu_samples)
            logger.info(f"GPU {gpu_key}: assigned {total_samples} samples in {len(gpu_samples)} micro batches")

        return gpu_assignments
    
    def calculate_inter_node_communication(self, gpu_assignments: Dict[Tuple[int, int], List[List[int]]],
                                         topk: int) -> Dict[str, float]:
        """Count dispatches whose selected experts live on a different node.

        Each sample's flat dispatch ids are reshaped to ``[tokens, topk]`` and
        compared against the experts hosted on every node other than the one
        holding the sample:

        * ``a2a_mode`` false: a token counts once per remote node that hosts
          at least one of its selected experts.
        * ``a2a_mode`` true: every matrix entry that names an expert on a
          remote node counts.

        Dispatch ids are compared in whatever integer dtype the sample arrives
        in; they are not narrowed to 8 bits, so expert ids of 256 and above
        are matched correctly.

        Args:
            gpu_assignments: Mapping ``(node_id, gpu_id)`` -> list of micro
                batches, each micro batch being a list of samples (flat
                sequences of expert ids, ``topk`` per token).
            topk: Number of expert ids per token.

        Returns:
            Dict with ``total_inter_node_communication`` (int),
            ``inter_node_ratio`` (total divided by token count, 0 when there
            are no tokens), ``total_tokens`` (int) and
            ``node_pair_communication`` mapping ``(source_node, target_node)``
            to its count; a pair appears once a sample from that source node
            has been processed.

        Raises:
            ValueError: From ``reshape`` when a sample's length is not a
                multiple of ``topk``.
        """
        # Group expert ids by hosting node once, instead of per sample.
        experts_on_node = {node_id: [] for node_id in range(self.num_nodes)}
        for expert_id, (host_node, _) in self.expert_to_gpu.items():
            if host_node in experts_on_node:
                experts_on_node[host_node].append(expert_id)
        expert_arrays = {
            node_id: np.asarray(expert_ids, dtype=np.int64)
            for node_id, expert_ids in experts_on_node.items()
        }

        total_inter_node_comm = 0
        total_tokens = 0
        node_pair_comm = {}

        for (source_node, _), micro_batches in gpu_assignments.items():
            remote_nodes = [node_id for node_id in range(self.num_nodes) if node_id != source_node]

            for micro_batch in micro_batches:
                for sample in micro_batch:
                    # Reshape dispatch IDs to [seq_len, topk], keeping the incoming dtype
                    dispatch_matrix = np.asarray(sample).reshape(-1, topk)
                    total_tokens += len(dispatch_matrix)

                    for target_node in remote_nodes:
                        hits = np.isin(dispatch_matrix, expert_arrays[target_node])
                        if not self.a2a_mode:
                            # One transfer per token and destination node, however many experts match
                            hits = hits.any(axis=1)
                        sent = int(hits.sum())

                        pair_key = (source_node, target_node)
                        node_pair_comm[pair_key] = node_pair_comm.get(pair_key, 0) + sent
                        total_inter_node_comm += sent

        # Calculate communication ratios
        inter_node_ratio = total_inter_node_comm / total_tokens if total_tokens > 0 else 0

        return {
            'total_inter_node_communication': int(total_inter_node_comm),
            'inter_node_ratio': float(inter_node_ratio),
            'total_tokens': int(total_tokens),
            'node_pair_communication': node_pair_comm,
        }
    
    def calculate_intra_node_communication(self, gpu_assignments: Dict[Tuple[int, int], List[List[int]]],
                                         topk: int) -> Dict[str, float]:
        """Count dispatches aimed at GPUs with a different local GPU index.

        For every sample held by a source GPU, each GPU in ``node_gpu_tuple``
        whose local index (the second element of its key) differs from the
        source GPU's local index is treated as a target:

        * ``a2a_mode`` false: a token counts once per target GPU hosting at
          least one of its selected experts.
        * ``a2a_mode`` true: every matrix entry naming an expert on the target
          GPU counts.

        The filter compares local GPU indices only, not node ids, so target
        GPUs on other nodes are included and GPUs on other nodes with the same
        local index are skipped.  Counts are credited to the source GPU's node.
        NOTE(review): this looks like a model where cross-node traffic lands on
        the same-index GPU of the destination node and is forwarded inside that
        node - confirm that this, and crediting the source node, is intended.

        Dispatch ids are compared in the dtype the sample arrives in; they are
        not narrowed to 8 bits, so expert ids of 256 and above match correctly.

        Args:
            gpu_assignments: Mapping ``(node_id, gpu_id)`` -> list of micro
                batches, each a list of samples (flat sequences of expert ids,
                ``topk`` per token).
            topk: Number of expert ids per token.

        Returns:
            Dict with ``total_intra_node_communication`` (int),
            ``intra_node_ratio`` (total divided by token count, 0 when there
            are no tokens), ``total_tokens`` (int) and
            ``node_intra_communication`` mapping every node id to its count.

        Raises:
            ValueError: From ``reshape`` when a sample's length is not a
                multiple of ``topk``.
        """
        # Build each GPU's expert-id array once instead of once per sample and target.
        expert_arrays = {
            gpu_key: np.asarray(sorted(self.gpu_to_experts[gpu_key]), dtype=np.int64)
            for gpu_key in self.node_gpu_tuple
        }

        total_intra_node_comm = 0
        total_tokens = 0

        # Track communication per node
        node_intra_comm = {node_id: 0 for node_id in range(self.num_nodes)}

        for (source_node, source_gpu), micro_batches in gpu_assignments.items():
            # A differing local index also rules out the source GPU itself.
            target_keys = [gpu_key for gpu_key in self.node_gpu_tuple if gpu_key[1] != source_gpu]

            for micro_batch in micro_batches:
                for sample in micro_batch:
                    # Reshape dispatch IDs to [seq_len, topk], keeping the incoming dtype
                    dispatch_matrix = np.asarray(sample).reshape(-1, topk)
                    total_tokens += len(dispatch_matrix)

                    for target_gpu_key in target_keys:
                        hits = np.isin(dispatch_matrix, expert_arrays[target_gpu_key])
                        if not self.a2a_mode:
                            # One transfer per token and target GPU, however many experts match
                            hits = hits.any(axis=1)
                        sent = int(hits.sum())

                        node_intra_comm[source_node] += sent
                        total_intra_node_comm += sent

        # Calculate communication ratios
        intra_node_ratio = total_intra_node_comm / total_tokens if total_tokens > 0 else 0

        return {
            'total_intra_node_communication': int(total_intra_node_comm),
            'intra_node_ratio': float(intra_node_ratio),
            'total_tokens': int(total_tokens),
            'node_intra_communication': node_intra_comm,
        }
    
    def calculate_gpu_load_balance(self, gpu_assignments: Dict[Tuple[int, int], List[List[int]]],
                                 topk: int) -> Dict[str, float]:
        """Compute a per-GPU load figure plus two balance metrics.

        A GPU's load is the number of dispatch entries, among the samples
        assigned to that GPU, that name an expert hosted on that same GPU.
        NOTE(review): dispatches arriving from samples held by other GPUs are
        not included in this figure - confirm that is the intended definition.

        Dispatch ids are compared in the dtype the sample arrives in; they are
        not narrowed to 8 bits, so expert ids of 256 and above match correctly.

        Args:
            gpu_assignments: Mapping ``(node_id, gpu_id)`` -> list of micro
                batches, each a list of samples (flat sequences of expert ids,
                ``topk`` per token).
            topk: Number of expert ids per token.

        Returns:
            Dict with ``gpu_loads`` (GPU key -> int load),
            ``load_balance_entropy`` (base-2 entropy of each GPU's share of the
            total load) and ``load_balance_ratio`` (smallest load divided by
            largest load).  Both metrics are ``0.0`` when there are no GPUs or
            the total load is zero.

        Raises:
            ValueError: From ``reshape`` when a sample's length is not a
                multiple of ``topk``.
        """
        gpu_loads = {}

        for gpu_key, micro_batches in gpu_assignments.items():
            local_experts = np.asarray(sorted(self.gpu_to_experts[gpu_key]), dtype=np.int64)
            gpu_load = 0

            for micro_batch in micro_batches:
                for sample in micro_batch:
                    # Reshape dispatch IDs to [seq_len, topk]; this also rejects malformed samples
                    dispatch_matrix = np.asarray(sample).reshape(-1, topk)
                    gpu_load += int(np.isin(dispatch_matrix, local_experts).sum())

            gpu_loads[gpu_key] = gpu_load

        load_values = np.array(list(gpu_loads.values()), dtype=np.float64)
        total_load = float(load_values.sum()) if load_values.size else 0.0

        if total_load > 0:
            load_ratios = load_values / total_load
            # Avoid log(0) by substituting a small epsilon for empty shares
            load_ratios = np.where(load_ratios > 0, load_ratios, 1e-10)
            entropy = float(-np.sum(load_ratios * np.log2(load_ratios)))
            balance_ratio = float(load_values.min() / load_values.max())
        else:
            # No GPUs or no load at all: the shares are undefined, so report neutral values
            entropy = 0.0
            balance_ratio = 0.0

        return {
            'gpu_loads': gpu_loads,
            'load_balance_entropy': entropy,
            'load_balance_ratio': balance_ratio,
        }
    
    def print_communication_summary(self, inter_stats: Dict, intra_stats: Dict, load_stats: Dict):
        """Write a plain-text report of the three statistics dicts to stdout.

        Args:
            inter_stats: Needs ``total_inter_node_communication``,
                ``inter_node_ratio``, ``total_tokens`` and
                ``node_pair_communication`` (keyed by node-id pairs).
            intra_stats: Needs ``total_intra_node_communication``,
                ``intra_node_ratio`` and ``node_intra_communication``.
            load_stats: Needs ``load_balance_entropy`` and
                ``load_balance_ratio``.
        """
        banner = "=" * 80

        print("\n" + banner)
        print("COMPLETE BASELINE COMMUNICATION ANALYSIS SUMMARY")
        print(banner)

        # Headline numbers: inter plus intra, relative to the inter-node token count
        inter_total = inter_stats['total_inter_node_communication']
        intra_total = intra_stats['total_intra_node_communication']
        combined = inter_total + intra_total
        token_count = inter_stats['total_tokens']
        if token_count > 0:
            combined_ratio = combined / token_count
        else:
            combined_ratio = 0

        print(f"Total tokens processed: {token_count:,}")
        print(f"Total communication: {combined:,} tokens ({combined_ratio:.2%})")
        print(f"  - Inter-node communication: {inter_total:,} tokens ({inter_stats['inter_node_ratio']:.2%})")
        print(f"  - Intra-node communication: {intra_total:,} tokens ({intra_stats['intra_node_ratio']:.2%})")

        # Balance metrics
        print(f"\nLoad Balance Statistics:")
        print(f"  - Load balance entropy: {load_stats['load_balance_entropy']:.4f}")
        print(f"  - Load balance ratio: {load_stats['load_balance_ratio']:.4f}")

        # One line per (source, target) node pair
        print(f"\nInter-node communication breakdown:")
        for pair, pair_count in inter_stats['node_pair_communication'].items():
            src, dst = pair
            print(f"  Node {src} → Node {dst}: {pair_count:,} tokens")

        # One line per node
        print(f"\nIntra-node communication breakdown:")
        for node, node_count in intra_stats['node_intra_communication'].items():
            print(f"  Node {node}: {node_count:,} tokens")

        print(banner)
    
    def run_complete_analysis(self, dispatch_path: str, num_experts: int, micro_batch_size: int,
                            topk: int, num_batch) -> Dict[str, Dict]:
        """Run the whole pipeline: place experts, load data, assign, measure, report.

        Args:
            dispatch_path: Location of the dispatch dataset, passed to
                ``load_dispatch_data``.
            num_experts: Total number of experts to place on the GPUs.
            micro_batch_size: Passed to ``assign_samples_to_gpus``.
            topk: Number of expert ids per token.
            num_batch: Number of leading samples to read from the dataset.
                Values larger than the dataset are reduced to the dataset
                length (with a warning); ``None`` means "all samples".

        Returns:
            Dict with ``inter_node_communication`` (without the per-pair
            breakdown), ``intra_node_communication``, ``load_balance`` and a
            ``configuration`` section echoing the settings used.
        """
        logger.info("Starting complete baseline communication analysis")

        # Setup expert distribution
        self.setup_expert_distribution(num_experts)

        # Load data
        dispatch_data = self.load_dispatch_data(dispatch_path)

        # Never index past the end of the dataset
        available = len(dispatch_data)
        if num_batch is None:
            sample_count = available
        else:
            sample_count = min(num_batch, available)
            if sample_count < num_batch:
                logger.warning(
                    f"Requested {num_batch} samples but the dataset only holds {available}; "
                    f"using {sample_count}"
                )

        # Convert to list for easier processing
        samples = [dispatch_data[i] for i in range(sample_count)]

        # Assign samples to GPUs
        gpu_assignments = self.assign_samples_to_gpus(samples, micro_batch_size)

        # Calculate inter-node communication
        logger.info("Calculating inter-node communication...")
        inter_stats = self.calculate_inter_node_communication(gpu_assignments, topk)

        # Calculate intra-node communication
        logger.info("Calculating intra-node communication...")
        intra_stats = self.calculate_intra_node_communication(gpu_assignments, topk)

        # Calculate load balance
        logger.info("Calculating GPU load balance...")
        load_stats = self.calculate_gpu_load_balance(gpu_assignments, topk)

        # Print summary while the per-pair breakdown is still present
        self.print_communication_summary(inter_stats, intra_stats, load_stats)

        # The per-pair breakdown is reported above but left out of the returned stats.
        # NOTE(review): presumably because its tuple keys do not serialise to JSON - confirm.
        inter_stats.pop('node_pair_communication', None)

        return {
            'inter_node_communication': inter_stats,
            'intra_node_communication': intra_stats,
            'load_balance': load_stats,
            'configuration': {
                'num_nodes': self.num_nodes,
                'gpus_per_node': self.gpus_per_node,
                'experts_per_gpu': self.experts_per_gpu,
                'total_gpus': self.total_gpus,
                'num_experts': num_experts,
                'micro_batch_size': micro_batch_size,
                'topk': topk
            }
        }


def main():
    """Parse the command line, build the calculator and run the analysis.

    Returns:
        ``0`` once the analysis has completed, for use as the process exit
        status.
    """
    parser = argparse.ArgumentParser(
        description='Calculate complete baseline communication for sequential assignment with DeepEP-style routing'
    )

    parser.add_argument('--dispatch_path', required=True, type=str,
                       help='Path to dispatch file')
    parser.add_argument('--num_experts', required=True, type=int,
                       help='Total number of experts')
    parser.add_argument('--micro_batch_size', required=True, type=int,
                       help='Micro batch size per node')
    parser.add_argument('--topk', type=int, required=True,
                       help='The number each token routed to')
    parser.add_argument('--num_nodes', type=int, default=4,
                       help='Number of nodes (default: 4)')
    parser.add_argument('--gpus_per_node', type=int, default=8,
                       help='GPUs per node (default: 8)')
    parser.add_argument('--experts_per_gpu', type=int, default=2,
                       help='Experts per GPU (default: 2)')
    # Previously fixed at 3200 inside this function; the default keeps that behaviour.
    parser.add_argument('--num_batch', type=int, default=3200,
                       help='Number of samples to read from the dispatch dataset (default: 3200)')
    parser.add_argument('--log_level', type=str, default='INFO',
                       choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                       help='Logging level')
    parser.add_argument("--a2a-mode", action="store_true",
                       help='Count every matching (token, expert) entry instead of '
                            'counting each token once per destination')

    args = parser.parse_args()

    # Set logging level on the root logger so it applies to this module's logger too
    logging.getLogger().setLevel(getattr(logging, args.log_level))

    # Initialize calculator
    calculator = CompleteBaselineCommunicationCalculator(
        num_nodes=args.num_nodes,
        gpus_per_node=args.gpus_per_node,
        experts_per_gpu=args.experts_per_gpu,
        a2a_mode=args.a2a_mode
    )

    # Run complete analysis; the summary is printed by the calculator itself
    calculator.run_complete_analysis(
        dispatch_path=args.dispatch_path,
        num_experts=args.num_experts,
        micro_batch_size=args.micro_batch_size,
        topk=args.topk,
        num_batch=args.num_batch
    )

    return 0

if __name__ == '__main__':
    # sys.exit is the programmatic way to end the process with main()'s status;
    # the bare exit() helper is added by the site module and is missing under `python -S`.
    sys.exit(main())
