from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor

from torch.nn import (Module, Linear, ReLU,
                      Softmax, Sequential, Conv2d,
                      MaxPool2d, AdaptiveAvgPool2d,
                      Flatten)
import torch.nn.init as init


class CNN(Module):
    """Small convolutional network with a global-average-pooled linear head.

    Layer stack (built by ``_build``):

    * ``Conv2d(input_channels -> width[0], kernel_size, stride=1)`` -> ``ReLU``
      -> ``MaxPool2d(kernel_size=2, stride=2)``
    * only when two widths are given:
      ``Conv2d(width[0] -> width[1], kernel_size, stride=2)`` -> ``ReLU``
    * ``AdaptiveAvgPool2d((1, 1))`` -> ``Flatten`` -> ``Linear(width[-1] -> out_dim)``
    * ``Softmax(dim=1)`` appended only when ``is_classification`` is True and
      ``logits_only`` is False.

    Args:
        out_dim: Number of output features of the final linear layer.
        input_channels: Number of channels of the input fed to the first conv.
        kernel_size: Kernel size shared by every convolution.
        width: Output channels of the conv layer(s). An ``int`` means a single
            conv layer; a sequence of one or two ints means one or two layers.
        is_classification: Together with ``logits_only=False``, appends a
            softmax over dim 1.
        add_bias: Whether the conv and linear layers carry a bias term.
        logits_only: If True, the softmax is never appended (raw logits out).

    Raises:
        ValueError: If ``width`` is empty or has more than two entries.
    """

    def __init__(
            self,
            out_dim: int,
            input_channels: int,
            kernel_size: int,
            width: int | list[int],
            *,
            is_classification: bool = True,
            add_bias: bool = True,
            logits_only: bool = True,
    ) -> None:
        super().__init__()

        self.input_channels = input_channels
        self.out_dim = out_dim
        self.is_class = is_classification
        self.add_bias = add_bias
        self.kernel_size = kernel_size

        # Normalize to a private list so later mutation of the caller's list
        # cannot silently change this model's recorded configuration.
        self.width = [width] if isinstance(width, int) else list(width)
        # Validate before any layer is constructed so bad configs fail fast
        # with a clear message instead of an IndexError deep in _build.
        self._check_width()

        modules = self._build()

        if is_classification and not logits_only:
            modules.append(Softmax(dim=1))

        self.net = Sequential(*modules)
        self._init_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the layer stack and return the network output."""
        return self.net(x)

    def _check_width(self) -> None:
        """Reject layer-width configurations ``_build`` cannot handle."""
        if not self.width:
            raise ValueError('width must contain at least one layer size')
        if len(self.width) > 2:
            raise ValueError('The CNN module doesn\'t support anything bigger than 2 layers at the moment, I\'m afraid')

    def _build(self) -> list[Module]:
        """Return the ordered list of layers (without the optional softmax).

        Assumes ``self.width`` has already been validated to hold 1 or 2 entries.
        """
        modules: list[Module] = [
            Conv2d(
                in_channels=self.input_channels,
                out_channels=self.width[0],
                bias=self.add_bias,
                kernel_size=self.kernel_size,
                stride=1,
            ),
            ReLU(),
            MaxPool2d(kernel_size=2, stride=2),
        ]
        if len(self.width) == 2:
            # The second stage downsamples with a strided conv instead of pooling.
            modules.append(
                Conv2d(
                    in_channels=self.width[0],
                    out_channels=self.width[1],
                    bias=self.add_bias,
                    kernel_size=self.kernel_size,
                    stride=2,
                )
            )
            modules.append(ReLU())

        # Global average pooling makes the head independent of spatial input size.
        modules.append(AdaptiveAvgPool2d(output_size=(1, 1)))
        modules.append(Flatten())
        modules.append(
            Linear(
                in_features=self.width[-1],
                out_features=self.out_dim,
                bias=self.add_bias,
            )
        )
        return modules

    def _init_weights(self) -> None:
        """Kaiming-uniform init (ReLU gain) for conv/linear weights; zero biases.

        Note: the gain is applied to every Conv2d and Linear, including the
        final output Linear layer.
        """
        for m in self.net.modules():
            if isinstance(m, (Conv2d, Linear)):
                init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)