import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from aion.fourm.fm_utils import NormCrossAttention
from aion.model import AION

from .modules import ConvNextEncoder1d

# Names exported by ``from <module> import *``: only the two AION probing
# heads. RegressorModel and AIONBaselineSpectrumModel are not listed and must
# be imported by name. NOTE(review): confirm that omission is intentional.
__all__ = ["AIONLinearProbing", "AIONCrossAttentionProbing"]


# Defines the task we are trying to solve
class RegressorModel(L.LightningModule):
    """Shared Lightning plumbing for models that regress a set of targets.

    Only the loss, the logging and the optimizer are provided here; the
    network itself is supplied by subclasses, which must override
    ``forward``.

    Args:
        n_outputs: Number of regression targets; stored on ``self.n_outputs``.
        lr: Learning rate handed to AdamW. It is recorded through
            ``save_hyperparameters()`` and read back as ``self.hparams.lr``.
    """

    def __init__(self, n_outputs, lr: float = 5e-3):
        super().__init__()
        self.save_hyperparameters()
        self.n_outputs = n_outputs

    def forward(self, x):
        """Map a batch of inputs to predictions; subclasses must implement."""
        raise NotImplementedError

    def _batch_mse(self, batch):
        """Return the mean-squared error of the model on one batch.

        ``batch`` is unpacked as an ``(inputs, targets)`` pair.
        """
        inputs, targets = batch
        predictions = self(inputs)
        # mse_loss uses its default reduction here; the extra .mean() is kept
        # so the computation matches the established behavior exactly.
        return F.mse_loss(predictions, targets).mean()

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (logged as ``train_loss``)."""
        train_loss = self._batch_mse(batch)
        self.log("train_loss", train_loss, on_epoch=True, prog_bar=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss (logged as ``val_loss``)."""
        val_loss = self._batch_mse(batch)
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of the module."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)


class AIONBaselineSpectrumModel(RegressorModel):
    """Supervised baseline that regresses targets directly from a spectrum.

    A ``ConvNextEncoder1d`` encodes a two-channel input built from the
    ``desi_spectrum_flux`` and ``desi_spectrum_ivar`` entries of the batch, a
    single learned query pools the encoded sequence with scaled dot-product
    attention, and a linear layer maps the pooled feature to ``n_outputs``
    values.

    Args:
        depths: Forwarded to ``ConvNextEncoder1d`` as ``depths``.
        dims: Forwarded to ``ConvNextEncoder1d`` as ``dims``; ``dims[-1]`` is
            also the width of the pooling query and of the projection input.
        n_outputs: Number of regression targets.
        lr: Learning rate used by the optimizer configured in the base class.
    """

    def __init__(
        self,
        depths: list[int] = [3, 9, 3, 3],
        dims: list[int] = [64, 96, 128, 192],
        n_outputs: int = 16,
        lr: float = 5e-3,
    ):
        # The default lists in the signature are created once, when the
        # function is defined, and shared by every call. Work on private
        # copies so that neither the recorded hyperparameters nor the encoder
        # can alias (and later mutate) those shared objects. The copies are
        # made before any save_hyperparameters() call so that they are the
        # values that get recorded.
        depths = list(depths)
        dims = list(dims)

        super().__init__(n_outputs, lr)
        self.save_hyperparameters()
        self.encoder = ConvNextEncoder1d(in_chans=2, depths=depths, dims=dims)
        # Single learned query used to pool the encoded sequence in forward().
        self.query = nn.Parameter(torch.randn(1, dims[-1]))
        self.proj = nn.Linear(dims[-1], n_outputs)

    def forward(self, x):
        """Predict ``n_outputs`` values for each spectrum in the batch.

        Args:
            x: Mapping providing ``"desi_spectrum_flux"`` and
                ``"desi_spectrum_ivar"`` tensors of identical shape
                (assumed ``(b, n)`` — TODO confirm against the dataloader).

        Returns:
            Tensor of predictions, one row per batch element.
        """
        spec = torch.stack(
            [x["desi_spectrum_flux"], x["desi_spectrum_ivar"]], dim=1
        )  # b 2 n
        out = self.encoder(spec).swapaxes(1, 2)  # b n c
        # Broadcast the learned query over the batch: (1, c) -> (b, 1, c).
        query = self.query.expand(out.size(0), -1, -1)
        # The query attends over the encoded sequence (keys == values).
        out = F.scaled_dot_product_attention(query, out, out).squeeze(1)  # b c
        return self.proj(out)  # b n_outputs


class AIONLinearProbing(RegressorModel):
    """Linear probe on top of mean-pooled embeddings from a pretrained AION.

    The AION model is loaded from ``model_path``, has ``freeze_encoder()``
    and ``freeze_decoder()`` called on it, and is wrapped with
    ``torch.compile``. A single ``nn.Linear`` layer maps the pooled embedding
    to the regression targets.

    Args:
        n_outputs: Number of regression targets.
        model_path: Identifier or path passed to ``AION.from_pretrained``.
        num_encoder_tokens: Forwarded to ``AION.encode`` on every call.
        lr: Learning rate used by the optimizer configured in the base class.
    """

    def __init__(
        self,
        n_outputs: int,
        model_path: str,
        num_encoder_tokens: int = 576,
        lr: float = 5e-3,
    ):
        super().__init__(n_outputs, lr)
        self.save_hyperparameters()
        self.model_path = model_path
        self.num_encoder_tokens = num_encoder_tokens

        # Build and freeze the backbone first, then register the compiled
        # wrapper under ``self.aion``.
        # NOTE(review): the backbone is never switched to eval mode here;
        # confirm it has no train-mode-dependent layers (e.g. dropout).
        # NOTE(review): forward() calls ``.encode`` rather than the module's
        # ``forward``; confirm torch.compile actually covers that path.
        backbone = AION.from_pretrained(self.model_path)
        backbone.freeze_encoder()
        backbone.freeze_decoder()
        self.aion = torch.compile(backbone)

        # Trainable linear head applied to the pooled embedding.
        self.fc = nn.Linear(self.aion.dim, self.n_outputs)

    def forward(self, x):
        """Encode ``x`` with AION, average over dim 1 and apply the head.

        The encoder call runs under ``torch.no_grad()``, so gradients only
        reach the linear head. Dim 1 of the encoder output is presumably the
        token axis — TODO confirm against ``AION.encode``.
        """
        with torch.no_grad():
            tokens = self.aion.encode(
                x, num_encoder_tokens=self.num_encoder_tokens
            )
        pooled = tokens.mean(dim=1)
        return self.fc(pooled)


class AIONCrossAttentionProbing(RegressorModel):
    """Attentive probe: one learned query and one linear head per target.

    The AION model is loaded from ``model_path``, has ``freeze_encoder()``
    and ``freeze_decoder()`` called on it, and is wrapped with
    ``torch.compile``. ``n_outputs`` learned queries cross-attend to its
    embeddings through a ``NormCrossAttention`` layer, and each attended
    query is reduced to a scalar by its own ``nn.Linear(dim, 1)``.

    Args:
        n_outputs: Number of regression targets (and of learned queries).
        num_heads: Number of heads given to ``NormCrossAttention``.
        model_path: Identifier or path passed to ``AION.from_pretrained``.
        num_encoder_tokens: Forwarded to ``AION.encode`` on every call.
        lr: Learning rate used by the optimizer configured in the base class.
    """

    def __init__(
        self,
        n_outputs: int,
        num_heads: int,
        model_path: str,
        num_encoder_tokens: int = 576,
        lr: float = 5e-3,
    ):
        super().__init__(n_outputs, lr)
        self.save_hyperparameters()
        self.model_path = model_path
        self.num_heads = num_heads
        self.num_encoder_tokens = num_encoder_tokens

        # Build and freeze the backbone, then register the compiled wrapper.
        # NOTE(review): the backbone is never switched to eval mode here;
        # confirm it has no train-mode-dependent layers (e.g. dropout).
        backbone = AION.from_pretrained(self.model_path)
        backbone.freeze_encoder()
        backbone.freeze_decoder()
        self.aion = torch.compile(backbone)
        self.dim = self.aion.dim

        # One learned query per regression target, shared across the batch.
        self.query = nn.Parameter(torch.randn(1, n_outputs, self.dim))

        cross_attention = NormCrossAttention(
            self.dim, num_heads=self.num_heads, proj_bias=False
        )
        self.attention = torch.compile(cross_attention)

        # One scalar read-out head per target; head i reads query i.
        heads = [nn.Linear(self.dim, 1) for _ in range(self.n_outputs)]
        self.debeds = nn.ModuleList(heads)

    def forward(self, x):
        """Encode ``x`` with AION and regress each target from its own query.

        The encoder call runs under ``torch.no_grad()``, so gradients only
        reach the queries, the attention layer and the read-out heads.

        Returns:
            The per-target head outputs concatenated along the last axis.
        """
        with torch.no_grad():
            embeddings = self.aion.encode(
                x, num_encoder_tokens=self.num_encoder_tokens
            )

        # Broadcast the learned queries over the batch before attending.
        batch_size = embeddings.size(0)
        queries = self.query.expand(batch_size, -1, -1)
        attended = self.attention(queries, embeddings)

        # Head idx reads the attended output of query idx.
        per_target = [
            head(attended[:, idx]) for idx, head in enumerate(self.debeds)
        ]
        return torch.cat(per_target, dim=-1)
