import argparse
import os
import shutil

import numpy as np
import paddle

from data.dataset import PoisonLabelDataset
from data.utils import (
    gen_poison_idx,
    get_bd_transform,
    get_dataset,
    get_loader,
    get_transform,
)
from model.model import LinearModel
from model.utils import (
    get_criterion,
    get_network,
    get_optimizer,
    get_scheduler,
    load_state,
)
from utils.setup import (
    get_logger,
    get_saved_dir,
    get_storage_dir,
    load_config,
    set_seed,
)
from utils.trainer.log import result2csv
from utils.trainer.supervise import poison_train, test


def main():
    """Parse CLI arguments, prepare output directories, and start training.

    The config file given by ``--config`` is loaded, the saved/log and
    storage/checkpoint directories are resolved from it, and a copy of the
    config file is placed in both the saved and the storage directory.
    Training then runs in the current process via ``main_worker``.
    """
    print("===Setup running===")
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./config/supervise/example.yaml")
    parser.add_argument(
        "--gpu",
        default="gpu",
        type=str,
        help="device string handed to paddle.set_device",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        help="checkpoint name (empty string means the latest checkpoint)\
            or False (means training from scratch).",
    )
    parser.add_argument("--amp", default=False, action="store_true")
    parser.add_argument(
        "--rank", default=0, type=int, help="node rank for distributed training"
    )
    args = parser.parse_args()

    config, inner_dir, config_name = load_config(args.config)
    args.saved_dir, args.log_dir = get_saved_dir(
        config, inner_dir, config_name, args.resume
    )
    shutil.copy2(args.config, args.saved_dir)
    args.storage_dir, args.ckpt_dir, _ = get_storage_dir(
        config, inner_dir, config_name, args.resume
    )
    shutil.copy2(args.config, args.storage_dir)

    # Multi-process (distributed) launching is not implemented in this script,
    # so training always runs as a single worker. The flag is set explicitly
    # because main_worker reads ``args.distributed``.
    args.distributed = False
    main_worker(0, args, config)


def main_worker(gpu, args, config):
    """Train a linear model on poisoned data and evaluate it after every epoch.

    Args:
        gpu: Worker index supplied by the launcher. It is not used in this
            function; the device is selected from ``args.gpu`` instead.
        args: Parsed command-line namespace. Reads ``log_dir``, ``saved_dir``,
            ``ckpt_dir``, ``resume``, ``gpu``, ``amp`` and, if present,
            ``distributed``.
        config: Loaded configuration mapping. Reads the keys ``seed``,
            ``backdoor``, ``transform``, ``dataset_dir``, ``prefetch``,
            ``loader``, ``network``, ``num_classes``, ``criterion``,
            ``lr_scheduler``, ``optimizer`` and ``num_epochs``.

    Side effects:
        Writes ``poison_idx.npy`` into ``args.saved_dir``, appends per-epoch
        results via ``result2csv`` into ``args.log_dir``, and saves
        ``latest_model.pdmodel`` (every epoch) and ``best_model.pdmodel``
        (when clean test accuracy improves) into ``args.ckpt_dir``.
    """
    set_seed(**config["seed"])
    logger = get_logger(args.log_dir, "supervise.log", args.resume)
    paddle.set_device(args.gpu)

    logger.info("===Prepare data===")
    bd_config = config["backdoor"]
    logger.info("Load backdoor config:\n{}".format(bd_config))
    bd_transform = get_bd_transform(bd_config)
    target_label = bd_config["target_label"]
    poison_ratio = bd_config["poison_ratio"]

    transform_config = config["transform"]
    pre_transform = get_transform(transform_config["pre"])
    train_transform = {
        "pre": pre_transform,
        "primary": get_transform(transform_config["train"]["primary"]),
        "remaining": get_transform(transform_config["train"]["remaining"]),
    }
    logger.info("Training transformations:\n {}".format(train_transform))
    test_transform = {
        "pre": pre_transform,
        "primary": get_transform(transform_config["test"]["primary"]),
        "remaining": get_transform(transform_config["test"]["remaining"]),
    }
    logger.info("Test transformations:\n {}".format(test_transform))

    logger.info("Load dataset from: {}".format(config["dataset_dir"]))
    clean_train_data = get_dataset(
        config["dataset_dir"], train_transform, prefetch=config["prefetch"]
    )
    # Generate the training poison index exactly once: the array saved to disk
    # must be the same one used to build the poisoned dataset. (Calling
    # gen_poison_idx a second time could yield a different selection if it is
    # randomized -- presumably it is; verify in data.utils.)
    poison_train_idx = gen_poison_idx(clean_train_data, target_label, poison_ratio)
    poison_idx_path = os.path.join(args.saved_dir, "poison_idx.npy")
    np.save(poison_idx_path, poison_train_idx)
    logger.info("Save poisoned index to {}".format(poison_idx_path))
    clean_test_data = get_dataset(
        config["dataset_dir"], test_transform, train=False, prefetch=config["prefetch"]
    )
    poison_train_data = PoisonLabelDataset(
        clean_train_data, bd_transform, poison_train_idx, target_label
    )
    # No poison ratio is passed for the test split; the helper's default applies.
    poison_test_idx = gen_poison_idx(clean_test_data, target_label)
    poison_test_data = PoisonLabelDataset(
        clean_test_data, bd_transform, poison_test_idx, target_label
    )
    # No distributed sampler is created in this script; kept as None so the
    # per-epoch hook below is a no-op unless a sampler is introduced later.
    poison_train_sampler = None
    poison_train_loader = get_loader(
        poison_train_data, config["loader"], shuffle=True
    )

    clean_test_loader = get_loader(clean_test_data, config["loader"])
    poison_test_loader = get_loader(poison_test_data, config["loader"])

    logger.info("\n===Setup training===")
    backbone = get_network(config["network"])
    logger.info("Create network: {}".format(config["network"]))
    linear_model = LinearModel(backbone, backbone.feature_dim, config["num_classes"])
    criterion = get_criterion(config["criterion"])
    logger.info("Create criterion: {}".format(criterion))
    scheduler = get_scheduler(config["lr_scheduler"])
    logger.info("Create scheduler: {}".format(config["lr_scheduler"]))
    optimizer = get_optimizer(linear_model, scheduler, config["optimizer"])
    logger.info("Create optimizer: {}".format(optimizer))
    resumed_epoch, best_acc, best_epoch = load_state(
        linear_model,
        args.resume,
        args.ckpt_dir,
        logger,
        optimizer,
        scheduler,
        is_best=True,
    )

    # The launcher may not define ``distributed`` at all; treat that as False.
    distributed = getattr(args, "distributed", False)

    for epoch in range(config["num_epochs"] - resumed_epoch):
        # 1-based epoch number counted from the start of the whole run,
        # including epochs completed before resuming.
        current_epoch = epoch + resumed_epoch + 1
        if distributed and poison_train_sampler is not None:
            poison_train_sampler.set_epoch(epoch)
        logger.info("===Epoch: {}/{}===".format(current_epoch, config["num_epochs"]))
        logger.info("Poison training...")
        poison_train_result = poison_train(
            linear_model,
            poison_train_loader,
            criterion,
            optimizer,
            logger,
            amp=args.amp,
        )
        logger.info("Test model on clean data...")
        clean_test_result = test(linear_model, clean_test_loader, criterion, logger)
        logger.info("Test model on poison data...")
        poison_test_result = test(linear_model, poison_test_loader, criterion, logger)

        if scheduler is not None:
            scheduler.step()
            # Paddle optimizers report the current learning rate via get_lr();
            # they have no torch-style ``param_groups`` list.
            logger.info("Adjust learning rate to {}".format(optimizer.get_lr()))

        # Save result and checkpoint.
        result = {
            "poison_train": poison_train_result,
            "clean_test": clean_test_result,
            "poison_test": poison_test_result,
        }
        result2csv(result, args.log_dir)

        # Update the best-so-far bookkeeping BEFORE building the checkpoint so
        # the saved ``best_acc``/``best_epoch`` reflect this epoch's result.
        is_best = clean_test_result["acc"] > best_acc
        if is_best:
            best_acc = clean_test_result["acc"]
            best_epoch = current_epoch
        logger.info(
            "Best test accuaracy {} in epoch {}".format(best_acc, best_epoch)
        )

        saved_dict = {
            "epoch": current_epoch,
            "result": result,
            "model_state_dict": linear_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_acc": best_acc,
            "best_epoch": best_epoch,
        }
        if scheduler is not None:
            saved_dict["scheduler_state_dict"] = scheduler.state_dict()

        if is_best:
            ckpt_path = os.path.join(args.ckpt_dir, "best_model.pdmodel")
            paddle.save(saved_dict, ckpt_path)
            logger.info("Save the best model to {}".format(ckpt_path))
        ckpt_path = os.path.join(args.ckpt_dir, "latest_model.pdmodel")
        paddle.save(saved_dict, ckpt_path)
        logger.info("Save the latest model to {}".format(ckpt_path))


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
