from typing import List
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0
import numpy as np
from aion.fourm.fm_utils import NormCrossAttention
from aion.model import AION

from transformers import AutoConfig, ConvNextV2ForImageClassification

# Public model classes exported by this module. The GZ10Model base class is
# not listed, so `from <module> import *` does not bring it into scope.
__all__ = ["AIONMeanPoolClassifier", "AIONCrossAttentionClassifier", "ConvNextNanoClassifier", "EfficientNetB0"]

# Defines the task we are trying to solve
class GZ10Model(L.LightningModule):
    """Base LightningModule for GZ10 galaxy classification.

    Provides the shared training/validation loop (cross-entropy loss, plus
    accuracy during validation) and the optimizer / learning-rate schedule.
    It does not define an architecture: subclasses must implement
    :meth:`forward`, returning one row of class logits per sample.

    Args:
        lr: Initial learning rate for AdamW.
        max_epochs: Planned training length; used to place the step-decay
            milestones of the learning-rate schedule.
    """

    # Fractions of ``max_epochs`` at which the learning rate is decayed.
    _LR_MILESTONE_FRACTIONS = (0.2, 0.4, 0.6, 0.8)
    # Multiplicative LR decay applied at each milestone.
    _LR_DECAY_GAMMA = 0.5

    def __init__(self, lr: float = 5e-3, max_epochs: int = 100):
        super().__init__()
        self.save_hyperparameters()

    def forward(self, x):
        raise NotImplementedError

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(logits, targets, loss)``."""
        x, y = batch
        y_hat = self(x)
        # F.cross_entropy already averages over the batch (reduction='mean').
        loss = F.cross_entropy(y_hat, y)
        return y_hat, y, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat, y, loss = self._shared_step(batch)
        preds = torch.argmax(y_hat, dim=1)
        # Keep accuracy as a tensor: calling .item() would force a
        # device->host sync every batch, and self.log accepts tensors.
        acc = (preds == y).float().mean()
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)
        return loss

    def _lr_milestones(self):
        """Return the epoch indices at which the learning rate is decayed.

        Milestones that round down to epoch 0 (only possible when
        ``max_epochs < 5``) are dropped. MultiStepLR would otherwise apply
        them when it is constructed, so training would start below ``lr``.
        """
        max_epochs = self.hparams.max_epochs
        candidates = (int(max_epochs * frac) for frac in self._LR_MILESTONE_FRACTIONS)
        return [epoch for epoch in candidates if epoch > 0]

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self._lr_milestones(), gamma=self._LR_DECAY_GAMMA
        )
        return [optimizer], [scheduler]


class AIONMeanPoolClassifier(GZ10Model):
    """Linear classifier on mean-pooled embeddings from a frozen pretrained AION model.

    Args:
        n_classes: Number of output classes.
        model_path: Passed to ``AION.from_pretrained`` to load the backbone.
        num_encoder_tokens: Forwarded to ``AION.encode``.
        lr: Initial learning rate for the optimizer.
        max_epochs: Planned training length used to place the LR-decay
            milestones. Defaults to 100, matching the base-class default.
    """

    def __init__(self,
                 n_classes: int,
                 model_path: str,
                 num_encoder_tokens: int = 576,
                 lr: float = 5e-3,
                 max_epochs: int = 100):
        # Forward max_epochs so the LR schedule honours the caller's value
        # instead of always falling back to the base-class default.
        super().__init__(lr, max_epochs)
        self.save_hyperparameters()
        self.n_classes = n_classes
        self.model_path = model_path
        self.num_encoder_tokens = num_encoder_tokens
        self.aion = AION.from_pretrained(self.model_path)
        self.aion.freeze_encoder()
        self.aion.freeze_decoder()
        # NOTE(review): torch.compile compiles the wrapped module's forward();
        # the `.encode(...)` call in forward() below is resolved through the
        # wrapper to the original module and is therefore likely not compiled.
        # The wrapper also prefixes state_dict keys with "aion._orig_mod.", so
        # removing it would break existing checkpoints -- confirm intent.
        self.aion = torch.compile(self.aion)
        self.fc = nn.Linear(self.aion.dim, self.n_classes)

    def forward(self, x):
        """Return class logits for input ``x``.

        ``x`` is passed unchanged to ``AION.encode``. The returned embeddings
        are averaged over dim 1 (presumably the token axis -- TODO confirm)
        and fed to the linear head.
        """
        # The backbone is frozen, so skip autograd bookkeeping for it.
        # NOTE(review): Lightning puts self.aion in train mode during fit; if
        # AION contains dropout it stays active here -- confirm that is intended.
        with torch.no_grad():
            embeddings = self.aion.encode(x, num_encoder_tokens=self.num_encoder_tokens)
        mean_embeddings = embeddings.mean(dim=1)
        return self.fc(mean_embeddings)


class AIONCrossAttentionClassifier(GZ10Model):
    """Classifier that pools frozen AION embeddings with a learned cross-attention query.

    A single learned query vector attends over the token embeddings produced
    by a frozen pretrained AION model; the attended vector goes through a
    linear head.

    Args:
        n_classes: Number of output classes.
        num_heads: Number of attention heads for ``NormCrossAttention``.
        model_path: Passed to ``AION.from_pretrained`` to load the backbone.
        num_encoder_tokens: Forwarded to ``AION.encode``.
        lr: Initial learning rate for the optimizer.
        max_epochs: Planned training length used to place the LR-decay
            milestones. Defaults to 100, matching the base-class default.
    """

    def __init__(self,
                 n_classes: int,
                 num_heads: int,
                 model_path: str,
                 num_encoder_tokens: int = 576,
                 lr: float = 5e-3,
                 max_epochs: int = 100):
        # Forward max_epochs so the LR schedule honours the caller's value
        # instead of always falling back to the base-class default.
        super().__init__(lr, max_epochs)
        self.save_hyperparameters()
        self.n_classes = n_classes
        self.model_path = model_path
        self.num_encoder_tokens = num_encoder_tokens
        self.aion = AION.from_pretrained(self.model_path)
        self.aion.freeze_encoder()
        self.aion.freeze_decoder()
        self.dim = self.aion.dim
        # One learned query vector of shape (1, dim), shared across the batch.
        self.query = nn.Parameter(torch.randn(1, self.dim))
        self.num_heads = num_heads
        self.attention = torch.compile(NormCrossAttention(self.dim, num_heads=self.num_heads, proj_bias=False))
        self.fc = nn.Linear(self.dim, self.n_classes)

    def forward(self, x):
        """Return class logits for input ``x`` (passed unchanged to ``AION.encode``)."""
        # The backbone is frozen, so skip autograd bookkeeping for it.
        # NOTE(review): Lightning puts self.aion in train mode during fit; if
        # AION contains dropout it stays active here -- confirm that is intended.
        with torch.no_grad():
            embeddings = self.aion.encode(x, num_encoder_tokens=self.num_encoder_tokens)
        # Broadcast the learned query over the batch: (1, dim) -> (batch_size, 1, dim).
        query = self.query.expand(embeddings.size(0), -1, -1)
        # Attend from the query to the token embeddings, then drop the
        # singleton query axis -- presumably yielding (batch_size, dim).
        pooled = self.attention(query, embeddings).squeeze(1)
        return self.fc(pooled)


class ConvNextNanoClassifier(GZ10Model):
    """GZ10 classifier using a ConvNeXt-V2 nano architecture on 4-channel images.

    Only the *configuration* of ``facebook/convnextv2-nano-22k-224`` is
    fetched. The model is built from that config, so its weights are
    randomly initialised rather than pretrained.

    Args:
        n_classes: Number of output classes.
        lr: Initial learning rate for the optimizer.
        max_epochs: Planned training length used to place the LR-decay
            milestones. Defaults to 100, matching the base-class default.
    """

    def __init__(self,
                 n_classes: int,
                 lr: float = 5e-3,
                 max_epochs: int = 100):
        # Forward max_epochs so the LR schedule honours the caller's value
        # instead of always falling back to the base-class default.
        super().__init__(lr, max_epochs)
        self.save_hyperparameters()
        self.n_classes = n_classes
        config = AutoConfig.from_pretrained('facebook/convnextv2-nano-22k-224')
        config.num_labels = self.n_classes
        # Inputs have 4 channels rather than the default RGB 3.
        config.num_channels = 4
        self.conv = ConvNextV2ForImageClassification(config=config)

    def forward(self, x):
        """Classify ``x['tok_image']`` and return the logits."""
        x = x['tok_image']
        logits = self.conv(x).logits
        return logits


class EfficientNetB0(GZ10Model):
    """GZ10 classifier built on a randomly initialised torchvision EfficientNet-B0.

    The stem convolution is replaced so the network takes 4-channel input,
    and the last classifier layer is resized to ``n_classes`` outputs.

    Args:
        n_classes: Number of output classes.
        lr: Initial learning rate for the optimizer.
        max_epochs: Planned training length used for the LR schedule.
    """

    def __init__(self,
                 n_classes: int,
                 lr: float = 5e-3,
                 max_epochs: int = 100):
        super().__init__(lr, max_epochs)
        self.save_hyperparameters()
        self.n_classes = n_classes

        backbone = efficientnet_b0(weights=None)
        # New stem: 4 input channels, 32 output channels, 3x3 / stride 2.
        backbone.features[0][0] = nn.Conv2d(
            in_channels=4,
            out_channels=32,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        # Resize the final linear layer, keeping its input width.
        head_in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(head_in_features, self.n_classes)
        self.model = backbone

    def forward(self, x):
        """Classify ``x['tok_image']`` and return the logits."""
        return self.model(x['tok_image'])
