"""
GPU Data Reader for Megatron Training Integration

This module provides utilities to read GPU-level binary data files
generated by the expert-sample coordination optimization process.
Uses Megatron IndexedDataset format for compatibility.
"""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from megatron.core.datasets import indexed_dataset

logger = logging.getLogger(__name__)


class GPUDataReader:
    """
    Reader for GPU-level binary data files.

    Each GPU's data lives under ``<gpu_data_dir>/node_<node_id>/gpu_<gpu_id>/``
    as IndexedDataset ``.bin``/``.idx`` pairs named ``tokens``, ``dispatch``,
    ``experts``, ``labels``, ``loss_mask``, ``position_ids`` and
    ``attention_mask``.

    Every ``read_gpu_*`` method loads one of those datasets completely into
    memory. Reads are best-effort: when the ``.bin`` file is missing or the
    dataset cannot be read, the problem is logged and ``[]`` is returned
    instead of raising.
    """

    # File stems of the IndexedDataset pairs checked by validate_gpu_data.
    _REQUIRED_DATASETS = (
        'tokens',
        'dispatch',
        'experts',
        'labels',
        'loss_mask',
        'position_ids',
        'attention_mask',
    )

    def __init__(self, gpu_data_dir: str):
        """
        Initialize GPU data reader.

        Args:
            gpu_data_dir: Directory containing GPU data files

        Raises:
            FileNotFoundError: If ``gpu_data_dir`` does not exist.
        """
        self.gpu_data_dir = Path(gpu_data_dir)
        if not self.gpu_data_dir.exists():
            raise FileNotFoundError(f"GPU data directory not found: {gpu_data_dir}")

    def _gpu_dir(self, node_id: int, gpu_id: int) -> Path:
        """Return the directory holding the files of one (node, GPU) pair."""
        return self.gpu_data_dir / f'node_{node_id}' / f'gpu_{gpu_id}'

    @staticmethod
    def _to_owned_array(sample) -> np.ndarray:
        """
        Turn one dataset entry into an independent array.

        An entry delivered as a ``(data, mode)`` tuple is unwrapped to its
        data part BEFORE copying: ``np.copy`` always returns an ndarray, so
        testing for a tuple after the copy can never succeed.
        """
        if isinstance(sample, tuple):
            sample = sample[0]
        # Copy so the result is a standalone array.
        # NOTE(review): presumably this avoids aliasing the dataset's
        # underlying buffer - confirm against IndexedDataset.
        return np.copy(sample)

    def _read_dataset(
        self,
        node_id: int,
        gpu_id: int,
        stem: str,
        missing_label: str,
        debug_label: str,
        error_label: str,
        log_contents: bool = False,
    ) -> List[np.ndarray]:
        """
        Load every entry of ``<gpu dir>/<stem>.bin`` / ``.idx`` into a list.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node
            stem: File name without extension (e.g. ``'tokens'``)
            missing_label: Wording used in the "file not found" warning
            debug_label: Wording used in the success debug message
            error_label: Wording used in the failure error message
            log_contents: Also append the loaded entries to the debug message

        Returns:
            One array per dataset entry, or ``[]`` when the ``.bin`` file is
            missing or reading fails (both cases are logged).
        """
        prefix = str(self._gpu_dir(node_id, gpu_id) / stem)

        if not Path(f"{prefix}.bin").exists():
            logger.warning(f"{missing_label} file not found: {prefix}.bin")
            return []

        try:
            dataset = indexed_dataset.IndexedDataset(prefix)
            # Index explicitly: only __len__ and __getitem__ are relied upon.
            samples = [self._to_owned_array(dataset[i]) for i in range(len(dataset))]
        except Exception as e:
            # Deliberate best-effort: report the failure and hand back nothing.
            logger.error(f"Failed to read {error_label} for node_{node_id}/gpu_{gpu_id}: {e}")
            return []

        contents = f": {samples}" if log_contents else ""
        logger.debug(f"Read {len(samples)} {debug_label} for node_{node_id}/gpu_{gpu_id}{contents}")
        return samples

    def read_gpu_tokens(self, node_id: int, gpu_id: int, start_batch: int = 0, end_batch: Optional[int] = None) -> List[np.ndarray]:
        """
        Read tokenized data for a specific GPU using IndexedDataset.
        Data is accumulated across batches for each node/gpu combination.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node
            start_batch: Accepted for interface compatibility; currently not
                applied - the whole accumulated dataset is returned
            end_batch: Accepted for interface compatibility; currently not
                applied - the whole accumulated dataset is returned

        Returns:
            List of token arrays for each sample (``[]`` on missing file or
            read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'tokens', 'Tokens', 'tokenized samples', 'tokens')

    def read_gpu_dispatch(self, node_id: int, gpu_id: int, start_batch: int = 0, end_batch: Optional[int] = None) -> List[np.ndarray]:
        """
        Read dispatch data for a specific GPU using IndexedDataset.
        Data is accumulated across batches for each node/gpu combination.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node
            start_batch: Accepted for interface compatibility; currently not
                applied - the whole accumulated dataset is returned
            end_batch: Accepted for interface compatibility; currently not
                applied - the whole accumulated dataset is returned

        Returns:
            List of dispatch arrays for each sample (``[]`` on missing file
            or read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'dispatch', 'Dispatch', 'dispatch samples', 'dispatch')

    def read_gpu_experts(self, node_id: int, gpu_id: int) -> List[np.ndarray]:
        """
        Read expert placement for a specific GPU using IndexedDataset.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node

        Returns:
            List with one array per entry of the ``experts`` dataset
            (``[]`` on missing file or read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'experts', 'Expert', 'experts', 'experts', log_contents=True)

    def read_gpu_labels(self, node_id: int, gpu_id: int) -> List[np.ndarray]:
        """
        Read labels data for a specific GPU using IndexedDataset.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node

        Returns:
            List of label arrays for each sample (``[]`` on missing file or
            read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'labels', 'Labels', 'label samples', 'labels')

    def read_gpu_loss_mask(self, node_id: int, gpu_id: int) -> List[np.ndarray]:
        """
        Read loss mask data for a specific GPU using IndexedDataset.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node

        Returns:
            List of loss mask arrays for each sample (``[]`` on missing file
            or read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'loss_mask', 'Loss mask', 'loss mask samples', 'loss mask')

    def read_gpu_position_ids(self, node_id: int, gpu_id: int) -> List[np.ndarray]:
        """
        Read position IDs data for a specific GPU using IndexedDataset.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node

        Returns:
            List of position ID arrays for each sample (``[]`` on missing
            file or read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'position_ids', 'Position IDs', 'position ID samples', 'position IDs')

    def read_gpu_attention_mask(self, node_id: int, gpu_id: int) -> List[np.ndarray]:
        """
        Read attention mask data for a specific GPU using IndexedDataset.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node

        Returns:
            List of attention mask arrays for each sample (``[]`` on missing
            file or read failure)
        """
        return self._read_dataset(node_id, gpu_id, 'attention_mask', 'Attention mask', 'attention mask samples', 'attention mask')

    def read_gpu_data(self, node_id: int, gpu_id: int, start_batch: int = 0, end_batch: Optional[int] = None) -> Dict[str, Any]:
        """
        Read all data for a specific GPU.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node
            start_batch: Forwarded to the tokens/dispatch readers, which
                currently do not apply it
            end_batch: Forwarded to the tokens/dispatch readers, which
                currently do not apply it

        Returns:
            Dictionary with the keys ``tokens``, ``dispatch``, ``experts``,
            ``labels``, ``loss_mask``, ``position_ids`` and
            ``attention_mask`` (each the list returned by the matching
            reader) plus ``num_samples``, the number of token entries.
        """
        tokens = self.read_gpu_tokens(node_id, gpu_id, start_batch, end_batch)

        return {
            'tokens': tokens,
            'dispatch': self.read_gpu_dispatch(node_id, gpu_id, start_batch, end_batch),
            'experts': self.read_gpu_experts(node_id, gpu_id),
            'labels': self.read_gpu_labels(node_id, gpu_id),
            'loss_mask': self.read_gpu_loss_mask(node_id, gpu_id),
            'position_ids': self.read_gpu_position_ids(node_id, gpu_id),
            'attention_mask': self.read_gpu_attention_mask(node_id, gpu_id),
            'num_samples': len(tokens),
        }

    def list_available_gpus(self) -> List[Tuple[int, int]]:
        """
        List all available (node_id, gpu_id) pairs in the GPU data directory.

        Only directories named ``node_<int>`` containing directories named
        ``gpu_<int>`` are considered; directories with the right prefix but a
        non-integer index are skipped with a warning.

        Returns:
            Sorted list of (node_id, gpu_id) tuples
        """
        gpu_pairs = []

        for node_dir in self.gpu_data_dir.iterdir():
            if not (node_dir.is_dir() and node_dir.name.startswith('node_')):
                continue
            try:
                node_id = int(node_dir.name.split('_')[1])
            except (ValueError, IndexError):
                logger.warning(f"Invalid node directory name: {node_dir.name}")
                continue

            for gpu_dir in node_dir.iterdir():
                if not (gpu_dir.is_dir() and gpu_dir.name.startswith('gpu_')):
                    continue
                try:
                    gpu_id = int(gpu_dir.name.split('_')[1])
                except (ValueError, IndexError):
                    logger.warning(f"Invalid GPU directory name: {gpu_dir.name}")
                    continue
                gpu_pairs.append((node_id, gpu_id))

        return sorted(gpu_pairs)

    def validate_gpu_data(self, node_id: int, gpu_id: int) -> bool:
        """
        Check that all required IndexedDataset files exist for a GPU.

        Only the presence of each ``.bin``/``.idx`` pair is checked; the
        file contents are not inspected.

        Args:
            node_id: Node identifier
            gpu_id: GPU identifier within the node

        Returns:
            True if every required ``.bin`` and ``.idx`` file exists, False
            (after logging the first missing pair) otherwise
        """
        gpu_dir = self._gpu_dir(node_id, gpu_id)

        for stem in self._REQUIRED_DATASETS:
            bin_path = gpu_dir / f'{stem}.bin'
            idx_path = gpu_dir / f'{stem}.idx'
            if not (bin_path.exists() and idx_path.exists()):
                logger.error(f"Required IndexedDataset files missing: {bin_path} or {idx_path}")
                return False

        return True


class MegatronGPUDataLoader:
    """
    Megatron-compatible GPU data loader.

    Thin wrapper that binds a GPUDataReader to one fixed (node_id, gpu_id)
    pair so callers can load that GPU's data without repeating the ids.
    """

    def __init__(self, gpu_data_dir: str, node_id: int, gpu_id: int):
        """
        Initialize Megatron GPU data loader.

        Args:
            gpu_data_dir: Directory containing GPU data files
            node_id: Node identifier for this loader
            gpu_id: GPU identifier for this loader

        Raises:
            FileNotFoundError: If ``gpu_data_dir`` does not exist (raised by
                the GPUDataReader constructor).
        """
        self.reader = GPUDataReader(gpu_data_dir)
        self.node_id = node_id
        self.gpu_id = gpu_id
        # Initialised to 0; no method of this class reads or updates it.
        self.current_batch_id = 0

    def load_data(self, start_batch: int = 0, end_batch: Optional[int] = None) -> Dict[str, Any]:
        """
        Load this GPU's data through the underlying reader.

        Args:
            start_batch: Forwarded unchanged to ``reader.read_gpu_data``
            end_batch: Forwarded unchanged to ``reader.read_gpu_data``

        Note:
            GPUDataReader accepts the batch bounds but does not apply them:
            it always returns the complete accumulated datasets.

        Returns:
            Dictionary produced by ``reader.read_gpu_data`` for this
            loader's (node_id, gpu_id)
        """
        return self.reader.read_gpu_data(self.node_id, self.gpu_id, start_batch, end_batch)

    def get_all_data(self) -> Dict[str, Any]:
        """
        Get all available data for this GPU.

        Equivalent to ``load_data()`` with its default arguments.

        Returns:
            Dictionary containing all data
        """
        return self.load_data()

    def reset(self):
        """Reset the loader (no-op for accumulated data)."""
