import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from .efficientnets import EfficientNets
from .densenets import DenseNets

# Per-variant preprocessing sizes for EfficientNet B0-B7, keyed by model name.
# ``resize_size`` is passed to ``transforms.Resize`` and ``crop_size`` to the
# random/center crop in ``get_transform``.
# NOTE(review): values appear to mirror torchvision's EfficientNet
# IMAGENET1K_V1 presets — confirm before changing them.
transform_params = {
    "efficientnet_b0": {"resize_size": 256, "crop_size": 224},
    "efficientnet_b1": {"resize_size": 256, "crop_size": 240},
    "efficientnet_b2": {"resize_size": 288, "crop_size": 288},
    "efficientnet_b3": {"resize_size": 320, "crop_size": 300},
    "efficientnet_b4": {"resize_size": 384, "crop_size": 380},
    "efficientnet_b5": {"resize_size": 456, "crop_size": 456},
    "efficientnet_b6": {"resize_size": 528, "crop_size": 528},
    "efficientnet_b7": {"resize_size": 600, "crop_size": 600},
}

def get_transform(name="efficientnet_b0", train=False):
    """Build the image preprocessing pipeline for network ``name``.

    Names containing ``"efficientnet"`` look up their sizes in
    ``transform_params`` (an unlisted EfficientNet name raises ``KeyError``);
    any other name uses a resize of 256 and a crop of 224.

    Args:
        name: Network identifier.
        train: If True, crop at a random position and randomly flip
            horizontally; otherwise take a deterministic center crop.

    Returns:
        A ``transforms.Compose`` that resizes (bicubic), crops, converts to a
        tensor and normalizes with the usual ImageNet channel mean/std.
    """
    sizes = (
        transform_params[name]
        if "efficientnet" in name
        else {"resize_size": 256, "crop_size": 224}
    )
    crop_size = sizes["crop_size"]

    resize = transforms.Resize(sizes["resize_size"], transforms.InterpolationMode.BICUBIC)
    if train:
        # Augmentation: random crop position plus a random horizontal flip.
        cropping = [transforms.RandomCrop(crop_size), transforms.RandomHorizontalFlip()]
    else:
        # Evaluation: deterministic center crop only.
        cropping = [transforms.CenterCrop(crop_size)]

    return transforms.Compose(
        [
            resize,
            *cropping,
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    
def Net(name='AlexNet', num_classes=1860, pretrained=False):
    """Construct a classification network by name.

    Names containing ``"efficientnet"`` or ``"densenet"`` are delegated to
    ``EfficientNets`` / ``DenseNets``, the only builders that receive
    ``pretrained``. Otherwise ``name`` must be exactly one of ``'AlexNet'``,
    ``'smallAlexNet'``, ``'tinyAlexNet'``, ``'resnet18'``, ``'resnet10'``,
    ``'CNN'``, ``'smallCNN'`` or ``'tinyCNN'``. These are constructed with
    only ``num_classes``, and ``pretrained`` is ignored for them.

    Args:
        name: Architecture identifier (see above).
        num_classes: Number of output classes.
        pretrained: Forwarded to ``EfficientNets`` / ``DenseNets`` only.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``name`` matches no supported architecture.
    """
    if "efficientnet" in name:
        return EfficientNets(name=name, num_classes=num_classes, pretrained=pretrained)
    if "densenet" in name:
        return DenseNets(name=name, num_classes=num_classes, pretrained=pretrained)

    if name == 'AlexNet':
        return AlexNet(num_classes=num_classes)
    if name == 'smallAlexNet':
        return smallAlexNet(num_classes=num_classes)
    if name == 'tinyAlexNet':
        return tinyAlexNet(num_classes=num_classes)
    if name == 'resnet18':
        # NOTE(review): absolute import, unlike the relative ``.efficientnets`` /
        # ``.densenets`` imports at the top of the file. This needs a top-level
        # ``resnet`` module on sys.path — confirm ``.resnet`` was not intended.
        from resnet import resnet18
        return resnet18(num_classes=num_classes)
    if name == 'resnet10':
        from resnet import resnet10
        return resnet10(num_classes=num_classes)
    if name == 'CNN':
        return CNN(num_classes=num_classes)
    if name == 'smallCNN':
        return smallCNN(num_classes=num_classes)
    if name == 'tinyCNN':
        return tinyCNN(num_classes=num_classes)

    # Without this, an unrecognised name used to fall through to
    # ``return network`` and crash with an UnboundLocalError.
    raise ValueError(f"Unknown network name: {name!r}")

class AlexNet(nn.Module):
    """AlexNet-style classifier: five conv layers followed by a three-layer MLP head.

    The sub-module names (``features``, ``avgpool``, ``classifier``) and the
    position of each layer inside its ``nn.Sequential`` define the
    ``state_dict`` keys. Do not rename or reorder them.

    Args:
        num_classes: Number of output logits.
        dropout: Drop probability for both ``nn.Dropout`` layers in the head.
    """

    def __init__(self, num_classes: int = 1860, dropout: float = 0.5) -> None:
        super().__init__()
        feature_layers = [
            # Stage 1: large strided 11x11 kernel, then 3x3/2 max-pool.
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Stage 2: 5x5 conv, then 3x3/2 max-pool.
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Stage 3: three 3x3 convs, then a final 3x3/2 max-pool.
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Adaptive pooling always yields a 6x6 map, so the first Linear sees
        # 256 * 6 * 6 inputs for any image large enough to pass the conv stack.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        head_layers = [
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 3-channel image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        pooled = self.avgpool(self.features(x))
        # Flatten everything except the batch dimension.
        return self.classifier(pooled.flatten(1))

class smallAlexNet(nn.Module):
    """Narrow AlexNet variant: five conv layers (16-64 channels) and a 2048-wide MLP head.

    Sub-module names and layer positions define the ``state_dict`` keys. Do
    not rename or reorder them.

    Args:
        num_classes: Number of output logits.
        dropout: Drop probability for both ``nn.Dropout`` layers in the head.
    """

    def __init__(self, num_classes: int = 1860, dropout: float = 0.5) -> None:
        super().__init__()
        feature_layers = [
            # Stage 1: strided 11x11 conv, then 3x3/2 max-pool.
            nn.Conv2d(3, 16, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Stage 2: 5x5 conv, then 3x3/2 max-pool.
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Stage 3: three 3x3 convs, then a final 3x3/2 max-pool.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Fixed 6x6 output makes the head input size independent of image size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        head_layers = [
            nn.Dropout(p=dropout),
            nn.Linear(32 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 3-channel image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        pooled = self.avgpool(self.features(x))
        # Flatten everything except the batch dimension.
        return self.classifier(pooled.flatten(1))

class tinyAlexNet(nn.Module):
    """Minimal AlexNet variant: three conv stages and a two-layer MLP head.

    Sub-module names and layer positions define the ``state_dict`` keys. Do
    not rename or reorder them.

    Args:
        num_classes: Number of output logits.
        dropout: Drop probability for both ``nn.Dropout`` layers in the head.
    """

    def __init__(self, num_classes: int = 1860, dropout: float = 0.5) -> None:
        super().__init__()
        feature_layers = [
            # Each stage: conv -> ReLU -> 3x3/2 max-pool.
            nn.Conv2d(3, 16, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)

        # Fixed 6x6 output makes the head input size independent of image size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        head_layers = [
            nn.Dropout(p=dropout),
            nn.Linear(32 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(2048, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 3-channel image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        pooled = self.avgpool(self.features(x))
        # Flatten everything except the batch dimension.
        return self.classifier(pooled.flatten(1))

class CNN(nn.Module):
    """Small conv net for single-channel images.

    Two stages of 5x5 conv -> ReLU -> 2x2 max-pool, then three fully connected
    layers (120 -> 84 -> ``num_classes``). ``fc1`` expects a 16x4x4 feature
    map, which a 28x28 input produces (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Attribute registration order fixes the state_dict key order.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map ``(N, 1, 28, 28)`` images to logits ``(N, num_classes)``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)


class smallCNN(nn.Module):
    """Narrower variant of ``CNN`` for single-channel images.

    Two stages of 5x5 conv -> ReLU -> 2x2 max-pool (3 then 9 channels), a
    64-unit hidden layer, then the output layer. ``fc1`` expects a 9x4x4
    feature map, which a 28x28 input produces (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Attribute registration order fixes the state_dict key order.
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(3, 9, 5)
        self.fc1 = nn.Linear(9 * 4 * 4, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map ``(N, 1, 28, 28)`` images to logits ``(N, num_classes)``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        hidden = F.relu(self.fc1(torch.flatten(x, 1)))  # keep the batch dim
        return self.fc2(hidden)

class tinyCNN(nn.Module):
    """Single-conv classifier for single-channel images.

    One 5x5 conv (3 channels) -> ReLU, then a 2x2/2 max-pool followed by a
    3x3/2 max-pool, then a linear layer. ``fc`` expects a 3x5x5 map, which a
    28x28 input produces (28 -> 24 -> 12 -> 5).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc = nn.Linear(3 * 5 * 5, num_classes)

    def forward(self, x):
        """Map ``(N, 1, H, W)`` images (e.g. 28x28) to logits ``(N, num_classes)``."""
        activated = F.relu(self.conv(x))
        pooled = self.pool2(self.pool1(activated))
        # Flatten everything except the batch dimension.
        return self.fc(pooled.flatten(1))

class TinyCNNs(nn.Module):
    """Configurable stack of ``num_layers`` 5x5 conv + ReLU blocks with a linear head.

    Layer ``i`` maps ``in_channels * 2**(i // 2)`` channels to
    ``in_channels * 2**((i + 1) // 2)``, so the width doubles every second
    layer. Padding 2 keeps the spatial size. The result is average-pooled to
    5x5 and fed to a linear classifier.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of conv + ReLU blocks (0 gives a pool + linear model).
        in_channels: Channel count of the input images.
    """

    def __init__(self, num_classes: int = 10, num_layers: int = 1, in_channels=1):
        super().__init__()
        layers = []
        for depth in range(num_layers):
            width_in = in_channels * 2 ** (depth // 2)
            width_out = in_channels * 2 ** ((depth + 1) // 2)
            layers += [nn.Conv2d(width_in, width_out, 5, padding=2), nn.ReLU()]
        self.convs = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(5)
        # Channel count after the last conv is in_channels * 2**(num_layers // 2).
        final_width = in_channels * 2 ** (num_layers // 2)
        self.fc = nn.Linear(25 * final_width, num_classes)

    def forward(self, x):
        """Map ``(N, in_channels, H, W)`` images to logits ``(N, num_classes)``."""
        x = self.convs(x)
        # Redundant after the ReLU that closes ``convs``, except when
        # num_layers == 0, where it rectifies the raw input.
        x = F.relu(x)
        x = self.pool(x)
        return self.fc(torch.flatten(x, start_dim=1))

class TinyNets(nn.Module):
    """Configurable stack of ``num_layers`` 5x5 convolutions with a linear head.

    Uses the same channel schedule as ``TinyCNNs``. Layer ``i`` maps
    ``in_channels * 2**(i // 2)`` channels to ``in_channels * 2**((i + 1) // 2)``,
    and padding 2 keeps the spatial size. There is no nonlinearity between
    the convs; a single ReLU is applied after the whole stack. The result is
    then average-pooled to 5x5 and fed to a linear classifier.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of stacked convolutions (0 gives a ReLU + pool + linear model).
        in_channels: Channel count of the input images. The default of 1
            reproduces the original single-channel behaviour.
    """

    def __init__(self, num_classes: int = 10, num_layers: int = 1, in_channels: int = 1):
        super().__init__()
        # NOTE(review): unlike TinyCNNs there is no ReLU between these convs, so
        # the stack composes into a single affine map — confirm this is intended.
        convs = [
            nn.Conv2d(in_channels * 2 ** (i // 2), in_channels * 2 ** ((i + 1) // 2), 5, padding=2)
            for i in range(num_layers)
        ]
        self.convs = nn.Sequential(*convs)
        self.pool = nn.AdaptiveAvgPool2d(5)
        # Channel count after the last conv is in_channels * 2**(num_layers // 2).
        self.fc = nn.Linear(25 * in_channels * 2 ** (num_layers // 2), num_classes)

    def forward(self, x):
        """Map ``(N, in_channels, H, W)`` images to logits ``(N, num_classes)``."""
        x = self.convs(x)
        x = F.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)  # keep the batch dimension
        return self.fc(x)


