#!/usr/bin/env python3
"""
Expert Preference Vector Analyzer

This tool reads predispatch binary files generated by predispatch.py and computes
expert preference vectors for each sample, then outputs the results to a text file.

The expert preference vector represents the frequency distribution of expert usage
for each sample, normalized by the total number of tokens in the sample.
"""

import argparse
import sys
import time
import logging
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np

# Add project paths
# Make the project root (two directory levels above this file's folder chain,
# i.e. parents[2] of the resolved file path) and the "Megatron-LM" directory
# inside it importable, so that the `megatron` import below can be resolved
# without installing the package.
_THIS_FILE = Path(__file__).resolve()
_PROJECT_ROOT = _THIS_FILE.parents[2]  # .../general_router
_MEGATRON_ROOT = _PROJECT_ROOT / "Megatron-LM"
# Each path is prepended only if it is not already present. Because the
# Megatron path is inserted last at index 0, it ends up ahead of the project
# root when both had to be added.
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
if str(_MEGATRON_ROOT) not in sys.path:
    sys.path.insert(0, str(_MEGATRON_ROOT))

from megatron.core.datasets import indexed_dataset

# Configure logging
# This runs at import time. Records at INFO level and above are sent both to
# stdout and to 'expert_preference_analysis.log'; the file name is relative,
# so the log file is created in the current working directory. main() later
# overrides the root logger level from the --log_level argument.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('expert_preference_analysis.log')
    ]
)
# Module-level logger shared by the analyzer class and main().
logger = logging.getLogger(__name__)


class ExpertPreferenceAnalyzer:
    """Expert preference vector analyzer

    For every sample in a predispatch dataset, computes the fraction of the
    sample's tokens that were dispatched to each expert (a vector of length
    ``num_experts``) and writes the vectors, plus optional summary
    statistics, to a text file.
    """

    def __init__(self, num_experts: int, key: str = 'text'):
        """
        Initialize analyzer

        Args:
            num_experts: Total number of experts in the MoE model
            key: Data key name (default: 'text'); it is embedded in the
                dispatch file names looked up by read_dispatch_data()
        """
        self.num_experts = num_experts
        self.key = key
        logger.info(f"Initialized analyzer with {num_experts} experts, key: {key}")

    def read_dispatch_data(self, predispatch_prefix: str) -> Tuple[List[np.ndarray], List[int]]:
        """
        Read dispatch data from binary files

        The files are expected at
        ``<prefix>_<key>_dispatch_ids.bin`` / ``.idx``.

        Args:
            predispatch_prefix: Prefix path to predispatch files (without .bin/.idx)

        Returns:
            dispatch_sequences: List of dispatch sequences for each sample
                (each array keeps the integer dtype stored in the dataset)
            sequence_lengths: List of sequence lengths for each sample

        Raises:
            FileNotFoundError: If the .bin or the .idx file is missing
        """
        predispatch_path = Path(predispatch_prefix)

        # Construct file paths
        base_name = f"{predispatch_path.name}_{self.key}_dispatch_ids"
        dataset_prefix = predispatch_path.parent / base_name
        bin_file = predispatch_path.parent / f"{base_name}.bin"
        idx_file = predispatch_path.parent / f"{base_name}.idx"

        if not bin_file.exists() or not idx_file.exists():
            raise FileNotFoundError(f"Dispatch files do not exist: {bin_file} or {idx_file}")

        logger.info(f"Reading dispatch data from: {bin_file}")

        # Use Megatron's IndexedDataset to read (it takes the path without
        # the .bin/.idx suffix)
        dataset = indexed_dataset.IndexedDataset(str(dataset_prefix))

        dispatch_sequences: List[np.ndarray] = []
        sequence_lengths: List[int] = []

        # Fetch the document boundaries once instead of on every iteration.
        doc_indices = dataset.document_indices
        # Document i spans sequences doc_indices[i] .. doc_indices[i + 1], so
        # there is one fewer document than boundary entries. Bounding the loop
        # by this count (rather than by the number of sequences) keeps
        # doc_indices[i + 1] in range even if a document holds several
        # sequences.
        # NOTE(review): relies on Megatron's IndexedDataset index layout - confirm
        total_samples = max(len(doc_indices) - 1, 0)
        logger.info(f"Total samples to process: {total_samples}")

        for i in range(total_samples):
            start_seq = int(doc_indices[i])
            end_seq = int(doc_indices[i + 1])

            # Each document is expected to hold a single sequence
            for seq in dataset[start_seq:end_seq]:
                # Copy the data out of the dataset, but keep the stored dtype:
                # forcing uint8 would silently wrap expert IDs >= 256.
                arr = np.array(seq)
                dispatch_sequences.append(arr)
                sequence_lengths.append(len(arr))

            # Progress logging
            if (i + 1) % 1000 == 0:
                logger.info(f"Processed {i + 1}/{total_samples} samples")

        logger.info(f"Successfully read {len(dispatch_sequences)} dispatch sequences")
        return dispatch_sequences, sequence_lengths

    def compute_expert_preference_vector(self, dispatch_sequence: np.ndarray) -> np.ndarray:
        """
        Compute expert preference vector for a single sample

        Args:
            dispatch_sequence: Dispatch sequence (expert IDs for each token)

        Returns:
            Expert preference vector (normalized frequency distribution),
            a float32 array of length ``num_experts``; all zeros for an
            empty sequence

        Raises:
            ValueError: If the sequence contains an expert ID that is
                negative (raised by np.bincount) or >= ``num_experts``
        """
        sequence = np.asarray(dispatch_sequence)
        total_tokens = sequence.size
        if total_tokens == 0:
            return np.zeros(self.num_experts, dtype=np.float32)

        # Count usage frequency of each expert
        expert_counts = np.bincount(sequence, minlength=self.num_experts)

        # bincount grows its output to max(id) + 1, so a longer result means
        # an out-of-range expert ID. Reject it instead of returning a vector
        # whose length disagrees with every other sample.
        if expert_counts.size > self.num_experts:
            raise ValueError(
                f"Dispatch sequence contains expert ID {expert_counts.size - 1}, "
                f"but only {self.num_experts} experts are configured"
            )

        # Normalize to get probability distribution
        return expert_counts.astype(np.float32) / total_tokens

    def analyze_samples(self, predispatch_prefix: str) -> List[Tuple[int, np.ndarray, int]]:
        """
        Analyze all samples and compute expert preference vectors

        Args:
            predispatch_prefix: Prefix path to predispatch files

        Returns:
            List of (sample_id, preference_vector, sequence_length) tuples,
            where sample_id is the position of the sample in read order
        """
        logger.info("Starting sample analysis...")
        start_time = time.time()

        # Read dispatch data
        dispatch_sequences, sequence_lengths = self.read_dispatch_data(predispatch_prefix)

        # Compute preference vectors
        results = []
        total_samples = len(dispatch_sequences)

        logger.info(f"Computing preference vectors for {total_samples} samples...")

        for i, (dispatch_seq, seq_len) in enumerate(zip(dispatch_sequences, sequence_lengths)):
            preference_vector = self.compute_expert_preference_vector(dispatch_seq)
            results.append((i, preference_vector, seq_len))

            # Progress logging
            if (i + 1) % 1000 == 0:
                rate = self._rate(i + 1, time.time() - start_time)
                logger.info(f"Processed {i + 1}/{total_samples} samples "
                            f"({rate:.1f} samples/sec)")

        elapsed = time.time() - start_time
        logger.info(f"Analysis completed in {elapsed:.2f} seconds")
        logger.info(f"Average processing rate: {self._rate(total_samples, elapsed):.1f} samples/sec")

        return results

    @staticmethod
    def _rate(count: int, elapsed: float) -> float:
        """Return count / elapsed, or infinity when no time has elapsed."""
        return count / elapsed if elapsed > 0 else float('inf')

    @staticmethod
    def _entropy_bits(vector: np.ndarray) -> float:
        """Return the Shannon entropy (base 2) of one preference vector."""
        # Zero entries are skipped so that log2(0) is never evaluated.
        nonzero = vector[vector > 0]
        return -np.sum(nonzero * np.log2(nonzero))

    @staticmethod
    def _summarize(results: List[Tuple[int, np.ndarray, int]]) -> Dict[str, object]:
        """
        Compute summary statistics over all results

        Args:
            results: List of (sample_id, preference_vector, sequence_length) tuples

        Returns:
            Dict with keys 'mean_length', 'min_length', 'max_length',
            'mean_entropy' and 'expert_usage' (per-expert mean frequency),
            or an empty dict when there are no results
        """
        if not results:
            return {}

        all_vectors = np.array([result[1] for result in results])
        all_lengths = np.array([result[2] for result in results])
        return {
            'mean_length': np.mean(all_lengths),
            'min_length': np.min(all_lengths),
            'max_length': np.max(all_lengths),
            'mean_entropy': np.mean(
                [ExpertPreferenceAnalyzer._entropy_bits(v) for v in all_vectors]
            ),
            'expert_usage': np.mean(all_vectors, axis=0),
        }

    def save_results_to_txt(self, results: List[Tuple[int, np.ndarray, int]],
                            output_path: str, include_stats: bool = True):
        """
        Save results to text file

        The file starts with '#'-prefixed header lines (optionally followed
        by statistics), then one line per sample:
        ``sample_id sequence_length expert_0_freq ... expert_{N-1}_freq``.
        Parent directories of the output file are created if needed.

        Args:
            results: List of (sample_id, preference_vector, sequence_length) tuples
            output_path: Output file path
            include_stats: Whether to include statistics in the output
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving results to: {output_path}")

        # Computed once and shared by the file header and the log summary.
        # An empty dict means there were no samples to summarize.
        stats = self._summarize(results) if include_stats else {}

        with open(output_path, 'w', encoding='utf-8') as f:
            # Write header
            f.write("# Expert Preference Vector Analysis Results\n")
            f.write(f"# Total samples: {len(results)}\n")
            f.write(f"# Number of experts: {self.num_experts}\n")
            f.write(f"# Format: sample_id sequence_length expert_0_freq expert_1_freq ... expert_{self.num_experts-1}_freq\n")
            f.write("#\n")

            if include_stats:
                f.write("# Statistics:\n")
                if stats:
                    f.write(f"# Average sequence length: {stats['mean_length']:.2f}\n")
                    f.write(f"# Min sequence length: {stats['min_length']}\n")
                    f.write(f"# Max sequence length: {stats['max_length']}\n")
                    f.write(f"# Average expert usage entropy: {stats['mean_entropy']:.4f}\n")
                else:
                    f.write("# No samples available\n")
                f.write("#\n")

            # Write data
            for sample_id, preference_vector, seq_len in results:
                # Format: sample_id sequence_length expert_0_freq expert_1_freq ...
                vector_str = ' '.join([f"{freq:.6f}" for freq in preference_vector])
                f.write(f"{sample_id} {seq_len} {vector_str}\n")

        logger.info(f"Results saved to: {output_path}")

        # Log summary statistics
        if include_stats:
            logger.info("Summary Statistics:")
            logger.info(f"  Total samples: {len(results)}")
            if not stats:
                logger.info("  No samples available")
                return

            logger.info(f"  Average sequence length: {stats['mean_length']:.2f}")
            logger.info(f"  Sequence length range: {stats['min_length']} - {stats['max_length']}")
            logger.info(f"  Average expert usage entropy: {stats['mean_entropy']:.4f}")

            # Expert usage statistics (skipped when vectors are empty, since
            # argmax/argmin are undefined on an empty array)
            expert_usage = stats['expert_usage']
            if expert_usage.size:
                most_used_expert = np.argmax(expert_usage)
                least_used_expert = np.argmin(expert_usage)
                logger.info(f"  Most used expert: {most_used_expert} (avg freq: {expert_usage[most_used_expert]:.4f})")
                logger.info(f"  Least used expert: {least_used_expert} (avg freq: {expert_usage[least_used_expert]:.4f})")


def _positive_int(value):
    """argparse ``type`` callable: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(argv=None):
    """Parse command line arguments

    Args:
        argv: Optional list of argument strings; when None (the default),
            argparse reads the arguments from sys.argv[1:]

    Returns:
        argparse.Namespace with the attributes predispatch_prefix,
        output_path, num_experts, key, no_stats and log_level
    """
    parser = argparse.ArgumentParser(
        description='Analyze expert preference vectors from predispatch results',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        '--predispatch_prefix',
        required=True,
        type=str,
        help='Prefix path to predispatch files (without .bin/.idx extension)'
    )

    parser.add_argument(
        '--output_path',
        required=True,
        type=str,
        help='Output text file path for results'
    )

    parser.add_argument(
        '--num_experts',
        required=True,
        # Zero or negative expert counts cannot produce a meaningful
        # preference vector, so they are rejected at parse time.
        type=_positive_int,
        help='Total number of experts in the MoE model'
    )

    parser.add_argument(
        '--key',
        default='text',
        type=str,
        # ArgumentDefaultsHelpFormatter appends "(default: ...)" itself, so
        # the help text must not repeat it.
        help='Data key name'
    )

    parser.add_argument(
        '--no_stats',
        action='store_true',
        help='Skip writing statistics to output file'
    )

    parser.add_argument(
        '--log_level',
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='Logging level'
    )

    return parser.parse_args(argv)


def main():
    """Main function

    Parses the command line, runs the analysis and writes the result file.
    Any exception raised during the analysis is logged together with its
    traceback and the process exits with status 1.
    """
    args = parse_args()

    # Set logging level
    logging.getLogger().setLevel(getattr(logging, args.log_level))

    logger.info("Starting Expert Preference Vector Analysis")
    logger.info(f"Predispatch prefix: {args.predispatch_prefix}")
    logger.info(f"Output path: {args.output_path}")
    logger.info(f"Number of experts: {args.num_experts}")
    logger.info(f"Data key: {args.key}")

    try:
        # Initialize analyzer
        analyzer = ExpertPreferenceAnalyzer(
            num_experts=args.num_experts,
            key=args.key
        )

        # Analyze samples
        results = analyzer.analyze_samples(args.predispatch_prefix)

        # Save results
        analyzer.save_results_to_txt(
            results,
            args.output_path,
            include_stats=not args.no_stats
        )

        logger.info("Analysis completed successfully!")

    except Exception as e:
        # Top-level boundary: logger.exception records the traceback as well
        # as the message, so a failure can be diagnosed from the log alone.
        logger.exception(f"Analysis failed: {e}")
        sys.exit(1)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
