"""
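This script runs basic normal training: it trains a classifier on the clean
dataset and saves the result as ``clean_model.pth`` in the run folder.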
"""

import os
import sys
import yaml

# Work relative to sys.path[0] (the script's own directory when launched as a
# script), so the "../" paths used throughout this file do not depend on the
# directory the command was started from.
os.chdir(sys.path[0])
# Make the parent directory importable. This must run before the `utils.*`
# imports below (presumably that is where the `utils` package lives — confirm).
sys.path.append("../")
# The return value is discarded, so this call has no effect on the program.
os.getcwd()

import argparse
from pprint import pformat
import torch
import time
import logging

from utils.aggregate_block.save_path_generate import generate_save_folder
from utils.aggregate_block.dataset_and_transform_generate import (
    get_num_classes,
    get_input_shape,
)
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.dataset_and_transform_generate import (
    dataset_and_transform_generate,
)
from utils.bd_dataset_v2 import dataset_wrapper_with_transform, get_labels
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.aggregate_block.train_settings_generate import (
    argparser_opt_scheduler,
    argparser_criterion,
)
from utils.log_assist import get_git_info
from utils.trainer_cls import ModelTrainerCLS_v2


class NormalCase:
    def __init__(self):
        """Create an empty runner.

        No state is set here; ``prepare`` later stores ``self.args`` and the
        stage methods store ``self.stage1_results``, ``self.net`` and
        ``self.device``.
        """

    def set_args(self, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser.add_argument("-n", "--num_workers", type=int, help="dataloader num_workers")
        parser.add_argument(
            "-pm",
            "--pin_memory",
            type=lambda x: str(x) in ["True", "true", "1"],
            help="dataloader pin_memory",
        )
        parser.add_argument(
            "-nb",
            "--non_blocking",
            type=lambda x: str(x) in ["True", "true", "1"],
            help=".to(), set the non_blocking = ?",
        )
        parser.add_argument(
            "-pf",
            "--prefetch",
            type=lambda x: str(x) in ["True", "true", "1"],
            help="use prefetch",
        )
        parser.add_argument("--amp", type=lambda x: str(x) in ["True", "true", "1"])
        parser.add_argument("--device", type=str)
        parser.add_argument("--lr_scheduler", type=str, help="which lr_scheduler use for optimizer")
        parser.add_argument("--epochs", type=int)
        parser.add_argument("--dataset", type=str, help="which dataset to use")
        parser.add_argument("--dataset_path", type=str)
        parser.add_argument("--batch_size", type=int)
        parser.add_argument("--lr", type=float)
        parser.add_argument("--steplr_stepsize", type=int)
        parser.add_argument("--steplr_gamma", type=float)
        parser.add_argument("--sgd_momentum", type=float)
        parser.add_argument("--wd", type=float, help="weight decay of sgd")
        parser.add_argument("--steplr_milestones", type=list)
        parser.add_argument("--client_optimizer", type=int)
        parser.add_argument("--random_seed", type=int, help="random_seed")
        parser.add_argument("--frequency_save", type=int, help=" frequency_save, 0 is never")
        parser.add_argument("--model", type=str, help="choose which kind of model")
        parser.add_argument(
            "--save_folder_name",
            type=str,
            help="(Optional) should be time str + given unique identification str",
        )
        parser.add_argument(
            "--git_hash",
            type=str,
            help="git hash number, in order to find which version of code is used",
        )
        parser.add_argument("--yaml_path", type=str, default="../config/attack/prototype/cifar10.yaml")
        parser.add_argument("--poison_idx_path", type=str, default=None)
        return parser

    def add_yaml_to_args(self, args):
        """Merge the YAML defaults into ``args`` in place; command-line values win.

        The file at ``args.yaml_path`` supplies the defaults. Every attribute of
        ``args`` whose value is not ``None`` overrides the YAML entry of the
        same name. The merged mapping then *replaces* ``args.__dict__``, so an
        option that is ``None`` on the command line and absent from the YAML
        file no longer exists as an attribute afterwards.

        Args:
            args: namespace holding at least ``yaml_path``; modified in place.

        Raises:
            TypeError: if the top level of the YAML file is not a mapping.
        """
        with open(args.yaml_path, "r") as f:
            clean_defaults = yaml.safe_load(f)

        if clean_defaults is None:
            # safe_load returns None for an empty document; treat it as
            # "no defaults" instead of failing on None.update below.
            clean_defaults = {}
        if not isinstance(clean_defaults, dict):
            raise TypeError(
                f"{args.yaml_path} must contain a mapping at top level, got {type(clean_defaults).__name__}"
            )

        clean_defaults.update({k: v for k, v in vars(args).items() if v is not None})
        args.__dict__ = clean_defaults

    def process_args(self, args):
        """Attach the fields derived from ``args.dataset`` and return ``args``.

        Adds ``terminal_info`` (the raw ``sys.argv``), ``num_classes``,
        ``input_height`` / ``input_width`` / ``input_channel`` and ``img_size``
        (the same three values as a tuple), and rewrites ``dataset_path`` to
        ``<dataset_path>/<dataset>``.

        Args:
            args: namespace with ``dataset`` and ``dataset_path``; modified in place.

        Returns:
            The same ``args`` object.
        """
        dataset_name = args.dataset

        args.terminal_info = sys.argv
        args.num_classes = get_num_classes(dataset_name)

        height, width, channel = get_input_shape(dataset_name)
        args.input_height = height
        args.input_width = width
        args.input_channel = channel
        args.img_size = (height, width, channel)

        # The dataset lives in a sub-folder named after it.
        args.dataset_path = "{}/{}".format(args.dataset_path, dataset_name)
        return args

    def prepare(self, args):
        """Set up the run folder, dump the arguments, configure logging and fix the seed.

        Side effects, in order:

        * picks the output folder: ``../record/<save_folder_name>`` when
          ``save_folder_name`` is present in ``args`` (created if missing),
          otherwise whatever ``generate_save_folder`` returns; the result is
          stored in ``args.save_path``;
        * writes ``args.__dict__`` to ``<save_path>/info.pickle`` via ``torch.save``;
        * adds a DEBUG-level file handler and an INFO-level console handler to
          the root logger (every call adds another pair);
        * calls ``fix_random(int(args.random_seed))``;
        * stores ``args`` as ``self.args``.

        Args:
            args: the merged argument namespace; ``save_path`` is added to it.
        """
        ### save path
        if "save_folder_name" not in args:
            save_path = generate_save_folder(
                run_info=("afterwards" if "load_path" in args.__dict__ else "attack")
                + "_"
                + (args.attack if "attack" in args.__dict__ else "prototype"),
                given_load_file_path=args.load_path if "load_path" in args else None,
                all_record_folder_path="../record",
            )
        else:
            save_path = "../record/" + args.save_folder_name
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(save_path, exist_ok=True)
        args.save_path = save_path

        torch.save(args.__dict__, save_path + "/info.pickle")

        ### set the logger
        logFormatter = logging.Formatter(
            fmt="%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s",
            datefmt="%Y-%m-%d:%H:%M:%S",
        )
        logger = logging.getLogger()
        # file handler: one log file per call, named after the current local time
        fileHandler = logging.FileHandler(
            save_path + "/" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log"
        )
        fileHandler.setFormatter(logFormatter)
        fileHandler.setLevel(logging.DEBUG)
        logger.addHandler(fileHandler)
        # console handler
        consoleHandler = logging.StreamHandler()
        consoleHandler.setFormatter(logFormatter)
        consoleHandler.setLevel(logging.INFO)
        logger.addHandler(consoleHandler)
        # The logger's own level must be <= the lowest handler level, otherwise
        # records are dropped before any handler sees them. NOTSET == 0.
        logger.setLevel(logging.NOTSET)

        # silence the very chatty DEBUG output of these libraries
        logging.getLogger("PIL").setLevel(logging.WARNING)
        logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)

        logging.info(pformat(args.__dict__))

        logging.debug("Only INFO or above level log will show in cmd. DEBUG level log only will show in log file.")

        # record the git information for debugging (best effort, if available)
        try:
            logging.debug(pformat(get_git_info()))
        except Exception:
            # Best effort only, but do not swallow KeyboardInterrupt/SystemExit
            # as a bare ``except:`` would; keep the traceback in the log file.
            logging.debug("Getting git info fails.", exc_info=True)

        ### set the random seed
        fix_random(int(args.random_seed))

        self.args = args

    def benign_prepare(self):
        """Build the clean train/test datasets and their label lists.

        Requires ``prepare`` to have stored ``self.args``.

        Returns:
            A 10-tuple, in this order: the six items returned by
            ``dataset_and_transform_generate`` (train dataset without
            transform, train image transform, train label transform, then the
            same three for test), followed by the wrapped train dataset, the
            train labels, the wrapped test dataset and the test labels.
        """
        assert "args" in self.__dict__
        cfg = self.args

        generated = dataset_and_transform_generate(cfg)
        raw_train, train_img_tf, train_label_tf, raw_test, test_img_tf, test_label_tf = generated
        logging.debug("dataset_and_transform_generate done")

        # train split: wrapped dataset plus the labels read from the raw dataset
        wrapped_train = dataset_wrapper_with_transform(raw_train, train_img_tf, train_label_tf)
        train_targets = get_labels(raw_train)

        # test split, same treatment
        wrapped_test = dataset_wrapper_with_transform(raw_test, test_img_tf, test_label_tf)
        test_targets = get_labels(raw_test)

        return (
            raw_train,
            train_img_tf,
            train_label_tf,
            raw_test,
            test_img_tf,
            test_label_tf,
            wrapped_train,
            train_targets,
            wrapped_test,
            test_targets,
        )

    def stage1_non_training_data_prepare(self):
        """Prepare the data for stage 2 and store it in ``self.stage1_results``.

        ``self.stage1_results`` becomes the 4-tuple ``(clean train dataset with
        transform, clean test dataset with transform, None, None)``; the last
        two slots are the ones ``stage2_training`` unpacks as the backdoor
        train/test datasets, which this clean baseline does not have.
        """
        # You should rewrite this for specific attack method

        logging.info("stage1 start")

        assert "args" in self.__dict__

        # benign_prepare returns ten items; only the two wrapped clean datasets
        # (positions 6 and 8) are needed here.
        (_, _, _, _, _, _, clean_train, _, clean_test, _) = self.benign_prepare()

        self.stage1_results = (clean_train, clean_test, None, None)

    def stage2_training(self):
        """Train the model on the clean data and save its weights.

        Uses ``self.args`` and ``self.stage1_results``. Builds the model with
        ``generate_cls_model``, picks the device, wraps the model in
        ``torch.nn.DataParallel`` when ``args.device`` lists several ids
        (contains a comma), runs ``args.epochs`` rounds of train / test /
        record through ``ModelTrainerCLS_v2`` and finally writes the CPU
        state dict to ``<save_path>/clean_model.pth``.

        Sets ``self.net`` and ``self.device``.
        """
        # You should rewrite this for specific attack method

        logging.info("stage2 start")
        assert "args" in self.__dict__
        cfg = self.args

        # Only the two clean datasets are used here; the backdoor slots are ignored.
        train_set, test_set, _, _ = self.stage1_results

        self.net = generate_cls_model(
            model_name=cfg.model,
            num_classes=cfg.num_classes,
            image_size=cfg.img_size[0],
        )

        def gpu_ids():
            # eg. "cuda:2,3,7" -> [2,3,7]; the first five characters are skipped
            return [int(token) for token in cfg.device[5:].split(",")]

        if not torch.cuda.is_available():
            device_name = "cpu"
        elif "," in cfg.device:
            # since DataParallel only allow .to("cuda"): use the first listed id
            device_name = f"cuda:{gpu_ids()[0]}"
        else:
            device_name = cfg.device
        self.device = torch.device(device_name)

        if "," in cfg.device:
            # NOTE(review): the state dict saved at the end comes from the wrapped
            # model, so its keys presumably carry DataParallel's "module." prefix
            # — confirm that whatever loads clean_model.pth expects this.
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids())

        trainer = ModelTrainerCLS_v2(
            self.net,
        )

        criterion = argparser_criterion(cfg)
        optimizer, scheduler = argparser_opt_scheduler(self.net, cfg)

        trainer.set_with_dataset(
            batch_size=cfg.batch_size,
            train_dataset=train_set,
            test_dataset_dict={
                "ACC": test_set,
            },
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=self.device,
            frequency_save=cfg.frequency_save,
            save_folder_path=cfg.save_path,
            save_prefix="attack",
            amp=cfg.amp,
            prefetch=cfg.prefetch,
            num_workers=cfg.num_workers,
            prefetch_transform_attr_name="wrap_img_transform",  # since we use the preprocess_bd_dataset
            pin_memory=cfg.pin_memory,
            non_blocking=cfg.non_blocking,
        )

        for _ in range(cfg.epochs):
            trainer.train_one_epoch()
            # feed the result for the "ACC" (clean test) dataset to the aggregator
            test_results = trainer.test_all_inner_dataloader()
            trainer.agg(test_results["ACC"])
            trainer.agg_save_dataframe()

        # .cpu() moves self.net itself to the CPU before its weights are saved
        cpu_state = self.net.cpu().state_dict()
        torch.save(cpu_state, f"{cfg.save_path}/clean_model.pth")


if __name__ == "__main__":
    # Pipeline: CLI options -> YAML defaults merged in -> derived fields ->
    # run folder / logging / seed -> data preparation -> training.
    runner = NormalCase()
    cli_parser = runner.set_args(argparse.ArgumentParser(description=sys.argv[0]))
    run_args = cli_parser.parse_args()
    runner.add_yaml_to_args(run_args)  # modifies run_args in place
    run_args = runner.process_args(run_args)
    runner.prepare(run_args)
    runner.stage1_non_training_data_prepare()
    runner.stage2_training()
