from typing import Optional, Callable

import torch
from torch import nn

from recordclass import RecordClass

from ....Layers import Neuron, ITLIF, LIFParameters
from ....Layers.NeuronConfig import NeuronConfig

from ....util import Lift

class StandardClassifier(nn.Module):
    """Readout head made of one lazily-sized linear projection.

    The projection maps whatever feature width arrives on the first forward
    pass to ``num_classes`` outputs. It is wrapped in ``Lift`` and held inside
    an ``nn.Sequential``.

    ``neuron``, ``params``, ``config`` and ``norm_layer`` are accepted but
    never used here. Presumably this lets every classifier in this module share
    one constructor signature; confirm against the call sites.

    Args:
        num_classes: Output width of the linear projection.
        neuron: Ignored.
        params: Ignored.
        config: Ignored.
        norm_layer: Ignored.
    """

    def __init__(
        self,
        num_classes: int = 10,
        neuron: Neuron | None = None,
        params: RecordClass | None = None,
        config: NeuronConfig | None = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        super().__init__()

        # LazyLinear defers choosing in_features until the first input is seen.
        projection = nn.LazyLinear(num_classes)
        # NOTE(review): Lift (from ....util) presumably applies the wrapped
        # module across an extra leading axis such as time; confirm.
        lifted = Lift(projection)
        # Kept inside a Sequential so parameter/state_dict names remain
        # ``model.0.*``.
        self.model = nn.Sequential(lifted)

    def forward(self, x) -> torch.Tensor:
        """Run ``x`` through the lifted linear projection and return the result."""
        output = self.model(x)
        return output
    
class ZhengClassifier(nn.Module):
    """Two-layer classifier head with a neuron stage between the projections.

    Pipeline, each stage wrapped in ``Lift``::

        LazyLinear(-> 256) -> neuron(params, config) -> Linear(256 -> num_classes)

    Args:
        num_classes: Output width of the final linear layer.
        neuron: Neuron class or factory, called as ``neuron(params, config)``.
            Required: the ``None`` default only keeps the signature aligned
            with the other classifiers in this module.
        params: First argument forwarded to ``neuron``.
        config: Second argument forwarded to ``neuron``.
        norm_layer: Accepted for signature compatibility; not used.

    Raises:
        TypeError: If ``neuron`` is ``None``.
    """

    def __init__(
        self,
        num_classes: int = 10,
        neuron: Neuron | None = None,
        params: RecordClass | None = None,
        config: NeuronConfig | None = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        super().__init__()

        # Previously a missing neuron surfaced as the opaque
        # "'NoneType' object is not callable"; fail with an explicit message
        # while keeping the same exception type.
        if neuron is None:
            raise TypeError(
                "ZhengClassifier requires a `neuron` class/factory "
                "(called as neuron(params, config)); got None"
            )

        hidden_features = 256

        # NOTE(review): Lift (from ....util) presumably applies each wrapped
        # module across an extra leading axis such as time; confirm.
        self.model = nn.Sequential(
            # LazyLinear infers its input width from the first forward pass.
            Lift(nn.LazyLinear(hidden_features)),
            Lift(neuron(params, config)),
            Lift(nn.Linear(hidden_features, num_classes))
        )

    def forward(self, x) -> torch.Tensor:
        """Run ``x`` through the linear -> neuron -> linear pipeline."""
        return self.model(x)
    
class ZhengStandardClassifier(nn.Module):
    """Readout head with a single lifted, lazily-sized linear layer.

    It is structurally the same as a plain linear readout: one ``LazyLinear``
    to ``num_classes`` wrapped in ``Lift``, stored at ``self.model[0]``.

    ``neuron``, ``params``, ``config`` and ``norm_layer`` are not used. They
    appear to be accepted only so the constructor matches the other
    classifiers in this module; verify against callers.

    Args:
        num_classes: Number of output features of the linear layer.
        neuron: Ignored.
        params: Ignored.
        config: Ignored.
        norm_layer: Ignored.
    """

    def __init__(
        self,
        num_classes: int = 10,
        neuron: Neuron | None = None,
        params: RecordClass | None = None,
        config: NeuronConfig | None = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        super().__init__()

        # A single stage; the input width is resolved lazily on first use.
        stages = [Lift(nn.LazyLinear(num_classes))]
        self.model = nn.Sequential(*stages)

    def forward(self, x) -> torch.Tensor:
        """Return the output of the lifted linear layer applied to ``x``."""
        result = self.model(x)
        return result
    