import argparse
import math
import os
import time
import numpy as np
from datetime import datetime
import logging
import tensorboard_logger as tb_logger
import pprint
import copy

import torch
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim
import torch.utils.data
from torchvision import transforms
from models.hyptorch import expmap0
import geoopt.optim.rsgd as rsgd_
import geoopt.optim.radam as radam_
from pytorch_optimizer import AdamP# , SAM 
from pytorch_optimizer import GSAM
from pytorch_optimizer import LinearScheduler, ProportionScheduler

from optim import SAM, RSAM, RSAM_OCNN
from optim import Retraction, Parallel_Transport, Projection
from utils import (CompLoss, DisLoss, DisLPLoss, RegLoss, SupConLoss, SphereFace2,
                SphereFaceR_N, SphereFaceR_H, SphereFaceR_S, PeBusePenalty,
                AverageMeter, adjust_learning_rate, warmup_learning_rate, 
                set_loader_small, set_loader_ImageNet, set_model)
from utils import enable_running_stats, disable_running_stats
from utils.attack import Attack

def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty value would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is accepted as False because bool('') was False too.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Training with CIDER and SupCon Loss')
parser.add_argument('--gpu', default=1, type=int, help='which GPU to use')
parser.add_argument('--seed', default=4, type=int, help='random seed')
parser.add_argument('--w', default=1, type=float,
                    help='loss scale')
parser.add_argument('--wr', default=1, type=float,
                    help='regularization loss scale')
parser.add_argument('--proto_m', default=0.5, type=float,
                    help='weight of prototype update')
parser.add_argument('--feat_dim', default=128, type=int,
                    help='feature dim')
parser.add_argument('--in-dataset', default="CIFAR-100", type=str, help='in-distribution dataset')
parser.add_argument('--id_loc', default="datasets/CIFAR100", type=str, help='location of in-distribution dataset')
parser.add_argument('--model', default='resnet34', type=str, help='model architecture: [resnet18, resnet34]')
parser.add_argument('--head', default='mlp', type=str, help='either mlp or linear head')
parser.add_argument('--loss', default='mgp', type=str,
                    choices=['cider_ce', 'mgp_ce', 'ce', 'supcon', 'cider', 'sf',
                             'sfrn', 'sfrh', 'sfrs', 'hypb', 'mgp'],
                    help='train loss')
parser.add_argument('--epochs', default=500, type=int,
                    help='number of total epochs to run')
parser.add_argument('--trial', type=str, default='0',
                    help='id for recording multiple runs')
parser.add_argument('--save-epoch', default=100, type=int,
                    help='save the model every save_epoch')
parser.add_argument('--start-epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=512, type=int,
                    help='mini-batch size (default: 512)')
parser.add_argument('--learning_rate', default=0.5, type=float,
                    help='initial learning rate')
# Step ("linear") learning-rate schedule.
parser.add_argument('--lr_decay_epochs', type=str, default='100,150,180',
                    help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                    help='decay rate for learning rate')
# Cosine learning-rate schedule.
parser.add_argument('--cosine', action='store_true',
                    help='using cosine annealing')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--temp', type=float, default=0.1,
                    help='temperature for loss function')
parser.add_argument('--warm', action='store_true',
                    help='warm-up for large batch training')
parser.add_argument('--normalize', action='store_true',
                    help='normalize feat embeddings')
parser.add_argument('--subset', type=_str2bool, default=False,
                    help='whether to use subset of training set to init prototypes')
parser.add_argument('--main_dir', default="./",
                    help='working space')
parser.add_argument('--ash_method', default="",
                    help='apply ash layer with percentage')
parser.add_argument('--c_ball', type=float, default=1.0,
                    help='curvature of the Poincare ball')
parser.add_argument('--train_origin', type=_str2bool, default=False,
                    help='train origin of the Poincare ball')
parser.add_argument('--train_c', type=_str2bool, default=False,
                    help='train curvature of the Poincare ball')
parser.add_argument('--r', type=int, default=1,
                    help='radius of disparity loss')
# The default is fractional, so the option has to be parsed as a float.
parser.add_argument('--margin', type=float, default=0.3,
                    help='margin of disparity loss')
parser.add_argument('--optimizer', type=str, default="sgd",
                    help='optimizer name: sam, gsam, rsam, rsgd; anything else uses SGD')
# Read as args.alpha when the GSAM optimizer is built.
parser.add_argument('--alpha', type=float, default=0.4,
                    help='alpha coefficient passed to GSAM (used with --optimizer gsam)')
parser.add_argument('--attack', type=str, default="",
                    help='adversarial attack name (empty string disables the attack)')
parser.set_defaults(bottleneck=True)
parser.set_defaults(augment=True)

# Parse the command line once at import time; the statements below derive the
# run name, output directories, logger, seeds and warm-up settings from it.
args = parser.parse_args()

# Snapshot of the options as parsed (taken before any attribute is rewritten
# below); this is what gets written to train_args.txt and to the log.
state = dict(vars(args))

date_time = datetime.now().strftime("%d_%m_%H:%M")

# Turn the comma-separated milestone string into a list of ints.
args.lr_decay_epochs = [int(step) for step in args.lr_decay_epochs.split(',')]

# Build the run name used for the log, checkpoint and tensorboard folders.
if 'cider' in args.loss or 'mgp' in args.loss:
    # NOTE(review): the "wd_" tag is filled with args.w (the loss scale), not
    # with the weight decay; left unchanged so existing run names stay stable.
    args.name = (f"{date_time}_{args.loss}_{args.model}_lr_{args.learning_rate}_cosine_"
        f"{args.cosine}_bsz_{args.batch_size}_{args.loss}_wd_{args.w}_{args.epochs}_{args.feat_dim}_"
        f"trial_{args.trial}_temp_{args.temp}_{args.in_dataset}_pm_{args.proto_m}")
else:
    # The remaining losses share one layout and differ only in the prefix.
    if args.loss == 'supcon':
        name_prefix = 'supcon'
    elif 'sf' in args.loss:
        name_prefix = 'sf'
    else:
        name_prefix = 'ce'
    args.name = (f"{date_time}_{name_prefix}_{args.model}_lr_{args.learning_rate}_cosine_"
        f"{args.cosine}_bsz_{args.batch_size}_{args.loss}_{args.epochs}_{args.feat_dim}_"
        f"trial_{args.trial}_temp_{args.temp}_{args.in_dataset}_{args.head}")

args.log_directory = f"{args.main_dir}/logs/{args.in_dataset}/{args.name}/"
args.model_directory = f"{args.main_dir}/checkpoints/{args.in_dataset}/{args.name}/"
args.tb_path = f'{args.main_dir}/save/sagd/{args.in_dataset}_tensorboard'
os.makedirs(args.model_directory, exist_ok=True)
os.makedirs(args.log_directory, exist_ok=True)
args.tb_folder = os.path.join(args.tb_path, args.name)
os.makedirs(args.tb_folder, exist_ok=True)

# Save the parsed arguments next to the logs.
with open(os.path.join(args.log_directory, 'train_args.txt'), 'w') as f:
    f.write(pprint.pformat(state))

# Logger writing the same formatted records to train_info.log and the console.
log = logging.getLogger(__name__)
formatter = logging.Formatter('%(asctime)s : %(message)s')
fileHandler = logging.FileHandler(os.path.join(args.log_directory, "train_info.log"), mode='w')
fileHandler.setFormatter(formatter)
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
log.setLevel(logging.DEBUG)
log.addHandler(fileHandler)
log.addHandler(streamHandler)

log.debug(state)

# Number of classes of the in-distribution dataset.
if args.in_dataset == "CIFAR-10":
    args.n_cls = 10
elif args.in_dataset in ["CIFAR-100", "ImageNet-100"]:
    args.n_cls = 100
else:
    # args.n_cls stays unset here; make that visible instead of failing later
    # with an AttributeError far away from the cause.
    log.warning("unknown --in-dataset %r: args.n_cls is not set", args.in_dataset)


# Seed the torch (CPU and CUDA) and NumPy random number generators.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
log.debug(f"{args.name}")

# Warm-up is forced for large batches (batch size above 256).
if args.batch_size > 256:
    args.warm = True
if args.warm:
    args.warmup_from = 0.001
    args.warm_epochs = 10
    if args.cosine:
        # Warm up to the value the cosine schedule has at the end of the
        # warm-up epochs, with eta_min = lr * decay_rate ** 3 as its floor.
        eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
        args.warmup_to = eta_min + (args.learning_rate - eta_min) * (
                1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
    else:
        args.warmup_to = args.learning_rate

def _build_optimizers(args, model, train_loader):
    """Create the optimizer(s) selected by ``args.optimizer``.

    Args:
        args: parsed command-line namespace (learning rate, momentum, ...).
        model: network whose parameters are optimized; its
            ``encoder.layer4[0].conv2.weight`` is the layer handled
            separately in the "rsam" setup.
        train_loader: training loader; its length gives the steps per epoch.

    Returns:
        ``(optimizer, optimizers, choosen_layer)`` where ``optimizer`` is the
        main optimizer, ``optimizers`` is the list handed to ``trainer``
        (``[optimizer]``, or ``[optimizer, optimizer_ocnn]`` for "rsam") and
        ``choosen_layer`` is the separately handled conv weight.
    """
    choosen_layer = model.encoder.layer4[0].conv2.weight
    optimizers = []

    if args.optimizer == "sam":
        # Momentum is intentionally not passed to the base SGD here.
        optimizer = SAM(model.parameters(), torch.optim.SGD,
                        rho=1.0,
                        adaptive=True,
                        lr=args.learning_rate,
                        weight_decay=args.weight_decay
                       )
    elif args.optimizer == "gsam":
        # NOTE(review): keyword names follow the pytorch_optimizer
        # documentation; confirm against the installed version. The GSAM
        # update step in trainer() is still marked unfinished.
        rho_max = 2.0
        rho_min = 2.0
        max_lr = 0.8
        # Total number of optimizer steps = epochs * batches per epoch.
        num_total_steps = args.epochs * len(train_loader)
        base_optimizer = torch.optim.SGD(model.parameters(),
                                         lr=args.learning_rate,
                                         momentum=args.momentum,
                                         nesterov=True,
                                         weight_decay=args.weight_decay
                                        )
        lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps, max_lr=max_lr)
        rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr,
                                            max_value=rho_max, min_value=rho_min)
        optimizer = GSAM(model.parameters(), base_optimizer, model,
                         rho_scheduler=rho_scheduler,
                         # Fall back to 0.4 when no --alpha option is defined.
                         alpha=getattr(args, "alpha", 0.4),
                         adaptive=True
                        )
    elif args.optimizer == "rsam":
        num_filters, filters_depth, filters_height, filters_width = choosen_layer.shape
        w = torch.nn.init.orthogonal_(
            torch.empty(num_filters, filters_width * filters_depth * filters_height)
        )
        # Write the orthogonal initialisation into the model's own weight, in
        # place, so that the tensor given to RSAM_OCNN and to trainer() is the
        # parameter that actually receives gradients.
        # NOTE(review): confirm RSAM_OCNN is meant to update the live
        # parameter rather than a detached copy.
        with torch.no_grad():
            choosen_layer.copy_(
                torch.reshape(w, (num_filters, filters_depth, filters_height, filters_width))
            )

        manifold = "sphere"
        param_ocnn = [{
            'params': choosen_layer,
            'lr': args.learning_rate,
            'weight_decay': args.weight_decay,
            'manifold': manifold,
            'proj': Projection(manifold=manifold),
            'retr': Retraction(manifold=manifold),
            'transp': Parallel_Transport(manifold=manifold)
        }]
        optimizer_ocnn = RSAM_OCNN(param_ocnn)

        optimizer = SAM(model.parameters(),
            torch.optim.SGD,
            rho=1.0,
            adaptive=True,
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        print("="*30, "RSAM set")
        optimizers = [optimizer, optimizer_ocnn]
    elif args.optimizer == "rsgd":
        optimizer = rsgd_.RiemannianSGD(model.parameters(), lr=args.learning_rate,
                                        momentum=args.momentum,
                                        nesterov=True,
                                        weight_decay=args.weight_decay
                                       )
    else:
        # Any other optimizer name falls back to Nesterov SGD.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                    momentum=args.momentum,
                                    nesterov=True,
                                    weight_decay=args.weight_decay
                                   )

    if not optimizers:
        optimizers = [optimizer]
    return optimizer, optimizers, choosen_layer


def main():
    """Build loaders, model, losses and optimizer(s), then run the training loop.

    Every epoch the learning rate is adjusted, ``trainer`` runs one pass over
    the training set, the averaged losses are written to tensorboard, and
    every ``args.save_epoch`` epochs a checkpoint is saved.
    """
    tb_log = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # The second loader (eval=True) is passed to set_criterions.
    if args.in_dataset == "ImageNet-100":
        train_loader, _ = set_loader_ImageNet(args)
        aux_loader, _ = set_loader_ImageNet(args, eval=True)
    else:
        train_loader, _ = set_loader_small(args)
        aux_loader, _ = set_loader_small(args, eval=True)

    model = set_model(args)
    criterions = set_criterions(args, model, aux_loader)
    (criterion_ce, criterion_supcon, criterion_comp,
        criterion_dis, criterion_reg, criterion_sf, criterion_hypb
    ) = criterions

    optimizer, optimizers, choosen_layer = _build_optimizers(args, model, train_loader)

    for epoch in range(args.start_epoch, args.epochs):
        # Only the main optimizer goes through the epoch-level schedule.
        adjust_learning_rate(args, optimizer, epoch)

        train_losses = trainer(args,
            train_loader, model, criterions, optimizers, epoch, log, choosen_layer=choosen_layer
        )
        (train_sloss, train_uloss, train_dloss, train_rloss,
            train_sfloss, train_celoss, train_hypbloss
        ) = train_losses

        # Tensorboard: per-loss curves.
        if 'cider' in args.loss:
            tb_log.log_value('train_uni_loss', train_uloss, epoch)
            tb_log.log_value('train_dis_loss', train_dloss, epoch)
        elif 'mgp' in args.loss:
            tb_log.log_value('train_reg_loss', train_rloss, epoch)
            tb_log.log_value('train_uni_loss', train_uloss, epoch)
            tb_log.log_value('train_dis_loss', train_dloss, epoch)
        else:
            train_loss_dict = {"supcon": train_sloss,
                "sf": train_sfloss, "sfrn": train_sfloss,
                "sfrh": train_sfloss, "sfrs": train_sfloss,
                "ce": train_celoss, "hypb": train_hypbloss
            }
            tb_log.log_value(f'train_{args.loss}_loss',
                train_loss_dict[args.loss], epoch
            )

        # The auxiliary CE term exists only for the "*_ce" losses; plain "ce"
        # is already logged above under the same 'train_ce_loss' tag.
        if '_ce' in args.loss:
            tb_log.log_value('train_ce_loss', train_celoss, epoch)

        tb_log.log_value('learning_rate',
            optimizer.param_groups[0]['lr'], epoch
        )

        # Build the checkpoint only on the epochs where it is written.
        if (epoch + 1) % args.save_epoch == 0:
            # NOTE(review): for "rsam" only the main optimizer state is
            # stored; the second optimizer is not checkpointed.
            checkpoint = {'epoch': epoch + 1,
                          'state_dict': model.state_dict(),
                          'opt_state_dict': optimizer.state_dict(),
                         }
            if 'cider' in args.loss:
                checkpoint.update({
                    'dis_state_dict': criterion_dis.state_dict(),
                    'uni_state_dict': criterion_comp.state_dict(),
                })
            elif 'mgp' in args.loss:
                checkpoint.update({
                    'dis_state_dict': criterion_dis.state_dict(),
                    'uni_state_dict': criterion_comp.state_dict(),
                    'reg_state_dict': criterion_reg.state_dict(),
                })
            else:
                criterion_dict = {"supcon": criterion_supcon,
                    "sf": criterion_sf, "sfrn": criterion_sf,
                    "sfrh": criterion_sf, "sfrs": criterion_sf,
                    "ce": criterion_ce, "hypb": criterion_hypb
                }
                checkpoint.update(
                    {f'{args.loss}_state_dict':
                     criterion_dict[args.loss].state_dict()
                    }
                )

            # The combined "*_ce" losses additionally store the CE criterion.
            if '_ce' in args.loss:
                checkpoint.update(
                    {'ce_state_dict': criterion_ce.state_dict()}
                )

            save_checkpoint(args, checkpoint, epoch + 1)


def set_criterions(args, model, aux_loader):
    """Instantiate every loss module used by the training loop.

    Args:
        args: parsed command-line namespace; ``args.loss`` selects the
            SphereFace variant and ``args.temp`` is the temperature given to
            the compactness and dispersion losses.
        model: network, forwarded to the dispersion loss.
        aux_loader: data loader, forwarded to the dispersion loss.

    Returns:
        list: ``[ce, supcon, comp, dis, reg, sf, hypb]`` criteria, in that
        order. All except the cross-entropy criterion are moved to CUDA.
    """
    # 'sfrn' / 'sfrh' / 'sfrs' pick a dedicated class; every other loss name
    # gets SphereFace2.
    sphereface_variants = {
        'sfrn': SphereFaceR_N,
        'sfrh': SphereFaceR_H,
        'sfrs': SphereFaceR_S,
    }

    ce_criterion = torch.nn.CrossEntropyLoss()
    supcon_criterion = SupConLoss(args).cuda()
    comp_criterion = CompLoss(args, temperature=args.temp).cuda()

    sphereface_cls = sphereface_variants.get(args.loss, SphereFace2)
    sf_criterion = sphereface_cls(args).cuda()

    hypb_criterion = PeBusePenalty(args).cuda()

    # Dispersion loss; the original author's note describes its prototypes as
    # updated "EMA style".
    dis_criterion = DisLoss(args, model, aux_loader, temperature=args.temp).cuda()
    reg_criterion = RegLoss(args).cuda()

    return [
        ce_criterion, supcon_criterion, comp_criterion,
        dis_criterion, reg_criterion, sf_criterion, hypb_criterion,
    ]


def get_losses(args, features, target, pred, criterions, all_losses):
    """Evaluate the loss selected by ``args.loss`` and update its meters.

    Args:
        args: namespace providing ``loss``, ``w``, ``wr`` and ``c_ball``.
        features: head output; for the 'mgp' losses this is a pair
            ``(features1, features2)``.
        target: class labels for the (two-view) batch.
        pred: classifier output, used only when ``"ce"`` is in ``args.loss``.
        criterions: ``[ce, supcon, comp, dis, reg, sf, hypb]`` criteria.
        all_losses: running meters in the same order as ``criterions``.

    Returns:
        tuple: ``(loss, losses)`` where ``loss`` is the total loss and
        ``losses`` maps a short name to the meter of each term that was
        computed.

    Raises:
        ValueError: if ``args.loss`` selects no loss term at all.
    """
    (criterion_ce, criterion_supcon, criterion_comp, criterion_dis,
        criterion_reg, criterion_sf, criterion_hypb
    ) = criterions
    (ce_losses, supcon_losses, comp_losses, dis_losses,
        reg_losses, sf_losses, hypb_losses
    ) = all_losses

    # Plain 'ce' matches none of the branches below, so both results need a
    # starting value before the classification term is added.
    loss = None
    losses = {}

    if 'cider' in args.loss:
        dis_loss = criterion_dis(features, target)
        comp_loss = criterion_comp(features, criterion_dis.prototypes, target)
        loss = args.w * comp_loss + dis_loss
        dis_losses.update(dis_loss.data, features.size(0))
        comp_losses.update(comp_loss.data, features.size(0))
        losses = {"dis": dis_losses, "comp": comp_losses}
    elif 'mgp' in args.loss:
        features1, features2 = features
        bsz = target.shape[0] // 2
        # Split the two views and stack them on a new dim 1 for SupCon.
        f1, f2 = torch.split(features2, [bsz, bsz], dim=0)
        features2 = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        # The regularization term is the SupCon loss on the second features.
        reg_loss = criterion_supcon(features2, target[:bsz])
        dis_loss = criterion_dis(features1, target)
        comp_loss = criterion_comp(features1, criterion_dis.prototypes, target)
        loss = args.w * comp_loss + dis_loss + args.wr * reg_loss
        dis_losses.update(dis_loss.data, features1.size(0))
        reg_losses.update(reg_loss.data, features1.size(0))
        comp_losses.update(comp_loss.data, features1.size(0))
        losses = {"dis": dis_losses, "comp": comp_losses, "reg": reg_losses}
    elif 'supcon' in args.loss:
        bsz = target.shape[0] // 2
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        supcon_loss = criterion_supcon(features, target[:bsz])
        supcon_losses.update(supcon_loss.data, features.size(0))
        loss = supcon_loss
        losses = {"supcon": supcon_losses}
    elif 'sf' in args.loss:
        sf_loss = criterion_sf(features, target)
        sf_losses.update(sf_loss.data, features.size(0))
        loss = sf_loss
        losses = {"sf": sf_losses}
    elif "hypb" in args.loss:
        features = expmap0(features, c=args.c_ball)
        # The labels are replaced by their class prototypes.
        target = criterion_hypb.prototypes[target]
        target = torch.autograd.Variable(target)

        hypb_loss = criterion_hypb(features, target)
        loss = hypb_loss
        hypb_losses.update(hypb_loss.data, features.size(0))
        losses = {"hypb": hypb_losses}

    # Classification term: the whole loss for 'ce', an extra term for '*_ce'.
    if "ce" in args.loss:
        ce_loss = criterion_ce(pred, target)
        loss = ce_loss if loss is None else loss + ce_loss
        ce_losses.update(ce_loss.data, target.size(0))
        losses.update({"ce": ce_losses})

    if loss is None:
        raise ValueError(f"unsupported loss: {args.loss!r}")

    return loss, losses


def get_log_losses(epoch, i, total_len, batch_time, losses):
    """Format one tab-separated progress line for the training log.

    Args:
        epoch: current epoch number.
        i: index of the current batch.
        total_len: number of batches in the epoch.
        batch_time: meter exposing ``val`` and ``avg`` (time per batch).
        losses: mapping from loss name to a meter exposing ``val`` and ``avg``.

    Returns:
        str: the epoch/batch/time header followed by one
        ``"<name> Loss <val> (<avg>)"`` field per entry of ``losses``.
    """
    header = ('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
             ).format(epoch, i, total_len, batch_time=batch_time)
    # One field per tracked loss, in the mapping's iteration order.
    fields = [
        '{name} Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(name=label, loss=meter)
        for label, meter in losses.items()
    ]
    return header + ''.join(fields)


def _forward_and_loss(args, model, images, target, criterions, all_losses):
    """Run the head and classifier forward passes and evaluate the loss."""
    multi = args.head == "manifold"
    features = model.head_forward(images, multi=multi)
    pred = model(images, multi=multi)
    return get_losses(args, features, target, pred, criterions, all_losses)


def trainer(args, train_loader, model, criterions, optimizers, epoch, log, choosen_layer=None):
    """Train for one epoch on the training set.

    Args:
        args: parsed command-line namespace.
        train_loader: loader yielding ``((view0, view1), target)`` batches.
        model: network providing ``head_forward`` and ``__call__``.
        criterions: loss modules as returned by ``set_criterions``.
        optimizers: ``[optimizer]`` or, for "rsam",
            ``[optimizer, optimizer_ocnn]``.
        epoch: current epoch number (used for warm-up and logging).
        log: logger receiving the periodic progress lines.
        choosen_layer: weight tensor whose ``requires_grad`` flag is toggled
            around the main optimizer's steps; required only for "rsam".

    Returns:
        tuple: epoch averages ``(supcon, comp, dis, reg, sf, ce, hypb)``.

    Raises:
        NotImplementedError: for ``--optimizer gsam``, whose step is unfinished.
        ValueError: for "rsam" without a second optimizer or ``choosen_layer``.
    """
    if args.optimizer == "gsam":
        # The GSAM step needs an lr scheduler that is not available here;
        # fail before touching the model rather than with a NameError after
        # a partial update.
        raise NotImplementedError(
            "the GSAM update step is unfinished in trainer(); choose another --optimizer"
        )

    batch_time = AverageMeter()
    ce_losses = AverageMeter()
    supcon_losses = AverageMeter()
    comp_losses = AverageMeter()
    dis_losses = AverageMeter()
    reg_losses = AverageMeter()
    sf_losses = AverageMeter()
    hypb_losses = AverageMeter()
    all_losses = [ce_losses, supcon_losses, comp_losses, dis_losses,
        reg_losses, sf_losses, hypb_losses
    ]

    optimizer = optimizers[0]
    optimizer_ocnn = optimizers[1] if len(optimizers) > 1 else None
    if args.optimizer == "rsam" and (optimizer_ocnn is None or choosen_layer is None):
        raise ValueError("rsam needs two optimizers and the choosen_layer tensor")

    model.train()
    end = time.time()

    # Optional adversarial attack. The attack is applied to un-normalized
    # data (package constraint), so normalization happens afterwards here;
    # set args.normalize=False when conducting an attack.
    atk = Attack(model, args.attack, n_class=args.n_cls) if args.attack else None
    normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447],
                                     std=[0.247, 0.244, 0.262]
                                    )

    for i, (images, target) in enumerate(train_loader):
        warmup_learning_rate(args, epoch, i, len(train_loader), optimizer)

        if atk:
            images[0] = atk(images[0], target)
            images[1] = atk(images[1], target)
            images[0] = normalize(images[0])
            images[1] = normalize(images[1])
        # Stack the two views along the batch dimension; labels are repeated
        # to match.
        images = torch.cat([images[0], images[1]], dim=0).cuda()
        target = target.repeat(2).cuda()

        optimizer.zero_grad()
        loss, losses = _forward_and_loss(args, model, images, target, criterions, all_losses)

        if args.optimizer == "sam":
            loss.backward(retain_graph=True)
            optimizer.first_step(zero_grad=True)

            # Second pass at the perturbed weights, without updating the
            # normalization running statistics a second time.
            disable_running_stats(model)
            loss, losses = _forward_and_loss(args, model, images, target, criterions, all_losses)
            loss.backward(retain_graph=True)
            optimizer.second_step(zero_grad=True)
            # Restore the running statistics for the next batch; without this
            # they stay disabled after the first step.
            # NOTE(review): assumes the utils helpers follow the reference
            # SAM enable/disable pair - confirm.
            enable_running_stats(model)
        elif args.optimizer == "rsam":
            loss.backward()

            # Ascent: the manifold layer first, then the remaining parameters
            # with the manifold layer's gradient flag switched off.
            optimizer_ocnn.ascent_step(zero_grad=True)
            choosen_layer.requires_grad = False
            optimizer.ascent_step(zero_grad=True)
            choosen_layer.requires_grad = True

            # Second pass at the perturbed weights.
            loss, losses = _forward_and_loss(args, model, images, target, criterions, all_losses)
            loss.backward()

            optimizer_ocnn.descent_step(zero_grad=True)
            choosen_layer.requires_grad = False
            optimizer.descent_step(zero_grad=True)
            choosen_layer.requires_grad = True
        else:
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            log_loss = get_log_losses(epoch, i, len(train_loader), batch_time, losses)
            log.debug(log_loss)

    # The last slot is the hypb average (it used to repeat the CE average).
    return (supcon_losses.avg, comp_losses.avg, dis_losses.avg, reg_losses.avg,
            sf_losses.avg, ce_losses.avg, hypb_losses.avg)


def save_checkpoint(args, state, epoch):
    """Save ``state`` to ``checkpoint_<epoch>.pth.tar`` in the model directory.

    Args:
        args: namespace providing ``model_directory``.
        state: object to serialize with ``torch.save``.
        epoch: epoch number used in the file name.
    """
    # os.path.join keeps the file inside the directory whether or not
    # model_directory ends with a path separator.
    filename = os.path.join(args.model_directory, f'checkpoint_{epoch}.pth.tar')
    torch.save(state, filename)


# Call main() only when this file is executed as a script.
if __name__ == '__main__':
    main()
