import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import neuron, functional, surrogate, layer, base

from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch import amp
import sys
import datetime
from torch import Tensor
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode
from torch.nn.modules.utils import _single
# Callable
from typing import Optional, List, Tuple, Union, Callable

# Single source of truth for every RNG seeded below; change it here only.
_seed_ = 2020
import random
random.seed(_seed_)  # Python-level RNG (was a duplicated literal 2020)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
# Ask cuDNN for deterministic kernels and disable its autotuner so repeated
# runs pick the same algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class ClassificationPresetTrain:
    """Training-time image preprocessing pipeline.

    The pipeline is assembled once in ``__init__`` and applied by calling the
    instance. In order, it contains:

    1. an optional random horizontal flip (when ``hflip_prob > 0``),
    2. an optional auto-augmentation stage selected by ``auto_augment_policy``
       (``"ra"`` -> RandAugment, ``"ta_wide"`` -> TrivialAugmentWide, any other
       non-``None`` value is handed to ``AutoAugmentPolicy``),
    3. ``PILToTensor`` -> ``ConvertImageDtype(torch.float)`` -> ``Normalize``,
    4. an optional ``RandomErasing`` (when ``random_erase_prob > 0``).
    """

    def __init__(
        self,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        random_erase_prob=0.0,
    ):
        pipeline = []

        if hflip_prob > 0:
            pipeline.append(transforms.RandomHorizontalFlip(hflip_prob))

        # At most one auto-augmentation stage is added.
        if auto_augment_policy == "ra":
            pipeline.append(autoaugment.RandAugment(interpolation=interpolation))
        elif auto_augment_policy == "ta_wide":
            pipeline.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
        elif auto_augment_policy is not None:
            policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
            pipeline.append(autoaugment.AutoAugment(policy=policy, interpolation=interpolation))

        # Conversion to a normalized float tensor always happens.
        pipeline += [
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]

        # Random erasing is appended after normalization.
        if random_erase_prob > 0:
            pipeline.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(pipeline)

    def __call__(self, img):
        """Run ``img`` through the composed pipeline and return the result."""
        return self.transforms(img)

from torch import Tensor
from typing import Tuple
class RandomMixup(torch.nn.Module):
    """Batch-level Mixup augmentation.

    Implements the augmentation from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_:
    with probability ``p`` every sample is blended with its neighbour in the
    batch, and the one-hot targets are blended with the same coefficient.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes, self.p = num_classes, p
        self.alpha, self.inplace = alpha, inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: the (possibly mixed) batch and the targets
            as a float one-hot / soft-label tensor of size (B, num_classes).

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        # --- input validation -------------------------------------------------
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch, target = batch.clone(), target.clone()

        # target is guaranteed to be 1-D at this point, so always one-hot encode.
        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # Leave the batch untouched with probability (1 - p).
        if torch.rand(1).item() >= self.p:
            return batch, target

        # Pair each sample with its predecessor: rolling is cheaper than shuffling.
        partner_batch = batch.roll(1, 0)
        partner_target = target.roll(1, 0)

        # Mixing coefficient lambda ~ Beta(alpha, alpha) (mixup paper, page 3),
        # drawn as the first component of a 2-way Dirichlet sample.
        concentration = torch.tensor([self.alpha, self.alpha])
        lam = float(torch._sample_dirichlet(concentration)[0])
        complement = 1.0 - lam

        batch.mul_(lam).add_(partner_batch.mul_(complement))
        target.mul_(lam).add_(partner_target.mul_(complement))

        return batch, target

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}, p={self.p}, alpha={self.alpha}, inplace={self.inplace})"
        )
class RandomCutmix(torch.nn.Module):
    """Batch-level CutMix augmentation.

    Implements the augmentation from
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_: with probability ``p`` one rectangular
    region of every sample is replaced by the same region of its neighbour in
    the batch, and the one-hot targets are blended by the area ratio.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes, self.p = num_classes, p
        self.alpha, self.inplace = alpha, inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: the (possibly patched) batch and the targets
            as a float one-hot / soft-label tensor of size (B, num_classes).

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        # --- input validation -------------------------------------------------
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch, target = batch.clone(), target.clone()

        # target is guaranteed to be 1-D at this point, so always one-hot encode.
        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # Leave the batch untouched with probability (1 - p).
        if torch.rand(1).item() >= self.p:
            return batch, target

        # Pair each sample with its predecessor: rolling is cheaper than shuffling.
        partner_batch = batch.roll(1, 0)
        partner_target = target.roll(1, 0)

        # Box sampling follows the cutmix paper, page 12 (with minor typo fixes).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = torchvision.transforms.functional.get_image_size(batch)

        # Box centre, drawn uniformly over the image.
        center_x = int(torch.randint(W, (1,)))
        center_y = int(torch.randint(H, (1,)))

        # Half side lengths such that the unclipped box covers (1 - lambda) of the area.
        half_ratio = 0.5 * math.sqrt(1.0 - lambda_param)
        half_w = int(half_ratio * W)
        half_h = int(half_ratio * H)

        # Clip the box to the image bounds.
        x1 = max(center_x - half_w, 0)
        y1 = max(center_y - half_h, 0)
        x2 = min(center_x + half_w, W)
        y2 = min(center_y + half_h, H)

        batch[:, :, y1:y2, x1:x2] = partner_batch[:, :, y1:y2, x1:x2]

        # Recompute lambda from the area actually pasted (clipping may shrink the box).
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target.mul_(lambda_param).add_(partner_target.mul_(1.0 - lambda_param))

        return batch, target

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}, p={self.p}, alpha={self.alpha}, inplace={self.inplace})"
        )

class IFNode5(nn.Module):
    """Spiking layer that mixes all ``T`` time-steps with a learnable ``T x T`` matrix.

    The weight and bias of an ``nn.Linear(T, T)`` are applied along the time
    axis, and the result is passed through ``surrogate_function`` to obtain
    spikes. The bias is initialised to ``-1``.

    :param T: number of time-steps (size of the last axis of the input)
    :param surrogate_function: callable applied to the mixed sequence
    """

    def __init__(self, T: int, surrogate_function=surrogate.ATan()):
        super().__init__()
        self.surrogate_function = surrogate_function
        self.fc = nn.Linear(T, T)
        nn.init.constant_(self.fc.bias, -1)

    def forward(self, x_seq: torch.Tensor):
        """Map an input of shape ``[N, C, T]`` to an output of the same shape."""
        batch, channels, steps = x_seq.shape

        # [N, C, T] -> [T, N, C] -> [T, N*C]: time becomes the row axis.
        time_major = x_seq.permute(2, 0, 1).flatten(1)

        # h = W @ x + b, with the per-step bias broadcast over all N*C columns.
        bias_column = self.fc.bias.unsqueeze(1)
        h_seq = torch.addmm(bias_column, self.fc.weight, time_major)

        spikes = self.surrogate_function(h_seq)

        # [T, N*C] -> [T, N, C] -> [N, C, T]
        return spikes.view(steps, batch, channels).permute(1, 2, 0)
    


class SlidingPSN(base.MemoryModule):
    @property
    def supported_backends(self):
        # Backends understood by multi_step_forward.
        return 'gemm', 'conv'

    def gen_gemm_weight(self, T: int):
        """Expand the ``k`` sliding taps into a dense ``[T, T]`` causal matrix.

        Row ``t`` holds the taps that produce output step ``t``: the last tap
        ``weight[k - 1]`` sits on the diagonal, earlier taps sit to its left,
        and taps that would reach before step 0 are dropped.

        :param T: sequence length
        :type T: int
        :return: matrix ``M`` with ``H[t] = sum_s M[t, s] * X[s]``
        """
        # Match the parameter's dtype as well as its device so the matrix can
        # be multiplied with inputs after the module has been cast.
        weight = torch.zeros([T, T], device=self.weight.device, dtype=self.weight.dtype)
        for i in range(T):
            end = i + 1
            start = max(0, i + 1 - self.k)
            length = min(end - start, self.k)
            weight[i][start: end] = self.weight[self.k - length: self.k]

        return weight

    def __init__(self, k: int, exp_init: bool = True,
                 surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's',
                 backend: str = 'gemm'):
        """
        :param k: the order of the Sliding PSN
        :type k: int
        :param exp_init: if ``True``, the weight will be initialized as ``(..., 1/4, 1/2, 1)``. If ``False``, the weight will be initialized by the kaiming uniform
        :type exp_init: bool
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: Callable
        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str
        :param backend: backend for this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
        :type backend: str

        The Sliding Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as

        .. math::

            H[t] &= \\sum_{i=0}^{k-1}W_{i}\\cdot X[t - k + 1 + i], \\\\
            S[t] &= \\Theta(H[t] - B),


        where :math:`W = [W_{0}, W_{1}, ..., W_{k-1}] \\in \\mathbb{R}^{k}` is the learnable weight, and :math:`B` is the learnable threshold.


        .. admonition:: Note
            :class: note

            The Sliding PSN supports both single-step and multi-step mode. But using the multi-step mode is much faster than the single-step mode.

            In this file the multi-step input layout is ``[N, C, T]`` (time on the last axis) for both backends.


        """

        super().__init__()
        # Window of the most recent (at most k) inputs, used in single-step mode.
        self.register_memory('queue', [])
        self.step_mode = step_mode
        self.k = k
        self.surrogate_function = surrogate_function
        self.backend = backend

        if exp_init:
            # (..., 1/4, 1/2, 1): the most recent input has weight 1.
            weight = torch.ones([k])
            for i in range(k - 2, -1, -1):
                weight[i] = weight[i + 1] / 2.
        else:
            weight = torch.ones([1, k])
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            weight = weight[0]

        self.weight = nn.Parameter(weight)
        # Stored as an additive bias, i.e. the negated threshold B.
        self.bias = nn.Parameter(torch.as_tensor(-1.))

    def single_step_forward(self, x: torch.Tensor):
        """Consume one time-step ``x`` and return spikes with the same shape."""
        self.queue.append(x.flatten())
        if len(self.queue) > self.k:
            self.queue.pop(0)

        # While fewer than k inputs have been seen, use only the last taps.
        weight = self.weight[self.k - len(self.queue): self.k]
        x_seq = torch.stack(self.queue)

        weight = weight.unsqueeze(-1)

        h = torch.sum(weight * x_seq, 0)
        spike = self.surrogate_function(h + self.bias)

        return spike.view(x.shape)

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Process a whole sequence of shape ``[N, C, T]`` and return spikes of the same shape.

        Both backends compute the same causal sliding sum along the last
        (time) axis; they differ only in how it is evaluated.

        :raises NotImplementedError: if ``self.backend`` is neither ``'gemm'`` nor ``'conv'``
        """
        # x_seq.shape = [N, C, T]
        x_seq_shape = x_seq.shape
        if self.backend == 'gemm':
            # The original code took T from dim 0 and mixed along dim 0, which
            # is the batch axis in the [N, C, T] layout used here. Mix along the
            # last axis instead so 'gemm' matches 'conv':
            # [N, C, T] -> [NC, T];  h[m, t] = sum_s weight[t, s] * x[m, s]
            weight = self.gen_gemm_weight(x_seq_shape[-1])
            h_seq = torch.addmm(self.bias, x_seq.flatten(0, 1), weight.t()).view(x_seq_shape)
            return self.surrogate_function(h_seq)
        elif self.backend == 'conv':
            # [N, C, T] -> [NC, T] -> [NC, 1, T]
            x_seq = x_seq.flatten(0, 1).unsqueeze(1)

            # Left-pad by k - 1 so that step t only sees steps <= t.
            x_seq = F.pad(x_seq, pad=(self.k - 1, 0))
            x_seq = F.conv1d(x_seq, self.weight.view(1, 1, -1), stride=1)

            x_seq = x_seq.squeeze(1).view(x_seq_shape)
            return self.surrogate_function(x_seq + self.bias)
        else:
            raise NotImplementedError(self.backend)

    def extra_repr(self):
        return super().extra_repr() + f', order={self.k}'

@torch.jit.script
def round_to_pow2(x: torch.Tensor):
    """Round every element of ``x`` to a signed power of two.

    The exponent is ``round(log2(|x|))`` and the sign of the input is kept.
    """
    exponent = torch.round(torch.log2(torch.abs(x)))
    return torch.sign(x) * torch.pow(2., exponent)

class RoundToPow2(torch.autograd.Function):
    """Autograd wrapper around ``round_to_pow2`` with a pass-through gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Nothing is saved on ctx: backward does not depend on the input.
        rounded = round_to_pow2(x)
        return rounded

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The incoming gradient is forwarded unchanged to the input.
        return grad_output

def roundtopow(x: torch.Tensor):
    """Apply the ``RoundToPow2`` autograd function to ``x`` and return the result."""
    quantized = RoundToPow2.apply(x)
    return quantized

class BN1d(nn.Module):
    """Holder of batch-norm parameters and running statistics.

    The module defines no ``forward``. It owns the affine parameters
    (``weight``, ``bias``) and the running buffers, and exposes
    :meth:`get_mean_var` so a caller can apply the normalization itself.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Affine parameters start as the identity transform (scale 1, shift 0).
        self.weight = nn.Parameter(torch.ones(num_features, **tensor_opts))
        self.bias = nn.Parameter(torch.zeros(num_features, **tensor_opts))

        # Running statistics start at mean 0 / variance 1.
        self.register_buffer('running_mean', torch.zeros(num_features, **tensor_opts))
        self.register_buffer('running_var', torch.ones(num_features, **tensor_opts))

    def get_mean_var(self, input: Tensor=None) -> Tuple[Tensor, Tensor]:
        """Return the ``(mean, var)`` pair to normalize with.

        In evaluation mode the running buffers are returned and ``input`` is
        not used. In training mode the statistics are reduced over axes 0 and 2
        of ``input`` (biased variance is returned), and the running buffers are
        updated with momentum ``self.momentum`` using the unbiased variance.
        """
        if not self.training:
            return self.running_mean, self.running_var

        reduce_axes = [0, 2]
        batch_mean = input.mean(reduce_axes)
        batch_var = input.var(reduce_axes, unbiased=False)

        # The running-statistics update must not be tracked by autograd.
        with torch.no_grad():
            keep = 1. - self.momentum
            unbiased_var = input.var(reduce_axes, unbiased=True)
            self.running_mean = keep * self.running_mean + self.momentum * batch_mean
            self.running_var = keep * self.running_var + self.momentum * unbiased_var

        return batch_mean, batch_var


class mul_free_channel_wise_psn(nn.Module):
    """Channel-wise sliding spiking neuron with power-of-two temporal weights.

    Each of the ``C`` channels has its own causal temporal kernel of size ``K``
    (a depthwise ``Conv1d``). Batch-norm statistics of the kernel's output are
    folded into the kernel and bias, the folded weights are rounded to signed
    powers of two with :func:`roundtopow`, and the resulting convolution is
    passed through ``surrogate_function``.

    :param C: number of channels
    :param K: temporal kernel size
    :param surrogate_function: callable applied to the membrane sequence
    :param dilation: dilation of the temporal kernel
    :param step_mode: accepted because the call sites in this file pass it;
        only stored, ``forward`` always processes the whole sequence
    :param backend: accepted because the call sites in this file pass it;
        only stored
    """

    def __init__(self, C: int, K: int, surrogate_function, dilation, step_mode: str = 'm', backend: str = 'torch'):
        super().__init__()
        # The original called self.register_memory('queue', []), which a plain
        # nn.Module does not provide (AttributeError on construction). The
        # queue was never read: forward keeps no state between calls.
        self.step_mode = step_mode
        self.backend = backend
        self.C = C
        self.K = K
        self.surrogate_function = surrogate_function
        self.conv1d = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=K, groups=C, dilation=dilation, bias=False)
        self.dilation = dilation
        self.bn1d = BN1d(C)
        nn.init.constant_(self.bn1d.bias, -1.)

    def forward(self, x_seq: torch.Tensor):
        """Map an input of shape ``[N, C, T]`` to an output of the same shape."""
        # Left-pad the time axis so that step t only sees steps <= t and the
        # output keeps length T.
        x_pad = F.pad(x_seq, (self.dilation * (self.K - 1), 0))

        if self.training:
            # Run the un-quantized convolution once just to measure batch
            # statistics (this also updates the running buffers).
            t = self.conv1d(x_pad)
            mean, var = self.bn1d.get_mean_var(t)
        else:
            mean, var = self.bn1d.get_mean_var()

        # Fold the batch-norm scale/shift into the kernel and bias. The
        # transposes move the channel axis last so the per-channel scale
        # broadcasts, then restore the [C, 1, K] kernel layout.
        weight = (self.conv1d.weight.transpose(0, 2) * self.bn1d.weight / (var + self.bn1d.eps).sqrt()).transpose(0, 2)
        bias = self.bn1d.bias - mean * self.bn1d.weight / (var + self.bn1d.eps).sqrt()
        # Quantize the folded weights to signed powers of two.
        qweight = roundtopow(weight)
        h_seq = F.conv1d(x_pad, qweight, bias=bias, stride=self.conv1d.stride, padding=self.conv1d.padding, dilation=self.conv1d.dilation, groups=self.conv1d.groups)

        s_seq = self.surrogate_function(h_seq)

        return s_seq


class RB(nn.Module):
    """Residual block: two (1x1 Conv1d -> BatchNorm1d -> spiking neuron) stages plus a skip connection.

    :param T: sequence length, used only by the ``'PSN'`` neuron
    :param C: number of channels (input and output)
    :param sn: neuron type: ``'PSN'`` (``'psn'`` is accepted as an alias),
        ``'SlidingPSN'`` or ``'mul_free_channel_wise_psn'``
    :param k: kernel size / order for the sliding neurons
    :param d1: dilation of the first neuron (``'mul_free_channel_wise_psn'`` only)
    :param d2: dilation of the second neuron (``'mul_free_channel_wise_psn'`` only)
    :raises NotImplementedError: if ``sn`` is not one of the supported names
    """

    def __init__(self, T, C, sn='psn', k:int=None, d1:int=1, d2:int=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(C)
        if sn in ('PSN', 'psn'):
            # The default value 'psn' previously matched no branch, leaving the
            # block without neurons until forward() failed.
            self.sn1 = IFNode5(T=T, surrogate_function=surrogate.ATan())
            self.sn2 = IFNode5(T=T, surrogate_function=surrogate.ATan())
        elif sn == 'SlidingPSN':
            self.sn1 = SlidingPSN(k=k, step_mode='m', backend='conv')
            self.sn2 = SlidingPSN(k=k, step_mode='m', backend='conv')
        elif sn == "mul_free_channel_wise_psn":
            # Only pass the arguments the constructor defines.
            self.sn1 = mul_free_channel_wise_psn(C=C, K=k, surrogate_function=surrogate.ATan(), dilation=d1)
            self.sn2 = mul_free_channel_wise_psn(C=C, K=k, surrogate_function=surrogate.ATan(), dilation=d2)
        else:
            # Fail at construction time instead of at the first forward pass.
            raise NotImplementedError(sn)

        self.conv2 = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=1)
        self.bn2 = nn.BatchNorm1d(C)

    def forward(self, x):
        """Apply both stages to ``x`` and add ``x`` back to the result."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.sn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.sn2(out)

        out += residual

        return out


class pixelCIFAR100Net(nn.Module):
    """Classifier that reads a CIFAR image as a sequence of 1024 pixels.

    The image is flattened to ``[N, 3, 1024]``. Stage one works at 128 channels
    over 1024 steps and ends with an average pooling by 4; stage two works at
    256 channels over 256 steps. The per-step class scores are averaged over
    the last axis.

    :param num_classes: number of output classes
    :param K: kernel size / order for the sliding neurons
    :param sn: neuron type: ``'PSN'`` (``'psn'`` is accepted as an alias),
        ``'SlidingPSN'`` or ``'mul_free_channel_wise_psn'``
    :raises NotImplementedError: if ``sn`` is not one of the supported names
    """

    def __init__(self, num_classes, K:int=8, sn='PSN'):
        super().__init__()
        if sn == "psn":
            # Normalise the lower-case alias so the residual blocks below
            # receive the spelling they compare against.
            sn = "PSN"

        if sn == "PSN":
            sn1 =  IFNode5(T=1024, surrogate_function=surrogate.ATan())
            sn2 =  IFNode5(T=256, surrogate_function=surrogate.ATan())
        elif sn == "SlidingPSN":
            sn1 = SlidingPSN(k=K, step_mode='m', backend='conv')
            sn2 = SlidingPSN(k=K, step_mode='m', backend='conv')
        elif sn == "mul_free_channel_wise_psn":
            # Only pass the arguments the constructor defines.
            sn1 = mul_free_channel_wise_psn(C=128, K=K, surrogate_function=surrogate.ATan(), dilation=1)
            sn2 = mul_free_channel_wise_psn(C=256, K=K, surrogate_function=surrogate.ATan(), dilation=1)
        else:
            # Previously an unknown name surfaced as an UnboundLocalError on sn1.
            raise NotImplementedError(sn)

        self.fc1 = layer.TLastSeqToANNContainer(nn.Linear(3, 128))
        self.rb1 = nn.Sequential(
            nn.BatchNorm1d(128),
            sn1,
            RB(T=1024, C=128, sn=sn, k=K, d1=1, d2=1),
            RB(T=1024, C=128, sn=sn, k=K, d1=1, d2=1),
            nn.AvgPool1d(4),  # 1024 steps -> 256 steps
        )
        self.fc2 =  layer.TLastSeqToANNContainer(nn.Linear(128, 256))
        self.rb2 = nn.Sequential(
            nn.BatchNorm1d(256),
            sn2,
            RB(T=256, C=256, sn=sn, k=K, d1=1, d2=1),
            RB(T=256, C=256, sn=sn, k=K, d1=1, d2=1),
        )
        self.classifier = layer.TLastSeqToANNContainer(nn.Linear(256, num_classes))

    def forward(self, x):
        """Return class scores averaged over the last axis of the classifier output."""
        x = x.flatten(2) # [N, C, T]
        x = self.fc1(x)
        x = self.rb1(x)
        x = self.fc2(x)
        x = self.rb2(x)
        x = self.classifier(x)
        return x.mean(-1)


def main():
    '''
    Train and evaluate ``pixelCIFAR100Net`` on sequential CIFAR10/100.

    Parses the command line, builds the augmented data loaders, then runs the
    train/test loop. Scalars go to TensorBoard, and after every epoch the
    latest checkpoint is written (plus the best one whenever test accuracy
    improves).

    Example:

        python pixel_cf10_d1.py -data-dir /datasets/CIFAR10 -amp -opt adamw -epochs 128 -class-num 10 -device cuda:0 -K 32 -lr 0.001 -sn mul_free_channel_wise_psn
    '''
    parser = argparse.ArgumentParser(description='Classify Sequential CIFAR10/100')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of CIFAR10/100 dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    # Validate the optimizer up front instead of failing after the datasets
    # have been prepared.
    parser.add_argument('-opt', type=str, required=True, choices=('sgd', 'adamw'),
                        help='use which optimizer. sgd or adamw')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-class-num', type=int, default=10)
    parser.add_argument('-K', type=int, default=16, help='the kernel size of the sliding PSN')
    # The old default 'psn' matched none of the names the network checks for.
    parser.add_argument('-sn', type=str, default='PSN',
                        choices=('PSN', 'SlidingPSN', 'mul_free_channel_wise_psn'),
                        help='which spiking neuron to use. PSN, SlidingPSN, mul_free_channel_wise_psn')
    args = parser.parse_args()
    print(args)

    # Each training batch gets either Mixup or CutMix, chosen at random.
    mixup_transforms = []
    mixup_transforms.append(RandomMixup(args.class_num, p=1.0, alpha=0.2)) #
    mixup_transforms.append(RandomCutmix(args.class_num, p=1.0, alpha=1.))
    mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
    collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731

    # Dataset class and normalization statistics depend only on the class count.
    if args.class_num == 10:
        dataset_cls = torchvision.datasets.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif args.class_num == 100:
        dataset_cls = torchvision.datasets.CIFAR100
        mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
        std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    else:
        raise NotImplementedError(args.class_num)

    transform_train = ClassificationPresetTrain(mean=mean,
                                                std=std,
                                                interpolation=InterpolationMode('bilinear'),
                                                auto_augment_policy='ta_wide',
                                                random_erase_prob=0.1)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = dataset_cls(
        root=args.data_dir,
        train=True,
        transform=transform_train,
        download=True)

    test_set = dataset_cls(
        root=args.data_dir,
        train=False,
        transform=transform_test,
        download=True)

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.b,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
        num_workers=args.j,
        pin_memory=True
    )

    # Keep the final partial batch: dropping it would silently evaluate on a
    # truncated test set.
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.b,
        shuffle=False,
        drop_last=False,
        num_workers=args.j,
        pin_memory=True
    )

    out_dir = f'logs/mul_free_channel_wise_psn_d1/sn_{args.sn}_e{args.epochs}_b{args.b}_{args.opt}_lr{args.lr}'

    if args.sn == 'SlidingPSN' or args.sn == 'mul_free_channel_wise_psn':
        out_dir += f'_K{args.K}'

    if args.amp:
        out_dir += '_amp'

    # Checkpoints live under <out-dir>/pt/..., TensorBoard logs under <out-dir>/...
    pt_dir = os.path.join(args.out_dir, 'pt', out_dir)
    out_dir = os.path.join(args.out_dir, out_dir)

    os.makedirs(pt_dir, exist_ok=True)

    net = pixelCIFAR100Net(args.class_num, K=args.K, sn=args.sn)
    net.to(args.device)

    total_params = sum(p.numel() for p in net.parameters())
    print(f"Total params: {total_params / 1e6} M")

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    # autocast wants the device *type* ('cuda'), not a full device string
    # such as 'cuda:0'.
    autocast_device_type = torch.device(args.device).type

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adamw':
        optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=0.)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']
        print(max_test_acc)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)

    # Record the exact configuration and command line next to the logs.
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0

        for batch_index, (img, label) in enumerate(train_data_loader):
            optimizer.zero_grad()
            img = img.to(args.device, non_blocking=True)
            label = label.to(args.device, non_blocking=True)

            with torch.amp.autocast(device_type=autocast_device_type, enabled=scaler is not None):
                y = net(img)
                loss = F.cross_entropy(y, label, label_smoothing=0.1)

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            train_samples += label.shape[0]
            train_loss += loss.item() * label.shape[0]
            # Labels are soft (mixup/cutmix), so compare against their arg-max.
            train_acc += (y.argmax(1) == label.argmax(1)).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                y = net(img)
                loss = F.cross_entropy(y, label)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (y.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(pt_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(pt_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
        print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
        print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')

    # Flush pending events and release the log file.
    writer.close()


# Run the training script only when executed directly, not when imported.
if __name__ == '__main__':
    main()