import yaml
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List
from torchvision.models import (
    convnext_tiny, ConvNeXt_Tiny_Weights,
    resnet18, ResNet18_Weights,
    densenet121, DenseNet121_Weights
)

from aion.tokenizers import load_tokenizer
from aion.fourm.fm_utils import NormCrossAttention
from aion.model import AION

# Public API of this module: the concrete model classes defined below.
__all__ = [
    "AIONLinearProbing",
    "AIONCrossAttentionProbing",
    "AIONFinetuning",
    "MultiBackbonePROVABGSModel",
]

# Defines the task we are trying to solve
class PROVABGSModel(L.LightningModule):
    """Base Lightning module for estimating galaxy properties by regression.

    It defines the task only -- MSE loss, AdamW optimizer and a step-decay
    learning-rate schedule -- not the architecture: subclasses must implement
    :meth:`forward`, whose output is compared against the batch targets with
    ``F.mse_loss``.

    Args:
        n_outputs: Number of regression targets; stored for subclasses.
        lr: Initial AdamW learning rate.
        milestone_interval: Spacing of the ``MultiStepLR`` milestones. The
            learning rate is multiplied by ``gamma`` at scheduler steps
            ``milestone_interval * i`` for ``i`` in 1..5 (scheduler steps are
            epochs under Lightning's default scheduler interval).
        gamma: Multiplicative decay factor applied at each milestone.
    """

    def __init__(self, n_outputs: int, lr: float = 5e-3, milestone_interval: int = 1, gamma: float = 0.75):
        super().__init__()
        # Records the constructor arguments in ``self.hparams`` (lr and gamma are
        # read back in configure_optimizers).
        self.save_hyperparameters()
        self.n_outputs = n_outputs
        # Five evenly spaced decay milestones.
        self.milestones = [milestone_interval * i for i in range(1, 6)]

    def forward(self, x):
        """Map a batch input to predictions; must be implemented by subclasses."""
        raise NotImplementedError

    def _shared_step(self, batch, stage: str) -> torch.Tensor:
        """Compute the MSE loss of one ``(x, y)`` batch and log it as ``f"{stage}_loss"``."""
        x, y = batch
        y_hat = self(x)
        # F.mse_loss already averages over all elements (default reduction="mean"),
        # so no extra ``.mean()`` is needed.
        loss = F.mse_loss(y_hat, y)
        self.log(f"{stage}_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training MSE loss, logged as ``train_loss``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation MSE loss, logged as ``val_loss``."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Build AdamW over all parameters plus the milestone ``MultiStepLR`` schedule."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.milestones, gamma=self.hparams.gamma
        )
        return [optimizer], [scheduler]


class AIONLinearProbing(PROVABGSModel):
    """Linear probe on top of a frozen, pretrained AION model.

    The output of ``AION.encode`` is averaged over dimension 1 (presumably the
    token axis -- TODO confirm) and mapped to ``n_outputs`` values by a single
    trainable ``nn.Linear`` layer.

    Args:
        n_outputs: Number of regression targets.
        model_path: Path/identifier given to ``AION.from_pretrained``.
        num_encoder_tokens: Forwarded to ``AION.encode``.
        lr, milestone_interval, gamma: Optimizer/scheduler settings handled
            by :class:`PROVABGSModel`.
    """

    def __init__(self,
                 n_outputs: int,
                 model_path: str,
                 num_encoder_tokens: int = 576,
                 lr: float = 5e-3,
                 milestone_interval: int = 1,
                 gamma: float = 0.75):
        super().__init__(n_outputs, lr, milestone_interval, gamma)
        self.save_hyperparameters()
        self.model_path = model_path
        self.num_encoder_tokens = num_encoder_tokens

        # Load the backbone, freeze both encoder and decoder, then keep only the
        # compiled wrapper as the registered submodule.
        backbone = AION.from_pretrained(model_path)
        backbone.freeze_encoder()
        backbone.freeze_decoder()
        self.aion = torch.compile(backbone)
        # NOTE(review): torch.compile compiles ``forward``; ``encode`` reached
        # through the wrapper probably runs eagerly -- confirm the compile helps.

        # The only trainable layer: embedding dim -> n_outputs.
        self.fc = nn.Linear(backbone.dim, n_outputs)

    def forward(self, x):
        """Predict ``n_outputs`` values per sample from mean-pooled AION embeddings."""
        with torch.no_grad():  # frozen backbone: no autograd graph needed
            tokens = self.aion.encode(x, num_encoder_tokens=self.num_encoder_tokens)
        return self.fc(tokens.mean(dim=1))
    
    
class AIONCrossAttentionProbing(PROVABGSModel):
    """Cross-attention probe on top of a frozen, pretrained AION model.

    ``n_outputs`` learned query vectors attend over the AION encoder output
    through a ``NormCrossAttention`` layer; each attended query is then mapped
    to one scalar by its own ``nn.Linear(dim, 1)`` head.

    Args:
        n_outputs: Number of regression targets (also the number of queries).
        num_heads: Number of attention heads of the cross-attention layer.
        model_path: Path/identifier given to ``AION.from_pretrained``.
        num_encoder_tokens: Forwarded to ``AION.encode``.
        lr, milestone_interval, gamma: Optimizer/scheduler settings handled
            by :class:`PROVABGSModel`.
    """

    def __init__(self,
                 n_outputs: int,
                 num_heads: int,
                 model_path: str,
                 num_encoder_tokens: int = 576,
                 lr: float = 5e-3,
                 milestone_interval: int = 1,
                 gamma: float = 0.75):
        super().__init__(n_outputs, lr, milestone_interval, gamma)
        self.save_hyperparameters()
        self.model_path = model_path
        self.num_heads = num_heads
        self.num_encoder_tokens = num_encoder_tokens

        # Frozen, compiled backbone.
        backbone = AION.from_pretrained(model_path)
        backbone.freeze_encoder()
        backbone.freeze_decoder()
        self.aion = torch.compile(backbone)
        self.dim = backbone.dim

        # Trainable probe. Creation order (query, attention, heads) is kept
        # fixed so random initialisation is reproducible under a seed.
        self.query = nn.Parameter(torch.randn(1, n_outputs, self.dim))
        cross_attention = NormCrossAttention(self.dim, num_heads=self.num_heads, proj_bias=False)
        self.attention = torch.compile(cross_attention)
        self.debeds = nn.ModuleList(nn.Linear(self.dim, 1) for _ in range(self.n_outputs))

    def forward(self, x):
        """Return one prediction per target, concatenated along the last dimension."""
        with torch.no_grad():  # frozen backbone: no autograd graph needed
            tokens = self.aion.encode(x, num_encoder_tokens=self.num_encoder_tokens)
        # Broadcast the learned queries over the batch dimension.
        queries = self.query.expand(tokens.size(0), -1, -1)
        attended = self.attention(queries, tokens)
        # Head ``idx`` reads only the ``idx``-th attended query.
        per_target = [head(attended[:, idx]) for idx, head in enumerate(self.debeds)]
        return torch.cat(per_target, dim=-1)


class AIONFinetuning(PROVABGSModel):
    """Finetune AION to regress properties through its own decoder outputs.

    For every modality in ``reused_modalities`` the decoder logits are turned
    into a scalar: the mean of the softmax distribution over a linear grid on
    [-4, 4] (one grid point per vocabulary entry). ``n_outputs`` is therefore
    ``len(reused_modalities)``.

    Args:
        model_path: Path/identifier given to ``AION.from_pretrained``.
        reused_modalities: Decoder modality keys whose outputs are reused as
            regression targets. The sequence is copied, so neither the default
            list nor the caller's list is aliased by the instance.
        num_encoder_tokens: Forwarded to the AION forward pass.
        finetuning_strategy: Which parts of AION are frozen. ``mask_token`` and
            ``decoder_proj_context`` are frozen in every strategy.
            ``'nothing'`` calls ``freeze_encoder()`` and ``freeze_decoder()``.
            ``'decoder_embeddings'`` calls the AION
            ``freeze_{decoder,encoder}_except_specific_embeddings`` helpers
            with the names of all decoder embeddings not in
            ``reused_modalities``. ``'decoder'`` calls only the encoder
            variant, so the decoder is not explicitly frozen.
        lr, milestone_interval, gamma: Optimizer/scheduler settings handled
            by :class:`PROVABGSModel`.

    Raises:
        ValueError: If ``finetuning_strategy`` is not supported (checked before
            the pretrained model is loaded).
    """

    _FINETUNING_STRATEGIES = ('nothing', 'decoder_embeddings', 'decoder')

    def __init__(self,
                 model_path: str,
                 reused_modalities=['tok_z', 'tok_a_i'],  # copied below, never mutated
                 num_encoder_tokens: int = 576,
                 finetuning_strategy: str = 'decoder_embeddings',
                 lr: float = 5e-3,
                 milestone_interval: int = 1,
                 gamma: float = 0.75):
        # Fail fast, before the (potentially expensive) pretrained model is loaded.
        if finetuning_strategy not in self._FINETUNING_STRATEGIES:
            raise ValueError(f"Finetuning strategy {finetuning_strategy} not supported")

        n_outputs = len(reused_modalities)
        super().__init__(n_outputs, lr, milestone_interval, gamma)
        self.save_hyperparameters()
        self.model_path = model_path
        self.reused_modalities = list(reused_modalities)
        self.num_encoder_tokens = num_encoder_tokens
        self.finetuning_strategy = finetuning_strategy

        self.aion = AION.from_pretrained(self.model_path)
        self.dim = self.aion.dim

        if finetuning_strategy == 'nothing':
            self.aion.freeze_encoder()
            self.aion.freeze_decoder()
            self._freeze_mask_token_and_context()
        elif finetuning_strategy == 'decoder_embeddings':
            self._freeze_mask_token_and_context()
            # Per the original design notes, freezing both sides this way leaves
            # only enc/dec mod_emb, enc emb, dec emb and dec to_logit trainable.
            embeddings_to_freeze = self._embeddings_to_freeze()
            self.aion.freeze_decoder_except_specific_embeddings(embeddings_to_freeze)
            self.aion.freeze_encoder_except_specific_embeddings(embeddings_to_freeze)
        else:  # 'decoder' (validated above)
            self._freeze_mask_token_and_context()
            # Only the encoder is frozen, except the embeddings of the reused modalities.
            self.aion.freeze_encoder_except_specific_embeddings(self._embeddings_to_freeze())

        # Print the names of any parameter in the entire model that might still be trainable
        print("Following parameters are adjusted during finetuning:")
        for name, param in self.aion.named_parameters():
            if param.requires_grad:
                print(name)
        print("----------------------------------------------------")

        self.aion = torch.compile(self.aion)

        # This module does not output a probability distribution but the mean of
        # the output distribution, for standardized values spanning [-4, 4].
        self.value_tensors = [
            torch.linspace(-4, 4, self.aion.decoder_embeddings[k].vocab_size).reshape(1, -1)
            for k in self.reused_modalities
        ]

    def _freeze_mask_token_and_context(self):
        """Disable gradients for ``mask_token`` and all ``decoder_proj_context`` parameters."""
        self.aion.mask_token.requires_grad = False
        for param in self.aion.decoder_proj_context.parameters():
            param.requires_grad = False

    def _embeddings_to_freeze(self) -> str:
        """Return the '-'-joined names of decoder embeddings not in ``reused_modalities``."""
        names = {name.split('.')[0] for name, _ in self.aion.decoder_embeddings.named_parameters()}
        # Sorted so the string is identical across runs (set order depends on the hash seed).
        return '-'.join(sorted(name for name in names if name not in self.reused_modalities))

    def forward(self, x):
        """Return the predicted mean value for every reused modality.

        Args:
            x: Mapping of input modality name to tensor, passed to the AION
                forward pass; the batch size is read from dim 0 of its first
                value.

        Returns:
            Tensor stacking, along the last dimension, one expected value per
            entry of ``reused_modalities`` (in that order).

        Raises:
            ValueError: If ``x`` is empty.
        """
        if not x:
            raise ValueError("AIONFinetuning.forward expects at least one input modality")
        reference = next(iter(x.values()))
        batch_size = reference.size(0)
        # One all-False (batch_size, 1) boolean mask per reused modality, created on
        # the inputs' device instead of always on the CPU.
        target_mask = {
            modality: torch.zeros(batch_size, 1, dtype=torch.bool, device=reference.device)
            for modality in self.reused_modalities
        }
        logits = self.aion(x,
                           target_mask=target_mask,
                           num_encoder_tokens=self.num_encoder_tokens)
        # Given the logits, compute the mean of the distribution they represent,
        # assuming it spans a linearly discretized space from -4 to 4.
        means = []
        for modality, grid in zip(self.reused_modalities, self.value_tensors):
            probs = F.softmax(logits[modality], dim=-1)
            means.append(torch.sum(probs * grid.to(probs.device), dim=-1))
        return torch.stack(means, dim=-1)
    

class MultiBackbonePROVABGSModel(PROVABGSModel):
    """
    Image-based model for estimating galaxy properties (or other regression
    targets) with a torchvision CNN backbone: ConvNeXt-Tiny, ResNet-18 or
    DenseNet-121.

    The backbone's first convolution is replaced so it accepts
    ``n_input_channels`` channels, and its final layer is replaced by
    ``Dropout`` + ``Linear`` producing ``n_outputs`` values. The required
    channel count depends on the survey: 4 for ``'legacysurvey'``, 5 for
    ``'hsc'``.

    Args:
        n_outputs: Number of regression targets.
        backbone: ``'convnext_tiny'``, ``'resnet18'`` or ``'densenet121'``
            (case-insensitive).
        n_input_channels: Channel count of ``x['image']``; must match ``survey``.
        lr, milestone_interval, gamma: Optimizer/scheduler settings handled
            by :class:`PROVABGSModel`.
        range_compress: If True, images are passed through ``torch.arcsinh``
            before the backbone.
        dropout_rate: Dropout probability in the regression head.
        survey: ``'legacysurvey'`` or ``'hsc'``; selects ``self.bands`` and
            ``self.channels``.
        tokenizer_path: Not used within this class.
        pretrained: Load ImageNet weights for the backbone. The replaced first
            convolution is always freshly initialised, so its pretrained
            weights are discarded.

    Raises:
        ValueError: On an unsupported survey or backbone, or when
            ``n_input_channels`` does not match the survey.
    """

    def __init__(
        self,
        n_outputs: int,
        backbone: str = "convnext_tiny",
        n_input_channels: int = 4,
        lr: float = 5e-3,
        milestone_interval: int = 1,
        gamma: float = 0.75,
        range_compress: bool = True,
        dropout_rate: float = 0.2,
        survey: str = "legacysurvey",
        tokenizer_path: str = "data/mmoma/outputs/multisurvey/a88h9lef/checkpoints/last.pt",
        pretrained: bool = False
    ):
        super().__init__(n_outputs, lr, milestone_interval, gamma)
        self.save_hyperparameters()

        # 1. Check survey compatibility. Explicit raises rather than ``assert`` so
        #    the checks are not stripped under ``python -O``. ``bands``/``channels``
        #    are stored on the instance; this class's forward does not use them.
        if survey == "legacysurvey":
            if n_input_channels != 4:
                raise ValueError("LegacySurvey only supports 4 input channels")
            self.bands = ['DES-G', 'DES-R', 'DES-I', 'DES-Z']
            self.channels = [5, 6, 7, 8]
        elif survey == "hsc":
            if n_input_channels != 5:
                raise ValueError("HSC only supports 5 input channels")
            self.bands = ['HSC-G', 'HSC-R', 'HSC-I', 'HSC-Z', 'HSC-Y']
            self.channels = [0, 1, 2, 3, 4]
        else:
            raise ValueError(
                f"Survey {survey} not supported. "
                f"Supported surveys are 'legacysurvey' and 'hsc'."
            )

        # 2. Build the backbone model
        self.model = self._build_backbone(
            backbone, n_input_channels, n_outputs, dropout_rate, pretrained
        )

        # 3. Range compression is a flag read by ``range_compress`` instead of a
        #    lambda stored on the instance, which would make the module unpicklable.
        self._use_arcsinh = range_compress

    @staticmethod
    def _replace_input_conv(old_layer: nn.Conv2d, n_input_channels: int) -> nn.Conv2d:
        """Return a freshly initialised Conv2d like *old_layer* but taking *n_input_channels* inputs."""
        return nn.Conv2d(
            in_channels=n_input_channels,
            out_channels=old_layer.out_channels,
            kernel_size=old_layer.kernel_size,
            stride=old_layer.stride,
            padding=old_layer.padding,
            bias=(old_layer.bias is not None),
        )

    @staticmethod
    def _regression_head(in_features: int, n_outputs: int, dropout_rate: float) -> nn.Sequential:
        """Return the ``Dropout`` + ``Linear`` head mapping *in_features* to *n_outputs*."""
        return nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features, n_outputs)
        )

    def _build_backbone(
        self,
        backbone: str,
        n_input_channels: int,
        n_outputs: int,
        dropout_rate: float,
        pretrained: bool
    ) -> nn.Module:
        """
        Build the chosen torchvision backbone with its first convolution adapted
        to ``n_input_channels`` and its final layer replaced by a regression head.

        Raises:
            ValueError: If ``backbone`` (case-insensitive) is not supported.
        """
        name = backbone.lower()

        if name == "convnext_tiny":
            weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
            model = convnext_tiny(weights=weights)
            # Stem conv is features[0][0]; the final Linear is classifier[2].
            model.features[0][0] = self._replace_input_conv(model.features[0][0], n_input_channels)
            model.classifier[2] = self._regression_head(
                model.classifier[2].in_features, n_outputs, dropout_rate
            )

        elif name == "resnet18":
            weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            model = resnet18(weights=weights)
            model.conv1 = self._replace_input_conv(model.conv1, n_input_channels)
            model.fc = self._regression_head(model.fc.in_features, n_outputs, dropout_rate)

        elif name == "densenet121":
            weights = DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
            model = densenet121(weights=weights)
            # For DenseNet the stem conv is features.conv0.
            model.features.conv0 = self._replace_input_conv(model.features.conv0, n_input_channels)
            model.classifier = self._regression_head(
                model.classifier.in_features, n_outputs, dropout_rate
            )

        else:
            raise ValueError(
                f"Backbone '{backbone}' not supported. "
                f"Choose from ['convnext_tiny', 'resnet18', 'densenet121']. "
                f"Or add support by extending _build_backbone()."
            )

        return model

    def range_compress(self, x):
        """Return ``torch.arcsinh(x)`` when range compression is enabled, else *x* unchanged."""
        return torch.arcsinh(x) if self._use_arcsinh else x

    def forward(self, x):
        """
        x is expected to be a dictionary with at least:
          x['image'] -> Tensor of shape (B, n_input_channels, H, W)
        Returns the backbone output (one value per regression target).
        """
        images = self.range_compress(x['image'])
        return self.model(images)
