import torch.nn as nn
import torch.nn.functional as F
from .base import BaseModel, init_weights, MLP_Classifier

class CNN_Encoder(nn.Module):
    """Stack of Conv2d -> BatchNorm2d -> ReLU stages with an optional reduction.

    Args:
        channel_seq: Sequence of channel counts. Each consecutive pair
            ``(channel_seq[i], channel_seq[i + 1])`` defines one stage, so a
            sequence of length ``n`` yields ``n - 1`` stages.
        type: How the convolutional output is reduced in ``forward``:
            ``"avgpool"`` applies adaptive average pooling to 1x1 and reshapes
            to ``(batch, -1)``; ``"flatten"`` applies ``nn.Flatten``; any other
            value returns the convolutional output unchanged.
    """

    def __init__(self, channel_seq=(3, 64, 128, 256), type = "avgpool"):
        super().__init__()
        self.channels = channel_seq
        self.layers = nn.ModuleList()
        # Walk consecutive (in, out) channel pairs; every conv uses a 3x3
        # kernel with stride 1 and padding 1, which keeps the spatial size.
        for in_ch, out_ch in zip(channel_seq, channel_seq[1:]):
            self.layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.type = type
        # Both reduction modules are always created; `forward` selects one
        # (or neither) based on `self.type`.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Run ``x`` through every stage, then reduce according to ``self.type``."""
        for stage in self.layers:
            x = stage(x)
        mode = self.type
        if mode == "avgpool":
            pooled = self.avgpool(x)
            return pooled.view(pooled.size(0), -1)
        if mode == "flatten":
            return self.flatten(x)
        # Unrecognised mode: hand back the feature map as-is.
        return x
    
# Configurations for get_cnn_model (the second matches its defaults):
#   channel_seq=(64, 128, 254), type="avgpool"
#   channel_seq=(64, 64, 64), type="flatten"
def get_cnn_model(input_shape, num_classes, hidden_dim=1024, channel_seq=(64, 64, 64), type="flatten"):
    """Assemble a ``BaseModel`` from a ``CNN_Encoder`` and an ``MLP_Classifier``.

    Args:
        input_shape: Indexable shape of one input sample. ``input_shape[0]`` is
            used as the encoder's input channel count; ``input_shape[1]`` and
            ``input_shape[2]`` are multiplied into the feature size when
            ``type`` is not ``"avgpool"`` (presumably ``(C, H, W)`` — confirm
            against callers).
        num_classes: Forwarded to ``MLP_Classifier`` as its second argument.
        hidden_dim: Forwarded to ``MLP_Classifier`` as its third argument and
            embedded in the model name.
        channel_seq: Output channel counts of the encoder stages; the input
            channel count is prepended automatically.
        type: Reduction mode handed to ``CNN_Encoder``.

    Returns:
        The ``BaseModel`` built from the initialised encoder, the initialised
        classifier and a descriptive name string.
    """
    full_channels = (input_shape[0], *channel_seq)
    encoder = CNN_Encoder(full_channels, type=type)

    if type == "avgpool":
        # Global pooling leaves one value per output channel.
        feature_dim = full_channels[-1]
    else:
        # The encoder's convolutions keep the spatial size, so the flattened
        # feature count is channels * input_shape[1] * input_shape[2].
        feature_dim = full_channels[-1] * input_shape[1] * input_shape[2]
    encoder.output_dim = feature_dim

    classifier = MLP_Classifier(feature_dim, num_classes, hidden_dim)

    channel_tag = "x".join(str(width) for width in full_channels)
    name = f"cnn{channel_tag}_{encoder.type}_mlp{hidden_dim}"

    # Initialise the encoder first, then the classifier, before wrapping them.
    initialised_encoder = encoder.apply(init_weights)
    initialised_classifier = classifier.apply(init_weights)
    return BaseModel(initialised_encoder, initialised_classifier, name)