import os
import time
import torch
import random
import argparse
import numpy as np
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
from dataset import BsdsDataset, VOCDataset, NyudDataset, MulticueDataset, BipedDataset
from tester import test_bsds, test_nyud, test_multicue
from trainer import train_bsds, train_voc, train_nyud, train_multicue

from models.U_SNN import U_SNN

from models.model_utils import initialize_weights

from loss import Loss
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_model_parm_nums,
    get_logger,
    plot_loss_curve,
)
from torch.utils.data import DataLoader


# 1. Command-line argument definitions
def get_parser(argv=None):
    """Define the command-line interface and parse it.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    ``ArgumentParser`` itself.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebooks). ``None`` keeps
            the original behaviour of reading the real command line.

    Returns:
        argparse.Namespace holding every training/testing option.
    """
    parser = argparse.ArgumentParser(description="PyTorch Training/Testing")
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initialization"
    )
    parser.add_argument(
        "--test", default=False, help="Only test the model", action="store_true"
    )
    parser.add_argument(
        "--train_batch_size",
        default=1,
        type=int,
        metavar="N",
        help="training batch size",
    )
    parser.add_argument(
        "--test_batch_size", default=1, type=int, metavar="N", help="testing batch size"
    )
    parser.add_argument(
        "--num_workers", default=4, type=int, metavar="N", help="number of workers"
    )
    parser.add_argument(
        "--sampler_num", default=-1, type=int, metavar="N", help="sampler num"
    )
    parser.add_argument(
        "--epochs",
        default=40,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--lr",
        "--learning_rate",
        default=1e-4,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--lr_stepsize",
        default=5,
        type=int,
        metavar="N",
        help="decay lr by a factor every lr_stepsize epochs",
    )
    parser.add_argument(
        "--lr_gamma",
        default=0.1,
        type=float,
        metavar="F",
        help="learning rate decay factor (gamma)",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="F", help="momentum"
    )
    parser.add_argument(
        "--weight_decay",
        "--wd",
        default=0.0005,
        type=float,
        metavar="F",
        help="weight decay (default: 0.0005)",
    )
    parser.add_argument(
        "--print_freq",
        "-p",
        default=500,
        type=int,
        metavar="N",
        help="print frequency (default: 500)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--store_folder",
        default="./output",
        type=str,
        metavar="PATH",
        help="path to store folder",
    )
    parser.add_argument(
        "--dataset",
        default="./data/BSDS500_flipped_rotated",
        type=str,
        metavar="PATH",
        help="path to dataset",
    )
    parser.add_argument(
        "--optimizer_method",
        default="Adam",
        type=str,
        # Only these two optimizers are constructed by the training script;
        # rejecting anything else here gives a clear CLI error instead of a
        # crash later when the LR scheduler receives no optimizer.
        choices=["Adam", "SGD"],
        metavar="OPT",
        help="optimizer method (default: Adam)",
    )
    parser.add_argument(
        "--loss_method",
        default="WCE",
        type=str,
        metavar="LOSS",
        help="loss method (default: Weighted Cross Entropy Loss)",
    )
    parser.add_argument(
        "--T",
        default=3,
        type=int,
        help="The timesteps of the SNN",
    )
    parser.add_argument('--pretrained_backbone',
                        default=None,
                        type=str,
                        help='path to pretrained backbone model')
    parser.add_argument('--amp',
                        action='store_true',
                        help='using the amp for training')
    parser.add_argument('--step_mode',
                        default='m',
                        type=str,
                        help='step mode of the neuron')
    parser.add_argument('--backend',
                        default='cupy',
                        type=str,
                        help='backend of the neuron')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args


def _build_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer_method``.

    Args:
        model: Module whose parameters are optimized.
        args: Parsed CLI namespace. Reads ``optimizer_method``, ``lr`` and
            ``weight_decay``, plus ``momentum`` for SGD.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: If ``optimizer_method`` is neither "Adam" nor "SGD".
    """
    if args.optimizer_method == "Adam":
        return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.optimizer_method == "SGD":
        return optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    # Fail fast: a None optimizer would otherwise surface later as an
    # unrelated TypeError from StepLR.
    raise ValueError(
        "Unsupported optimizer_method {!r}; expected 'Adam' or 'SGD'".format(
            args.optimizer_method
        )
    )


def main():
    """Parse CLI arguments, then evaluate or train U_SNN on the BSDS dataset.

    With ``--test``, the model (optionally restored via ``--resume``) is only
    evaluated. Otherwise it is trained for epochs ``[start, args.epochs)``,
    evaluated and checkpointed after every epoch, and the per-epoch training
    losses are plotted at the end.

    All artifacts (log file, test outputs, checkpoints, loss curve) are
    written under ``args.store_folder``, resolved relative to this file's
    directory.

    Raises:
        ValueError: If ``args.optimizer_method`` is not "Adam" or "SGD".
    """
    args = get_parser()

    # Seeding: fall back to the wall-clock time so unseeded runs differ; the
    # chosen seed is recorded when ``args`` is logged below.
    if args.seed is None:
        args.seed = int(time.time())
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Output folder. os.path.join discards current_dir when store_folder is
    # absolute, so absolute paths are used as-is.
    current_dir = os.path.abspath(os.path.dirname(__file__))
    store_dir = os.path.join(current_dir, args.store_folder)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(store_dir, exist_ok=True)

    # Logger: one timestamped log file per run.
    now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    logger = get_logger(os.path.join(store_dir, "log-{}.txt".format(now_str)))
    logger.info(args)

    # 1. Test set (fixed order, no samples dropped).
    test_dataset = BsdsDataset(dataset_path=args.dataset, flag="test")
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,
        pin_memory=True,
    )

    # 2. Model
    model = nn.DataParallel(U_SNN(args=args)).to(device)
    logger.info("Number of parameter: {:.2f}M".format(get_model_parm_nums(model)))

    # 3. Optimizer (raises ValueError on an unknown method).
    opt = _build_optimizer(model, args)

    # 4. LR scheduler: multiply the LR by lr_gamma every lr_stepsize epochs.
    lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer=opt, step_size=args.lr_stepsize, gamma=args.lr_gamma
    )

    # 5. Loss
    loss = Loss(args.loss_method).to(device)

    # 6. Resume from a checkpoint
    current_epoch = 0
    if args.resume:
        # NOTE(review): checkpoints saved below store the 0-based index of the
        # epoch just completed. Confirm load_checkpoint returns that value + 1;
        # otherwise resuming re-runs the last finished epoch.
        current_epoch = load_checkpoint(
            model, opt=opt, lr_scheduler=lr_scheduler, path=args.resume
        )

    # 7. AMP gradient scaler (only when --amp is given).
    if args.amp:
        from torch.cuda.amp import GradScaler

        scaler = GradScaler()
    else:
        scaler = None

    def run_test(subdir):
        """Evaluate ``model`` on the test split, saving outputs under ``store_dir/subdir``."""
        test_bsds(
            test_loader,
            model,
            save_dir=os.path.join(store_dir, subdir),
            logger=logger,
            device=device,
            multi_scale=False,
        )

    if args.test:
        run_test("test")
        return

    train_epoch_losses = []
    for epoch in range(current_epoch, args.epochs):
        if epoch == 0:
            # Evaluate once before the first epoch of training.
            logger.info("Initial test...")
            run_test("initial_test")

        # The training set is rebuilt every epoch, presumably so that
        # sub_sample (--sampler_num) draws a fresh subset each time.
        # TODO confirm against BsdsDataset.
        train_dataset = BsdsDataset(
            dataset_path=args.dataset,
            flag="train",
            sub_sample=args.sampler_num,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=args.num_workers,
            pin_memory=True,
        )
        # NOTE(review): lr_scheduler is handed to train_bsds and also stepped
        # below. Confirm train_bsds does not step it too, or the LR would
        # decay twice per epoch.
        train_epoch_loss = train_bsds(
            train_loader,
            model,
            opt,
            lr_scheduler,
            args.print_freq,
            args.epochs,
            epoch,
            save_dir=os.path.join(store_dir, "epoch-{}-train".format(epoch + 1)),
            logger=logger,
            device=device,
            loss=loss,
            scaler=scaler,
        )
        run_test("epoch-{}-test".format(epoch + 1))
        lr_scheduler.step()

        # Checkpoint after every epoch; file names are 1-based, "epoch" is 0-based.
        save_checkpoint(
            state={
                "model": model.state_dict(),
                "opt": opt.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
            },
            path=os.path.join(store_dir, "epoch-{}-ckpt.pt".format(epoch + 1)),
        )
        train_epoch_losses.append(train_epoch_loss)

    # Training-loss curve over the epochs run in this invocation.
    plot_loss_curve(train_epoch_losses, os.path.join(store_dir, "loss_curve.png"))


if __name__ == "__main__":
    main()
