"""

code: 
"""
import os
import sys


sys.path.append("../")
sys.path.append(os.getcwd())
import numpy as np
import torch

# from data.utils import (
#     get_transform,
#     get_semi_idx,
# )
from defense.dbd.data.prefetch import PrefetchLoader

from utils.aggregate_block.dataset_and_transform_generate import get_transform_self

from model.utils import (
    get_network_dbd,
    load_state,
    get_criterion,
    get_network,
    get_optimizer,
    get_saved_epoch,
    get_scheduler,
)

from model.model import SelfModel, LinearModel
from data.dataset import PoisonLabelDataset, SelfPoisonDataset, MixMatchDataset

# from utils.bd_dataset import prepro_cls_DatasetBD
from utils.nCHW_nHWC import nCHW_to_nHWC


def get_information(args, result, config_ori):
    config = config_ori
    # pre_transform = get_transform(config["transform"]["pre"])
    # # train_primary_transform = get_transform(config["transform"]["train"]["primary"])
    # # train_remaining_transform = get_transform(config["transform"]["train"]["remaining"])
    # # train_transform = {
    # #     "pre": pre_transform,
    # #     "primary": train_primary_transform,
    # #     "remaining": train_remaining_transform,
    # # }
    # # logger.info("Training transformations:\n {}".format(train_transform))
    # aug_primary_transform = get_transform(config["transform"]["aug"]["primary"])
    # aug_remaining_transform = get_transform(config["transform"]["aug"]["remaining"])
    # aug_transform = {
    #     "pre": pre_transform,
    #     "primary": aug_primary_transform,
    #     "remaining": aug_remaining_transform,
    # }
    aug_transform = get_transform_self(
        args.dataset, *([args.input_height, args.input_width]), train=True, prefetch=args.prefetch
    )
    # logger.info("Augmented transformations:\n {}".format(aug_transform))
    # logger.info("Load dataset from: {}".format(config["dataset_dir"]))
    # clean_train_data = get_dataset(config["dataset_dir"], train_transform)
    # poison_train_idx = gen_poison_idx(clean_train_data, target_label, poison_ratio)
    # poison_idx_path = os.path.join(args.saved_dir, "poison_idx.npy")
    # np.save(poison_idx_path, poison_train_idx)
    # logger.info("Save poisoned index to {}".format(poison_idx_path))
    # poison_train_data = PoisonLabelDataset(
    #     clean_train_data, bd_transform, poison_train_idx, target_label
    # )
    x = result["bd_train"]["x"]
    y = result["bd_train"]["y"]
    # data_set = torch.utils.data.TensorDataset(x,y)
    # dataset = prepro_cls_DatasetBD(
    #     full_dataset_without_transform=data_set,
    #     poison_idx=np.zeros(len(data_set)),  # one-hot to determine which image may take bd_transform
    #     bd_image_pre_transform=None,
    #     bd_label_pre_transform=None,
    #     ori_image_transform_in_loading=transform,
    #     ori_label_transform_in_loading=None,
    #     add_details_in_preprocess=False,
    # )
    self_poison_train_data = SelfPoisonDataset(x, y, aug_transform, args)
    # if args.distributed:
    #     self_poison_train_sampler = DistributedSampler(self_poison_train_data)
    #     batch_size = int(config["loader"]["batch_size"])
    #     num_workers = config["loader"]["num_workers"]
    #     self_poison_train_loader = get_loader(
    #         self_poison_train_data,
    #         batch_size=batch_size,
    #         sampler=self_poison_train_sampler,
    #         num_workers=num_workers,
    #     )
    # else:
    # self_poison_train_sampler = None
    self_poison_train_loader_ori = torch.utils.data.DataLoader(
        self_poison_train_data,
        batch_size=args.batch_size_self,
        num_workers=args.num_workers,
        drop_last=False,
        shuffle=True,
        pin_memory=True,
    )
    if args.prefetch:
        self_poison_train_loader = PrefetchLoader(
            self_poison_train_loader_ori,
            self_poison_train_data.mean,
            self_poison_train_data.std,
        )
    else:
        self_poison_train_loader = self_poison_train_loader_ori
    # self_poison_train_loader = get_loader(
    #         self_poison_train_data, config["loader"], shuffle=True
    #     )

    # logger.info("\n===Setup training===")
    backbone = get_network_dbd(args)
    # logger.info("Create network: {}".format(config["network"]))
    self_model = SelfModel(backbone)
    self_model = self_model.to(args.device)
    # if args.distributed:
    #     # Convert BatchNorm*D layer to SyncBatchNorm before wrapping Network with DDP.
    #     if config["sync_bn"]:
    #         self_model = nn.SyncBatchNorm.convert_sync_batchnorm(self_model)
    #         logger.info("Turn on synchronized batch normalization in ddp.")
    #     self_model = nn.parallel.DistributedDataParallel(self_model, device_ids=[gpu])
    criterion = get_criterion(config["criterion"])
    criterion = criterion.to(args.device)
    # logger.info("Create criterion: {}".format(criterion))
    optimizer = get_optimizer(self_model, config["optimizer"])
    # logger.info("Create optimizer: {}".format(optimizer))
    scheduler = get_scheduler(optimizer, config["lr_scheduler"])
    # logger.info("Create scheduler: {}".format(config["lr_scheduler"]))
    resumed_epoch = load_state(
        self_model,
        args.resume,
        args.checkpoint_load,
        0,
        optimizer,
        scheduler,
    )
    box = {
        "self_poison_train_loader": self_poison_train_loader,
        "self_model": self_model,
        "criterion": criterion,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "resumed_epoch": resumed_epoch,
    }
    return box
