import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.jcgel.encoders.jcgel_resnet_encoder import JCGResNet18
# Number of target classes for each supported dataset, keyed by the dataset
# name that JCGResNet_cls reads from ``config.dataset``.
NUM_CLASSES = {
    "eurosat": 10,
    "cifar10": 10,
    "cifar100": 100,
    "flowers102": 102,
    "stanfordcars": 196,
    "stl10": 10,
    "food101": 101,
    "caltech101": 101,
    "oxfordiiitpet": 37,
    "aircraft": 100,
}


class JCGResNet_cls(nn.Module):
    """Classifier that wraps a ``JCGResNet18`` encoder.

    The number of output classes is looked up in ``NUM_CLASSES`` using
    ``config.dataset``; the encoder is built with that class count and the
    same ``config`` object.
    """

    def __init__(self, config):
        """Build the encoder for the dataset named by ``config.dataset``.

        Args:
            config: Configuration object. ``config.dataset`` must be a key of
                ``NUM_CLASSES``; the whole object is also forwarded to
                ``JCGResNet18``.

        Raises:
            KeyError: If ``config.dataset`` is not a key of ``NUM_CLASSES``.
        """
        super().__init__()

        self.num_classes = NUM_CLASSES[config.dataset]
        self.encoder = JCGResNet18(self.num_classes, config=config)

    def forward(self, x: torch.Tensor, label: torch.Tensor = None, loss_fn=F.cross_entropy) -> dict:
        """Run the encoder and optionally compute a loss.

        Args:
            x: Input batch passed unchanged to the encoder.
            label: Optional targets. When ``None`` no loss is computed.
            loss_fn: Callable invoked as ``loss_fn(y_hat, label)``; defaults
                to ``torch.nn.functional.cross_entropy``.

        Returns:
            A dict with three keys:
                ``"loss"``: ``loss_fn(y_hat, label)``, or ``None`` when
                    ``label`` is ``None``.
                ``"y_hat"``: the raw encoder output.
                ``"pred"``: ``argmax`` of ``y_hat`` over dimension 1.
        """
        # NOTE(review): dim=1 assumes y_hat is laid out as (batch, classes) —
        # confirm against the encoder's output shape.
        y_hat = self.encoder(x)
        pred = torch.argmax(y_hat, dim=1)

        loss = loss_fn(y_hat, label) if label is not None else None
        return {"loss": loss, "y_hat": y_hat, "pred": pred}

    def init_weights(self):
        """Re-initialise every Conv2d, Linear and BatchNorm2d in the encoder.

        Conv2d and Linear weights get Kaiming-normal initialisation
        (``mode='fan_in'``, ``nonlinearity='relu'``) and zero biases.
        BatchNorm2d affine weights are set to 1 and biases to 0.
        """
        for m in self.encoder.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight and bias set to None;
                # skip them instead of crashing inside nn.init.constant_.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def freeze(self):
        """Make only parameters whose name contains ``'encoder.fc'`` trainable.

        Every other parameter gets ``requires_grad = False``. Only the
        ``requires_grad`` flags are touched; the module's train/eval mode is
        left unchanged.
        """
        for n, p in self.named_parameters():
            # Substring match on the qualified parameter name.
            p.requires_grad = 'encoder.fc' in n
            # else: