import os
from argparse import Namespace

import torch
from layers import TemperatureScaler
from torch import nn

# Short alias for ``torch.Tensor``, used in the ``forward`` annotations below.
T = torch.Tensor


class MAML(nn.Module):
    """Common base class for the MAML backbone networks defined in this file.

    Subclasses assemble their network into ``layers`` and call :meth:`init`
    as the last step of their constructor.
    """

    # The network body; assigned by each subclass constructor.
    layers: nn.Module
    # NOTE(review): presumably flipped once the temperature layer has been
    # tuned -- confirm against the training/evaluation code.
    tuned: bool = False
    # NOTE(review): presumably enables temperature scaling of the outputs --
    # confirm against the training/evaluation code.
    do_temp: bool = False

    def init(self) -> None:
        """Weight-initialisation hook; intentionally leaves all modules untouched.

        I initially attempted to use the xavier initializers in the paper, but
        the network always overfitted to miniimagenet very badly. I was stumped
        for a long time and finally figured out that the initializers were to
        blame; the default pytorch initializers gave good results.

        The method is kept (rather than removed) so subclasses keep a single
        place to hook custom initialisation into.
        """
        # Nothing to do: the default initialisation of each layer is kept on purpose.
        return None


class MAMLOmniglot(MAML):
    """Four-block convolutional classifier (strided-convolution variant).

    Each block is ``Conv2d(3x3, stride) -> BatchNorm2d -> ReLU``. The blocks
    are followed by a 2x2 average pool, a flatten and a linear layer that maps
    ``filters`` features to ``args.n_way`` outputs.

    Args:
        args: Parsed options; ``args.n_way`` sets the number of output units.
        in_ch: Number of channels of the input images.
        filters: Number of output channels of every convolution.
        stride: Stride of every convolution.

    NOTE(review): the linear layer takes exactly ``filters`` inputs, so the
    feature map must be 1x1 after the average pool -- presumably 28x28 inputs
    with the default stride; confirm against the data pipeline.
    """

    def __init__(self, args: Namespace, in_ch: int = 1, filters: int = 64, stride: int = 2):
        super().__init__()

        self.args = args
        self.name = os.path.join("CNN5", f"{filters}-filter")
        self.tmp_layer = TemperatureScaler()

        stack = []
        prev_ch = in_ch  # only the first convolution sees the raw image channels
        for _ in range(4):
            stack.append(
                nn.Conv2d(
                    in_channels=prev_ch,
                    out_channels=filters,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                )
            )
            # momentum=1.0: the running statistics are overwritten by the
            # statistics of the most recent batch.
            stack.append(nn.BatchNorm2d(filters, momentum=1.0, affine=True))
            stack.append(nn.ReLU(inplace=True))
            prev_ch = filters

        # Classification head: pool, flatten, then one logit per class.
        stack.append(nn.AvgPool2d(2))
        stack.append(nn.Flatten(start_dim=1))
        stack.append(nn.Linear(filters, args.n_way))

        self.layers = nn.Sequential(*stack)
        self.init()

    def forward(self, x: T) -> T:
        """Run ``x`` through the sequential network and return its output."""
        out = self.layers(x)
        return out  # type: ignore


class MAMLMiniImageNet(MAML):
    """Four-block convolutional classifier (max-pooling variant).

    Each block is ``Conv2d(3x3, stride) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2)``.
    The blocks are followed by a flatten and a linear layer that maps
    ``5 * 5 * filters`` features to ``args.n_way`` outputs.

    Args:
        args: Parsed options; ``args.n_way`` sets the number of output units.
        in_ch: Number of channels of the input images.
        filters: Number of output channels of every convolution.
        stride: Stride of every convolution.

    NOTE(review): the linear layer is sized for a 5x5 final feature map --
    presumably 84x84 inputs with the default stride; confirm against the data
    pipeline.
    """

    def __init__(self, args: Namespace, in_ch: int = 3, filters: int = 32, stride: int = 1):
        super().__init__()

        self.args = args
        self.name = os.path.join("CNN5", f"{filters}-filter")
        self.tmp_layer = TemperatureScaler()

        # Input width of each of the four blocks: image channels first, then `filters`.
        block_inputs = [in_ch] + [filters] * 3

        modules = []
        for c_in in block_inputs:
            modules += [
                nn.Conv2d(in_channels=c_in, out_channels=filters, kernel_size=3, stride=stride, padding=1),
                # momentum=1.0: the running statistics are overwritten by the
                # statistics of the most recent batch.
                nn.BatchNorm2d(filters, momentum=1.0, affine=True),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]

        # Classification head: flatten the feature map and emit one logit per class.
        n_features = 5 * 5 * filters
        modules += [nn.Flatten(start_dim=1), nn.Linear(n_features, args.n_way)]

        self.layers = nn.Sequential(*modules)
        self.init()

    def forward(self, x: T) -> T:
        """Run ``x`` through the sequential network and return its output."""
        out = self.layers(x)
        return out  # type: ignore
