from dataclasses import dataclass
from collections import OrderedDict
import math

import torch
import torch.nn as nn
# import torch.nn.functional as F

from ..wrapping import WrappedModel, InferenceRecord
from ..wrapping import layer_names as ln
from .lib.recorded_sequential import RecordedSequential


# Short string tags, presumably one per kind of layer a model can contain.
# NOTE(review): nothing in the visible part of this file reads these names;
# the model below builds its layer names from `ln.*` (the `layer_names`
# module) instead. Confirm whether these are still needed or duplicate it.
CONV = "conv"
BATCH_NORM = "bn"
ACTIVATION = "act"
MAX_POOL = "mp"
FLATTEN = "fl"
LINEAR = "fc"


@dataclass
class SimpleConvConfig:
    """Architecture settings consumed by ``SimpleConv``."""

    # Identifier for this configuration. SimpleConv does not currently read
    # it (the `id=config.id` argument to the base constructor is commented out).
    id: str
    # Side length of the input; SimpleConv squares the derived size, so the
    # input is treated as square.
    input_size: int
    # Output channel count of each convolution, in order. One
    # conv / batch-norm / ReLU group is built per entry.
    conv_layer_filters: list[int]
    # Output width of each linear layer, in order. A ReLU follows every
    # linear layer except the last.
    fc_layer_units: list[int]
    # Channel count fed to the first convolution.
    input_channels: int = 3
    # Kernel size shared by every convolution and by the max-pool layer.
    conv_size: int = 3


class SimpleConv(WrappedModel):
    """Configurable convolutional network assembled from a ``SimpleConvConfig``.

    The network is a single ``RecordedSequential`` made of, in order:

    * one ``Conv2d`` -> ``BatchNorm2d`` -> ``ReLU`` group per entry of
      ``config.conv_layer_filters``;
    * one ``MaxPool2d`` whose kernel size is ``config.conv_size``;
    * a ``Flatten``;
    * one ``Linear`` per entry of ``config.fc_layer_units``, with a ``ReLU``
      after every linear layer except the last.

    Layer names are built from the constants in ``layer_names`` (imported as
    ``ln``) combined with the layer's index.
    """

    def __init__(self, config: SimpleConvConfig) -> None:
        """Build the layer stack described by ``config``.

        Args:
            config: Architecture description. ``config.input_size`` is used
                as the side length of a square input when sizing the first
                linear layer.

        Raises:
            ValueError: If the spatial size left after the convolutions and
                the pooling layer is smaller than 1, i.e. the configured
                input is too small for the configured layers.
        """
        # NOTE: config.id is not forwarded to the base class.
        super().__init__()

        layers: list[tuple[str, nn.Module]] = []
        size = config.input_size
        channels = config.input_channels
        for i, n_filters in enumerate(config.conv_layer_filters):
            layers.append((f"{ln.CONV}_{i}", nn.Conv2d(
                channels,
                n_filters,
                config.conv_size,
            )))
            layers.append((f"{ln.BATCH_NORM}_{i}", nn.BatchNorm2d(n_filters)))
            layers.append((f"{ln.ACTIVATION}_{i}", nn.ReLU()))

            # Conv2d is created with default stride/padding, so each one
            # shortens the side length by conv_size - 1.
            size -= config.conv_size - 1
            channels = n_filters

        layers.append((ln.MAX_POOL, nn.MaxPool2d(config.conv_size)))
        # MaxPool2d's stride defaults to its kernel size, so the side length
        # is divided by conv_size and rounded down.
        size = math.floor(size / config.conv_size)
        if size < 1:
            # Without this check a negative size would be squared below into
            # a positive (and wrong) feature count, hiding the problem until
            # the first forward pass.
            raise ValueError(
                f"input_size={config.input_size} is too small for "
                f"{len(config.conv_layer_filters)} convolution(s) and a "
                f"max-pool with conv_size={config.conv_size}"
            )
        layers.append((ln.FLATTEN, nn.Flatten()))

        # `channels` is the last conv's filter count, or input_channels when
        # there are no conv layers (indexing conv_layer_filters[-1] would
        # raise IndexError in that case).
        features = size * size * channels
        last_fc = len(config.fc_layer_units) - 1
        for i, n_units in enumerate(config.fc_layer_units):
            layers.append((f"{ln.LINEAR}_{i}", nn.Linear(features, n_units)))
            if i < last_fc:
                # The "_l" infix keeps these names distinct from the
                # activations that follow the convolutions.
                layers.append((f"{ln.ACTIVATION}_l{i}", nn.ReLU()))
            features = n_units

        self.layers = RecordedSequential(OrderedDict(layers))

    def forward(self, x: torch.Tensor) -> InferenceRecord:
        """Run ``x`` through the layer stack and return its result.

        Args:
            x: Input tensor.
                # assumes (batch, input_channels, input_size, input_size)
                # -- TODO confirm against callers.

        Returns:
            Whatever ``RecordedSequential`` produces for ``x``; annotated as
            an ``InferenceRecord``.
        """
        return self.layers(x)
