from typing import Optional, Callable

import torch
from torch import nn

from recordclass import RecordClass

from ....Layers import Neuron
from ....Layers.NeuronConfig import NeuronConfig

from ....Normalization import TDBN1D

from ....util import Lift

class StandardClassifier(nn.Module):
    """Classifier head consisting of one lazily-sized linear layer.

    The head is ``Lift(nn.LazyLinear(num_classes))`` stored as the only
    entry of an ``nn.Sequential`` under ``self.model``.

    Args:
        num_classes: Number of output features of the linear layer.
        neuron: Not used by this head.
        params: Not used by this head.
        config: Not used by this head.
        norm_layer: Not used by this head.
        mpbn: Not used by this head.

    NOTE(review): the unused arguments are presumably accepted so that all
    classifier heads in this module share one constructor signature --
    confirm against the call sites.
    """

    def __init__(
        self,
        num_classes: int = 10,
        neuron: Neuron | None = None,
        params: RecordClass | None = None,
        config: NeuronConfig | None = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        mpbn: bool = False
    ):
        super().__init__()

        # LazyLinear defers choosing its input width until it is first called,
        # so this head does not need to know the backbone's feature size.
        readout = Lift(nn.LazyLinear(num_classes))
        self.model = nn.Sequential(readout)

    def forward(self, x) -> torch.Tensor:
        """Pass ``x`` through ``self.model`` and return the result."""
        out = self.model(x)
        return out
    
class ZhengClassifier(nn.Module):
    """Classifier head with one hidden stage before the output layer.

    ``self.model`` is an ``nn.Sequential`` of, in order:

    0. ``Lift(nn.LazyLinear(256))`` -- input width chosen on first call,
    1. ``TDBN1D(256, v_th=params.v_th)``,
    2. ``Lift(neuron(params, config))``,
    3. ``Lift(nn.Linear(256, num_classes))``.

    Args:
        num_classes: Number of output features of the final linear layer.
        neuron: Callable invoked as ``neuron(params, config)``; its result
            is wrapped in ``Lift`` and placed after the normalization.
            Required, despite the ``None`` default.
        params: Object passed to ``neuron`` and read for its ``v_th``
            attribute. Required, despite the ``None`` default.
        config: Forwarded unchanged to ``neuron``; may be ``None`` as far as
            this class is concerned.
        norm_layer: Not used by this head.
        mpbn: Not used by this head.

    Raises:
        TypeError: If ``neuron`` or ``params`` is ``None``.

    NOTE(review): ``norm_layer`` and ``mpbn`` are presumably accepted so that
    all classifier heads share one constructor signature -- confirm against
    the call sites.
    """

    def __init__(
        self,
        num_classes: int = 10,
        neuron: Neuron | None = None,
        params: RecordClass | None = None,
        config: NeuronConfig | None = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        mpbn: bool = False
    ):
        super().__init__()

        # Both arguments default to None only to keep the shared signature;
        # fail early with a clear message instead of an opaque
        # "'NoneType' object ..." error from inside the layer construction.
        if neuron is None:
            raise TypeError(
                "ZhengClassifier requires a `neuron` factory "
                "(called as neuron(params, config)), got None"
            )
        if params is None:
            raise TypeError(
                "ZhengClassifier requires `params` "
                "(its `v_th` attribute is read), got None"
            )

        # Width of the hidden stage shared by the three middle layers.
        hidden_features = 256

        self.model = nn.Sequential(
            Lift(nn.LazyLinear(hidden_features)),
            TDBN1D(hidden_features, v_th=params.v_th),
            Lift(neuron(params, config)),
            Lift(nn.Linear(hidden_features, num_classes)),
        )

    def forward(self, x) -> torch.Tensor:
        """Pass ``x`` through ``self.model`` and return the result."""
        return self.model(x)
    
class ZhengStandardClassifier(nn.Module):
    """Classifier head consisting of one fixed-width linear layer.

    The head is ``Lift(nn.Linear(512, num_classes))`` stored as the only
    entry of an ``nn.Sequential`` under ``self.model``. The linear layer's
    input width is hard-coded to 512.

    Args:
        num_classes: Number of output features of the linear layer.
        neuron: Not used by this head.
        params: Not used by this head.
        config: Not used by this head.
        norm_layer: Not used by this head.
        mpbn: Not used by this head.

    NOTE(review): the unused arguments are presumably accepted so that all
    classifier heads in this module share one constructor signature --
    confirm against the call sites.
    """

    def __init__(
        self,
        num_classes: int = 10,
        neuron: Neuron | None = None,
        params: RecordClass | None = None,
        config: NeuronConfig | None = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        mpbn: bool = False
    ):
        super().__init__()

        # Eagerly-initialized readout; the input width is fixed at 512.
        readout = Lift(nn.Linear(512, num_classes))
        self.model = nn.Sequential(readout)

    def forward(self, x) -> torch.Tensor:
        """Pass ``x`` through ``self.model`` and return the result."""
        out = self.model(x)
        return out