import os
import sys

sys.path.append("../")
sys.path.append(os.getcwd())
import yaml
import logging

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from selection.utils.selection_trainer import SelectionTrainer
from selection.utils.selection_utils import (
    all_acc,
    distribute_bd_sample,
    generate_backdoor_dataset,
    generate_clean_dataset,
    generate_new_mix_dataset,
    generate_random_poison_idx,
    get_argparser,
    load_pickle,
    loss_with_mask,
    preprocess_args,
    save_pickle,
    set_trigger,
)
from utils.aggregate_block.bd_attack_generate import bd_attack_img_trans_generate, bd_attack_label_trans_generate
from utils.aggregate_block.fix_random import fix_random


def add_args(parser):
    """Register the sample-selection command-line options on ``parser``.

    Every option is registered without a default (so it parses to ``None`` when
    omitted), which lets ``main`` keep the value from the selection YAML config
    for options not given on the command line.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, for chaining.
    """
    selection_options = (
        ("--select_pratio", float),
        ("--forget_steps", int),
        ("--forget_epochs", int),
        ("--mask_dist", str),
        ("--start_step", int),
    )
    for flag, value_type in selection_options:
        parser.add_argument(flag, type=value_type)
    return parser


def _count_forgetting_events(correctness):
    """Count, per sample, how often a correct prediction turned incorrect.

    Args:
        correctness: 2-D integer 0/1 array with one row per sample and one
            column per epoch (1 = the sample was predicted correctly).

    Returns:
        1-D array with the number of 1 -> 0 transitions along each row.
    """
    correctness = np.asarray(correctness)
    # A forgetting event is a drop from "correct" (1) to "incorrect" (0)
    # between two consecutive epochs, i.e. a difference of -1.
    diff = correctness[:, 1:] - correctness[:, :-1]
    return np.sum(diff == -1, axis=1)


def _select_next_poison_mask(
    forget_events,
    poison_idx,
    origin_idx,
    clean_targets,
    clean_label,
    attack_target,
    retain_ratio=0.5,
):
    """Build the poison mask for the next search step.

    Keeps the ``retain_ratio`` fraction of currently poisoned samples with the
    most forgetting events, then refills the budget with randomly drawn
    unpoisoned samples so the number of poisoned samples stays the same. In
    the clean-label setting, refill candidates are restricted to samples whose
    clean label equals ``attack_target``.

    Args:
        forget_events: per-sample forgetting-event counts, aligned with ``poison_idx``.
        poison_idx: current mask (1 = poisoned, 0 = clean).
        origin_idx: dataset indices aligned with ``forget_events``/``poison_idx``.
        clean_targets: clean training labels, indexable by dataset index.
        clean_label: truthy for a clean-label attack.
        attack_target: target class; only read when ``clean_label`` is truthy.
        retain_ratio: fraction of currently poisoned samples to keep.

    Returns:
        Float ndarray of length ``len(poison_idx)`` with 1 at selected indices.

    Raises:
        ValueError: from ``np.random.choice`` when there are not enough refill
            candidates to restore the poisoning budget.
    """
    mask_values = np.asarray(poison_idx)
    origin_idx = np.asarray(origin_idx).astype(int)
    bd_rows = mask_values == 1
    cl_rows = mask_values == 0

    bd_forget_events = np.asarray(forget_events)[bd_rows]
    bd_origin_idx = origin_idx[bd_rows]
    cl_origin_idx = origin_idx[cl_rows]

    # Order the poisoned samples by forgetting events, largest first.
    order = np.argsort(bd_forget_events)[::-1]
    sorted_bd_origin_idx = bd_origin_idx[order]
    n_poison = len(sorted_bd_origin_idx)
    keep_bd_idx = sorted_bd_origin_idx[: int(n_poison * retain_ratio)].astype(int)

    if clean_label:
        target_idx = np.where(np.asarray(clean_targets) == attack_target)[0]
        # Vectorized, order-preserving membership filter (was an O(n*m) scan).
        candidates = cl_origin_idx[np.isin(cl_origin_idx, target_idx)]
    else:
        candidates = cl_origin_idx
    refill_idx = np.random.choice(candidates, n_poison - len(keep_bd_idx), replace=False).astype(int)

    new_mask = np.zeros(len(poison_idx))
    new_mask[keep_bd_idx] = 1
    new_mask[refill_idx] = 1
    return new_mask


def main():
    """Search for a poisoning mask via forgetting events, saving it per step.

    Loads the YAML config at ``--selection_yaml_path``. Any CLI option that is
    not ``None`` overrides it. The script then builds a mixed clean/backdoor
    training set and, for each step in ``[start_step, forget_steps)``:

    1. trains a freshly built model for ``forget_epochs`` epochs with the current mask;
    2. counts per-sample forgetting events;
    3. keeps the poisoned samples that are forgotten most often and refills
       the budget with random clean samples;
    4. saves the new mask to ``<save_path>/mask_<step>.npy``.

    The mask is carried over from one step to the next. When resuming with
    ``start_step > 0``, the mask saved by step ``start_step - 1`` is loaded if
    it exists. An optional ``retain_ratio`` config key (default 0.5) controls
    how many poisoned samples are kept per step.
    """
    parser = get_argparser()
    parser = add_args(parser)
    args = parser.parse_args()
    with open(args.selection_yaml_path, "r") as f:
        config = yaml.safe_load(f)
    # CLI values override the YAML config; unset (None) CLI options keep the YAML value.
    config.update({k: v for k, v in args.__dict__.items() if v is not None})
    args.__dict__ = config
    args = set_trigger(args)
    args = preprocess_args(args)
    fix_random(int(args.random_seed))
    # For a multi-device spec such as "cuda:0,1" only the first listed GPU is used.
    device = torch.device(
        (f"cuda:{[int(i) for i in args.device[5:].split(',')][0]}" if "," in args.device else args.device)
        if torch.cuda.is_available()
        else "cpu"
    )
    args.device = device

    (
        train_dataset_without_transform,
        train_img_transform,
        train_label_transform,
        test_dataset_without_transform,
        test_img_transform,
        test_label_transform,
        clean_train_dataset_with_transform,
        clean_train_dataset_targets,
        clean_test_dataset_with_transform,
        clean_test_dataset_targets,
    ) = generate_clean_dataset(args)

    bd_label_transform = bd_attack_label_trans_generate(args)
    logging.info("Generate initial backdoor mask...")
    train_poison_idx, test_poison_idx = generate_random_poison_idx(
        args, clean_train_dataset_targets, clean_test_dataset_targets, bd_label_transform, pratio=args.select_pratio,clean_label=args.clean_label
    )

    (
        bd_train_dataset_with_transform,
        bd_test_dataset_with_transform,
    ) = generate_backdoor_dataset(
        args,
        clean_train_dataset_targets=clean_train_dataset_targets,
        clean_test_dataset_targets=clean_test_dataset_targets,
        train_dataset_without_transform=train_dataset_without_transform,
        test_dataset_without_transform=test_dataset_without_transform,
        train_img_transform=train_img_transform,  # per original author: train_img_transform and test_img_transform are the same here
        train_label_transform=train_label_transform,
        test_img_transform=test_img_transform,
        test_label_transform=test_label_transform,
        # Every training sample gets a backdoored counterpart; which version is
        # trained on is presumably selected via poison_idx inside
        # train_one_epoch_mix — TODO confirm.
        train_poison_index=np.ones(len(train_dataset_without_transform)),
        test_poison_index=test_poison_idx,
    )

    mix_bd_train_dataset = generate_new_mix_dataset(
        clean_train_dataset_with_transform=clean_train_dataset_with_transform,
        bd_train_dataset_with_transform=bd_train_dataset_with_transform,
    )

    mix_bd_train_dataloader = DataLoader(
        mix_bd_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )

    retain_ratio = getattr(args, "retain_ratio", 0.5)

    # Initialize the mask ONCE so each step refines the previous step's mask
    # (previously it was reset to the random initial mask at every step).
    poison_idx = train_poison_idx
    if args.start_step > 0:
        resume_path = os.path.join(args.save_path, f"mask_{args.start_step - 1}.npy")
        if os.path.exists(resume_path):
            logging.info(f"Resume search from {resume_path}")
            poison_idx = np.load(resume_path)
        else:
            logging.warning(f"{resume_path} not found; starting from the random initial mask")

    for step in range(args.start_step, args.forget_steps):
        trainer = SelectionTrainer(args, device)
        trainer.set_model()
        trainer.set_optimizer_scheduler()
        logging.info(f">>> Search m in step {step}")
        correctness = []
        for epoch_theta in range(args.forget_epochs):
            (
                _one_epoch_loss,
                _train_mix_acc,
                train_asr,
                train_clean_acc,
                _train_ra,
                epoch_label,
                _epoch_orgin_label,
                epoch_predict,
                epoch_origin_idx,
            ) = trainer.train_one_epoch_mix(mix_bd_train_dataloader, poison_idx)

            # NOTE(review): stacking per-epoch results column-wise assumes
            # train_one_epoch_mix returns samples in the same order every epoch
            # (aligned with poison_idx) despite shuffle=True — confirm.
            pred_res = (epoch_predict.cpu().numpy() == epoch_label.cpu().numpy()) * 1
            correctness.append(pred_res[:, np.newaxis])
            logging.info(f"        Epoch {epoch_theta} | Train ASR: {train_asr:.4f} | Train ACC: {train_clean_acc:.4f}")

        forget_events = _count_forgetting_events(np.concatenate(correctness, axis=1))
        poison_idx = _select_next_poison_mask(
            forget_events,
            poison_idx,
            epoch_origin_idx.cpu().numpy(),
            clean_train_dataset_targets,
            args.clean_label,
            getattr(args, "attack_target", None),
            retain_ratio=retain_ratio,
        )

        np.save(os.path.join(args.save_path, f"mask_{step}.npy"), poison_idx)


if __name__ == "__main__":
    # Run the mask search only when executed as a script, not on import.
    main()
