from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional
import random

class ActivationDataset(ABC):
    """
    Abstract base class for MoE activation datasets.

    A dataset holds multiple forward passes (requests) through the model.
    Each forward pass is an ordered list of layer-info dicts; subclasses
    decide how they are produced by implementing ``_generate_forward_passes``.
    """

    def __init__(self, num_requests: int):
        """
        Initialize activation dataset.

        Generation runs immediately inside this constructor, so a subclass
        must assign every attribute its ``_generate_forward_passes`` reads
        before calling ``super().__init__``.

        Args:
            num_requests: Number of forward passes (requests) in the dataset
        """
        self.num_requests = num_requests

        # Generate forward pass sequences
        self.forward_passes = self._generate_forward_passes()

    @abstractmethod
    def _generate_forward_passes(self) -> Dict[int, List[Dict[str, Any]]]:
        """
        Generate forward pass sequences for all requests.

        Iteration looks requests up by the ids ``0 .. len(self) - 1``, so
        implementations should key the result with consecutive ids from 0.

        Returns:
            Dictionary mapping request_id to list of layer info dicts with
            'layer_id' and 'layer_type'
        """
        pass

    def get_activations(self, request_id: int) -> List[Dict[str, Any]]:
        """
        Get forward pass sequence for a specific request.

        Args:
            request_id: ID of the request

        Returns:
            The stored list of layer info dictionaries (with 'layer_id' and
            'layer_type' keys) for that request, or an empty list when the
            id is unknown. The stored list is returned as-is, not copied.
        """
        return self.forward_passes.get(request_id, [])

    def __iter__(self):
        """Yield the forward pass of each request id in ``range(len(self))``."""
        for request_id in range(len(self)):
            yield self.get_activations(request_id)

    def __len__(self) -> int:
        """
        Return number of forward passes.

        This is the smaller of ``num_requests`` and the number of generated
        passes, never less than zero.
        """
        available = min(self.num_requests, len(self.forward_passes))
        # len() raises ValueError on a negative result, which would also break
        # iteration; a negative num_requests is treated as an empty dataset.
        return max(0, available)


class DeterministicActivations(ActivationDataset):
    """
    Deterministic forward pass dataset with alternating attention and MoE layers.

    Each request activates a window of ``experts_per_request`` consecutive
    expert ids (wrapping modulo ``experts_per_layer``). The window start
    advances by ``experts_per_request`` from one request to the next, and
    every MoE layer of a request reports the same expert ids.
    """

    def __init__(self, num_requests: int, num_layers: int = 32, experts_per_layer: int = 32, experts_per_request: int = 4):
        """
        Initialize deterministic forward passes.

        Args:
            num_requests: Number of forward passes to generate
            num_layers: Number of layers per forward pass
            experts_per_layer: Number of experts in a MoE layer; expert ids
                are taken modulo this value
            experts_per_request: Number of experts activated per MoE layer
                for a single request
        """
        # These must be set before super().__init__, which triggers generation.
        self.num_layers = num_layers
        self.experts_per_layer = experts_per_layer
        self.experts_per_request = experts_per_request
        super().__init__(num_requests)

    def _generate_forward_passes(self) -> Dict[int, List[Dict[str, Any]]]:
        """
        Generate deterministic forward pass sequences.

        Returns:
            Dictionary mapping request_id to a list of ``num_layers`` layer
            info dicts. Even layer ids are 'attention' layers; odd layer ids
            are 'moe' layers that also carry an 'activated_experts' list.
        """
        forward_passes = {}

        window_start = 0

        for request_id in range(self.num_requests):
            # Consecutive expert ids starting at the current window, wrapping around.
            activated_experts = [
                (window_start + i) % self.experts_per_layer
                for i in range(self.experts_per_request)
            ]
            window_start = (window_start + self.experts_per_request) % self.experts_per_layer

            forward_pass = []
            for layer_id in range(self.num_layers):
                # Alternate between attention and MoE layers
                if layer_id % 2 == 0:
                    layer_info = {
                        'layer_id': layer_id,
                        'layer_type': 'attention'
                    }
                else:
                    layer_info = {
                        'layer_id': layer_id,
                        'layer_type': 'moe',
                        # Give each layer its own list so that a consumer
                        # mutating one layer's experts cannot silently change
                        # every other MoE layer of the same request.
                        'activated_experts': list(activated_experts)
                    }

                forward_pass.append(layer_info)

            forward_passes[request_id] = forward_pass

        return forward_passes


class NoisyDeterministicActivations(ActivationDataset):
    """
    Deterministic alternating attention/MoE dataset with noisy expert selection.

    Each request starts from a window of ``experts_per_request`` consecutive
    expert ids (wrapping modulo ``experts_per_layer``). In every MoE layer,
    each of those experts is independently replaced by a uniformly random
    expert id with probability ``noise_probability``.
    """

    def __init__(self, num_requests: int, num_layers: int = 32, 
                 random_seed: Optional[int] = 42,
                 experts_per_layer: int = 32, experts_per_request: int = 4,
                 noise_probability: float = 0.2):
        """
        Initialize noisy forward passes.

        Args:
            num_requests: Number of forward passes to generate
            num_layers: Number of layers per forward pass
            random_seed: Seed for the noise generator. The same seed and
                parameters give the same dataset; ``None`` seeds from
                system entropy.
            experts_per_layer: Number of experts in a MoE layer; expert ids
                are taken modulo this value
            experts_per_request: Number of experts activated per MoE layer
                for a single request
            noise_probability: Probability that a single activated expert is
                replaced by a random one in a given MoE layer
        """
        # These must be set before super().__init__, which triggers generation.
        self.num_layers = num_layers
        self.random_seed = random_seed
        self.experts_per_layer = experts_per_layer
        self.experts_per_request = experts_per_request
        self.noise_probability = noise_probability

        super().__init__(num_requests)

    def _generate_forward_passes(self) -> Dict[int, List[Dict[str, Any]]]:
        """
        Generate noisy forward pass sequences.

        Returns:
            Dictionary mapping request_id to a list of ``num_layers`` layer
            info dicts. Even layer ids are 'attention' layers; odd layer ids
            are 'moe' layers that also carry an 'activated_experts' list.
        """
        forward_passes = {}

        # Private generator seeded from random_seed: results are reproducible
        # and the global `random` module state is left untouched.
        rng = random.Random(self.random_seed)

        window_start = 0

        for request_id in range(self.num_requests):
            # Noise-free expert window for this request (wraps around).
            activated_experts = [
                (window_start + i) % self.experts_per_layer
                for i in range(self.experts_per_request)
            ]
            window_start = (window_start + self.experts_per_request) % self.experts_per_layer

            forward_pass = []
            for layer_id in range(self.num_layers):
                # Alternate between attention and MoE layers
                if layer_id % 2 == 0:
                    layer_info = {
                        'layer_id': layer_id,
                        'layer_type': 'attention'
                    }
                else:
                    layer_noised_experts = []
                    for expert in activated_experts:
                        if rng.random() < self.noise_probability:
                            # The replacement is drawn from all experts, so it
                            # may equal the original id or duplicate another
                            # expert already selected in this layer.
                            layer_noised_experts.append(
                                rng.randint(0, self.experts_per_layer - 1)
                            )
                        else:
                            layer_noised_experts.append(expert)
                    layer_info = {
                        'layer_id': layer_id,
                        'layer_type': 'moe',
                        'activated_experts': layer_noised_experts
                    }

                forward_pass.append(layer_info)

            forward_passes[request_id] = forward_pass

        return forward_passes


class DeepSeekBasedActivations(ActivationDataset):
    """
    Activation dataset that replays exact DeepSeek forward passes from collected routing data.

    The routing data is a sequence of per-layer records. Each record is read
    through the keys 'layer_idx', 'topk_indices' and 'top_k', and expands into
    one 'attention' layer followed by one 'moe' layer.
    """

    def __init__(self, path, substitute_start_value=0, num_requests: int = int(2e9), experts_per_request: int = 6):
        """
        Initialize DeepSeek-based forward passes.

        Args:
            path: Path of the routing data file, loaded with ``torch.load``
            substitute_start_value: Value subtracted from every recorded
                'layer_idx' before layer ids are derived from it
            num_requests: Number of forward passes to generate, default is 2e9 to take all the data
            experts_per_request: Number of experts per MoE layer (top-k).
                Stored on the instance; generation takes the experts from
                the routing data instead.

        Raises:
            FileNotFoundError: If ``path`` does not exist
        """
        self.experts_per_request = experts_per_request
        self.substitute_start_value = substitute_start_value
        self.routing_data = []
        self.path = path
        # Data must be loaded before super().__init__, which triggers generation.
        self._load_routing_data(path)
        super().__init__(num_requests)

    def _load_routing_data(self, filename):
        """
        Load raw routing data from file into ``self.routing_data``.

        Args:
            filename: Path of the file to load

        Raises:
            FileNotFoundError: If ``filename`` does not exist
        """
        import torch
        import os

        if not os.path.exists(filename):
            raise FileNotFoundError(f"Could not find raw routing data file. Tried: {filename}")

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load routing data from trusted sources.
        self.routing_data = torch.load(filename)

    def _generate_forward_passes(self) -> Dict[int, List[Dict[str, Any]]]:
        """
        Generate forward pass sequences using exact routing data.

        Records are consumed in order. A record whose 'layer_idx' is lower
        than the previous record's starts a new forward pass.

        Returns:
            Dictionary mapping consecutive request ids (from 0) to lists of
            layer info dicts. 'moe' entries also carry 'activated_experts'
            (the flattened 'topk_indices') and 'top_k'.
        """
        forward_passes = {}
        current_pass = []
        prev_layer_idx = None

        for routing_batch in self.routing_data:
            layer_idx = routing_batch['layer_idx']

            # The layer index dropping back marks the start of the next pass.
            if prev_layer_idx is not None and layer_idx < prev_layer_idx:
                forward_passes[len(forward_passes)] = current_pass
                current_pass = []
            prev_layer_idx = layer_idx

            # Each recorded layer becomes an (attention, moe) pair with
            # consecutive even/odd layer ids.
            base_layer_id = 2 * (layer_idx - self.substitute_start_value)

            current_pass.append({
                'layer_id': base_layer_id,
                'layer_type': 'attention',
            })

            current_pass.append({
                'layer_id': base_layer_id + 1,
                'layer_type': 'moe',
                'activated_experts': routing_batch['topk_indices'].flatten().tolist(),
                'top_k': routing_batch['top_k']
            })

        # Flush the trailing pass: no later record exists to trigger the
        # boundary check above, so without this the last pass would be lost
        # (and a file with a single pass would give an empty dataset).
        if current_pass:
            forward_passes[len(forward_passes)] = current_pass

        return forward_passes
