
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.jcgel.encoders.encoder_imbalance import JCG_Encoder3C_Imbalance
class JCG_Imbalance_CLS(nn.Module):
    def __init__(self, config):
        super().__init__()

        # encoder, decoder = CE_EN_DE[config.dataset]
        self.encoder = JCG_Encoder3C_Imbalance(config)
        # self.img_size = config.img_size
        # self.rotations = config.rotations
        self.linear = nn.Linear(256, config.latent_dim)
        self.num_classes = 10
        self.cls = nn.Linear(config.latent_dim, self.num_classes)
    def forward(self, x: torch.Tensor, label: torch.Tensor = None, loss_fn=F.cross_entropy) -> dict:
        outputs = {}

        z, hidden_states = self.encoder(x)
        last_hidden_states = hidden_states[-1]

        y_hat = self.cls(self.linear(z.contiguous().view(z.size(0), -1)))
        # pdb.set_trace()
        pred = torch.argmax(y_hat, dim=1)

        loss = loss_fn(y_hat, label) if label is not None else None
        outputs["loss"] = loss
        outputs["y_hat"] = y_hat
        outputs["pred"] = pred
        return outputs

    def init_weights(self):
        for n, p in self.named_parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            # else: