import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
from functools import partial

# Process-wide side effect at import time: floating-point tensors and module
# parameters created after this line default to float64 unless a dtype is passed
# explicitly. This also affects any other code in the same process.
torch.set_default_dtype(torch.double)

def get_model(wandb_config):
    """Build the model selected by ``wandb_config['model']``.

    Args:
        wandb_config: Mapping of run settings. ``'model'`` selects the
            architecture: ``"CNN"``, ``"ResNet18"``, ``"ConvNet6"`` or ``"Linear"``.
            The whole mapping is forwarded to the CNN, ResNet18 and ConvNet6
            constructors. ``"Linear"`` ignores it.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        KeyError: If ``wandb_config`` has no ``'model'`` entry.
        ValueError: If ``wandb_config['model']`` is not a supported name.
            Previously this surfaced as an ``UnboundLocalError`` on ``model``.
    """
    name = wandb_config['model']
    if name == "CNN":
        return Net(wandb_config)
    if name == "ResNet18":
        return ResNet18(wandb_config)
    if name == "ConvNet6":
        return ConvNet6(wandb_config)
    if name == "Linear":
        return LinearModel()
    raise ValueError(
        f"Unknown model {name!r}; expected one of "
        "'CNN', 'ResNet18', 'ConvNet6', 'Linear'"
    )

class LinearModel(nn.Module):
    """Single fully connected layer mapping a flattened input to class logits.

    Args:
        in_features: Number of values per flattened sample. The default is
            ``32 * 32 * 3`` (a 3-channel 32x32 image).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_features: int = 32 * 32 * 3, num_classes: int = 10):
        super().__init__()
        self.in_features = in_features
        # The layer is explicitly float64, independent of the global default dtype.
        self.weights = nn.Linear(in_features, num_classes, dtype=torch.float64, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to ``(-1, in_features)`` and return the logits.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. permuted or channels_last tensors) are accepted. It copies only
        when a view is impossible.
        """
        x = x.reshape(-1, self.in_features)
        return self.weights(x)

class ResNet18:
    """Factory for a torchvision ResNet-18 adjusted for small, 10-class images.

    ``ResNet18(wandb_config)`` returns the torchvision model itself, not an
    instance of this class, because ``__new__`` returns the built network.
    When ``wandb_config['bn_track_running_stats']`` is falsy, every batch-norm
    layer is built with ``track_running_stats=False``.
    """

    def __new__(cls, wandb_config):
        build_kwargs = {}
        if not wandb_config['bn_track_running_stats']:
            build_kwargs['norm_layer'] = partial(nn.BatchNorm2d, track_running_stats=False)
        net = resnet18(**build_kwargs)

        # Swap the stem for a 3x3, stride-1 convolution.
        net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # Drop the stem max-pool so small images keep their spatial resolution.
        net.maxpool = nn.Identity()
        # New classifier head with 10 outputs.
        feature_width = net.fc.in_features
        net.fc = nn.Linear(feature_width, 10)
        return net
    

class Net(nn.Module):
    """Small CNN: two (5x5 conv -> ReLU -> 2x2 max-pool -> dropout) stages, then three dense layers.

    For a 3x32x32 input the spatial size goes 32 -> 28 -> 14 -> 10 -> 5, so the
    flattened feature width fed to ``fc1`` is 64 * 5 * 5 = 1600.

    Args:
        wandb_config: Mapping whose ``'dropout'`` entry is the dropout probability.
            The same probability is used after both conv stages and after ``fc1``.
    """

    def __init__(self, wandb_config) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.fc1 = nn.Linear(1600, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 10)
        self.dropout = nn.Dropout(wandb_config['dropout'])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return 10 class logits per sample."""
        for conv in (self.conv1, self.conv2):
            x = self.dropout(self.pool(F.relu(conv(x))))
        flat = x.view(-1, self.fc1.in_features)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        hidden = F.relu(self.fc2(hidden))  # no dropout after fc2
        return self.fc3(hidden)
    

class ConvNet6(nn.Module):
    """CNN with three (conv -> conv -> 2x2 max-pool) stages and three dense layers.

    All convolutions are 3x3, stride 1, ``padding='same'``. Channels go
    3 -> 64 -> 128 -> 256. ``dense1`` expects 4096 = 256 * 4 * 4 flattened
    features, which matches a 32x32 input after three 2x halvings.
    ``wandb_config`` is accepted but not read by this class.
    """

    expansion = 1  # not read within this module; kept as part of the class interface

    def __init__(self, wandb_config):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding='same')
        self.conv1 = nn.Conv2d(3, 64, **conv_opts)
        self.conv2 = nn.Conv2d(64, 64, **conv_opts)
        self.conv3 = nn.Conv2d(64, 128, **conv_opts)
        self.conv4 = nn.Conv2d(128, 128, **conv_opts)
        self.conv5 = nn.Conv2d(128, 256, **conv_opts)
        self.conv6 = nn.Conv2d(256, 256, **conv_opts)

        self.dense1 = nn.Linear(4096, 256)
        self.dense2 = nn.Linear(256, 256)
        self.dense3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return 10 class logits per sample."""
        stages = (
            (self.conv1, self.conv2),
            (self.conv3, self.conv4),
            (self.conv5, self.conv6),
        )
        for first, second in stages:
            x = F.relu(first(x))
            x = F.max_pool2d(F.relu(second(x)), kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)
        for dense in (self.dense1, self.dense2):
            x = F.relu(dense(x))
        return self.dense3(x)
    