import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint, odeint
import numpy as np
import torch.nn.functional as F


class ODENet(nn.Module):
    """Downsampling stem -> ``ODEBlock`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENet, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlock(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODENetSpatial(nn.Module):
    """Downsampling stem -> ``ODEBlockSpatial`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSpatial, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSpatial(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODENetSpatialPlusTwoMul(nn.Module):
    """Downsampling stem -> ``ODEBlockSpatialPlusTwoMul`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSpatialPlusTwoMul, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSpatialPlusTwoMul(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODENetSigmoid(nn.Module):
    """Downsampling stem -> ``ODEBlockSigmoid`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSigmoid, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSigmoid(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockSigmoid(nn.Module):
    """Integrates ``ODEfuncSigmoid`` dynamics over ``self.integration_time``.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name passed to the integrator.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence of output times (0 is prepended when
            missing). ``0`` disables integration, making the block the identity.
        norm: normalization type passed to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSigmoid, self).__init__()
        self.odefunc = ODEfuncSigmoid(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        self.return_last_only = True

    def forward(self, x):
        """Integrate from ``x``; return the final state, or all states if ``return_last_only`` is False."""
        if self.integration_time is None:
            return x

        # cast the time grid to the input's tensor type (dtype and device)
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of function evaluations counted by the dynamics function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0.0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0.0
        # last entry, consistent with ``out[-1]`` being the state at t=t1
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # 0-d tensors/arrays become Python scalars and others become lists,
            # so a value read back from the ``t1`` getter can be assigned again
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            times = list(value)
            if not times:
                raise ValueError("Integration times must not be empty")
            if times[0] != 0:
                # integration always starts at t=0
                times = [0] + times
            self.integration_time = torch.tensor(times, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncSigmoid(nn.Module):
    """Dynamics function with sigmoid activations.

    Pipeline: norm -> sigmoid -> conv(t) -> norm -> sigmoid -> conv(t) -> norm,
    where each conv is a 3x3, stride-1, padding-1 ``ConcatConv2d`` that also
    receives ``t``. Each call increments ``self.nfe``.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncSigmoid, self).__init__()
        make_norm = normalization(norm)
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        self.norm1 = make_norm(dim)
        self.sigmoid = nn.Sigmoid()
        self.conv1 = ConcatConv2d(dim, dim, **conv_opts)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, **conv_opts)
        self.norm3 = make_norm(dim)
        self.nfe = 0  # number of function evaluations

    def forward(self, t, x):
        self.nfe += 1
        h = self.sigmoid(self.norm1(x))
        h = self.sigmoid(self.norm2(self.conv1(t, h)))
        return self.norm3(self.conv2(t, h))


### my
class ODENetSiLU(nn.Module):
    """Downsampling stem -> ``ODEBlockSiLU`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSiLU, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSiLU(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockSiLU(nn.Module):
    """Integrates ``ODEfuncSiLU`` dynamics over ``self.integration_time``.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name passed to the integrator.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence of output times (0 is prepended when
            missing). ``0`` disables integration, making the block the identity.
        norm: normalization type passed to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSiLU, self).__init__()
        self.odefunc = ODEfuncSiLU(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        self.return_last_only = True

    def forward(self, x):
        """Integrate from ``x``; return the final state, or all states if ``return_last_only`` is False."""
        if self.integration_time is None:
            return x

        # cast the time grid to the input's tensor type (dtype and device)
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of function evaluations counted by the dynamics function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0.0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0.0
        # last entry, consistent with ``out[-1]`` being the state at t=t1
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # 0-d tensors/arrays become Python scalars and others become lists,
            # so a value read back from the ``t1`` getter can be assigned again
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            times = list(value)
            if not times:
                raise ValueError("Integration times must not be empty")
            if times[0] != 0:
                # integration always starts at t=0
                times = [0] + times
            self.integration_time = torch.tensor(times, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncSiLU(nn.Module):
    """Dynamics function with SiLU activations.

    Pipeline: norm -> SiLU -> conv(t) -> norm -> SiLU -> conv(t) -> norm,
    where each conv is a 3x3, stride-1, padding-1 ``ConcatConv2d`` that also
    receives ``t``. Each call increments ``self.nfe``.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncSiLU, self).__init__()
        make_norm = normalization(norm)
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        self.norm1 = make_norm(dim)
        self.silu = nn.SiLU()
        self.conv1 = ConcatConv2d(dim, dim, **conv_opts)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, **conv_opts)
        self.norm3 = make_norm(dim)
        self.nfe = 0  # number of function evaluations

    def forward(self, t, x):
        self.nfe += 1
        h = self.silu(self.norm1(x))
        h = self.silu(self.norm2(self.conv1(t, h)))
        return self.norm3(self.conv2(t, h))


### my
class ODENetSiLUReLU(nn.Module):
    """Downsampling stem -> ``ODEBlockSiLUReLU`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSiLUReLU, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSiLUReLU(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockSiLUReLU(nn.Module):
    """Integrates ``ODEfuncSiLUReLU`` dynamics over ``self.integration_time``.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name passed to the integrator.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence of output times (0 is prepended when
            missing). ``0`` disables integration, making the block the identity.
        norm: normalization type passed to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSiLUReLU, self).__init__()
        self.odefunc = ODEfuncSiLUReLU(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        self.return_last_only = True

    def forward(self, x):
        """Integrate from ``x``; return the final state, or all states if ``return_last_only`` is False."""
        if self.integration_time is None:
            return x

        # cast the time grid to the input's tensor type (dtype and device)
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of function evaluations counted by the dynamics function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0.0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0.0
        # last entry, consistent with ``out[-1]`` being the state at t=t1
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # 0-d tensors/arrays become Python scalars and others become lists,
            # so a value read back from the ``t1`` getter can be assigned again
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            times = list(value)
            if not times:
                raise ValueError("Integration times must not be empty")
            if times[0] != 0:
                # integration always starts at t=0
                times = [0] + times
            self.integration_time = torch.tensor(times, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncSiLUReLU(nn.Module):
    """Dynamics function with the time-dependent ``SiLUReLU`` activation.

    Pipeline: norm -> act(t) -> conv(t) -> norm -> act(t) -> conv(t) -> norm,
    where each conv is a 3x3, stride-1, padding-1 ``ConcatConv2d`` that also
    receives ``t``. Each call increments ``self.nfe``.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncSiLUReLU, self).__init__()
        make_norm = normalization(norm)
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        self.norm1 = make_norm(dim)
        self.silurelu = SiLUReLU()
        self.conv1 = ConcatConv2d(dim, dim, **conv_opts)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, **conv_opts)
        self.norm3 = make_norm(dim)
        self.nfe = 0  # number of function evaluations

    def forward(self, t, x):
        self.nfe += 1
        h = self.silurelu(t, self.norm1(x))
        h = self.silurelu(t, self.norm2(self.conv1(t, h)))
        return self.norm3(self.conv2(t, h))


class SiLUReLU(nn.Module):
    """Time-dependent activation ``x * sigmoid(x / (t + eps))``.

    At ``t = 1`` this is (up to ``eps``) SiLU; as ``t`` approaches 0 the gate
    sharpens towards a step function, so the output approaches ReLU.
    """

    def __init__(self):
        super(SiLUReLU, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.eps = 1e-5  # keeps the division finite at t = 0

    def forward(self, t, x):
        temperature = t + self.eps
        gate = self.sigmoid(x / temperature)
        return x * gate


### my
class ODENetSiLUSigmoid(nn.Module):
    """Downsampling stem -> ``ODEBlockSiLUSigmoid`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stem.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem variant; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method: ODE solver name forwarded to the ODE block(s).
        tol: solver tolerance forwarded to the ODE block(s).
        adjoint: adjoint flag forwarded to the ODE block(s).
        t1: integration end time (or times) forwarded to the ODE block(s).
        dropout: dropout forwarded to the classifier head.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a supported stem name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSiLUSigmoid, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        ode_opts = dict(adjoint=adjoint, t1=t1, tol=tol, method=method)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(in_ch, **ode_opts, **common)
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(in_ch, **ode_opts, **common)
        else:
            # fail fast: otherwise self.downsample stays unset and forward() breaks later
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSiLUSigmoid(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated along dim 0.

        If the stem returns a ``(features, x)`` pair, each entry of ``features``
        is either globally average-pooled (when the head's last layer was
        stripped by ``to_features_extractor(keep_pool=True)``) or passed through
        the classifier, and the stacked result is placed before the ODE-block
        output. If the ODE block returns stacked states (``dim() > 4``), the
        classifier is applied to each state separately.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: intermediate features, second: tensor to continue with
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``module`` attribute
            head = getattr(self.classifier, "module", None)
            if head is not None and isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # dynamics at multiple timestamps: classify each state separately
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Convert the network into a feature extractor, in place.

        ODE blocks (including an ODE stem) are switched to return states at all
        integration timestamps. With ``keep_pool=True`` the last layer of the
        classifier head is replaced by an empty ``nn.Sequential``; otherwise the
        classifier is replaced by the first two children of its head. Calling
        this again after a ``keep_pool=False`` conversion leaves the classifier
        untouched.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        head = getattr(self.classifier, "module", None)
        if head is None:
            # classifier was already truncated by a previous keep_pool=False call
            return
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            head[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(*list(head.children())[:2])

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockSiLUSigmoid(nn.Module):
    """Integrates ``ODEfuncSiLUSigmoid`` dynamics over ``self.integration_time``.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name passed to the integrator.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence of output times (0 is prepended when
            missing). ``0`` disables integration, making the block the identity.
        norm: normalization type passed to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSiLUSigmoid, self).__init__()
        self.odefunc = ODEfuncSiLUSigmoid(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        self.return_last_only = True

    def forward(self, x):
        """Integrate from ``x``; return the final state, or all states if ``return_last_only`` is False."""
        if self.integration_time is None:
            return x

        # cast the time grid to the input's tensor type (dtype and device)
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of function evaluations counted by the dynamics function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0.0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0.0
        # last entry, consistent with ``out[-1]`` being the state at t=t1
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # 0-d tensors/arrays become Python scalars and others become lists,
            # so a value read back from the ``t1`` getter can be assigned again
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            times = list(value)
            if not times:
                raise ValueError("Integration times must not be empty")
            if times[0] != 0:
                # integration always starts at t=0
                times = [0] + times
            self.integration_time = torch.tensor(times, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncSiLUSigmoid(nn.Module):
    """Right-hand side f(t, x) of the ODE using the time-dependent ``SiLUSigmoid`` activation.

    Pipeline: norm1 -> act(t) -> conv1(t) -> norm2 -> act(t) -> conv2(t) -> norm3.
    ``nfe`` counts ``forward`` calls, i.e. evaluations requested by the solver.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncSiLUSigmoid, self).__init__()
        make_norm = normalization(norm)  # factory used as make_norm(channels)
        # Keep this construction order: it fixes parameter order and init RNG consumption.
        self.norm1 = make_norm(dim)
        self.silusigmoid = SiLUSigmoid()
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = make_norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        stages = ((self.norm1, self.conv1), (self.norm2, self.conv2))
        hidden = x
        for norm_layer, conv_layer in stages:
            hidden = conv_layer(t, self.silusigmoid(t, norm_layer(hidden)))
        return self.norm3(hidden)


class SiLUSigmoid(nn.Module):
    """Time-dependent activation ``x ** relu(t) * sigmoid(x)``.

    It gives ``sigmoid(x)`` for t <= 0 and SiLU ``x * sigmoid(x)`` at t = 1.

    ``torch.pow`` returns NaN for a negative base with a non-integer exponent. The ODE
    solver would hit that for every negative activation at fractional t. The power is
    therefore evaluated as the real part of the principal power: ``|x| ** p * cos(pi * p)``
    for ``x < 0``, with ``p = relu(t)``. This equals ``x ** p`` wherever that is real
    (x >= 0, or integer p) and stays finite elsewhere.
    """

    def __init__(self):
        super(SiLUSigmoid, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, t, x):
        # t must be a tensor (F.relu requires one); negative times clamp to exponent 0.
        p = F.relu(t)
        magnitude = torch.pow(x.abs(), p)
        # Re((-|x|) ** p) = |x| ** p * cos(pi * p); identical to x ** p for integer p.
        real_power = torch.where(x < 0, magnitude * torch.cos(np.pi * p), magnitude)
        return torch.mul(real_power, self.sigmoid(x))


### my
class ODENetVaryingGroup(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockVaryingGroup`` -> ``FCClassifier``.

    Same layout as ``ODENet``, with ``ODEfuncVaryingGroup`` dynamics in the ODE block.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channel width of the stem output / ODE state.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method, tol, adjoint, t1: solver settings for the ODE block (and the ODE stems).
        dropout: passed to ``FCClassifier``.
        norm: normalization name passed to the stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetVaryingGroup, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here rather than with a missing attribute later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockVaryingGroup(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run the stem, ODE block and classifier.

        Returns the classifier output for the ODE block result; a trajectory is classified
        per time step and stacked. If the stem returns ``(features, x)``, the stem features
        are concatenated in front along dim 0. They are globally average-pooled when the
        classifier head was removed, and classified otherwise.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first elements are features, second is output to continue the forward
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # no classification to be performed, apply GAP and return
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # we want to apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # Trajectory with a leading time dim: classify each time step.
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the model in place to output features instead of class scores.

        The ODE block, and the stem when it is an ODE stem, are set to return every
        integration time. With ``keep_pool`` the classifier's last layer is replaced by an
        empty ``nn.Sequential``. Otherwise the classifier becomes its first two children.

        NOTE(review): ``forward`` reads ``self.classifier.module``, which the plain
        ``nn.Sequential`` built for ``keep_pool=False`` lacks. That mode presumably fails
        whenever the stem returns features; confirm.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count, zeroing it afterwards if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockVaryingGroup(nn.Module):
    """Integrate ``ODEfuncVaryingGroup`` dynamics from t=0 over ``integration_time``.

    The ``t1`` property builds ``integration_time``:
    - A scalar gives ``[0, t1]``. ``t1 == 0`` disables integration, and ``forward`` then
      returns its input unchanged.
    - A list, tuple or 1-d tensor is used as the output times, with 0 prepended when it
      does not start at 0.

    ``integration_time`` is a plain attribute, not a registered buffer. ``forward``
    therefore re-casts it with ``type_as`` on every call.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockVaryingGroup, self).__init__()
        self.odefunc = ODEfuncVaryingGroup(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which sets self.integration_time
        self.tol = tol  # used as both rtol and atol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns the state at the last time; False: at every requested time.
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``.

        Returns ``x`` itself when integration is disabled. Otherwise returns the solver
        output at the last time (``return_last_only``) or at all times in
        ``integration_time``.
        """
        if self.integration_time is None:
            return x

        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Function-evaluation counter kept by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time for ``[0, t1]``).

        Returns a 0-d tensor, or ``tensor(0.)`` when integration is disabled. With several
        output times this is the first time after 0, not the last one.
        """
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # A 0-d tensor (e.g. the value returned by the getter) is treated as a scalar.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)

        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration times must not be empty")
            # Integration always starts at t=0.
            if value[0] != 0:
                value = [0] + value

            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncVaryingGroup(nn.Module):
    """ODE right-hand side f(t, x) whose GroupNorm group count switches with ``t``.

    Pipeline: norm1 -> ReLU -> conv1(t) -> norm2 -> ReLU -> conv2(t) -> norm3.

    Each norm site owns four ``nn.GroupNorm`` variants, ``norm{site}_G{groups}`` for 8, 16,
    32 and 64 groups. All three sites use the variant picked once per call from
    ``_GROUP_SCHEDULE``: t < 0.25 -> 8, t < 0.5 -> 16, t < 0.75 -> 32, otherwise 64.
    ``nn.GroupNorm`` requires ``dim`` to be divisible by every group count.
    ``nfe`` counts ``forward`` calls.
    """

    # (exclusive upper bound on t, group count), checked in order.
    _GROUP_SCHEDULE = ((0.25, 8), (0.50, 16), (0.75, 32))
    # Used when t is not below any bound (t >= 0.75, or a NaN t).
    _LAST_GROUPS = 64

    def __init__(self, dim, norm="group"):
        super(ODEfuncVaryingGroup, self).__init__()
        # The factory result is not used (group norms are built explicitly below); the call
        # is kept from the original. NOTE(review): drop it if it plays no validation role.
        normalization(norm)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.nfe = 0

        # Register norm{site}_G{groups} site-major, as before, so state_dict keys and
        # parameter order are unchanged.
        all_groups = [groups for _, groups in self._GROUP_SCHEDULE] + [self._LAST_GROUPS]
        for site in (1, 2, 3):
            for groups in all_groups:
                setattr(self, f"norm{site}_G{groups}", nn.GroupNorm(groups, dim))

    def _groups_for(self, t):
        """Return the GroupNorm group count to use at time ``t``."""
        for upper, groups in self._GROUP_SCHEDULE:
            if t < upper:
                return groups
        return self._LAST_GROUPS

    def forward(self, t, x):
        self.nfe += 1
        groups = self._groups_for(t)
        norm1, norm2, norm3 = (
            getattr(self, f"norm{site}_G{groups}") for site in (1, 2, 3)
        )

        out = self.relu(norm1(x))
        out = self.conv1(t, out)
        out = self.relu(norm2(out))
        out = self.conv2(t, out)
        return norm3(out)


### my
class ODENetVaryingNorm(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockVaryingNorm`` -> ``FCClassifier``.

    Same layout as ``ODENet``, with ``ODEfuncVaryingNorm`` dynamics in the ODE block.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channel width of the stem output / ODE state.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method, tol, adjoint, t1: solver settings for the ODE block (and the ODE stems).
        dropout: passed to ``FCClassifier``.
        norm: normalization name passed to the stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetVaryingNorm, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here rather than with a missing attribute later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockVaryingNorm(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run the stem, ODE block and classifier.

        Returns the classifier output for the ODE block result; a trajectory is classified
        per time step and stacked. If the stem returns ``(features, x)``, the stem features
        are concatenated in front along dim 0. They are globally average-pooled when the
        classifier head was removed, and classified otherwise.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first elements are features, second is output to continue the forward
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # no classification to be performed, apply GAP and return
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # we want to apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # Trajectory with a leading time dim: classify each time step.
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the model in place to output features instead of class scores.

        The ODE block, and the stem when it is an ODE stem, are set to return every
        integration time. With ``keep_pool`` the classifier's last layer is replaced by an
        empty ``nn.Sequential``. Otherwise the classifier becomes its first two children.

        NOTE(review): ``forward`` reads ``self.classifier.module``, which the plain
        ``nn.Sequential`` built for ``keep_pool=False`` lacks. That mode presumably fails
        whenever the stem returns features; confirm.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count, zeroing it afterwards if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockVaryingNorm(nn.Module):
    """Integrate ``ODEfuncVaryingNorm`` dynamics from t=0 over ``integration_time``.

    The ``t1`` property builds ``integration_time``:
    - A scalar gives ``[0, t1]``. ``t1 == 0`` disables integration, and ``forward`` then
      returns its input unchanged.
    - A list, tuple or 1-d tensor is used as the output times, with 0 prepended when it
      does not start at 0.

    ``integration_time`` is a plain attribute, not a registered buffer. ``forward``
    therefore re-casts it with ``type_as`` on every call.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockVaryingNorm, self).__init__()
        self.odefunc = ODEfuncVaryingNorm(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which sets self.integration_time
        self.tol = tol  # used as both rtol and atol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns the state at the last time; False: at every requested time.
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``.

        Returns ``x`` itself when integration is disabled. Otherwise returns the solver
        output at the last time (``return_last_only``) or at all times in
        ``integration_time``.
        """
        if self.integration_time is None:
            return x

        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Function-evaluation counter kept by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time for ``[0, t1]``).

        Returns a 0-d tensor, or ``tensor(0.)`` when integration is disabled. With several
        output times this is the first time after 0, not the last one.
        """
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # A 0-d tensor (e.g. the value returned by the getter) is treated as a scalar.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)

        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration times must not be empty")
            # Integration always starts at t=0.
            if value[0] != 0:
                value = [0] + value

            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncVaryingNorm(nn.Module):
    """ODE right-hand side f(t, x) whose normalization layers depend on ``t``.

    Every norm is a 32-group ``GroupNormVaryingNorm`` called as ``norm(t, x)``:
    norm1(t) -> ReLU -> conv1(t) -> norm2(t) -> ReLU -> conv2(t) -> norm3(t).

    ``norm`` is accepted for signature compatibility with the other ODE functions but is
    ignored. ``nfe`` counts ``forward`` calls.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncVaryingNorm, self).__init__()
        num_groups = 32
        # Keep this construction order: it fixes parameter order and init RNG consumption.
        self.norm1 = GroupNormVaryingNorm(num_groups, dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = GroupNormVaryingNorm(num_groups, dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = GroupNormVaryingNorm(num_groups, dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        stages = ((self.norm1, self.conv1), (self.norm2, self.conv2))
        hidden = x
        for norm_layer, conv_layer in stages:
            hidden = conv_layer(t, self.relu(norm_layer(t, hidden)))
        return self.norm3(t, hidden)


class GroupNormVaryingNorm(nn.Module):
    """Group normalization whose dispersion statistic depends on the ODE time ``t``.

    For every sample and group, the centred values ``d = x - mean`` are divided by
    ``sqrt((mean |d| ** p) ** (2 / p) + eps)`` with ``p = 1 + t``. This is the square of
    an L^p standard deviation. At ``t = 1`` (p = 2) it is exactly the biased variance
    used by ``nn.GroupNorm``. Averaging inside the power keeps the statistic independent
    of the group size, so the output scale does not depend on the spatial resolution.

    An optional per-channel affine transform (``weight`` ones, ``bias`` zeros at init)
    follows, as in ``nn.GroupNorm``. Input must be 4-d ``(N, C, H, W)``.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
        super().__init__(**kwargs)
        if num_channels % num_groups != 0:
            raise ValueError(
                f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
            )
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
        else:
            self.weight = None
            self.bias = None

    def forward(self, t, x):
        N, C, H, W = x.shape
        if C != self.num_channels:
            raise ValueError(f"Expected {self.num_channels} channels, got {C}")

        grouped = x.reshape(N, self.num_groups, -1)
        centred = grouped - grouped.mean(dim=2, keepdim=True)
        p = 1.0 + t
        # (mean |d|^p)^(2/p): reduces to mean(d^2), the ordinary variance, when p == 2.
        dispersion = centred.abs().pow(p).mean(dim=2, keepdim=True).pow(2.0 / p)

        out = (centred / torch.sqrt(dispersion + self.eps)).reshape(N, C, H, W)
        if self.affine:
            out = self.weight.reshape(1, -1, 1, 1) * out + self.bias.reshape(1, -1, 1, 1)
        return out


### my
class ODENetVaryingDrop(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockVaryingDrop`` -> ``FCClassifier``.

    Same layout as ``ODENet``, with ``ODEfuncVaryingDrop`` dynamics in the ODE block.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channel width of the stem output / ODE state.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method, tol, adjoint, t1: solver settings for the ODE block (and the ODE stems).
        dropout: passed to ``FCClassifier``.
        norm: normalization name passed to the stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetVaryingDrop, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here rather than with a missing attribute later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockVaryingDrop(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run the stem, ODE block and classifier.

        Returns the classifier output for the ODE block result; a trajectory is classified
        per time step and stacked. If the stem returns ``(features, x)``, the stem features
        are concatenated in front along dim 0. They are globally average-pooled when the
        classifier head was removed, and classified otherwise.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first elements are features, second is output to continue the forward
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # no classification to be performed, apply GAP and return
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # we want to apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # Trajectory with a leading time dim: classify each time step.
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the model in place to output features instead of class scores.

        The ODE block, and the stem when it is an ODE stem, are set to return every
        integration time. With ``keep_pool`` the classifier's last layer is replaced by an
        empty ``nn.Sequential``. Otherwise the classifier becomes its first two children.

        NOTE(review): ``forward`` reads ``self.classifier.module``, which the plain
        ``nn.Sequential`` built for ``keep_pool=False`` lacks. That mode presumably fails
        whenever the stem returns features; confirm.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count, zeroing it afterwards if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockVaryingDrop(nn.Module):
    """Integrate ``ODEfuncVaryingDrop`` dynamics from t=0 over ``integration_time``.

    The ``t1`` property builds ``integration_time``:
    - A scalar gives ``[0, t1]``. ``t1 == 0`` disables integration, and ``forward`` then
      returns its input unchanged.
    - A list, tuple or 1-d tensor is used as the output times, with 0 prepended when it
      does not start at 0.

    ``integration_time`` is a plain attribute, not a registered buffer. ``forward``
    therefore re-casts it with ``type_as`` on every call.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockVaryingDrop, self).__init__()
        self.odefunc = ODEfuncVaryingDrop(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which sets self.integration_time
        self.tol = tol  # used as both rtol and atol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns the state at the last time; False: at every requested time.
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``.

        Returns ``x`` itself when integration is disabled. Otherwise returns the solver
        output at the last time (``return_last_only``) or at all times in
        ``integration_time``.
        """
        if self.integration_time is None:
            return x

        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Function-evaluation counter kept by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time for ``[0, t1]``).

        Returns a 0-d tensor, or ``tensor(0.)`` when integration is disabled. With several
        output times this is the first time after 0, not the last one.
        """
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # A 0-d tensor (e.g. the value returned by the getter) is treated as a scalar.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)

        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration times must not be empty")
            # Integration always starts at t=0.
            if value[0] != 0:
                value = [0] + value

            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncVaryingDrop(nn.Module):
    """ODE right-hand side f(t, x) followed by dropout whose rate grows with ``t``.

    Pipeline: norm1 -> ReLU -> conv1(t) -> norm2 -> ReLU -> conv2(t) -> norm3 -> dropout.
    The dropout probability is ``t * max_drop_prob``, clipped to ``[0, 1]``, and dropout
    is only active in training mode. ``nfe`` counts ``forward`` calls.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncVaryingDrop, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0

        # Dropout rate reached at t = 1; the rate keeps growing linearly for larger t.
        # NOTE(review): the name suggests a cap; confirm whether t > 1 should saturate here.
        self.max_drop_prob = 0.1

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)

        # F.dropout rejects p > 1, which t * max_drop_prob reaches once t > 1 / max_drop_prob,
        # so clip to [0, 1]. float() accepts Python scalars and one-element tensors alike.
        drop_prob = min(max(float(t) * self.max_drop_prob, 0.0), 1.0)
        return F.dropout(out, drop_prob, self.training)


### my
class ODENetFixedDrop(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockFixedDrop`` -> ``FCClassifier``.

    Same layout as ``ODENet``, with ``ODEfuncFixedDrop`` dynamics in the ODE block.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channel width of the stem output / ODE state.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method, tol, adjoint, t1: solver settings for the ODE block (and the ODE stems).
        dropout: passed to ``FCClassifier``.
        norm: normalization name passed to the stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetFixedDrop, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here rather than with a missing attribute later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockFixedDrop(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run the stem, ODE block and classifier.

        Returns the classifier output for the ODE block result; a trajectory is classified
        per time step and stacked. If the stem returns ``(features, x)``, the stem features
        are concatenated in front along dim 0. They are globally average-pooled when the
        classifier head was removed, and classified otherwise.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first elements are features, second is output to continue the forward
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # no classification to be performed, apply GAP and return
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # we want to apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # Trajectory with a leading time dim: classify each time step.
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the model in place to output features instead of class scores.

        The ODE block, and the stem when it is an ODE stem, are set to return every
        integration time. With ``keep_pool`` the classifier's last layer is replaced by an
        empty ``nn.Sequential``. Otherwise the classifier becomes its first two children.

        NOTE(review): ``forward`` reads ``self.classifier.module``, which the plain
        ``nn.Sequential`` built for ``keep_pool=False`` lacks. That mode presumably fails
        whenever the stem returns features; confirm.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count, zeroing it afterwards if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockFixedDrop(nn.Module):
    """Integrate ``ODEfuncFixedDrop`` dynamics from t=0 over ``integration_time``.

    The ``t1`` property builds ``integration_time``:
    - A scalar gives ``[0, t1]``. ``t1 == 0`` disables integration, and ``forward`` then
      returns its input unchanged.
    - A list, tuple or 1-d tensor is used as the output times, with 0 prepended when it
      does not start at 0.

    ``integration_time`` is a plain attribute, not a registered buffer. ``forward``
    therefore re-casts it with ``type_as`` on every call.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockFixedDrop, self).__init__()
        self.odefunc = ODEfuncFixedDrop(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which sets self.integration_time
        self.tol = tol  # used as both rtol and atol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns the state at the last time; False: at every requested time.
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``.

        Returns ``x`` itself when integration is disabled. Otherwise returns the solver
        output at the last time (``return_last_only``) or at all times in
        ``integration_time``.
        """
        if self.integration_time is None:
            return x

        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Function-evaluation counter kept by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time for ``[0, t1]``).

        Returns a 0-d tensor, or ``tensor(0.)`` when integration is disabled. With several
        output times this is the first time after 0, not the last one.
        """
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # A 0-d tensor (e.g. the value returned by the getter) is treated as a scalar.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)

        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration times must not be empty")
            # Integration always starts at t=0.
            if value[0] != 0:
                value = [0] + value

            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncFixedDrop(nn.Module):
    """ODE right-hand side f(t, x) followed by dropout at a constant rate.

    Pipeline: norm1 -> ReLU -> conv1(t) -> norm2 -> ReLU -> conv2(t) -> norm3 -> dropout.
    Dropout uses ``max_drop_prob`` regardless of ``t`` and is only active in training
    mode. ``nfe`` counts ``forward`` calls.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncFixedDrop, self).__init__()
        make_norm = normalization(norm)  # factory used as make_norm(channels)
        # Keep this construction order: it fixes parameter order and init RNG consumption.
        self.norm1 = make_norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = make_norm(dim)
        self.nfe = 0
        self.max_drop_prob = 0.1  # constant dropout rate, independent of t

    def forward(self, t, x):
        self.nfe += 1
        stages = ((self.norm1, self.conv1), (self.norm2, self.conv2))
        hidden = x
        for norm_layer, conv_layer in stages:
            hidden = conv_layer(t, self.relu(norm_layer(hidden)))
        hidden = self.norm3(hidden)
        return F.dropout(hidden, p=self.max_drop_prob, training=self.training)


### my
class ODENetSoftplus(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockSoftplus`` -> ``FCClassifier``.

    Same layout as ``ODENet``, with ``ODEfuncSoftplus`` dynamics in the ODE block.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channel width of the stem output / ODE state.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method, tol, adjoint, t1: solver settings for the ODE block (and the ODE stems).
        dropout: passed to ``FCClassifier``.
        norm: normalization name passed to the stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetSoftplus, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here rather than with a missing attribute later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockSoftplus(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run the stem, ODE block and classifier.

        Returns the classifier output for the ODE block result; a trajectory is classified
        per time step and stacked. If the stem returns ``(features, x)``, the stem features
        are concatenated in front along dim 0. They are globally average-pooled when the
        classifier head was removed, and classified otherwise.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first elements are features, second is output to continue the forward
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # no classification to be performed, apply GAP and return
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # we want to apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # Trajectory with a leading time dim: classify each time step.
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the model in place to output features instead of class scores.

        The ODE block, and the stem when it is an ODE stem, are set to return every
        integration time. With ``keep_pool`` the classifier's last layer is replaced by an
        empty ``nn.Sequential``. Otherwise the classifier becomes its first two children.

        NOTE(review): ``forward`` reads ``self.classifier.module``, which the plain
        ``nn.Sequential`` built for ``keep_pool=False`` lacks. That mode presumably fails
        whenever the stem returns features; confirm.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count, zeroing it afterwards if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockSoftplus(nn.Module):
    """Integrate ``ODEfuncSoftplus`` dynamics from t=0 over ``integration_time``.

    The ``t1`` property builds ``integration_time``:
    - A scalar gives ``[0, t1]``. ``t1 == 0`` disables integration, and ``forward`` then
      returns its input unchanged.
    - A list, tuple or 1-d tensor is used as the output times, with 0 prepended when it
      does not start at 0.

    ``integration_time`` is a plain attribute, not a registered buffer. ``forward``
    therefore re-casts it with ``type_as`` on every call.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSoftplus, self).__init__()
        self.odefunc = ODEfuncSoftplus(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which sets self.integration_time
        self.tol = tol  # used as both rtol and atol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns the state at the last time; False: at every requested time.
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``.

        Returns ``x`` itself when integration is disabled. Otherwise returns the solver
        output at the last time (``return_last_only``) or at all times in
        ``integration_time``.
        """
        if self.integration_time is None:
            return x

        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Function-evaluation counter kept by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time for ``[0, t1]``).

        Returns a 0-d tensor, or ``tensor(0.)`` when integration is disabled. With several
        output times this is the first time after 0, not the last one.
        """
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # A 0-d tensor (e.g. the value returned by the getter) is treated as a scalar.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)

        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration times must not be empty")
            # Integration always starts at t=0.
            if value[0] != 0:
                value = [0] + value

            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncSoftplus(nn.Module):
    """ODE right-hand side ``f(t, x)``: norm -> softplus -> conv, twice, then norm.

    ``t`` is passed to both ``ConcatConv2d`` layers; ``nfe`` counts calls.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncSoftplus, self).__init__()
        make_norm = normalization(norm)
        conv_kw = dict(kernel_size=3, stride=1, padding=1)
        # Sub-modules are created in this exact order so parameter
        # registration and random initialisation stay the same.
        self.norm1 = make_norm(dim)
        self.softplus = nn.Softplus()
        self.conv1 = ConcatConv2d(dim, dim, **conv_kw)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, **conv_kw)
        self.norm3 = make_norm(dim)
        self.nfe = 0  # number of forward evaluations

    def forward(self, t, x):
        self.nfe += 1
        h = self.norm1(x)
        # Each stage: activation, time-conditioned conv, then the next norm.
        for conv, post_norm in ((self.conv1, self.norm2), (self.conv2, self.norm3)):
            h = post_norm(conv(t, self.softplus(h)))
        return h


### my
class ODENetReLUSoftplus(nn.Module):
    """Classifier: downsample -> ``ODEBlockReLUSoftplus`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stage.
        out: number of classifier outputs.
        n_filters: channel width produced by the downsampling stage and used
            by the ODE block and the classifier.
        downsample: one of ``"residual"``, ``"convolution"``, ``"minimal"``,
            ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method, tol, adjoint, t1: forwarded to the ODE block, and to the
            downsampling stage when it is ``"ode"`` / ``"ode2"``.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the supported names.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetReLUSoftplus, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast: otherwise ``self.downsample`` is never set and the
            # error only surfaces later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unknown downsample {downsample!r}; expected one of 'residual', "
                "'convolution', 'minimal', 'one-shot', 'ode', 'ode2'"
            )

        self.odeblock = ODEBlockReLUSoftplus(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run downsample -> ODE block -> classifier; concatenate along dim 0.

        If ``self.downsample`` returns a ``(features, x)`` pair, the features
        are pooled (or classified) per element and placed before the ODE
        block's output in the concatenation.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element holds features, second is the tensor to continue with
            f, x = x
            # After ``to_features_extractor(keep_pool=False)`` the classifier is
            # a bare ``nn.Sequential`` without ``.module``; use it directly then.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the classifier to every feature map
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # presumably a (time, batch, C, H, W) trajectory: classify each step
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a feature extractor, in place.

        ODE blocks are switched to return every time point. With ``keep_pool``
        the classifier's last layer is replaced by an empty ``nn.Sequential``;
        otherwise the classifier becomes an ``nn.Sequential`` of the first two
        children of ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count; zero it when ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockReLUSoftplus(nn.Module):
    """Integrate ``ODEfuncReLUSoftplus`` dynamics over ``integration_time``.

    Args:
        n_filters: channel count forwarded to ``ODEfuncReLUSoftplus``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time (scalar) or sequence of time points; see the ``t1`` setter.
        norm: normalization name forwarded to ``ODEfuncReLUSoftplus``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockReLUSoftplus, self).__init__()
        self.odefunc = ODEfuncReLUSoftplus(n_filters, norm=norm)
        self.t1 = t1  # routed through the ``t1`` setter, builds ``integration_time``
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: ``forward`` returns only the final state; False: every time point.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the state at the last time point when ``return_last_only`` is True,
        and otherwise the solver output stacked over all time points.
        """
        if self.integration_time is None:
            return x

        # Match the input's dtype/device; stored back so later calls reuse it.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0
        # Last element: with several time points, index 1 would be an
        # intermediate time rather than the end time used by ``forward``.
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        """Set the integration time points.

        A scalar ``v`` yields ``[0, v]``; ``v == 0`` disables integration.
        A list / tuple / 1-D tensor is used as the time points, with ``0``
        prepended when it does not already start at 0. A 0-d tensor is
        treated as a scalar.

        Raises:
            ValueError: for unsupported types or an empty sequence.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty sequence of time points")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncReLUSoftplus(nn.Module):
    """ODE right-hand side with a time-dependent softplus activation.

    Layout: norm -> act -> conv -> norm -> act -> conv -> norm, where
    ``act = F.softplus(., beta=1 / max(t, eps))``. Because
    ``softplus(x; beta) = log(1 + exp(beta * x)) / beta``, small ``t`` gives
    a large ``beta`` and the activation approaches ReLU; ``eps`` caps
    ``beta`` at ``1 / eps`` so ``t <= 0`` never divides by zero.

    Args:
        dim: channel count of the ODE state.
        norm: normalization name passed to ``normalization``.
        eps: lower bound applied to ``t`` when computing ``beta``.
    """

    def __init__(self, dim, norm="group", eps=1e-5):
        super(ODEfuncReLUSoftplus, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward evaluations
        self.eps = eps

    def forward(self, t, x):
        self.nfe += 1
        # Read ``t`` on the host once per call: converting a CUDA tensor to a
        # Python number forces a device synchronisation.
        beta = 1.0 / max(float(t), self.eps)
        out = self.norm1(x)
        out = F.softplus(out, beta=beta)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = F.softplus(out, beta=beta)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


### my
class ODENetELU(nn.Module):
    """Classifier: downsample -> ``ODEBlockELU`` -> ``FCClassifier``.

    NOTE(review): this module contains a second ``class ODENetELU``
    definition further down; that later one rebinds the name at import
    time, so this class object is shadowed. Consider deleting one of them.

    Args:
        in_ch: number of input channels given to the downsampling stage.
        out: number of classifier outputs.
        n_filters: channel width produced by the downsampling stage and used
            by the ODE block and the classifier.
        downsample: one of ``"residual"``, ``"convolution"``, ``"minimal"``,
            ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method, tol, adjoint, t1: forwarded to the ODE block, and to the
            downsampling stage when it is ``"ode"`` / ``"ode2"``.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the supported names.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetELU, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast: otherwise ``self.downsample`` is never set and the
            # error only surfaces later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unknown downsample {downsample!r}; expected one of 'residual', "
                "'convolution', 'minimal', 'one-shot', 'ode', 'ode2'"
            )

        self.odeblock = ODEBlockELU(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run downsample -> ODE block -> classifier; concatenate along dim 0.

        If ``self.downsample`` returns a ``(features, x)`` pair, the features
        are pooled (or classified) per element and placed before the ODE
        block's output in the concatenation.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element holds features, second is the tensor to continue with
            f, x = x
            # After ``to_features_extractor(keep_pool=False)`` the classifier is
            # a bare ``nn.Sequential`` without ``.module``; use it directly then.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the classifier to every feature map
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # presumably a (time, batch, C, H, W) trajectory: classify each step
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a feature extractor, in place.

        ODE blocks are switched to return every time point. With ``keep_pool``
        the classifier's last layer is replaced by an empty ``nn.Sequential``;
        otherwise the classifier becomes an ``nn.Sequential`` of the first two
        children of ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count; zero it when ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockELU(nn.Module):
    """Integrate ``ODEfuncELU`` dynamics over ``integration_time``.

    Args:
        n_filters: channel count forwarded to ``ODEfuncELU``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time (scalar) or sequence of time points; see the ``t1`` setter.
        norm: normalization name forwarded to ``ODEfuncELU``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockELU, self).__init__()
        self.odefunc = ODEfuncELU(n_filters, norm=norm)
        self.t1 = t1  # routed through the ``t1`` setter, builds ``integration_time``
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: ``forward`` returns only the final state; False: every time point.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the state at the last time point when ``return_last_only`` is True,
        and otherwise the solver output stacked over all time points.
        """
        if self.integration_time is None:
            return x

        # Match the input's dtype/device; stored back so later calls reuse it.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0
        # Last element: with several time points, index 1 would be an
        # intermediate time rather than the end time used by ``forward``.
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        """Set the integration time points.

        A scalar ``v`` yields ``[0, v]``; ``v == 0`` disables integration.
        A list / tuple / 1-D tensor is used as the time points, with ``0``
        prepended when it does not already start at 0. A 0-d tensor is
        treated as a scalar.

        Raises:
            ValueError: for unsupported types or an empty sequence.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty sequence of time points")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncELU(nn.Module):
    """ODE right-hand side ``f(t, x)``: norm -> ELU -> conv, twice, then norm.

    ``t`` is passed to both ``ConcatConv2d`` layers; ``nfe`` counts calls.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncELU, self).__init__()
        make_norm = normalization(norm)
        conv_kw = dict(kernel_size=3, stride=1, padding=1)
        # Sub-modules are created in this exact order so parameter
        # registration and random initialisation stay the same.
        self.norm1 = make_norm(dim)
        self.elu = nn.ELU()
        self.conv1 = ConcatConv2d(dim, dim, **conv_kw)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, **conv_kw)
        self.norm3 = make_norm(dim)
        self.nfe = 0  # number of forward evaluations

    def forward(self, t, x):
        self.nfe += 1
        h = self.norm1(x)
        # Each stage: activation, time-conditioned conv, then the next norm.
        for conv, post_norm in ((self.conv1, self.norm2), (self.conv2, self.norm3)):
            h = post_norm(conv(t, self.elu(h)))
        return h


### my
class ODENetReLUELU(nn.Module):
    """Classifier: downsample -> ``ODEBlockReLUELU`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stage.
        out: number of classifier outputs.
        n_filters: channel width produced by the downsampling stage and used
            by the ODE block and the classifier.
        downsample: one of ``"residual"``, ``"convolution"``, ``"minimal"``,
            ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method, tol, adjoint, t1: forwarded to the ODE block, and to the
            downsampling stage when it is ``"ode"`` / ``"ode2"``.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the supported names.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetReLUELU, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast: otherwise ``self.downsample`` is never set and the
            # error only surfaces later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unknown downsample {downsample!r}; expected one of 'residual', "
                "'convolution', 'minimal', 'one-shot', 'ode', 'ode2'"
            )

        self.odeblock = ODEBlockReLUELU(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run downsample -> ODE block -> classifier; concatenate along dim 0.

        If ``self.downsample`` returns a ``(features, x)`` pair, the features
        are pooled (or classified) per element and placed before the ODE
        block's output in the concatenation.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element holds features, second is the tensor to continue with
            f, x = x
            # After ``to_features_extractor(keep_pool=False)`` the classifier is
            # a bare ``nn.Sequential`` without ``.module``; use it directly then.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the classifier to every feature map
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # presumably a (time, batch, C, H, W) trajectory: classify each step
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a feature extractor, in place.

        ODE blocks are switched to return every time point. With ``keep_pool``
        the classifier's last layer is replaced by an empty ``nn.Sequential``;
        otherwise the classifier becomes an ``nn.Sequential`` of the first two
        children of ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count; zero it when ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockReLUELU(nn.Module):
    """Integrate ``ODEfuncReLUELU`` dynamics over ``integration_time``.

    Args:
        n_filters: channel count forwarded to ``ODEfuncReLUELU``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time (scalar) or sequence of time points; see the ``t1`` setter.
        norm: normalization name forwarded to ``ODEfuncReLUELU``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockReLUELU, self).__init__()
        self.odefunc = ODEfuncReLUELU(n_filters, norm=norm)
        self.t1 = t1  # routed through the ``t1`` setter, builds ``integration_time``
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: ``forward`` returns only the final state; False: every time point.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the state at the last time point when ``return_last_only`` is True,
        and otherwise the solver output stacked over all time points.
        """
        if self.integration_time is None:
            return x

        # Match the input's dtype/device; stored back so later calls reuse it.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0
        # Last element: with several time points, index 1 would be an
        # intermediate time rather than the end time used by ``forward``.
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        """Set the integration time points.

        A scalar ``v`` yields ``[0, v]``; ``v == 0`` disables integration.
        A list / tuple / 1-D tensor is used as the time points, with ``0``
        prepended when it does not already start at 0. A 0-d tensor is
        treated as a scalar.

        Raises:
            ValueError: for unsupported types or an empty sequence.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty sequence of time points")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncReLUELU(nn.Module):
    """ODE right-hand side with a time-dependent ELU activation.

    Layout: norm -> act -> conv -> norm -> act -> conv -> norm, where
    ``act = F.elu(., alpha=max(t, 0))``. At ``t <= 0`` the negative branch
    ``alpha * (exp(x) - 1)`` vanishes and the activation equals ReLU; as
    ``t`` grows the negative saturation level ``-alpha`` deepens.

    Args:
        dim: channel count of the ODE state.
        norm: normalization name passed to ``normalization``.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncReLUELU, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward evaluations

    def forward(self, t, x):
        self.nfe += 1
        # Read ``t`` on the host once per call: converting a CUDA tensor to a
        # Python number forces a device synchronisation.
        alpha = max(float(t), 0.0)
        out = self.norm1(x)
        out = F.elu(out, alpha=alpha)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = F.elu(out, alpha=alpha)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out


### my
class ODENetELU(nn.Module):
    """Classifier: downsample -> ``ODEBlockELU`` -> ``FCClassifier``.

    NOTE(review): another ``class ODENetELU`` definition appears earlier in
    this module; this later definition is the one bound to the name at
    import time. Consider deleting the earlier one.

    Args:
        in_ch: number of input channels given to the downsampling stage.
        out: number of classifier outputs.
        n_filters: channel width produced by the downsampling stage and used
            by the ODE block and the classifier.
        downsample: one of ``"residual"``, ``"convolution"``, ``"minimal"``,
            ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method, tol, adjoint, t1: forwarded to the ODE block, and to the
            downsampling stage when it is ``"ode"`` / ``"ode2"``.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the supported names.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetELU, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast: otherwise ``self.downsample`` is never set and the
            # error only surfaces later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unknown downsample {downsample!r}; expected one of 'residual', "
                "'convolution', 'minimal', 'one-shot', 'ode', 'ode2'"
            )

        self.odeblock = ODEBlockELU(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run downsample -> ODE block -> classifier; concatenate along dim 0.

        If ``self.downsample`` returns a ``(features, x)`` pair, the features
        are pooled (or classified) per element and placed before the ODE
        block's output in the concatenation.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element holds features, second is the tensor to continue with
            f, x = x
            # After ``to_features_extractor(keep_pool=False)`` the classifier is
            # a bare ``nn.Sequential`` without ``.module``; use it directly then.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the classifier to every feature map
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # presumably a (time, batch, C, H, W) trajectory: classify each step
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a feature extractor, in place.

        ODE blocks are switched to return every time point. With ``keep_pool``
        the classifier's last layer is replaced by an empty ``nn.Sequential``;
        otherwise the classifier becomes an ``nn.Sequential`` of the first two
        children of ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count; zero it when ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODENetCTNR(nn.Module):
    """Classifier: downsample -> ``ODEBlockCTNR`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stage.
        out: number of classifier outputs.
        n_filters: channel width produced by the downsampling stage and used
            by the ODE block and the classifier.
        downsample: one of ``"residual"``, ``"convolution"``, ``"minimal"``,
            ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method, tol, adjoint, t1: forwarded to the ODE block, and to the
            downsampling stage when it is ``"ode"`` / ``"ode2"``.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the supported names.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetCTNR, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast: otherwise ``self.downsample`` is never set and the
            # error only surfaces later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unknown downsample {downsample!r}; expected one of 'residual', "
                "'convolution', 'minimal', 'one-shot', 'ode', 'ode2'"
            )

        self.odeblock = ODEBlockCTNR(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run downsample -> ODE block -> classifier; concatenate along dim 0.

        If ``self.downsample`` returns a ``(features, x)`` pair, the features
        are pooled (or classified) per element and placed before the ODE
        block's output in the concatenation.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element holds features, second is the tensor to continue with
            f, x = x
            # After ``to_features_extractor(keep_pool=False)`` the classifier is
            # a bare ``nn.Sequential`` without ``.module``; use it directly then.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the classifier to every feature map
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # presumably a (time, batch, C, H, W) trajectory: classify each step
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a feature extractor, in place.

        ODE blocks are switched to return every time point. With ``keep_pool``
        the classifier's last layer is replaced by an empty ``nn.Sequential``;
        otherwise the classifier becomes an ``nn.Sequential`` of the first two
        children of ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count; zero it when ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockCTNR(nn.Module):
    """Integrate ``ODEfuncCTNR`` dynamics over ``integration_time``.

    Args:
        n_filters: channel count forwarded to ``ODEfuncCTNR``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time (scalar) or sequence of time points; see the ``t1`` setter.
        norm: normalization name forwarded to ``ODEfuncCTNR``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockCTNR, self).__init__()
        self.odefunc = ODEfuncCTNR(n_filters, norm=norm)
        self.t1 = t1  # routed through the ``t1`` setter, builds ``integration_time``
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: ``forward`` returns only the final state; False: every time point.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the state at the last time point when ``return_last_only`` is True,
        and otherwise the solver output stacked over all time points.
        """
        if self.integration_time is None:
            return x

        # Match the input's dtype/device; stored back so later calls reuse it.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0
        # Last element: with several time points, index 1 would be an
        # intermediate time rather than the end time used by ``forward``.
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        """Set the integration time points.

        A scalar ``v`` yields ``[0, v]``; ``v == 0`` disables integration.
        A list / tuple / 1-D tensor is used as the time points, with ``0``
        prepended when it does not already start at 0. A 0-d tensor is
        treated as a scalar.

        Raises:
            ValueError: for unsupported types or an empty sequence.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty sequence of time points")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncCTNR(nn.Module):
    """ODE right-hand side that injects time as learned per-channel biases.

    Layout: norm -> relu -> conv -> (+ dense1(t)) -> norm -> relu -> conv
    -> (+ dense2(t)) -> norm. ``dense1`` / ``dense2`` map the scalar ``t``
    to a ``dim``-vector that is broadcast over the trailing spatial dims.

    Args:
        dim: channel count of the ODE state.
        norm: normalization name passed to ``normalization``.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncCTNR, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward evaluations

        # Time embeddings: Linear(1 -> dim) applied to a column filled with t.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        # (batch, 1) column filled with t, built directly with ``out``'s device
        # and dtype. Building it on the CPU as float32 costs a host->device copy
        # per call and mismatches the Linear weights when the model is double.
        tt = torch.ones(out.size(0), 1, device=out.device, dtype=out.dtype) * t
        # (batch, dim) -> (batch, dim, 1, 1) so it broadcasts over H, W
        out = out + self.dense1(tt)[:, :, None, None]
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]
        out = self.norm3(out)
        return out


### my
class ODENetCNTR(nn.Module):
    """Classifier: downsample -> ``ODEBlockCNTR`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels given to the downsampling stage.
        out: number of classifier outputs.
        n_filters: channel width produced by the downsampling stage and used
            by the ODE block and the classifier.
        downsample: one of ``"residual"``, ``"convolution"``, ``"minimal"``,
            ``"one-shot"``, ``"ode"`` or ``"ode2"``.
        method, tol, adjoint, t1: forwarded to the ODE block, and to the
            downsampling stage when it is ``"ode"`` / ``"ode2"``.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the supported names.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetCNTR, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast: otherwise ``self.downsample`` is never set and the
            # error only surfaces later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unknown downsample {downsample!r}; expected one of 'residual', "
                "'convolution', 'minimal', 'one-shot', 'ode', 'ode2'"
            )

        self.odeblock = ODEBlockCNTR(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run downsample -> ODE block -> classifier; concatenate along dim 0.

        If ``self.downsample`` returns a ``(features, x)`` pair, the features
        are pooled (or classified) per element and placed before the ODE
        block's output in the concatenation.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element holds features, second is the tensor to continue with
            f, x = x
            # After ``to_features_extractor(keep_pool=False)`` the classifier is
            # a bare ``nn.Sequential`` without ``.module``; use it directly then.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the classifier to every feature map
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # presumably a (time, batch, C, H, W) trajectory: classify each step
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a feature extractor, in place.

        ODE blocks are switched to return every time point. With ``keep_pool``
        the classifier's last layer is replaced by an empty ``nn.Sequential``;
        otherwise the classifier becomes an ``nn.Sequential`` of the first two
        children of ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation count; zero it when ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockCNTR(nn.Module):
    """Integrate ``ODEfuncCNTR`` dynamics over ``integration_time``.

    Args:
        n_filters: channel count forwarded to ``ODEfuncCNTR``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time (scalar) or sequence of time points; see the ``t1`` setter.
        norm: normalization name forwarded to ``ODEfuncCNTR``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockCNTR, self).__init__()
        self.odefunc = ODEfuncCNTR(n_filters, norm=norm)
        self.t1 = t1  # routed through the ``t1`` setter, builds ``integration_time``
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: ``forward`` returns only the final state; False: every time point.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the state at the last time point when ``return_last_only`` is True,
        and otherwise the solver output stacked over all time points.
        """
        if self.integration_time is None:
            return x

        # Match the input's dtype/device; stored back so later calls reuse it.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Final integration time (``0`` when integration is disabled)."""
        if self.integration_time is None:
            return 0
        # Last element: with several time points, index 1 would be an
        # intermediate time rather than the end time used by ``forward``.
        return self.integration_time[-1]

    @t1.setter
    def t1(self, value):
        """Set the integration time points.

        A scalar ``v`` yields ``[0, v]``; ``v == 0`` disables integration.
        A list / tuple / 1-D tensor is used as the time points, with ``0``
        prepended when it does not already start at 0. A 0-d tensor is
        treated as a scalar.

        Raises:
            ValueError: for unsupported types or an empty sequence.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty sequence of time points")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncCNTR(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` (CNTR variant).

    Pipeline: norm1 -> (+ time offset 1) -> ReLU -> conv1 -> norm2
    -> (+ time offset 2) -> ReLU -> conv2 -> norm3.
    Each time offset is ``Linear(1, dim)`` applied to the scalar ``t``.
    The result is broadcast as a per-channel offset.

    Args:
        dim: channel count of the state (in/out channels of both convs).
        norm: name passed to ``normalization`` to build the norm layers.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncCNTR, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        # Function-evaluation counter, incremented on every forward call.
        self.nfe = 0

        # Time embeddings: map the scalar t to one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (0-d tensor) for state ``x``.

        The offsets are indexed to ``(1, dim, 1, 1)``. This assumes x is 4-D
        ``(N, dim, H, W)`` (TODO confirm for all callers).
        """
        self.nfe += 1
        # Build the (1, 1) time input directly on x's device and dtype. This
        # avoids a CPU allocation plus host->device copy on every evaluation.
        # It also keeps float64 runs working, since the solver may pass t in
        # another dtype. Autograd still flows back to t.
        tt = t.to(device=x.device, dtype=x.dtype).reshape(1, 1)

        out = self.norm1(x)
        out = out + self.dense1(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm3(out)
        return out


### my
class ODENetCNRT(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockCNRT`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"``, ``"ode2"``.
        method: ODE solver name (ODE block and ODE-based stems).
        tol: solver tolerance forwarded to the ODE modules.
        adjoint: use the adjoint method for the ODE modules.
        t1: integration end time, or a sequence of evaluation times.
        dropout: dropout rate forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a recognised stem type.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetCNRT, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast. Otherwise self.downsample would stay unset and
            # forward() would crash later with an AttributeError.
            raise ValueError(f"Unknown downsample: {downsample!r}")

        self.odeblock = ODEBlockCNRT(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        The stem may return a ``(features, x)`` pair, presumably the ODE stems
        after ``to_features_extractor()``. In that case the stem features are
        processed first and placed before the ODE block output. If the
        classifier's last layer is an empty ``nn.Sequential``, they are only
        global-average-pooled; otherwise the full classifier is applied.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # (intermediate features, state to continue the forward with)
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # Classification layer removed: global average pooling only.
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # Apply the whole classifier to every intermediate feature map.
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # The ODE block returned one state per time point (leading axis).
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Make the model return features at every integration time point.

        Args:
            keep_pool: if True, replace only the classifier's last layer with an
                empty ``nn.Sequential``. The original comment says the norm,
                ReLU and global avg pooling are kept. If False, replace the
                classifier with a Sequential of its head's first two children.
        """
        # NOTE(review): with keep_pool=False the new classifier has no
        # ``.module`` attribute. forward() would then fail on
        # ``self.classifier.module[-1]`` whenever the stem returns a tuple.
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count; optionally reset it to 0.

        Only ``self.odeblock`` is counted, not an ODE-based stem.
        """
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockCNRT(nn.Module):
    """Integrates the ``ODEfuncCNRT`` dynamics starting at t=0.

    Args:
        n_filters: channel count of the state, forwarded to ``ODEfuncCNRT``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence/tensor of evaluation times (see the ``t1``
            setter). ``t1 == 0`` turns the block into an identity.
        norm: normalization name forwarded to ``ODEfuncCNRT``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockCNRT, self).__init__()
        self.odefunc = ODEfuncCNRT(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which builds integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward() returns only the state at the last time point.
        # False: forward() returns the full solver output (one state per entry
        # of integration_time, indexed along dim 0).
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``; identity when disabled (t1 == 0)."""
        if self.integration_time is None:
            return x

        # Keep the time grid on the same device/dtype as the state.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last time point (t = t1)
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the wrapped ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of the integration grid, i.e. the end time for a scalar t1.

        Returns ``tensor(0.)`` when the block is disabled, mirroring ``t1 = 0``.
        """
        # NOTE(review): with more than two time points this is the first time
        # after 0, not the final one -- confirm callers expect that.
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar (int, float or 0-d tensor): ``0`` disables the block, any other
          value integrates over ``[0, value]``;
        * list, tuple or 1-d tensor: used as the time grid, with ``0`` prepended
          when it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor):
            # .tolist() gives a plain number for a 0-d tensor, which then takes
            # the scalar branch below instead of failing on value[0].
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Time sequence must contain at least one value")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncCNRT(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` (CNRT variant).

    Pipeline: norm1 -> ReLU -> (+ time offset 1) -> conv1 -> norm2 -> ReLU
    -> (+ time offset 2) -> conv2 -> norm3.
    Each time offset is ``Linear(1, dim)`` applied to the scalar ``t``.
    The result is broadcast as a per-channel offset.

    Args:
        dim: channel count of the state (in/out channels of both convs).
        norm: name passed to ``normalization`` to build the norm layers.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncCNRT, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        # Function-evaluation counter, incremented on every forward call.
        self.nfe = 0

        # Time embeddings: map the scalar t to one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (0-d tensor) for state ``x``.

        The offsets are indexed to ``(1, dim, 1, 1)``. This assumes x is 4-D
        ``(N, dim, H, W)`` (TODO confirm for all callers).
        """
        self.nfe += 1
        # Build the (1, 1) time input directly on x's device and dtype. This
        # avoids a CPU allocation plus host->device copy on every evaluation.
        # It also keeps float64 runs working, since the solver may pass t in
        # another dtype. Autograd still flows back to t.
        tt = t.to(device=x.device, dtype=x.dtype).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = out + self.dense1(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.conv2(out)
        out = self.norm3(out)
        return out


### my
class ODENetinit1en3(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1en3`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"``, ``"ode2"``.
        method: ODE solver name (ODE block and ODE-based stems).
        tol: solver tolerance forwarded to the ODE modules.
        adjoint: use the adjoint method for the ODE modules.
        t1: integration end time, or a sequence of evaluation times.
        dropout: dropout rate forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a recognised stem type.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1en3, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast. Otherwise self.downsample would stay unset and
            # forward() would crash later with an AttributeError.
            raise ValueError(f"Unknown downsample: {downsample!r}")

        self.odeblock = ODEBlockinit1en3(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        The stem may return a ``(features, x)`` pair, presumably the ODE stems
        after ``to_features_extractor()``. In that case the stem features are
        processed first and placed before the ODE block output. If the
        classifier's last layer is an empty ``nn.Sequential``, they are only
        global-average-pooled; otherwise the full classifier is applied.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # (intermediate features, state to continue the forward with)
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # Classification layer removed: global average pooling only.
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # Apply the whole classifier to every intermediate feature map.
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # The ODE block returned one state per time point (leading axis).
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Make the model return features at every integration time point.

        Args:
            keep_pool: if True, replace only the classifier's last layer with an
                empty ``nn.Sequential``. The original comment says the norm,
                ReLU and global avg pooling are kept. If False, replace the
                classifier with a Sequential of its head's first two children.
        """
        # NOTE(review): with keep_pool=False the new classifier has no
        # ``.module`` attribute. forward() would then fail on
        # ``self.classifier.module[-1]`` whenever the stem returns a tuple.
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count; optionally reset it to 0.

        Only ``self.odeblock`` is counted, not an ODE-based stem.
        """
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1en3(nn.Module):
    """Integrates the ``ODEfuncinit1en3`` dynamics starting at t=0.

    Args:
        n_filters: channel count of the state, forwarded to ``ODEfuncinit1en3``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence/tensor of evaluation times (see the ``t1``
            setter). ``t1 == 0`` turns the block into an identity.
        norm: normalization name forwarded to ``ODEfuncinit1en3``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1en3, self).__init__()
        self.odefunc = ODEfuncinit1en3(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which builds integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward() returns only the state at the last time point.
        # False: forward() returns the full solver output (one state per entry
        # of integration_time, indexed along dim 0).
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``; identity when disabled (t1 == 0)."""
        if self.integration_time is None:
            return x

        # Keep the time grid on the same device/dtype as the state.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last time point (t = t1)
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the wrapped ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of the integration grid, i.e. the end time for a scalar t1.

        Returns ``tensor(0.)`` when the block is disabled, mirroring ``t1 = 0``.
        """
        # NOTE(review): with more than two time points this is the first time
        # after 0, not the final one -- confirm callers expect that.
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar (int, float or 0-d tensor): ``0`` disables the block, any other
          value integrates over ``[0, value]``;
        * list, tuple or 1-d tensor: used as the time grid, with ``0`` prepended
          when it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor):
            # .tolist() gives a plain number for a 0-d tensor, which then takes
            # the scalar branch below instead of failing on value[0].
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Time sequence must contain at least one value")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1en3(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with small time-embedding init.

    Pipeline: norm1 -> ReLU -> conv1 -> (+ time offset 1) -> norm2 -> ReLU
    -> conv2 -> (+ time offset 2) -> norm3.
    Each time offset is ``Linear(1, dim)`` applied to the scalar ``t``.
    The result is broadcast as a per-channel offset.

    Args:
        dim: channel count of the state (in/out channels of both convs).
        norm: name passed to ``normalization`` to build the norm layers.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncinit1en3, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        # Function-evaluation counter, incremented on every forward call.
        self.nfe = 0

        # Time embeddings: map the scalar t to one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # Weights drawn from N(0, 1e-3^2).
        # NOTE(review): only the weights are re-initialised. The biases keep
        # nn.Linear's default init, which for in_features=1 spans U(-1, 1) and
        # can dominate the small weights at init -- confirm that is intended.
        nn.init.normal_(self.dense1.weight, mean=0.0, std=1e-3)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=1e-3)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (0-d tensor) for state ``x``.

        The offsets are indexed to ``(1, dim, 1, 1)``. This assumes x is 4-D
        ``(N, dim, H, W)`` (TODO confirm for all callers).
        """
        self.nfe += 1
        # Build the (1, 1) time input directly on x's device and dtype. This
        # avoids a CPU allocation plus host->device copy on every evaluation.
        # It also keeps float64 runs working, since the solver may pass t in
        # another dtype. Autograd still flows back to t.
        tt = t.to(device=x.device, dtype=x.dtype).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.norm3(out)
        return out


### my
class ODENetinit1en2(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1en2`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"``, ``"ode2"``.
        method: ODE solver name (ODE block and ODE-based stems).
        tol: solver tolerance forwarded to the ODE modules.
        adjoint: use the adjoint method for the ODE modules.
        t1: integration end time, or a sequence of evaluation times.
        dropout: dropout rate forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a recognised stem type.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1en2, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast. Otherwise self.downsample would stay unset and
            # forward() would crash later with an AttributeError.
            raise ValueError(f"Unknown downsample: {downsample!r}")

        self.odeblock = ODEBlockinit1en2(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        The stem may return a ``(features, x)`` pair, presumably the ODE stems
        after ``to_features_extractor()``. In that case the stem features are
        processed first and placed before the ODE block output. If the
        classifier's last layer is an empty ``nn.Sequential``, they are only
        global-average-pooled; otherwise the full classifier is applied.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # (intermediate features, state to continue the forward with)
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # Classification layer removed: global average pooling only.
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # Apply the whole classifier to every intermediate feature map.
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # The ODE block returned one state per time point (leading axis).
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Make the model return features at every integration time point.

        Args:
            keep_pool: if True, replace only the classifier's last layer with an
                empty ``nn.Sequential``. The original comment says the norm,
                ReLU and global avg pooling are kept. If False, replace the
                classifier with a Sequential of its head's first two children.
        """
        # NOTE(review): with keep_pool=False the new classifier has no
        # ``.module`` attribute. forward() would then fail on
        # ``self.classifier.module[-1]`` whenever the stem returns a tuple.
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count; optionally reset it to 0.

        Only ``self.odeblock`` is counted, not an ODE-based stem.
        """
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1en2(nn.Module):
    """Integrates the ``ODEfuncinit1en2`` dynamics starting at t=0.

    Args:
        n_filters: channel count of the state, forwarded to ``ODEfuncinit1en2``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence/tensor of evaluation times (see the ``t1``
            setter). ``t1 == 0`` turns the block into an identity.
        norm: normalization name forwarded to ``ODEfuncinit1en2``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1en2, self).__init__()
        self.odefunc = ODEfuncinit1en2(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which builds integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward() returns only the state at the last time point.
        # False: forward() returns the full solver output (one state per entry
        # of integration_time, indexed along dim 0).
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``; identity when disabled (t1 == 0)."""
        if self.integration_time is None:
            return x

        # Keep the time grid on the same device/dtype as the state.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last time point (t = t1)
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the wrapped ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of the integration grid, i.e. the end time for a scalar t1.

        Returns ``tensor(0.)`` when the block is disabled, mirroring ``t1 = 0``.
        """
        # NOTE(review): with more than two time points this is the first time
        # after 0, not the final one -- confirm callers expect that.
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar (int, float or 0-d tensor): ``0`` disables the block, any other
          value integrates over ``[0, value]``;
        * list, tuple or 1-d tensor: used as the time grid, with ``0`` prepended
          when it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor):
            # .tolist() gives a plain number for a 0-d tensor, which then takes
            # the scalar branch below instead of failing on value[0].
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Time sequence must contain at least one value")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1en2(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with small time-embedding init.

    Pipeline: norm1 -> ReLU -> conv1 -> (+ time offset 1) -> norm2 -> ReLU
    -> conv2 -> (+ time offset 2) -> norm3.
    Each time offset is ``Linear(1, dim)`` applied to the scalar ``t``.
    The result is broadcast as a per-channel offset.

    Args:
        dim: channel count of the state (in/out channels of both convs).
        norm: name passed to ``normalization`` to build the norm layers.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncinit1en2, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        # Function-evaluation counter, incremented on every forward call.
        self.nfe = 0

        # Time embeddings: map the scalar t to one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # Weights drawn from N(0, 1e-2^2).
        # NOTE(review): only the weights are re-initialised. The biases keep
        # nn.Linear's default init, which for in_features=1 spans U(-1, 1) and
        # can dominate the small weights at init -- confirm that is intended.
        nn.init.normal_(self.dense1.weight, mean=0.0, std=1e-2)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=1e-2)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (0-d tensor) for state ``x``.

        The offsets are indexed to ``(1, dim, 1, 1)``. This assumes x is 4-D
        ``(N, dim, H, W)`` (TODO confirm for all callers).
        """
        self.nfe += 1
        # Build the (1, 1) time input directly on x's device and dtype. This
        # avoids a CPU allocation plus host->device copy on every evaluation.
        # It also keeps float64 runs working, since the solver may pass t in
        # another dtype. Autograd still flows back to t.
        tt = t.to(device=x.device, dtype=x.dtype).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.norm3(out)
        return out


### my
class ODENetinit1en1(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1en1`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"``, ``"ode2"``.
        method: ODE solver name (ODE block and ODE-based stems).
        tol: solver tolerance forwarded to the ODE modules.
        adjoint: use the adjoint method for the ODE modules.
        t1: integration end time, or a sequence of evaluation times.
        dropout: dropout rate forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a recognised stem type.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1en1, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast. Otherwise self.downsample would stay unset and
            # forward() would crash later with an AttributeError.
            raise ValueError(f"Unknown downsample: {downsample!r}")

        self.odeblock = ODEBlockinit1en1(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        The stem may return a ``(features, x)`` pair, presumably the ODE stems
        after ``to_features_extractor()``. In that case the stem features are
        processed first and placed before the ODE block output. If the
        classifier's last layer is an empty ``nn.Sequential``, they are only
        global-average-pooled; otherwise the full classifier is applied.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # (intermediate features, state to continue the forward with)
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # Classification layer removed: global average pooling only.
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # Apply the whole classifier to every intermediate feature map.
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # The ODE block returned one state per time point (leading axis).
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Make the model return features at every integration time point.

        Args:
            keep_pool: if True, replace only the classifier's last layer with an
                empty ``nn.Sequential``. The original comment says the norm,
                ReLU and global avg pooling are kept. If False, replace the
                classifier with a Sequential of its head's first two children.
        """
        # NOTE(review): with keep_pool=False the new classifier has no
        # ``.module`` attribute. forward() would then fail on
        # ``self.classifier.module[-1]`` whenever the stem returns a tuple.
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count; optionally reset it to 0.

        Only ``self.odeblock`` is counted, not an ODE-based stem.
        """
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1en1(nn.Module):
    """Integrates the ``ODEfuncinit1en1`` dynamics starting at t=0.

    Args:
        n_filters: channel count of the state, forwarded to ``ODEfuncinit1en1``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence/tensor of evaluation times (see the ``t1``
            setter). ``t1 == 0`` turns the block into an identity.
        norm: normalization name forwarded to ``ODEfuncinit1en1``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1en1, self).__init__()
        self.odefunc = ODEfuncinit1en1(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which builds integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward() returns only the state at the last time point.
        # False: forward() returns the full solver output (one state per entry
        # of integration_time, indexed along dim 0).
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``; identity when disabled (t1 == 0)."""
        if self.integration_time is None:
            return x

        # Keep the time grid on the same device/dtype as the state.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last time point (t = t1)
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the wrapped ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of the integration grid, i.e. the end time for a scalar t1.

        Returns ``tensor(0.)`` when the block is disabled, mirroring ``t1 = 0``.
        """
        # NOTE(review): with more than two time points this is the first time
        # after 0, not the final one -- confirm callers expect that.
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar (int, float or 0-d tensor): ``0`` disables the block, any other
          value integrates over ``[0, value]``;
        * list, tuple or 1-d tensor: used as the time grid, with ``0`` prepended
          when it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor):
            # .tolist() gives a plain number for a 0-d tensor, which then takes
            # the scalar branch below instead of failing on value[0].
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Time sequence must contain at least one value")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1en1(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with small time-embedding init.

    Pipeline: norm1 -> ReLU -> conv1 -> (+ time offset 1) -> norm2 -> ReLU
    -> conv2 -> (+ time offset 2) -> norm3.
    Each time offset is ``Linear(1, dim)`` applied to the scalar ``t``.
    The result is broadcast as a per-channel offset.

    Args:
        dim: channel count of the state (in/out channels of both convs).
        norm: name passed to ``normalization`` to build the norm layers.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncinit1en1, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        # Function-evaluation counter, incremented on every forward call.
        self.nfe = 0

        # Time embeddings: map the scalar t to one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # Weights drawn from N(0, 1e-1^2).
        # NOTE(review): only the weights are re-initialised. The biases keep
        # nn.Linear's default init, which for in_features=1 spans U(-1, 1) --
        # confirm that is intended.
        nn.init.normal_(self.dense1.weight, mean=0.0, std=1e-1)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=1e-1)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (0-d tensor) for state ``x``.

        The offsets are indexed to ``(1, dim, 1, 1)``. This assumes x is 4-D
        ``(N, dim, H, W)`` (TODO confirm for all callers).
        """
        self.nfe += 1
        # Build the (1, 1) time input directly on x's device and dtype. This
        # avoids a CPU allocation plus host->device copy on every evaluation.
        # It also keeps float64 runs working, since the solver may pass t in
        # another dtype. Autograd still flows back to t.
        tt = t.to(device=x.device, dtype=x.dtype).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (1, dim, 1, 1) broadcast
        out = self.norm3(out)
        return out


### my
class ODENetinit1ep0(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1ep0`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: output size of the classifier head.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of ``"residual"``, ``"convolution"``,
            ``"minimal"``, ``"one-shot"``, ``"ode"``, ``"ode2"``.
        method: ODE solver name (ODE block and ODE-based stems).
        tol: solver tolerance forwarded to the ODE modules.
        adjoint: use the adjoint method for the ODE modules.
        t1: integration end time, or a sequence of evaluation times.
        dropout: dropout rate forwarded to ``FCClassifier``.
        norm: normalization name forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not a recognised stem type.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1ep0, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail fast. Otherwise self.downsample would stay unset and
            # forward() would crash later with an AttributeError.
            raise ValueError(f"Unknown downsample: {downsample!r}")

        self.odeblock = ODEBlockinit1ep0(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        The stem may return a ``(features, x)`` pair, presumably the ODE stems
        after ``to_features_extractor()``. In that case the stem features are
        processed first and placed before the ODE block output. If the
        classifier's last layer is an empty ``nn.Sequential``, they are only
        global-average-pooled; otherwise the full classifier is applied.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # (intermediate features, state to continue the forward with)
            f, x = x
            if isinstance(self.classifier.module[-1], nn.Sequential):
                # Classification layer removed: global average pooling only.
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # Apply the whole classifier to every intermediate feature map.
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # The ODE block returned one state per time point (leading axis).
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Make the model return features at every integration time point.

        Args:
            keep_pool: if True, replace only the classifier's last layer with an
                empty ``nn.Sequential``. The original comment says the norm,
                ReLU and global avg pooling are kept. If False, replace the
                classifier with a Sequential of its head's first two children.
        """
        # NOTE(review): with keep_pool=False the new classifier has no
        # ``.module`` attribute. forward() would then fail on
        # ``self.classifier.module[-1]`` whenever the stem returns a tuple.
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count; optionally reset it to 0.

        Only ``self.odeblock`` is counted, not an ODE-based stem.
        """
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1ep0(nn.Module):
    """Integrates the ``ODEfuncinit1ep0`` dynamics starting at t=0.

    Args:
        n_filters: channel count of the state, forwarded to ``ODEfuncinit1ep0``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a sequence/tensor of evaluation times (see the ``t1``
            setter). ``t1 == 0`` turns the block into an identity.
        norm: normalization name forwarded to ``ODEfuncinit1ep0``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1ep0, self).__init__()
        self.odefunc = ODEfuncinit1ep0(n_filters, norm=norm)
        self.t1 = t1  # goes through the t1 setter, which builds integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward() returns only the state at the last time point.
        # False: forward() returns the full solver output (one state per entry
        # of integration_time, indexed along dim 0).
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE with initial state ``x``; identity when disabled (t1 == 0)."""
        if self.integration_time is None:
            return x

        # Keep the time grid on the same device/dtype as the state.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last time point (t = t1)
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the wrapped ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of the integration grid, i.e. the end time for a scalar t1.

        Returns ``tensor(0.)`` when the block is disabled, mirroring ``t1 = 0``.
        """
        # NOTE(review): with more than two time points this is the first time
        # after 0, not the final one -- confirm callers expect that.
        if self.integration_time is None:
            return torch.tensor(0.0)
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar (int, float or 0-d tensor): ``0`` disables the block, any other
          value integrates over ``[0, value]``;
        * list, tuple or 1-d tensor: used as the time grid, with ``0`` prepended
          when it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor):
            # .tolist() gives a plain number for a 0-d tensor, which then takes
            # the scalar branch below instead of failing on value[0].
            value = value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Time sequence must contain at least one value")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1ep0(nn.Module):
    """Time-dependent dynamics f(t, x): two norm-ReLU-conv stages, each followed by a
    learned per-channel offset computed from ``t``.

    ``t`` enters through ``dense1`` / ``dense2`` (``nn.Linear(1, dim)``), whose
    weights are drawn from N(0, ``time_weight_std`` ** 2).

    Args:
        dim: number of channels of the state.
        norm: normalization type, resolved by ``normalization``.
        time_weight_std: std of the normal init of the time-embedding weights
            (default 1e0, the value this variant was defined with).
    """

    def __init__(self, dim, norm="group", time_weight_std=1e0):
        super(ODEfuncinit1ep0, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings: scalar t -> one additive offset per channel
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=time_weight_std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=time_weight_std)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (scalar or 0-dim tensor) for state ``x``.

        Assumes ``x`` is (batch, dim, H, W); the time offsets are broadcast over H, W.
        """
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        # (batch, 1) column filled with t, created directly on the target device and
        # in the embedding weights' dtype (no per-call CPU alloc + device copy).
        tt = (
            torch.ones(
                out.size(0), 1, device=out.device, dtype=self.dense1.weight.dtype
            )
            * t
        )
        out = out + self.dense1(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm3(out)
        return out


### my
class ODENetinit1ep1(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1ep1`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channels produced by the stem and evolved by the ODE block.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: ODE solver name.
        tol: solver tolerance.
        adjoint: integrate with the adjoint method when True.
        t1: integration end time (or timestamps) for the ODE block and ODE stems.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1ep1, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockinit1ep1(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        If the stem returns a ``(features, x)`` pair, each intermediate feature is
        passed through the classifier (or only global-average-pooled once the last
        classifier layer was removed) and those stacked results come first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            f, x = x  # intermediate features, then the tensor to continue with
            # After to_features_extractor(keep_pool=False) the classifier is a bare
            # nn.Sequential without a ``module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # last layer removed: apply global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # states at several timestamps, stacked on dim 0
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the model into a feature extractor, in place.

        ODE blocks (including an ODE stem) return the state at every timestamp.
        With ``keep_pool`` the last classifier layer is replaced by an empty
        ``nn.Sequential``; otherwise the classifier keeps only its first two children.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1ep1(nn.Module):
    """Integrates ``ODEfuncinit1ep1`` over ``self.integration_time``.

    Args:
        n_filters: number of channels of the ODE state.
        tol: solver tolerance, passed as both ``rtol`` and ``atol``.
        method: solver name passed to ``odeint`` / ``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint`` when True.
        t1: end time (scalar) or sequence/tensor of timestamps; see the ``t1``
            setter. ``0`` disables integration.
        norm: normalization type forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1ep1, self).__init__()
        self.odefunc = ODEfuncinit1ep1(n_filters, norm=norm)
        self.t1 = t1  # routed through the t1 setter, which sets self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # If False, forward returns the state at every timestamp, stacked on dim 0.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from state ``x``.

        Returns ``x`` itself when integration is disabled, the state at the last
        timestamp when ``return_last_only`` is True, and otherwise the full solver
        output (one state per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # integration_time is a plain attribute (not a buffer), so align its
        # dtype/device with the input here.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last timestamp
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``self.odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """``integration_time[1]``, or ``0`` when integration is disabled.

        With more than two timestamps this is the second timestamp, not the last.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration timestamps.

        A scalar ``v`` (including a 0-dim tensor) gives ``[0, v]``; ``0`` disables
        integration. A non-empty list, tuple or tensor is used as the timestamp
        sequence, with ``0`` prepended if it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()  # .tolist() of a 0-dim tensor is not indexable

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration timestamps must not be empty")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1ep1(nn.Module):
    """Time-dependent dynamics f(t, x): two norm-ReLU-conv stages, each followed by a
    learned per-channel offset computed from ``t``.

    ``t`` enters through ``dense1`` / ``dense2`` (``nn.Linear(1, dim)``), whose
    weights are drawn from N(0, ``time_weight_std`` ** 2).

    Args:
        dim: number of channels of the state.
        norm: normalization type, resolved by ``normalization``.
        time_weight_std: std of the normal init of the time-embedding weights
            (default 1e1, the value this variant was defined with).
    """

    def __init__(self, dim, norm="group", time_weight_std=1e1):
        super(ODEfuncinit1ep1, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings: scalar t -> one additive offset per channel
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=time_weight_std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=time_weight_std)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (scalar or 0-dim tensor) for state ``x``.

        Assumes ``x`` is (batch, dim, H, W); the time offsets are broadcast over H, W.
        """
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        # (batch, 1) column filled with t, created directly on the target device and
        # in the embedding weights' dtype (no per-call CPU alloc + device copy).
        tt = (
            torch.ones(
                out.size(0), 1, device=out.device, dtype=self.dense1.weight.dtype
            )
            * t
        )
        out = out + self.dense1(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm3(out)
        return out


### my
class ODENetinit1ep2(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1ep2`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channels produced by the stem and evolved by the ODE block.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: ODE solver name.
        tol: solver tolerance.
        adjoint: integrate with the adjoint method when True.
        t1: integration end time (or timestamps) for the ODE block and ODE stems.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1ep2, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockinit1ep2(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        If the stem returns a ``(features, x)`` pair, each intermediate feature is
        passed through the classifier (or only global-average-pooled once the last
        classifier layer was removed) and those stacked results come first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            f, x = x  # intermediate features, then the tensor to continue with
            # After to_features_extractor(keep_pool=False) the classifier is a bare
            # nn.Sequential without a ``module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # last layer removed: apply global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # states at several timestamps, stacked on dim 0
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the model into a feature extractor, in place.

        ODE blocks (including an ODE stem) return the state at every timestamp.
        With ``keep_pool`` the last classifier layer is replaced by an empty
        ``nn.Sequential``; otherwise the classifier keeps only its first two children.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1ep2(nn.Module):
    """Integrates ``ODEfuncinit1ep2`` over ``self.integration_time``.

    Args:
        n_filters: number of channels of the ODE state.
        tol: solver tolerance, passed as both ``rtol`` and ``atol``.
        method: solver name passed to ``odeint`` / ``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint`` when True.
        t1: end time (scalar) or sequence/tensor of timestamps; see the ``t1``
            setter. ``0`` disables integration.
        norm: normalization type forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1ep2, self).__init__()
        self.odefunc = ODEfuncinit1ep2(n_filters, norm=norm)
        self.t1 = t1  # routed through the t1 setter, which sets self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # If False, forward returns the state at every timestamp, stacked on dim 0.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from state ``x``.

        Returns ``x`` itself when integration is disabled, the state at the last
        timestamp when ``return_last_only`` is True, and otherwise the full solver
        output (one state per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # integration_time is a plain attribute (not a buffer), so align its
        # dtype/device with the input here.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last timestamp
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``self.odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """``integration_time[1]``, or ``0`` when integration is disabled.

        With more than two timestamps this is the second timestamp, not the last.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration timestamps.

        A scalar ``v`` (including a 0-dim tensor) gives ``[0, v]``; ``0`` disables
        integration. A non-empty list, tuple or tensor is used as the timestamp
        sequence, with ``0`` prepended if it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()  # .tolist() of a 0-dim tensor is not indexable

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration timestamps must not be empty")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1ep2(nn.Module):
    """Time-dependent dynamics f(t, x): two norm-ReLU-conv stages, each followed by a
    learned per-channel offset computed from ``t``.

    ``t`` enters through ``dense1`` / ``dense2`` (``nn.Linear(1, dim)``), whose
    weights are drawn from N(0, ``time_weight_std`` ** 2).

    Args:
        dim: number of channels of the state.
        norm: normalization type, resolved by ``normalization``.
        time_weight_std: std of the normal init of the time-embedding weights
            (default 1e2, the value this variant was defined with).
    """

    def __init__(self, dim, norm="group", time_weight_std=1e2):
        super(ODEfuncinit1ep2, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings: scalar t -> one additive offset per channel
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=time_weight_std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=time_weight_std)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (scalar or 0-dim tensor) for state ``x``.

        Assumes ``x`` is (batch, dim, H, W); the time offsets are broadcast over H, W.
        """
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        # (batch, 1) column filled with t, created directly on the target device and
        # in the embedding weights' dtype (no per-call CPU alloc + device copy).
        tt = (
            torch.ones(
                out.size(0), 1, device=out.device, dtype=self.dense1.weight.dtype
            )
            * t
        )
        out = out + self.dense1(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm3(out)
        return out


### my
class ODENetinit1ep3(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1ep3`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channels produced by the stem and evolved by the ODE block.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: ODE solver name.
        tol: solver tolerance.
        adjoint: integrate with the adjoint method when True.
        t1: integration end time (or timestamps) for the ODE block and ODE stems.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1ep3, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockinit1ep3(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        If the stem returns a ``(features, x)`` pair, each intermediate feature is
        passed through the classifier (or only global-average-pooled once the last
        classifier layer was removed) and those stacked results come first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            f, x = x  # intermediate features, then the tensor to continue with
            # After to_features_extractor(keep_pool=False) the classifier is a bare
            # nn.Sequential without a ``module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # last layer removed: apply global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # states at several timestamps, stacked on dim 0
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the model into a feature extractor, in place.

        ODE blocks (including an ODE stem) return the state at every timestamp.
        With ``keep_pool`` the last classifier layer is replaced by an empty
        ``nn.Sequential``; otherwise the classifier keeps only its first two children.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1ep3(nn.Module):
    """Integrates ``ODEfuncinit1ep3`` over ``self.integration_time``.

    Args:
        n_filters: number of channels of the ODE state.
        tol: solver tolerance, passed as both ``rtol`` and ``atol``.
        method: solver name passed to ``odeint`` / ``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint`` when True.
        t1: end time (scalar) or sequence/tensor of timestamps; see the ``t1``
            setter. ``0`` disables integration.
        norm: normalization type forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1ep3, self).__init__()
        self.odefunc = ODEfuncinit1ep3(n_filters, norm=norm)
        self.t1 = t1  # routed through the t1 setter, which sets self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # If False, forward returns the state at every timestamp, stacked on dim 0.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from state ``x``.

        Returns ``x`` itself when integration is disabled, the state at the last
        timestamp when ``return_last_only`` is True, and otherwise the full solver
        output (one state per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # integration_time is a plain attribute (not a buffer), so align its
        # dtype/device with the input here.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last timestamp
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``self.odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """``integration_time[1]``, or ``0`` when integration is disabled.

        With more than two timestamps this is the second timestamp, not the last.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration timestamps.

        A scalar ``v`` (including a 0-dim tensor) gives ``[0, v]``; ``0`` disables
        integration. A non-empty list, tuple or tensor is used as the timestamp
        sequence, with ``0`` prepended if it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()  # .tolist() of a 0-dim tensor is not indexable

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration timestamps must not be empty")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1ep3(nn.Module):
    """Time-dependent dynamics f(t, x): two norm-ReLU-conv stages, each followed by a
    learned per-channel offset computed from ``t``.

    ``t`` enters through ``dense1`` / ``dense2`` (``nn.Linear(1, dim)``), whose
    weights are drawn from N(0, ``time_weight_std`` ** 2).

    Args:
        dim: number of channels of the state.
        norm: normalization type, resolved by ``normalization``.
        time_weight_std: std of the normal init of the time-embedding weights
            (default 1e3, the value this variant was defined with).
    """

    def __init__(self, dim, norm="group", time_weight_std=1e3):
        super(ODEfuncinit1ep3, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings: scalar t -> one additive offset per channel
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=time_weight_std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=time_weight_std)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (scalar or 0-dim tensor) for state ``x``.

        Assumes ``x`` is (batch, dim, H, W); the time offsets are broadcast over H, W.
        """
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        # (batch, 1) column filled with t, created directly on the target device and
        # in the embedding weights' dtype (no per-call CPU alloc + device copy).
        tt = (
            torch.ones(
                out.size(0), 1, device=out.device, dtype=self.dense1.weight.dtype
            )
            * t
        )
        out = out + self.dense1(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm3(out)
        return out


### my
class ODENetinit1ep4(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockinit1ep4`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channels produced by the stem and evolved by the ODE block.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: ODE solver name.
        tol: solver tolerance.
        adjoint: integrate with the adjoint method when True.
        t1: integration end time (or timestamps) for the ODE block and ODE stems.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit1ep4, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockinit1ep4(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        If the stem returns a ``(features, x)`` pair, each intermediate feature is
        passed through the classifier (or only global-average-pooled once the last
        classifier layer was removed) and those stacked results come first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            f, x = x  # intermediate features, then the tensor to continue with
            # After to_features_extractor(keep_pool=False) the classifier is a bare
            # nn.Sequential without a ``module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # last layer removed: apply global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # states at several timestamps, stacked on dim 0
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the model into a feature extractor, in place.

        ODE blocks (including an ODE stem) return the state at every timestamp.
        With ``keep_pool`` the last classifier layer is replaced by an empty
        ``nn.Sequential``; otherwise the classifier keeps only its first two children.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit1ep4(nn.Module):
    """Integrates ``ODEfuncinit1ep4`` over ``self.integration_time``.

    Args:
        n_filters: number of channels of the ODE state.
        tol: solver tolerance, passed as both ``rtol`` and ``atol``.
        method: solver name passed to ``odeint`` / ``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint`` when True.
        t1: end time (scalar) or sequence/tensor of timestamps; see the ``t1``
            setter. ``0`` disables integration.
        norm: normalization type forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit1ep4, self).__init__()
        self.odefunc = ODEfuncinit1ep4(n_filters, norm=norm)
        self.t1 = t1  # routed through the t1 setter, which sets self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # If False, forward returns the state at every timestamp, stacked on dim 0.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from state ``x``.

        Returns ``x`` itself when integration is disabled, the state at the last
        timestamp when ``return_last_only`` is True, and otherwise the full solver
        output (one state per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # integration_time is a plain attribute (not a buffer), so align its
        # dtype/device with the input here.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last timestamp
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``self.odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """``integration_time[1]``, or ``0`` when integration is disabled.

        With more than two timestamps this is the second timestamp, not the last.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration timestamps.

        A scalar ``v`` (including a 0-dim tensor) gives ``[0, v]``; ``0`` disables
        integration. A non-empty list, tuple or tensor is used as the timestamp
        sequence, with ``0`` prepended if it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()  # .tolist() of a 0-dim tensor is not indexable

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration timestamps must not be empty")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit1ep4(nn.Module):
    """Time-dependent dynamics f(t, x): two norm-ReLU-conv stages, each followed by a
    learned per-channel offset computed from ``t``.

    ``t`` enters through ``dense1`` / ``dense2`` (``nn.Linear(1, dim)``), whose
    weights are drawn from N(0, ``time_weight_std`` ** 2).

    Args:
        dim: number of channels of the state.
        norm: normalization type, resolved by ``normalization``.
        time_weight_std: std of the normal init of the time-embedding weights
            (default 1e4, the value this variant was defined with).
    """

    def __init__(self, dim, norm="group", time_weight_std=1e4):
        super(ODEfuncinit1ep4, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings: scalar t -> one additive offset per channel
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=time_weight_std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=time_weight_std)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t`` (scalar or 0-dim tensor) for state ``x``.

        Assumes ``x`` is (batch, dim, H, W); the time offsets are broadcast over H, W.
        """
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        # (batch, 1) column filled with t, created directly on the target device and
        # in the embedding weights' dtype (no per-call CPU alloc + device copy).
        tt = (
            torch.ones(
                out.size(0), 1, device=out.device, dtype=self.dense1.weight.dtype
            )
            * t
        )
        out = out + self.dense1(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(tt)[:, :, None, None]  # (batch, dim, 1, 1)
        out = self.norm3(out)
        return out


### my
class ODENetzerobiasall(nn.Module):
    """Classifier: downsampling stem -> ``ODEBlockzerobiasall`` -> ``FCClassifier``.

    Args:
        in_ch: number of input channels.
        out: number of classifier outputs.
        n_filters: channels produced by the stem and evolved by the ODE block.
        downsample: stem type, one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: ODE solver name.
        tol: solver tolerance.
        adjoint: integrate with the adjoint method when True.
        t1: integration end time (or timestamps) for the ODE block and ODE stems.
        dropout: dropout forwarded to ``FCClassifier``.
        norm: normalization type forwarded to every sub-module.

    Raises:
        ValueError: if ``downsample`` is not one of the names above.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetzerobiasall, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockzerobiasall(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier; results are concatenated on dim 0.

        If the stem returns a ``(features, x)`` pair, each intermediate feature is
        passed through the classifier (or only global-average-pooled once the last
        classifier layer was removed) and those stacked results come first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            f, x = x  # intermediate features, then the tensor to continue with
            # After to_features_extractor(keep_pool=False) the classifier is a bare
            # nn.Sequential without a ``module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # last layer removed: apply global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:  # apply the classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # states at several timestamps, stacked on dim 0
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the model into a feature extractor, in place.

        ODE blocks (including an ODE stem) return the state at every timestamp.
        With ``keep_pool`` the last classifier layer is replaced by an empty
        ``nn.Sequential``; otherwise the classifier keeps only its first two children.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, zeroing it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockzerobiasall(nn.Module):
    """Integrates ``ODEfunczerobiasall`` over ``self.integration_time``.

    Args:
        n_filters: number of channels of the ODE state.
        tol: solver tolerance, passed as both ``rtol`` and ``atol``.
        method: solver name passed to ``odeint`` / ``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint`` when True.
        t1: end time (scalar) or sequence/tensor of timestamps; see the ``t1``
            setter. ``0`` disables integration.
        norm: normalization type forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockzerobiasall, self).__init__()
        self.odefunc = ODEfunczerobiasall(n_filters, norm=norm)
        self.t1 = t1  # routed through the t1 setter, which sets self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # If False, forward returns the state at every timestamp, stacked on dim 0.
        self.return_last_only = True

    def forward(self, x):
        """Integrate starting from state ``x``.

        Returns ``x`` itself when integration is disabled, the state at the last
        timestamp when ``return_last_only`` is True, and otherwise the full solver
        output (one state per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # integration_time is a plain attribute (not a buffer), so align its
        # dtype/device with the input here.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # state at the last timestamp
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations recorded by ``self.odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """``integration_time[1]``, or ``0`` when integration is disabled.

        With more than two timestamps this is the second timestamp, not the last.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration timestamps.

        A scalar ``v`` (including a 0-dim tensor) gives ``[0, v]``; ``0`` disables
        integration. A non-empty list, tuple or tensor is used as the timestamp
        sequence, with ``0`` prepended if it does not already start at 0.

        Raises:
            ValueError: for an empty sequence or an unsupported type.
        """
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            value = value.item()  # .tolist() of a 0-dim tensor is not indexable

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Integration timestamps must not be empty")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfunczerobiasall(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with every bias zero-initialised.

    Computation: norm1 -> ReLU -> conv1 (+ dense1(t)) -> norm2 -> ReLU
    -> conv2 (+ dense2(t)) -> norm3. ``dense1``/``dense2`` map the scalar time
    to one offset per channel, broadcast over batch and spatial positions.

    Variant "zerobiasall": the biases of conv1, conv2, dense1 and dense2 are
    all initialised to zero.

    Args:
        dim: channel count of the state (input and output of both convs).
        norm: name passed to ``normalization`` to get the norm-layer factory.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfunczerobiasall, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward() calls (function evaluations)

        # Time embeddings: scalar t -> one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # zerobiasall: start every bias at zero.
        nn.init.zeros_(self.conv1.bias)
        nn.init.zeros_(self.conv2.bias)
        nn.init.zeros_(self.dense1.bias)
        nn.init.zeros_(self.dense2.bias)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: current time; a one-element tensor (as passed by the solver)
                or a Python number.
            x: 4-D state ``(batch, dim, H, W)``.

        Returns:
            Output of ``norm3``; the convs keep the spatial size
            (kernel 3, stride 1, padding 1).
        """
        self.nfe += 1
        # t is shared by the whole batch: embed it once as a (1, 1) input and
        # let the (1, dim, 1, 1) offsets broadcast. Using x's dtype/device avoids
        # a host->device copy and a dtype mismatch in non-float32 models.
        t_in = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(t_in)[:, :, None, None]
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(t_in)[:, :, None, None]
        out = self.norm3(out)
        return out


### my
class ODENetzerobiasnone(nn.Module):
    """Downsampling stem -> ``ODEBlockzerobiasnone`` -> ``FCClassifier``.

    Variant "zerobiasnone": the ODE dynamics keep PyTorch's default bias
    initialisation.

    Args:
        in_ch: number of input channels of the stem.
        out: output size forwarded to ``FCClassifier``.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: solver name (ODE block and ODE-based stems).
        tol: solver tolerance (ODE block and ODE-based stems).
        adjoint: use the adjoint solver (ODE block and ODE-based stems).
        t1: integration end time / timestamps (ODE block and ODE-based stems).
        dropout: forwarded to ``FCClassifier``.
        norm: normalization name forwarded to stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not a supported name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetzerobiasnone, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here instead of with an AttributeError later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockzerobiasnone(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier.

        Normally returns the classifier output for the final ODE state. When
        the stem and/or ODE block return per-timestamp outputs (see
        ``to_features_extractor``), each timestamp is processed separately and
        the results are concatenated along dim 0, stem timestamps first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: stem features; second: tensor to continue with
            f, x = x
            # After to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``.module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every stem feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # extra leading dim: one ODE state per timestamp
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the network to returning per-timestamp features.

        The ODE block (and an ODE-based stem) return all timestamps instead of
        only the last one. With ``keep_pool`` the classifier's last layer is
        replaced by an empty ``nn.Sequential``; otherwise the classifier becomes
        a plain ``nn.Sequential`` of the first two children of
        ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation counter; zero it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockzerobiasnone(nn.Module):
    """Integrates ``ODEfunczerobiasnone`` dynamics starting at t=0.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint``.
        t1: end time, or a sequence/tensor of timestamps (see the ``t1`` setter).
        norm: normalization name forwarded to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockzerobiasnone, self).__init__()
        self.odefunc = ODEfunczerobiasnone(n_filters, norm=norm)
        self.t1 = t1  # property setter: builds self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns only the last solver output; False: all of them.
        self.return_last_only = True

    def forward(self, x):
        """Integrate the dynamics starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``).
        Otherwise returns the solver output at the last timestamp when
        ``return_last_only`` is True, else the whole solver output (one entry
        per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # Keep the timestamps on the input's device and dtype.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ last timestamp
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time of a two-point
        grid), or 0 when integration is disabled."""
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar 0 disables integration (``integration_time = None``);
        * any other scalar ``v`` gives the grid ``[0, v]``;
        * a list/tuple/tensor is used as the grid, with 0 prepended if it
          does not already start at 0. A 0-d tensor counts as a scalar.

        Raises:
            ValueError: for an unsupported type or an empty sequence.
        """
        if isinstance(value, torch.Tensor):
            # 0-d tensors have no [0]; treat them as plain scalars.
            value = value.item() if value.dim() == 0 else value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Timestamp sequence must not be empty")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfunczerobiasnone(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with default bias initialisation.

    Computation: norm1 -> ReLU -> conv1 (+ dense1(t)) -> norm2 -> ReLU
    -> conv2 (+ dense2(t)) -> norm3. ``dense1``/``dense2`` map the scalar time
    to one offset per channel, broadcast over batch and spatial positions.

    Variant "zerobiasnone": no bias is zeroed; all layers keep PyTorch's
    default initialisation.

    Args:
        dim: channel count of the state (input and output of both convs).
        norm: name passed to ``normalization`` to get the norm-layer factory.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfunczerobiasnone, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward() calls (function evaluations)

        # Time embeddings: scalar t -> one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        # zerobiasnone: biases intentionally left at their default init.

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: current time; a one-element tensor (as passed by the solver)
                or a Python number.
            x: 4-D state ``(batch, dim, H, W)``.

        Returns:
            Output of ``norm3``; the convs keep the spatial size
            (kernel 3, stride 1, padding 1).
        """
        self.nfe += 1
        # t is shared by the whole batch: embed it once as a (1, 1) input and
        # let the (1, dim, 1, 1) offsets broadcast. Using x's dtype/device avoids
        # a host->device copy and a dtype mismatch in non-float32 models.
        t_in = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(t_in)[:, :, None, None]
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(t_in)[:, :, None, None]
        out = self.norm3(out)
        return out


### my
class ODENetzerobiasconv(nn.Module):
    """Downsampling stem -> ``ODEBlockzerobiasconv`` -> ``FCClassifier``.

    Variant "zerobiasconv": the ODE dynamics start with zero conv biases.

    Args:
        in_ch: number of input channels of the stem.
        out: output size forwarded to ``FCClassifier``.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: solver name (ODE block and ODE-based stems).
        tol: solver tolerance (ODE block and ODE-based stems).
        adjoint: use the adjoint solver (ODE block and ODE-based stems).
        t1: integration end time / timestamps (ODE block and ODE-based stems).
        dropout: forwarded to ``FCClassifier``.
        norm: normalization name forwarded to stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not a supported name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetzerobiasconv, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here instead of with an AttributeError later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockzerobiasconv(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier.

        Normally returns the classifier output for the final ODE state. When
        the stem and/or ODE block return per-timestamp outputs (see
        ``to_features_extractor``), each timestamp is processed separately and
        the results are concatenated along dim 0, stem timestamps first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: stem features; second: tensor to continue with
            f, x = x
            # After to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``.module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every stem feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # extra leading dim: one ODE state per timestamp
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the network to returning per-timestamp features.

        The ODE block (and an ODE-based stem) return all timestamps instead of
        only the last one. With ``keep_pool`` the classifier's last layer is
        replaced by an empty ``nn.Sequential``; otherwise the classifier becomes
        a plain ``nn.Sequential`` of the first two children of
        ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation counter; zero it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockzerobiasconv(nn.Module):
    """Integrates ``ODEfunczerobiasconv`` dynamics starting at t=0.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint``.
        t1: end time, or a sequence/tensor of timestamps (see the ``t1`` setter).
        norm: normalization name forwarded to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockzerobiasconv, self).__init__()
        self.odefunc = ODEfunczerobiasconv(n_filters, norm=norm)
        self.t1 = t1  # property setter: builds self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns only the last solver output; False: all of them.
        self.return_last_only = True

    def forward(self, x):
        """Integrate the dynamics starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``).
        Otherwise returns the solver output at the last timestamp when
        ``return_last_only`` is True, else the whole solver output (one entry
        per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # Keep the timestamps on the input's device and dtype.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ last timestamp
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time of a two-point
        grid), or 0 when integration is disabled."""
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar 0 disables integration (``integration_time = None``);
        * any other scalar ``v`` gives the grid ``[0, v]``;
        * a list/tuple/tensor is used as the grid, with 0 prepended if it
          does not already start at 0. A 0-d tensor counts as a scalar.

        Raises:
            ValueError: for an unsupported type or an empty sequence.
        """
        if isinstance(value, torch.Tensor):
            # 0-d tensors have no [0]; treat them as plain scalars.
            value = value.item() if value.dim() == 0 else value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Timestamp sequence must not be empty")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfunczerobiasconv(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with zero-initialised conv biases.

    Computation: norm1 -> ReLU -> conv1 (+ dense1(t)) -> norm2 -> ReLU
    -> conv2 (+ dense2(t)) -> norm3. ``dense1``/``dense2`` map the scalar time
    to one offset per channel, broadcast over batch and spatial positions.

    Variant "zerobiasconv": conv1/conv2 biases start at zero; the dense
    (time-embedding) biases keep PyTorch's default initialisation.

    Args:
        dim: channel count of the state (input and output of both convs).
        norm: name passed to ``normalization`` to get the norm-layer factory.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfunczerobiasconv, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward() calls (function evaluations)

        # Time embeddings: scalar t -> one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # zerobiasconv: zero only the convolution biases.
        nn.init.zeros_(self.conv1.bias)
        nn.init.zeros_(self.conv2.bias)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: current time; a one-element tensor (as passed by the solver)
                or a Python number.
            x: 4-D state ``(batch, dim, H, W)``.

        Returns:
            Output of ``norm3``; the convs keep the spatial size
            (kernel 3, stride 1, padding 1).
        """
        self.nfe += 1
        # t is shared by the whole batch: embed it once as a (1, 1) input and
        # let the (1, dim, 1, 1) offsets broadcast. Using x's dtype/device avoids
        # a host->device copy and a dtype mismatch in non-float32 models.
        t_in = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(t_in)[:, :, None, None]
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(t_in)[:, :, None, None]
        out = self.norm3(out)
        return out


### my
class ODENetzerobiasdense(nn.Module):
    """Downsampling stem -> ``ODEBlockzerobiasdense`` -> ``FCClassifier``.

    Variant "zerobiasdense": the ODE dynamics start with zero time-embedding
    (dense) biases.

    Args:
        in_ch: number of input channels of the stem.
        out: output size forwarded to ``FCClassifier``.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: solver name (ODE block and ODE-based stems).
        tol: solver tolerance (ODE block and ODE-based stems).
        adjoint: use the adjoint solver (ODE block and ODE-based stems).
        t1: integration end time / timestamps (ODE block and ODE-based stems).
        dropout: forwarded to ``FCClassifier``.
        norm: normalization name forwarded to stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not a supported name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetzerobiasdense, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here instead of with an AttributeError later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockzerobiasdense(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier.

        Normally returns the classifier output for the final ODE state. When
        the stem and/or ODE block return per-timestamp outputs (see
        ``to_features_extractor``), each timestamp is processed separately and
        the results are concatenated along dim 0, stem timestamps first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: stem features; second: tensor to continue with
            f, x = x
            # After to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``.module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every stem feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # extra leading dim: one ODE state per timestamp
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the network to returning per-timestamp features.

        The ODE block (and an ODE-based stem) return all timestamps instead of
        only the last one. With ``keep_pool`` the classifier's last layer is
        replaced by an empty ``nn.Sequential``; otherwise the classifier becomes
        a plain ``nn.Sequential`` of the first two children of
        ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation counter; zero it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockzerobiasdense(nn.Module):
    """Integrates ``ODEfunczerobiasdense`` dynamics starting at t=0.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint``.
        t1: end time, or a sequence/tensor of timestamps (see the ``t1`` setter).
        norm: normalization name forwarded to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockzerobiasdense, self).__init__()
        self.odefunc = ODEfunczerobiasdense(n_filters, norm=norm)
        self.t1 = t1  # property setter: builds self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns only the last solver output; False: all of them.
        self.return_last_only = True

    def forward(self, x):
        """Integrate the dynamics starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``).
        Otherwise returns the solver output at the last timestamp when
        ``return_last_only`` is True, else the whole solver output (one entry
        per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # Keep the timestamps on the input's device and dtype.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ last timestamp
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time of a two-point
        grid), or 0 when integration is disabled."""
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar 0 disables integration (``integration_time = None``);
        * any other scalar ``v`` gives the grid ``[0, v]``;
        * a list/tuple/tensor is used as the grid, with 0 prepended if it
          does not already start at 0. A 0-d tensor counts as a scalar.

        Raises:
            ValueError: for an unsupported type or an empty sequence.
        """
        if isinstance(value, torch.Tensor):
            # 0-d tensors have no [0]; treat them as plain scalars.
            value = value.item() if value.dim() == 0 else value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Timestamp sequence must not be empty")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfunczerobiasdense(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with zero-initialised dense biases.

    Computation: norm1 -> ReLU -> conv1 (+ dense1(t)) -> norm2 -> ReLU
    -> conv2 (+ dense2(t)) -> norm3. ``dense1``/``dense2`` map the scalar time
    to one offset per channel, broadcast over batch and spatial positions.

    Variant "zerobiasdense": dense1/dense2 (time-embedding) biases start at
    zero; the conv biases keep PyTorch's default initialisation.

    Args:
        dim: channel count of the state (input and output of both convs).
        norm: name passed to ``normalization`` to get the norm-layer factory.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfunczerobiasdense, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward() calls (function evaluations)

        # Time embeddings: scalar t -> one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # zerobiasdense: zero only the time-embedding biases.
        nn.init.zeros_(self.dense1.bias)
        nn.init.zeros_(self.dense2.bias)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: current time; a one-element tensor (as passed by the solver)
                or a Python number.
            x: 4-D state ``(batch, dim, H, W)``.

        Returns:
            Output of ``norm3``; the convs keep the spatial size
            (kernel 3, stride 1, padding 1).
        """
        self.nfe += 1
        # t is shared by the whole batch: embed it once as a (1, 1) input and
        # let the (1, dim, 1, 1) offsets broadcast. Using x's dtype/device avoids
        # a host->device copy and a dtype mismatch in non-float32 models.
        t_in = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(t_in)[:, :, None, None]
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(t_in)[:, :, None, None]
        out = self.norm3(out)
        return out


### my
class ODENetinit3ep0(nn.Module):
    """Downsampling stem -> ``ODEBlockinit3ep0`` -> ``FCClassifier``.

    Variant "init3ep0": the ODE dynamics' time-embedding weights are drawn
    from a normal distribution with std 3.

    Args:
        in_ch: number of input channels of the stem.
        out: output size forwarded to ``FCClassifier``.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: solver name (ODE block and ODE-based stems).
        tol: solver tolerance (ODE block and ODE-based stems).
        adjoint: use the adjoint solver (ODE block and ODE-based stems).
        t1: integration end time / timestamps (ODE block and ODE-based stems).
        dropout: forwarded to ``FCClassifier``.
        norm: normalization name forwarded to stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not a supported name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit3ep0, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here instead of with an AttributeError later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockinit3ep0(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier.

        Normally returns the classifier output for the final ODE state. When
        the stem and/or ODE block return per-timestamp outputs (see
        ``to_features_extractor``), each timestamp is processed separately and
        the results are concatenated along dim 0, stem timestamps first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: stem features; second: tensor to continue with
            f, x = x
            # After to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``.module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every stem feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # extra leading dim: one ODE state per timestamp
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the network to returning per-timestamp features.

        The ODE block (and an ODE-based stem) return all timestamps instead of
        only the last one. With ``keep_pool`` the classifier's last layer is
        replaced by an empty ``nn.Sequential``; otherwise the classifier becomes
        a plain ``nn.Sequential`` of the first two children of
        ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation counter; zero it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit3ep0(nn.Module):
    """Integrates ``ODEfuncinit3ep0`` dynamics starting at t=0.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint``.
        t1: end time, or a sequence/tensor of timestamps (see the ``t1`` setter).
        norm: normalization name forwarded to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit3ep0, self).__init__()
        self.odefunc = ODEfuncinit3ep0(n_filters, norm=norm)
        self.t1 = t1  # property setter: builds self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns only the last solver output; False: all of them.
        self.return_last_only = True

    def forward(self, x):
        """Integrate the dynamics starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``).
        Otherwise returns the solver output at the last timestamp when
        ``return_last_only`` is True, else the whole solver output (one entry
        per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # Keep the timestamps on the input's device and dtype.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ last timestamp
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time of a two-point
        grid), or 0 when integration is disabled."""
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar 0 disables integration (``integration_time = None``);
        * any other scalar ``v`` gives the grid ``[0, v]``;
        * a list/tuple/tensor is used as the grid, with 0 prepended if it
          does not already start at 0. A 0-d tensor counts as a scalar.

        Raises:
            ValueError: for an unsupported type or an empty sequence.
        """
        if isinstance(value, torch.Tensor):
            # 0-d tensors have no [0]; treat them as plain scalars.
            value = value.item() if value.dim() == 0 else value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Timestamp sequence must not be empty")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit3ep0(nn.Module):
    """Time-conditioned ODE dynamics ``f(t, x)`` with wide time-embedding init (std 3).

    Computation: norm1 -> ReLU -> conv1 (+ dense1(t)) -> norm2 -> ReLU
    -> conv2 (+ dense2(t)) -> norm3. ``dense1``/``dense2`` map the scalar time
    to one offset per channel, broadcast over batch and spatial positions.

    Variant "init3ep0": dense1/dense2 weights are drawn from N(0, 3^2);
    all biases keep PyTorch's default initialisation.

    Args:
        dim: channel count of the state (input and output of both convs).
        norm: name passed to ``normalization`` to get the norm-layer factory.
    """

    def __init__(self, dim, norm="group"):
        super(ODEfuncinit3ep0, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # number of forward() calls (function evaluations)

        # Time embeddings: scalar t -> one offset per channel.
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)

        # init3ep0: time-embedding weights ~ N(0, std=3e0).
        nn.init.normal_(self.dense1.weight, mean=0.0, std=3e0)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=3e0)

    def forward(self, t, x):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: current time; a one-element tensor (as passed by the solver)
                or a Python number.
            x: 4-D state ``(batch, dim, H, W)``.

        Returns:
            Output of ``norm3``; the convs keep the spatial size
            (kernel 3, stride 1, padding 1).
        """
        self.nfe += 1
        # t is shared by the whole batch: embed it once as a (1, 1) input and
        # let the (1, dim, 1, 1) offsets broadcast. Using x's dtype/device avoids
        # a host->device copy and a dtype mismatch in non-float32 models.
        t_in = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(1, 1)

        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = out + self.dense1(t_in)[:, :, None, None]
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = out + self.dense2(t_in)[:, :, None, None]
        out = self.norm3(out)
        return out


### my
class ODENetinit3ep1(nn.Module):
    """Downsampling stem -> ``ODEBlockinit3ep1`` -> ``FCClassifier``.

    Variant "init3ep1": the ODE dynamics' time-embedding weights are drawn
    from a normal distribution with std 30.

    Args:
        in_ch: number of input channels of the stem.
        out: output size forwarded to ``FCClassifier``.
        n_filters: channel width produced by the stem and used by the ODE block.
        downsample: stem type; one of "residual", "convolution", "minimal",
            "one-shot", "ode", "ode2".
        method: solver name (ODE block and ODE-based stems).
        tol: solver tolerance (ODE block and ODE-based stems).
        adjoint: use the adjoint solver (ODE block and ODE-based stems).
        t1: integration end time / timestamps (ODE block and ODE-based stems).
        dropout: forwarded to ``FCClassifier``.
        norm: normalization name forwarded to stem, ODE block and classifier.

    Raises:
        ValueError: if ``downsample`` is not a supported name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit3ep1, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # Fail here instead of with an AttributeError later in forward().
            raise ValueError(f"Unknown downsample type: {downsample!r}")

        self.odeblock = ODEBlockinit3ep1(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        """Run stem, ODE block and classifier.

        Normally returns the classifier output for the final ODE state. When
        the stem and/or ODE block return per-timestamp outputs (see
        ``to_features_extractor``), each timestamp is processed separately and
        the results are concatenated along dim 0, stem timestamps first.
        """
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # first element: stem features; second: tensor to continue with
            f, x = x
            # After to_features_extractor(keep_pool=False) the classifier is a
            # plain nn.Sequential without a ``.module`` attribute.
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed: global average pooling only
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to every stem feature
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # extra leading dim: one ODE state per timestamp
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Switch the network to returning per-timestamp features.

        The ODE block (and an ODE-based stem) return all timestamps instead of
        only the last one. With ``keep_pool`` the classifier's last layer is
        replaced by an empty ``nn.Sequential``; otherwise the classifier becomes
        a plain ``nn.Sequential`` of the first two children of
        ``classifier.module``.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's evaluation counter; zero it if ``reset``."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit3ep1(nn.Module):
    """Integrates ``ODEfuncinit3ep1`` dynamics starting at t=0.

    Args:
        n_filters: channel count passed to the dynamics function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint`` / ``odeint_adjoint``.
        adjoint: if True, integrate with ``odeint_adjoint``.
        t1: end time, or a sequence/tensor of timestamps (see the ``t1`` setter).
        norm: normalization name forwarded to the dynamics function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit3ep1, self).__init__()
        self.odefunc = ODEfuncinit3ep1(n_filters, norm=norm)
        self.t1 = t1  # property setter: builds self.integration_time
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # True: forward returns only the last solver output; False: all of them.
        self.return_last_only = True

    def forward(self, x):
        """Integrate the dynamics starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``).
        Otherwise returns the solver output at the last timestamp when
        ``return_last_only`` is True, else the whole solver output (one entry
        per timestamp along dim 0).
        """
        if self.integration_time is None:
            return x

        # Keep the timestamps on the input's device and dtype.
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ last timestamp
        return out

    @property
    def nfe(self):
        """Number of dynamics evaluations counted by the ODE function."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time`` (the end time of a two-point
        grid), or 0 when integration is disabled."""
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        """Set the integration grid.

        * scalar 0 disables integration (``integration_time = None``);
        * any other scalar ``v`` gives the grid ``[0, v]``;
        * a list/tuple/tensor is used as the grid, with 0 prepended if it
          does not already start at 0. A 0-d tensor counts as a scalar.

        Raises:
            ValueError: for an unsupported type or an empty sequence.
        """
        if isinstance(value, torch.Tensor):
            # 0-d tensors have no [0]; treat them as plain scalars.
            value = value.item() if value.dim() == 0 else value.tolist()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple)):
            value = list(value)
            if not value:
                raise ValueError("Timestamp sequence must not be empty")
            if value[0] != 0:
                value = [0] + value  # integration always starts at t=0
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit3ep1(nn.Module):
    """ODE dynamics with an additive, learned per-channel embedding of ``t``.

    norm -> ReLU -> conv3x3 -> (+ dense1(t)) -> norm -> ReLU -> conv3x3
    -> (+ dense2(t)) -> norm. Unlike ``ODEfunc`` the convolutions are plain
    ``nn.Conv2d``; time enters only through two ``Linear(1, dim)`` layers
    whose outputs are broadcast over the spatial dimensions.

    Args:
        dim: number of channels of the state.
        norm: normalization name passed to ``normalization``.
        std: standard deviation of the normal init of the two time-projection
            weights (default 3e1).
    """

    def __init__(self, dim, norm="group", std=3e1):
        super(ODEfuncinit3ep1, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings t -> R^dim, weights drawn from N(0, std^2)
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=std)

    def forward(self, t, x):
        self.nfe += 1
        out = self.conv1(self.relu(self.norm1(x)))
        # (B, 1) column filled with t on the activations' dtype/device, so it
        # works for float64 models and for tensor or Python-number t
        tt = out.new_ones((out.size(0), 1)) * t
        out = out + self.dense1(tt)[:, :, None, None]
        out = self.conv2(self.relu(self.norm2(out)))
        out = out + self.dense2(tt)[:, :, None, None]
        return self.norm3(out)


### my
class ODENetinit3ep2(nn.Module):
    """ODENet variant whose ODE block uses ``ODEBlockinit3ep2`` dynamics.

    Pipeline: downsample -> ``ODEBlockinit3ep2`` -> ``FCClassifier``.

    Args:
        in_ch: input channels.
        out: number of classes.
        n_filters: channel width of the ODE block.
        downsample: one of "residual", "convolution", "minimal", "one-shot",
            "ode", "ode2".
        method, tol, adjoint, t1: ODE solver settings, also used by the
            "ode"/"ode2" downsamplers.
        dropout: dropout probability in the classifier (0 disables it).
        norm: normalization name used by every sub-module.

    Raises:
        ValueError: for an unknown ``downsample`` name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        method="dopri5",
        tol=1e-3,
        adjoint=False,
        t1=1,
        dropout=0,
        norm="group",
    ):
        super(ODENetinit3ep2, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        elif downsample == "ode":
            self.downsample = ODEDownsample(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        elif downsample == "ode2":
            self.downsample = ODEDownsample2(
                in_ch, adjoint=adjoint, t1=t1, tol=tol, method=method, **common
            )
        else:
            # fail early instead of an AttributeError on self.downsample in forward
            raise ValueError("Unknown downsample type: {}".format(downsample))

        self.odeblock = ODEBlockinit3ep2(
            n_filters=n_filters,
            tol=tol,
            adjoint=adjoint,
            t1=t1,
            method=method,
            norm=norm,
        )
        self.classifier = FCClassifier(
            in_ch=n_filters, out=out, dropout=dropout, norm=norm
        )

    def forward(self, x):
        out = []
        x = self.downsample(x)
        if isinstance(x, (tuple, list)):
            # ODE downsamplers in feature-extraction mode return (features, x)
            f, x = x
            # after to_features_extractor(keep_pool=False) the classifier is a
            # bare nn.Sequential without a ``.module`` attribute
            head = getattr(self.classifier, "module", self.classifier)
            if isinstance(head[-1], nn.Sequential):
                # no classification to be performed, apply GAP and return
                f = torch.stack([fi.mean(-1).mean(-1) for fi in f])
            else:
                # apply the (possibly truncated) classifier to all features
                f = torch.stack([self.classifier(fi) for fi in f])
            out.append(f)

        x = self.odeblock(x)
        if x.dim() > 4:
            # trajectory of states: apply the classifier at every timestamp
            x = torch.stack([self.classifier(xi) for xi in x])
        else:
            x = self.classifier(x)

        out.append(x)
        return torch.cat(out)

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Turn the network into a multi-timestamp feature extractor.

        Makes the ODE block(s) return every timestamp and truncates the
        classifier: with ``keep_pool`` only the final layer is removed
        (norm, ReLU and global average pooling stay); otherwise only the
        first two classifier layers are kept.
        """
        if isinstance(self.downsample, (ODEDownsample, ODEDownsample2)):
            self.downsample.odeblock.return_last_only = False
        self.odeblock.return_last_only = False  # returns dynamic @ multiple timestamps
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )

    def nfe(self, reset=False):
        """Return the ODE block's function-evaluation count, optionally resetting it."""
        nfe = self.odeblock.nfe
        if reset:
            self.odeblock.nfe = 0
        return nfe


### my
class ODEBlockinit3ep2(nn.Module):
    """Integrate ``ODEfuncinit3ep2`` dynamics over ``integration_time``.

    Args:
        n_filters: number of channels of the state, forwarded to the ODE function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint``/``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a list/tuple/tensor of timestamps (see the ``t1``
            setter). ``t1 == 0`` disables integration (identity block).
        norm: normalization name forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockinit3ep2, self).__init__()
        self.odefunc = ODEfuncinit3ep2(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # False -> forward returns the state at every timestamp (feature extraction)
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the final state when ``return_last_only`` is True, otherwise the
        solver output for every entry of ``integration_time``.
        """
        if self.integration_time is None:
            return x

        # match the input's dtype/device before handing times to the solver
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations counted by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time``; 0 when integration is disabled.

        NOTE(review): when several timestamps were given this is the first
        timestamp after 0, not the final one.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # Accepts a scalar (0 disables integration), a 0-dim tensor (treated as
        # a scalar) or a non-empty list/tuple/1-D tensor of timestamps; 0 is
        # prepended as the start time when missing.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            # tolist() on a 0-dim tensor yields a bare number, not a list
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty list, tuple, or tensor")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEfuncinit3ep2(nn.Module):
    """ODE dynamics with an additive, learned per-channel embedding of ``t``.

    norm -> ReLU -> conv3x3 -> (+ dense1(t)) -> norm -> ReLU -> conv3x3
    -> (+ dense2(t)) -> norm. Unlike ``ODEfunc`` the convolutions are plain
    ``nn.Conv2d``; time enters only through two ``Linear(1, dim)`` layers
    whose outputs are broadcast over the spatial dimensions.

    Args:
        dim: number of channels of the state.
        norm: normalization name passed to ``normalization``.
        std: standard deviation of the normal init of the two time-projection
            weights (default 3e2).
    """

    def __init__(self, dim, norm="group", std=3e2):
        super(ODEfuncinit3ep2, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0  # incremented on every forward call

        # time embeddings t -> R^dim, weights drawn from N(0, std^2)
        self.dense1 = nn.Linear(1, dim)
        self.dense2 = nn.Linear(1, dim)
        nn.init.normal_(self.dense1.weight, mean=0.0, std=std)
        nn.init.normal_(self.dense2.weight, mean=0.0, std=std)

    def forward(self, t, x):
        self.nfe += 1
        out = self.conv1(self.relu(self.norm1(x)))
        # (B, 1) column filled with t on the activations' dtype/device, so it
        # works for float64 models and for tensor or Python-number t
        tt = out.new_ones((out.size(0), 1)) * t
        out = out + self.dense1(tt)[:, :, None, None]
        out = self.conv2(self.relu(self.norm2(out)))
        out = out + self.dense2(tt)[:, :, None, None]
        return self.norm3(out)


class ResNet(nn.Module):
    """Discrete residual baseline: downsample -> 6 ``ResBlock`` -> ``FCClassifier``.

    Args:
        in_ch: input channels.
        out: number of classes.
        n_filters: channel width of the residual stage.
        downsample: one of "residual", "convolution", "minimal", "one-shot".
        dropout: dropout probability in the classifier (0 disables it).
        norm: normalization name used by every sub-module.

    Raises:
        ValueError: for an unknown ``downsample`` name.
    """

    def __init__(
        self,
        in_ch,
        out=10,
        n_filters=64,
        downsample="residual",
        dropout=0,
        norm="group",
    ):
        super(ResNet, self).__init__()

        common = dict(out_ch=n_filters, norm=norm)
        if downsample == "residual":
            self.downsample = ResDownsample(in_ch, **common)
        elif downsample == "convolution":
            self.downsample = ConvDownsample(in_ch, **common)
        elif downsample == "minimal":
            self.downsample = MinimalConvDownsample(in_ch, **common)
        elif downsample == "one-shot":
            self.downsample = OneShotDownsample(in_ch, **common)
        else:
            # fail early instead of an AttributeError on self.downsample in forward
            raise ValueError("Unknown downsample type: {}".format(downsample))

        # forward ``norm`` so the residual stage uses the same normalization
        # as the downsample and classifier
        self.features = nn.Sequential(
            *[ResBlock(n_filters, n_filters, norm=norm) for _ in range(6)]
        )
        self.classifier = FCClassifier(n_filters, out=out, dropout=dropout, norm=norm)
        self._extract_features = False

    def to_features_extractor(self, keep_pool=True):  # ugly hack
        """Make ``forward`` return classifier-head outputs for every stage.

        Registers forward hooks that store the detached output of the
        downsample and of each residual block; ``forward`` then applies the
        truncated classifier to each stored tensor and stacks the results.

        Args:
            keep_pool: if True drop only the last classifier layer (keeping
                norm, ReLU and global average pooling); otherwise keep only
                the first two classifier layers.
        """
        if keep_pool:
            # remove last classification layer but maintain norm, relu and global avg pooling
            self.classifier.module[-1] = nn.Sequential()
        else:
            self.classifier = nn.Sequential(
                *list(self.classifier.module.children())[:2]
            )
        self._extract_features = True
        # slot 0 holds the downsample output, slots 1.. the residual blocks
        self._tmp_features = [None] * (len(self.features) + 1)

        def hooks(idx):
            def __hook(m, i, o):
                self._tmp_features[idx] = o.data

            return __hook

        self.downsample.register_forward_hook(hooks(0))
        for n, block in enumerate(self.features):
            block.register_forward_hook(hooks(n + 1))

    def forward(self, x):
        x = self.downsample(x)
        x = self.features(x)
        if self._extract_features:
            # hooks filled _tmp_features during the two calls above
            x = torch.stack([self.classifier(xi) for xi in self._tmp_features])
        else:
            x = self.classifier(x)
        return x

    def nfe(self, reset=False):
        """No ODE solver here, so the function-evaluation count is always 0."""
        return 0


"""
   Initial Downsample Blocks
"""


class OneShotDownsample(nn.Module):
    """Halve the spatial resolution with a single strided 4x4 convolution.

    Extra keyword arguments (e.g. ``norm``) are accepted so the constructor
    matches the other downsample blocks, and are ignored.
    """

    def __init__(self, in_ch, out_ch=64, **kwargs):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 maps H x W to H/2 x W/2 for even sizes
        self.module = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, *input):
        conv = self.module
        return conv(*input)


class MinimalConvDownsample(nn.Module):
    """Lightweight convolutional stem with a 24-channel hidden width.

    An unpadded 3x3 conv (shrinks H, W by 2), then two stride-2 4x4 convs,
    with norm + ReLU between consecutive convolutions.
    """

    def __init__(self, in_ch, out_ch=64, norm="group"):
        super().__init__()
        make_norm = normalization(norm)
        hidden = 24
        stages = [
            nn.Conv2d(in_ch, hidden, 3, 1),
            make_norm(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 4, 2, 1),
            make_norm(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_ch, 4, 2, 1),
        ]
        self.module = nn.Sequential(*stages)

    def forward(self, *input):
        stem = self.module
        return stem(*input)


class ConvDownsample(nn.Module):
    """Convolutional stem with a 64-channel hidden width.

    An unpadded 3x3 conv (shrinks H, W by 2), then two stride-2 4x4 convs,
    with norm + ReLU between consecutive convolutions.
    """

    def __init__(self, in_ch, out_ch=64, norm="group"):
        super().__init__()
        make_norm = normalization(norm)
        width = 64
        stages = [
            nn.Conv2d(in_ch, width, 3, 1),
            make_norm(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, 4, 2, 1),
            make_norm(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, out_ch, 4, 2, 1),
        ]
        self.module = nn.Sequential(*stages)

    def forward(self, *input):
        stem = self.module
        return stem(*input)


class ResDownsample(nn.Module):
    """Residual stem: unpadded 3x3 conv to 64 channels, then two stride-2 ``ResBlock``.

    Each residual block uses a stride-2 1x1 convolution as its shortcut so
    the skip connection matches the downsampled main branch.
    """

    def __init__(self, in_ch, out_ch=64, norm="group"):
        super().__init__()
        stem = nn.Conv2d(in_ch, 64, 3, 1)
        first = ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2), norm=norm)
        second = ResBlock(
            64, out_ch, stride=2, downsample=conv1x1(64, out_ch, 2), norm=norm
        )
        self.module = nn.Sequential(stem, first, second)

    def forward(self, *input):
        body = self.module
        return body(*input)


class ODEDownsample(nn.Module):
    """Downsample with strided conv -> ``ODEBlock`` -> max-pool.

    Both the 4x4/stride-2 conv and the 4/2/1 max-pool halve H and W for even
    sizes. When the ODE block returns a trajectory (more than 4 dims), the
    whole trajectory is returned together with the pooled final state.
    """

    def __init__(
        self,
        in_ch,
        out_ch=64,
        method="dopri5",
        adjoint=False,
        t1=1,
        tol=1e-3,
        norm="group",
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 4, 2, 1)  # first downsample
        self.odeblock = ODEBlock(
            n_filters=out_ch,
            adjoint=adjoint,
            t1=t1,
            tol=tol,
            method=method,
            norm=norm,
        )
        self.maxpool = nn.MaxPool2d(4, 2, 1)

    def forward(self, x):
        states = self.odeblock(self.conv1(x))
        if states.dim() <= 4:
            return self.maxpool(states)
        # feature-extraction mode: (all states, pooled last state)
        return states, self.maxpool(states[-1])


class ODEDownsample2(nn.Module):
    """Downsample with strided conv -> ``ODEBlock`` -> norm/ReLU -> strided conv.

    When the ODE block returns a trajectory (more than 4 dims), norm/ReLU is
    applied to every state. If ``apply_conv`` is True the second convolution
    is also applied to every state and the last one is returned as the
    continuation; otherwise it is applied only to the last state.
    """

    def __init__(
        self,
        in_ch,
        out_ch=64,
        method="dopri5",
        adjoint=False,
        t1=1,
        tol=1e-3,
        norm="group",
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 4, 2, 1)  # first downsample
        self.odeblock = ODEBlock(
            n_filters=out_ch,
            adjoint=adjoint,
            t1=t1,
            tol=tol,
            method=method,
            norm=norm,
        )
        self.norm = nn.Sequential(normalization(norm)(out_ch), nn.ReLU(inplace=True))
        self.conv2 = nn.Conv2d(out_ch, out_ch, 4, 2, 1)  # downsample for successive ode

        self.apply_conv = False

    def forward(self, x):
        states = self.odeblock(self.conv1(x))
        if states.dim() <= 4:
            return self.conv2(self.norm(states))

        states = torch.stack([self.norm(s) for s in states])
        if not self.apply_conv:
            # conv2 only on the final state, which continues the forward pass
            return states, self.conv2(states[-1])
        states = torch.stack([self.conv2(s) for s in states])
        return states, states[-1]


"""
    Final FC Module
"""


class FCClassifier(nn.Module):
    """Classification head: norm -> ReLU -> global avg pool -> [Dropout] -> Flatten -> Linear.

    The layer order matters: ``to_features_extractor`` elsewhere replaces
    ``module[-1]`` (the Linear) or keeps only the first two children.
    Dropout is inserted only when ``dropout`` is truthy.
    """

    def __init__(self, in_ch=64, out=10, dropout=0, norm="group"):
        super().__init__()
        make_norm = normalization(norm)
        layers = [
            make_norm(in_ch),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        layers.extend([Flatten(), nn.Linear(in_ch, out)])

        self.module = nn.Sequential(*layers)

    def forward(self, *input):
        head = self.module
        return head(*input)


"""
    Helper Modules
"""


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 (size-preserving at stride 1)."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d`` with optional stride."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


def normalization(norm="group"):
    """Return a factory ``f(dim) -> nn.Module`` for the requested normalization.

    Supported names:
        * ``"group"``: ``GroupNorm`` with ``min(32, dim)`` groups.
        * ``"batch"``: ``BatchNorm2d(dim, track_running_stats=False)``.
        * ``"G<n>"`` for any positive integer ``n`` (e.g. ``"G8"``,
          ``"G256"``): ``GroupNorm`` with ``min(n, dim)`` groups.

    Raises:
        NotImplementedError: for any other name.
    """

    def _group_norm_factory(max_groups):
        # one parameterized factory instead of a copy per group count
        def _group_norm(dim):
            return nn.GroupNorm(min(max_groups, dim), dim)

        return _group_norm

    def _batch_norm(dim):
        return nn.BatchNorm2d(dim, track_running_stats=False)

    if norm == "group":
        return _group_norm_factory(32)
    if norm == "batch":
        return _batch_norm
    if isinstance(norm, str) and norm.startswith("G"):
        digits = norm[1:]
        if digits.isdecimal() and int(digits) > 0:
            return _group_norm_factory(int(digits))

    raise NotImplementedError("Normalization layer not implemented: {}".format(norm))


class ResBlock(nn.Module):
    """Pre-activation residual block: norm -> ReLU -> conv3x3 -> norm -> ReLU -> conv3x3, plus shortcut.

    Args:
        inplanes: input channels.
        planes: output channels.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the pre-activated input to
            form the shortcut. It is needed when ``stride != 1`` or
            ``inplanes != planes``, otherwise the final addition has
            mismatched shapes.
        norm: normalization name passed to ``normalization``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm="group"):
        super().__init__()
        make_norm = normalization(norm)
        self.norm1 = make_norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = make_norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        activated = self.relu(self.norm1(x))
        # the projection shortcut sees the pre-activated input, the identity
        # shortcut the raw input
        skip = x if self.downsample is None else self.downsample(activated)
        residual = self.conv2(self.relu(self.norm2(self.conv1(activated))))
        return residual + skip


class ConcatConv2d(nn.Module):
    """Convolution whose input gets an extra constant channel holding ``t``.

    The time channel is prepended to ``x`` along dim 1, so the wrapped layer
    has ``dim_in + 1`` input channels. ``transpose=True`` wraps an
    ``nn.ConvTranspose2d`` instead of an ``nn.Conv2d``; remaining keyword
    arguments go to the layer constructor.
    """

    def __init__(self, dim_in, dim_out, transpose=False, **kwargs):
        super().__init__()
        layer_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = layer_cls(dim_in + 1, dim_out, **kwargs)

    def forward(self, t, x):
        time_plane = torch.ones_like(x[:, :1, :, :]) * t
        stacked = torch.cat((time_plane, x), dim=1)
        return self._layer(stacked)


class ODEfunc(nn.Module):
    """Time-dependent dynamics ``f(t, x)`` used by ``ODEBlock``.

    norm -> ReLU -> ConcatConv2d -> norm -> ReLU -> ConcatConv2d -> norm.
    ``nfe`` counts ``forward`` calls, i.e. solver function evaluations.
    """

    def __init__(self, dim, norm="group"):
        super().__init__()
        make_norm = normalization(norm)
        self.norm1 = make_norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = make_norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = make_norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        hidden = self.conv1(t, self.relu(self.norm1(x)))
        hidden = self.conv2(t, self.relu(self.norm2(hidden)))
        return self.norm3(hidden)


### my
class ODEfuncSpatial(nn.Module):
    """``ODEfunc`` variant that adds learned, time-scaled spatial bias maps.

    After each ``ConcatConv2d`` the term ``t * spatial_k`` is added, where
    ``spatial_k`` is a learned ``(1, 1, spatial_size, spatial_size)`` map
    broadcast over batch and channels.

    Args:
        dim: number of channels of the state.
        norm: normalization name passed to ``normalization``.
        spatial_size: side of the learned maps; it must broadcast against the
            state's spatial size. Defaults to 8 (CIFAR); use 16 for
            tiny-imagenet-200.
    """

    def __init__(self, dim, norm="group", spatial_size=8):
        super(ODEfuncSpatial, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0

        # learned spatial biases, initialised from N(0, 0.1^2)
        H = spatial_size
        self.spatial1 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, 1, H, H))
        )
        self.spatial2 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, 1, H, H))
        )

    def forward(self, t, x):
        # NOTE(review): assumes x is (B, dim, H, W) with H == W == spatial_size,
        # e.g. (128, 256, 8, 8) on CIFAR and (128, 256, 16, 16) on
        # tiny-imagenet-200 — confirm against the downsample in use.
        self.nfe += 1
        out = self.relu(self.norm1(x))
        out = self.conv1(t, out) + t * self.spatial1
        out = self.relu(self.norm2(out))
        out = self.conv2(t, out) + t * self.spatial2
        return self.norm3(out)


### my
class ODEfuncSpatialPlusTwoMul(nn.Module):
    """``ODEfuncSpatial`` variant with multiplicative gates before each conv.

    Before each ``ConcatConv2d`` the activation is multiplied by
    ``exp(t) * channel_mul_k`` (shape ``(1, dim, 1, 1)``) and then by
    ``exp(t) * spatial_mul_k`` (shape ``(1, 1, H, H)``). After each conv,
    ``t * spatial_k`` is added.

    Args:
        dim: number of channels of the state.
        norm: normalization name passed to ``normalization``.
        spatial_size: side ``H`` of the spatial maps; it must broadcast
            against the state's spatial size. Defaults to 8 (CIFAR); use 16
            for tiny-imagenet-200.
    """

    def __init__(self, dim, norm="group", spatial_size=8):
        super(ODEfuncSpatialPlusTwoMul, self).__init__()
        norm = normalization(norm)
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.norm3 = norm(dim)
        self.nfe = 0

        # all learned maps/gates are initialised from N(0, 0.1^2)
        H = spatial_size
        self.spatial1 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, 1, H, H))
        )
        self.spatial2 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, 1, H, H))
        )
        self.channel_mul1 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, dim, 1, 1))
        )
        self.channel_mul2 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, dim, 1, 1))
        )
        self.spatial_mul1 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, 1, H, H))
        )
        self.spatial_mul2 = nn.Parameter(
            torch.normal(mean=0.0, std=1e-1 * torch.ones(1, 1, H, H))
        )

    def forward(self, t, x):
        # NOTE(review): assumes x is (B, dim, H, W) with H == W == spatial_size
        # — confirm against the downsample in use.
        self.nfe += 1
        # exp(t) computed once per call, directly on the parameters'
        # dtype/device (no new float32 CPU tensor per use)
        scale = torch.exp(
            torch.as_tensor(
                t, dtype=self.channel_mul1.dtype, device=self.channel_mul1.device
            )
        )

        out = self.relu(self.norm1(x))
        # gates are applied before the ConcatConv2d
        out = torch.mul(out, scale * self.channel_mul1)
        out = torch.mul(out, scale * self.spatial_mul1)
        out = self.conv1(t, out) + t * self.spatial1

        out = self.relu(self.norm2(out))
        out = torch.mul(out, scale * self.channel_mul2)
        out = torch.mul(out, scale * self.spatial_mul2)
        out = self.conv2(t, out) + t * self.spatial2
        return self.norm3(out)


class ODEBlock(nn.Module):
    """Integrate ``ODEfunc`` dynamics over ``integration_time``.

    Args:
        n_filters: number of channels of the state, forwarded to ``ODEfunc``.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint``/``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a list/tuple/tensor of timestamps (see the ``t1``
            setter). ``t1 == 0`` disables integration (identity block).
        norm: normalization name forwarded to ``ODEfunc``.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlock, self).__init__()
        self.odefunc = ODEfunc(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # False -> forward returns the state at every timestamp (feature extraction)
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the final state when ``return_last_only`` is True, otherwise the
        solver output for every entry of ``integration_time``.
        """
        if self.integration_time is None:
            return x

        # match the input's dtype/device before handing times to the solver
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations counted by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time``; 0 when integration is disabled.

        NOTE(review): when several timestamps were given this is the first
        timestamp after 0, not the final one.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # Accepts a scalar (0 disables integration), a 0-dim tensor (treated as
        # a scalar) or a non-empty list/tuple/1-D tensor of timestamps; 0 is
        # prepended as the start time when missing.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            # tolist() on a 0-dim tensor yields a bare number, not a list
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty list, tuple, or tensor")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEBlockSpatial(nn.Module):
    """Integrate ``ODEfuncSpatial`` dynamics over ``integration_time``.

    Args:
        n_filters: number of channels of the state, forwarded to the ODE function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint``/``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a list/tuple/tensor of timestamps (see the ``t1``
            setter). ``t1 == 0`` disables integration (identity block).
        norm: normalization name forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSpatial, self).__init__()
        self.odefunc = ODEfuncSpatial(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # False -> forward returns the state at every timestamp (feature extraction)
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the final state when ``return_last_only`` is True, otherwise the
        solver output for every entry of ``integration_time``.
        """
        if self.integration_time is None:
            return x

        # match the input's dtype/device before handing times to the solver
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations counted by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time``; 0 when integration is disabled.

        NOTE(review): when several timestamps were given this is the first
        timestamp after 0, not the final one.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # Accepts a scalar (0 disables integration), a 0-dim tensor (treated as
        # a scalar) or a non-empty list/tuple/1-D tensor of timestamps; 0 is
        # prepended as the start time when missing.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            # tolist() on a 0-dim tensor yields a bare number, not a list
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty list, tuple, or tensor")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


### my
class ODEBlockSpatialPlusTwoMul(nn.Module):
    """Integrate ``ODEfuncSpatialPlusTwoMul`` dynamics over ``integration_time``.

    Args:
        n_filters: number of channels of the state, forwarded to the ODE function.
        tol: used as both ``rtol`` and ``atol`` of the solver.
        method: solver name forwarded to ``odeint``/``odeint_adjoint``.
        adjoint: use ``odeint_adjoint`` instead of ``odeint``.
        t1: end time, or a list/tuple/tensor of timestamps (see the ``t1``
            setter). ``t1 == 0`` disables integration (identity block).
        norm: normalization name forwarded to the ODE function.
    """

    def __init__(
        self, n_filters=64, tol=1e-3, method="dopri5", adjoint=False, t1=1, norm="group"
    ):
        super(ODEBlockSpatialPlusTwoMul, self).__init__()
        self.odefunc = ODEfuncSpatialPlusTwoMul(n_filters, norm=norm)
        self.t1 = t1
        self.tol = tol
        self.method = method
        self.odeint = odeint_adjoint if adjoint else odeint
        # False -> forward returns the state at every timestamp (feature extraction)
        self.return_last_only = True

    def forward(self, x):
        """Solve the ODE starting from ``x``.

        Returns ``x`` unchanged when integration is disabled (``t1 == 0``),
        the final state when ``return_last_only`` is True, otherwise the
        solver output for every entry of ``integration_time``.
        """
        if self.integration_time is None:
            return x

        # match the input's dtype/device before handing times to the solver
        self.integration_time = self.integration_time.type_as(x)
        out = self.odeint(
            self.odefunc,
            x,
            self.integration_time,
            method=self.method,
            rtol=self.tol,
            atol=self.tol,
        )
        if self.return_last_only:
            out = out[-1]  # dynamics @ t=t1
        return out

    @property
    def nfe(self):
        """Number of ODE-function evaluations counted by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

    @property
    def t1(self):
        """Second entry of ``integration_time``; 0 when integration is disabled.

        NOTE(review): when several timestamps were given this is the first
        timestamp after 0, not the final one.
        """
        if self.integration_time is None:
            return 0
        return self.integration_time[1]

    @t1.setter
    def t1(self, value):
        # Accepts a scalar (0 disables integration), a 0-dim tensor (treated as
        # a scalar) or a non-empty list/tuple/1-D tensor of timestamps; 0 is
        # prepended as the start time when missing.
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            # tolist() on a 0-dim tensor yields a bare number, not a list
            value = value.item()

        if isinstance(value, (int, float)):
            if value == 0:
                self.integration_time = None
            else:
                self.integration_time = torch.tensor([0, value], dtype=torch.float32)
        elif isinstance(value, (list, tuple, torch.Tensor)):
            value = value.tolist() if isinstance(value, torch.Tensor) else list(value)
            if not value:
                raise ValueError("Argument must be a non-empty list, tuple, or tensor")
            if value[0] != 0:
                value = [0] + value
            self.integration_time = torch.tensor(value, dtype=torch.float32)
        else:
            raise ValueError("Argument must be a scalar, a list, or a tensor")


class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension: (B, ...) -> (B, prod(...))."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten returns a view when possible and also handles
        # non-contiguous inputs and empty batches, where view(-1, C) raises
        return torch.flatten(x, start_dim=1)


if __name__ == "__main__":
    # Smoke test: ODE downsample + ODE block used as a multi-timestamp feature
    # extractor. Falls back to CPU when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = ODENet(3, downsample="ode", t1=[0.1, 0.2, 0.3, 1]).to(device)
    net.to_features_extractor()
    a = torch.rand(7, 3, 32, 32).to(device)
    print(net(a).shape)
