import torch
import torch.nn as nn
from transformers import ResNetConfig, ResNetForImageClassification
from transformers.models.resnet.modeling_resnet import ResNetModel, ResNetEncoder
from transformers import ConvNextImageProcessor
from torchvision.transforms import (
    Compose, 
    RandomHorizontalFlip, 
    RandomCrop, 
    ToTensor, 
    Normalize
)
import os
from typing import List

class ResNetCifarEmbeddings(nn.Module):
    """CIFAR-style ResNet stem: 3x3 convolution -> BatchNorm -> ReLU.

    Uses stride 1 and padding 1 with no pooling, so the output keeps the
    input's spatial size (suited to small CIFAR images).

    Args:
        config: object providing ``num_channels`` (input channels) and
            ``embedding_size`` (output channels), e.g. a ``ResNetConfig``.
    """

    def __init__(self, config):
        super().__init__()
        # 3x3 conv, stride 1, padding 1: spatial size is preserved. No conv
        # bias because the following BatchNorm applies its own shift.
        self.embedder = nn.Conv2d(
            config.num_channels,
            config.embedding_size,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.norm = nn.BatchNorm2d(config.embedding_size)
        self.activation = nn.ReLU()

    def forward(self, pixel_values):
        """Embed a batch of images.

        Args:
            pixel_values: image batch; presumably shaped
                (batch, num_channels, height, width) -- TODO confirm with callers.

        Returns:
            Feature map with ``embedding_size`` channels and the same spatial
            size as the input.
        """
        # Upcast half-precision inputs (both fp16 and bf16) when the conv
        # weights are fp32, so the convolution does not fail on a dtype
        # mismatch. Other dtypes (e.g. integer pixels) are passed through
        # unchanged on purpose.
        if (
            pixel_values.dtype in (torch.float16, torch.bfloat16)
            and self.embedder.weight.dtype == torch.float32
        ):
            pixel_values = pixel_values.to(torch.float32)

        x = self.embedder(pixel_values)
        x = self.norm(x)
        return self.activation(x)

# Custom ResNet-20 classification model for CIFAR
class ResNet20ForCIFAR(ResNetForImageClassification):
    """ResNet-20 image classifier for CIFAR-sized inputs.

    Reuses ``ResNetForImageClassification.forward`` but builds its own backbone
    and head, and swaps the backbone stem for ``ResNetCifarEmbeddings``.
    BatchNorm layers can additionally be frozen (kept in eval mode, so they use
    and do not update their running statistics) independently of the rest of
    the model via ``set_batchnorm_training``.
    """

    def __init__(self, config):
        # Deliberately bypass ResNetForImageClassification.__init__ and call
        # its parent's initialiser directly; backbone and head are built here.
        super(ResNetForImageClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.resnet = ResNetModel(config)
        # Classification head: flatten the pooled features, then project to the
        # label space (identity when the config defines no labels).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else nn.Identity(),
        )

        # Replace the default stem with the CIFAR 3x3 / stride-1 stem.
        self.resnet.embedder = ResNetCifarEmbeddings(config)

        # Initialise weights and apply final processing. Called after the stem
        # swap so any weight initialisation it performs also covers the new stem.
        self.post_init()

        # NOTE(review): presumably selects the loss used by transformers'
        # `loss_function` machinery, since this class name does not end in a
        # standard `For...` suffix -- confirm against the installed version.
        self.loss_type = "ForSequenceClassification"
        # True: BatchNorm follows the model's train/eval mode.
        # False: BatchNorm stays in eval mode even while the model trains.
        self._train_batchnorm = True
        # Not used within this class. A plain Python list, so its entries are
        # NOT registered as submodules (excluded from parameters(), state_dict()).
        self.frozen_model: List[ResNet20ForCIFAR] = []

    def train(self, mode=True):
        """Set train/eval mode while respecting the BatchNorm freeze setting.

        Behaves like ``nn.Module.train``; additionally, when switching to
        training mode while BatchNorm training is disabled, every
        ``nn.BatchNorm2d`` is put back into eval mode.

        Returns:
            self
        """
        super().train(mode)
        if mode and not self._train_batchnorm:
            self.set_batchnorm_training(False)
        return self

    def set_batchnorm_training(self, train_batchnorm):
        """Enable or freeze training of every ``nn.BatchNorm2d`` layer.

        Args:
            train_batchnorm: if truthy, BatchNorm layers follow the model's
                current mode: batch statistics are used and running statistics
                updated only while the model is in training mode. If falsy,
                BatchNorm layers are put in eval mode (running statistics used,
                not updated), and ``train()`` keeps them there.
        """
        self._train_batchnorm = train_batchnorm
        # Enabling BN training must not put BN into training mode while the
        # model itself is in eval mode; otherwise inference after
        # `model.eval()` would keep updating the running statistics.
        bn_training = bool(train_batchnorm) and self.training
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.train(bn_training)

    def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
        """Classify ``pixel_values``; a thin wrapper over the parent ``forward``.

        Exposes only these four arguments and returns the parent's output
        unchanged (its form depends on ``return_dict``).
        """
        output = super().forward(
            pixel_values=pixel_values,
            labels=labels,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return output

# CIFAR-10 class names, in label-index order (index 0 = "plane", ...).
_CIFAR10_CLASSES = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)

# ResNet-20 for CIFAR: three stages of three basic blocks with 16/32/64
# channels on top of a 16-channel stem; the first stage does not downsample.
config = ResNetConfig(
    layer_type="basic",
    depths=[3, 3, 3],
    hidden_sizes=[16, 32, 64],
    embedding_size=16,
    hidden_act="relu",
    downsample_in_first_stage=False,
    num_channels=3,
    # Both label mappings are derived from the single class tuple above so
    # they cannot drift apart.
    num_labels=len(_CIFAR10_CLASSES),
    label2id={label: index for index, label in enumerate(_CIFAR10_CLASSES)},
    id2label=dict(enumerate(_CIFAR10_CLASSES)),
    model_type="resnet",
)

# Per-channel (R, G, B) CIFAR-10 statistics used for normalisation.
_CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
_CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]

# Resizing is disabled (inputs are presumably already 32x32 CIFAR images);
# pixels are rescaled and then normalised with the statistics above.
image_processor = ConvNextImageProcessor(
    do_resize=False,
    size={"height": 32, "width": 32},
    do_rescale=True,
    do_normalize=True,
    image_mean=_CIFAR10_MEAN,
    image_std=_CIFAR10_STD,
)

# Save this processor so it can be loaded together with the model later.
# The save is unconditional: the directory may already exist (for example
# because a model checkpoint was saved there first), and the processor must
# still be written into it and kept in sync with the settings defined above.
_PROCESSOR_DIR = "./resnet20-cifar"
os.makedirs(_PROCESSOR_DIR, exist_ok=True)
image_processor.save_pretrained(_PROCESSOR_DIR)