import torch
from typing import Dict, Tuple, Any, Optional

from .fourm.fm import FM
from .fourm.modality_info import MODALITY_INFO

class AION(FM):
    """
    AION-1 (AstronomIcal Omnimodal Network) - A foundation model for astronomical data.
    
    AION is a multimodal transformer-based foundation model specifically designed for 
    astronomical observations. It extends the 4M (Massively Multimodal Masked Modeling) 
    architecture to handle diverse astronomical data types including multi-band images, 
    spectra, and other observational modalities.
    
    The model is pre-trained using generative masked modeling on data from major 
    astronomical surveys including:
    - DES (Dark Energy Survey) 
    - HSC (Hyper Suprime-Cam)
    - DESI (Dark Energy Spectroscopic Instrument)
    - Gaia
    
    Key capabilities:
    - Any-to-any multimodal generation and prediction
    - Cross-survey translation (e.g., DES to HSC)
    - Physical parameter estimation (redshift, morphology, etc.)
    - Spectrum super-resolution
    - Emergent cross-modal understanding
    
    Model sizes:
    - AION-1 Base: 300M parameters
    - AION-1 Large: 800M parameters  
    - AION-1 XLarge: 3B parameters
    """

    def embed_inputs(
        self,
        input_dict: Dict[str, torch.Tensor],
        mask: Optional[Dict[str, torch.Tensor]] = None,
        num_encoder_tokens: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Embed already-tokenized input modalities and select the encoder tokens.

        Every tensor in ``input_dict`` is cast to ``torch.long``, moved to the model
        device and passed through the matching module of ``self.encoder_embeddings``.
        No tokenization happens in this method: continuous data (images, spectra, ...)
        must be converted to integer token ids beforehand, otherwise the integer cast
        silently truncates it.

        Args:
            input_dict (Dict[str, torch.Tensor]): Dictionary mapping modality names to
                token tensors. Each tensor should have shape (B, N_mod) where N_mod
                depends on the modality; 1-D tensors are treated as (B, 1).
            mask (Optional[Dict[str, torch.Tensor]]): Optional per-modality input masks.
                Its keys must be a subset of the keys of ``input_dict``. Provided masks
                are cast to bool and moved to the model device; modalities without an
                entry get an all-False mask with the same (B, N_mod) shape as their
                tokens. Defaults to None.
            num_encoder_tokens (int): Encoder token budget forwarded to
                ``self.forward_mask_encoder``. Defaults to 256.

        Returns:
            tuple:
                - encoder_tokens (torch.Tensor): Selected encoder tokens from all modalities.
                  Shape (B, N, D) where N is the number of selected encoder tokens.
                - encoder_emb (torch.Tensor): Corresponding embeddings for encoder tokens.
                  Shape (B, N, D)
                - encoder_mask (torch.Tensor): A boolean mask indicating which encoder tokens
                  are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)
                - mod_mask (torch.Tensor): An integer mask marking the modality type for each
                  encoder token (with -1 indicating unassigned pad tokens). Shape (B, N)

        Raises:
            AssertionError: If ``input_dict`` or ``mask`` is not a dictionary, if ``mask``
                has a key missing from ``input_dict``, if a modality has no encoder
                embedding, or if an image modality does not have the expected number
                of patches.
        """
        if mask is None:
            mask = {}

        # Raised explicitly instead of using ``assert`` statements so the validation
        # still runs under ``python -O``; the exception type is unchanged for callers.
        if not isinstance(input_dict, dict):
            raise AssertionError("first input must be a dictionary")
        if not isinstance(mask, dict):
            raise AssertionError("Mask must be a dictionary if provided")
        if not all(key in input_dict for key in mask):
            raise AssertionError("All keys in the input mask must be in X")
        if not all(key in self.encoder_embeddings for key in input_dict):
            raise AssertionError("All keys in X must be in self.encoder_embeddings")

        device = next(self.parameters()).device

        encoder_mod_dict = {}
        for mod, tensor in input_dict.items():
            tensor = tensor.to(device=device, dtype=torch.long)
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(1)

            input_mask = mask.get(mod)
            if input_mask is None:
                # Only allocate the default mask when the caller supplied none.
                input_mask = torch.zeros(tensor.shape[0], tensor.shape[1], dtype=torch.bool, device=device)
            else:
                # Bring caller-provided masks to the same device as the tokens and to
                # the same dtype as the default mask, mirroring embed_targets.
                input_mask = input_mask.to(device=device, dtype=torch.bool)

            embedding = self.encoder_embeddings[mod]
            if MODALITY_INFO[mod]['type'] == 'img' and tensor.shape[1] != embedding.num_patches:
                raise AssertionError(f"Expected size {embedding.num_patches} for modality {mod}, but got {tensor.shape[1]}")

            encoder_mod_dict[mod] = embedding({'tensor': tensor, 'input_mask': input_mask})

        encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder(encoder_mod_dict, num_encoder_tokens)

        return encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask
    
    def embed_targets(
        self,
        target_mask: Dict[str, torch.Tensor],
        num_decoder_tokens: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepares decoder tokens for generating or predicting target astronomical modalities.

        For every requested modality the mask is cast to bool, moved to the model
        device and passed to ``forward_embed`` of the matching module in
        ``self.decoder_embeddings`` together with all-zero placeholder token ids and an
        all-False decoder attention mask. The per-modality results are then combined by
        ``self.forward_mask_decoder``.

        Args:
            target_mask (Dict[str, torch.Tensor]): Dictionary defining which modalities and
                tokens to predict. Keys are modality names, values are boolean masks.
                Example: {'desi': spectrum_mask, 'hsc': image_mask}
            num_decoder_tokens (int): Decoder token budget forwarded to
                ``self.forward_mask_decoder``. Defaults to 256.

        Returns:
            tuple:
                - decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens.
                - decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D)
                - decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M)
                - target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens. Shape (B, M)
                - decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. Shape (B, M, M)
                - mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token (with -1 indicating unassigned pad tokens). Shape (B, M)

        Raises:
            AssertionError: If ``target_mask`` is not a dictionary or contains a modality
                without a decoder embedding.
        """
        # Raised explicitly instead of using ``assert`` statements so the validation
        # still runs under ``python -O``; the exception type is unchanged for callers.
        if not isinstance(target_mask, dict):
            raise AssertionError("Target mask must be a dictionary")
        if not all(key in self.decoder_embeddings for key in target_mask):
            raise AssertionError("All keys in target mask must be in self.decoder_embeddings")

        device = next(self.parameters()).device

        decoder_mod_dict = {}
        for mod, mask in target_mask.items():
            mask = mask.to(device=device, dtype=torch.bool)
            # Only the mask is supplied to this method, so the token ids handed to the
            # embedding are all-zero placeholders with the mask's shape.
            placeholder_ids = torch.zeros_like(mask, dtype=torch.long)
            # All-False per-modality attention mask (already bool and on ``device``).
            mod_attention_mask = torch.zeros_like(mask)
            decoder_mod_dict[mod] = self.decoder_embeddings[mod].forward_embed({
                'tensor': placeholder_ids,
                'target_mask': mask,
                'decoder_attention_mask': mod_attention_mask,
            })

        decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, decoder_mod_mask = self.forward_mask_decoder(decoder_mod_dict, num_decoder_tokens)

        return decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, decoder_mod_mask

    def _encode(self, encoder_tokens, encoder_emb, encoder_mask):
        """
        Turn embedded input tokens into the context handed to the decoder.

        The embeddings are added to the tokens before ``self.forward_encoder`` runs,
        and added once more after ``self.decoder_proj_context`` has projected the
        encoder output.
        """
        encoded = self.forward_encoder(encoder_tokens + encoder_emb, encoder_mask=encoder_mask)
        return self.decoder_proj_context(encoded) + encoder_emb
    
    def _decode(self, encoder_outputs, encoder_mask, decoder_tokens, decoder_emb, decoder_attention_mask):
        """
        Run ``self.forward_decoder`` on the embedded target tokens.

        The decoder input is the sum of the decoder tokens and their embeddings; the
        encoder context and both masks are passed through unchanged.
        """
        decoder_input = decoder_tokens + decoder_emb
        return self.forward_decoder(
            decoder_input,
            encoder_outputs,
            encoder_mask=encoder_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

    def encode(
        self,
        input_dict: Dict[str, torch.Tensor],
        input_mask: Optional[Dict[str, torch.Tensor]] = None,
        num_encoder_tokens: int = 256,
    ) -> torch.Tensor:
        """
        Compute encoder representations for a set of input modalities.

        The inputs are embedded with ``embed_inputs`` and the resulting tokens,
        embeddings and mask are passed to ``_encode``. The returned representations
        are intended for downstream use such as linear probing, fine-tuning or
        similarity search.

        Example:
            >>> # Encode galaxy observations from two surveys
            >>> embeddings = model.encode({
            ...     'des': des_images,  # Dark Energy Survey
            ...     'hsc': hsc_images   # Hyper Suprime-Cam
            ... })

        Args:
            input_dict (Dict[str, torch.Tensor]): Modality name -> input tensor,
                forwarded to ``embed_inputs``.
            input_mask (Optional[Dict[str, torch.Tensor]]): Optional per-modality
                input masks, forwarded to ``embed_inputs`` as ``mask``.
            num_encoder_tokens (int): Maximum number of encoder tokens. Defaults to 256.

        Returns:
            torch.Tensor: The output of ``_encode``, documented as shape (B, N, D).
        """
        tokens, emb, enc_mask, _ = self.embed_inputs(
            input_dict,
            mask=input_mask,
            num_encoder_tokens=num_encoder_tokens,
        )
        return self._encode(tokens, emb, enc_mask)

    def forward(self,
                input_dict: Dict[str, torch.Tensor],
                target_mask: Dict[str, torch.Tensor],
                input_mask: Optional[Dict[str, torch.Tensor]] = None,
                num_decoder_tokens: int = 256,
                num_encoder_tokens: int = 256) -> Dict[str, torch.Tensor]:
        """
        Perform any-to-any multimodal prediction on astronomical data.

        AION can generate predictions for any combination of target modalities given
        any combination of input modalities. This enables powerful capabilities like:
        - Survey translation (e.g., predict HSC observations from DES)
        - Cross-modal prediction (e.g., predict spectra from images)
        - Physical parameter estimation (e.g., predict redshift from images)
        - Data fusion from multiple instruments

        The inputs are embedded with ``embed_inputs`` (which casts them to integer
        token ids), the targets with ``embed_targets``; the encoder and decoder are
        then run and the logits of each requested modality are computed from the
        decoder outputs that belong to it.

        Example:
            >>> # Predict DESI spectra from HSC image tokens
            >>> logits = model.forward(
            ...     input_dict={'hsc': hsc_tokens},
            ...     target_mask={'desi': spectrum_mask}
            ... )
            >>>
            >>> # Multi-modal fusion: combine multiple inputs to predict properties
            >>> logits = model.forward(
            ...     input_dict={'des': des_tokens, 'gaia': gaia_tokens},
            ...     target_mask={'redshift': z_mask, 'morphology': morph_mask}
            ... )

        Args:
            input_dict (Dict[str, torch.Tensor]): Input modality data
            target_mask (Dict[str, torch.Tensor]): Target modalities to predict
            input_mask (Optional[Dict[str, torch.Tensor]]): Optional input masking
            num_decoder_tokens (int): Maximum decoder tokens. Defaults to 256.
            num_encoder_tokens (int): Maximum encoder tokens. Defaults to 256.

        Returns:
            Dict[str, torch.Tensor]: Dictionary mapping target modality names to their
                predicted logits. The decoder outputs of a modality are gathered with
                a boolean mask, so (assuming the (B, M) modality mask documented by
                ``embed_targets``) the batch and token dimensions are flattened into
                one leading dimension; the last dimension is whatever the modality's
                ``forward_logits`` produces.
        """
        # Embed inputs and targets; the decoder validity mask and target ids returned
        # by embed_targets are not needed to compute logits.
        encoder_tokens, encoder_emb, encoder_mask, _ = self.embed_inputs(input_dict, mask=input_mask, num_encoder_tokens=num_encoder_tokens)
        decoder_tokens, decoder_emb, _, _, decoder_attention_mask, decoder_mod_mask = self.embed_targets(target_mask, num_decoder_tokens=num_decoder_tokens)

        # Run the encoder, then the decoder conditioned on the encoder context
        encoder_output = self._encode(encoder_tokens, encoder_emb, encoder_mask)
        decoder_output = self._decode(encoder_output, encoder_mask, decoder_tokens, decoder_emb, decoder_attention_mask)

        # Compute the logits of each requested modality from the decoder tokens
        # whose modality id matches it.
        mod_logits = {}
        for mod in target_mask:
            mod_id = self.modality_info[mod]["id"]
            mod_logits[mod] = self.decoder_embeddings[mod].forward_logits(decoder_output[decoder_mod_mask == mod_id])

        return mod_logits