import os
import sys

sys.path.append("../")
sys.path.append(os.getcwd())
import yaml
import logging
import multiprocessing
import pickle
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from selection.utils.selection_dataset import NewMixBackdoorDataset, SplitDataset
from selection.utils.selection_strategy import sort_samples_with_strategy
from selection.utils.selection_trainer import SelectionTrainer
from selection.utils.selection_utils import (
    all_acc,
    distribute_bd_sample,
    generate_backdoor_dataset,
    generate_clean_dataset,
    generate_new_mix_dataset,
    generate_random_poison_idx,
    get_argparser,
    load_pickle,
    loss_with_mask,
    preprocess_args,
    save_pickle,
    set_trigger,
)
from utils.aggregate_block.bd_attack_generate import bd_attack_img_trans_generate, bd_attack_label_trans_generate
from utils.aggregate_block.dataset_and_transform_generate import dataset_and_transform_generate
from utils.aggregate_block.fix_random import fix_random
from utils.bd_dataset_v2 import dataset_wrapper_with_transform, get_labels, prepro_cls_DatasetBD_v2


def add_args(parser):
    """Register the mask-search options on ``parser`` and return the same parser.

    Adds ``--select_pratio`` (float), ``--m_steps`` (int), ``--theta_epochs``
    (int) and ``--mask_dist`` (str). No defaults are declared, so an option
    that is omitted on the command line parses to ``None``.
    """
    option_types = (
        ("--select_pratio", float),
        ("--m_steps", int),
        ("--theta_epochs", int),
        ("--mask_dist", str),
    )
    for flag, value_type in option_types:
        parser.add_argument(flag, type=value_type)
    return parser


def main():
    """Alternately train the model and re-select the poisoned-sample mask.

    Steps, as implemented below:

    1. Parse the command line and merge it over the YAML file given by
       ``args.selection_yaml_path``; CLI values that are not ``None`` win.
    2. Build the clean datasets, a random initial poison mask, a backdoor
       training set in which every training sample is poisoned, and the mixed
       clean/backdoor dataset that is used for training.
    3. For ``args.m_steps`` outer steps: train for ``args.theta_epochs`` epochs
       under the current mask, then call ``search_mask`` on the per-sample
       losses of the last epoch to choose a new mask. Metrics are logged before
       and after the update and the new mask is saved to
       ``{args.save_path}/mask_{step}.npy``.

    Raises:
        ValueError: If ``args.theta_epochs`` is smaller than 1. The mask update
            needs the losses produced by at least one training epoch.
    """
    parser = get_argparser()
    parser = add_args(parser)
    args = parser.parse_args()
    with open(args.selection_yaml_path, "r") as f:
        config = yaml.safe_load(f)
    # Only CLI options that were actually supplied (non-None) override the YAML values.
    config.update({k: v for k, v in args.__dict__.items() if v is not None})
    args.__dict__ = config
    args = set_trigger(args)
    args = preprocess_args(args)
    if args.theta_epochs < 1:
        # Without at least one epoch the loss/prediction variables consumed by
        # the mask update below would never be assigned.
        raise ValueError(f"theta_epochs must be >= 1, got {args.theta_epochs}")
    fix_random(int(args.random_seed))
    # For a comma-separated device string (e.g. "cuda:0,1") only the first
    # listed index is used; fall back to CPU when CUDA is unavailable.
    device = torch.device(
        (f"cuda:{[int(i) for i in args.device[5:].split(',')][0]}" if "," in args.device else args.device)
        if torch.cuda.is_available()
        else "cpu"
    )
    args.device = device

    (
        train_dataset_without_transform,
        train_img_transform,
        train_label_transform,
        test_dataset_without_transform,
        test_img_transform,
        test_label_transform,
        clean_train_dataset_with_transform,
        clean_train_dataset_targets,
        clean_test_dataset_with_transform,
        clean_test_dataset_targets,
    ) = generate_clean_dataset(args)

    bd_label_transform = bd_attack_label_trans_generate(args)

    logging.info("Generate initial backdoor mask...")
    train_poison_idx, test_poison_idx = generate_random_poison_idx(
        args,
        clean_train_dataset_targets,
        clean_test_dataset_targets,
        bd_label_transform,
        pratio=args.select_pratio,
        clean_label=args.clean_label,
    )

    # The backdoor training set poisons every sample (all-ones index); which
    # samples actually count as poisoned is decided later by the mask.
    (
        bd_train_dataset_with_transform,
        bd_test_dataset_with_transform,
    ) = generate_backdoor_dataset(
        args,
        clean_train_dataset_targets=clean_train_dataset_targets,
        clean_test_dataset_targets=clean_test_dataset_targets,
        train_dataset_without_transform=train_dataset_without_transform,
        test_dataset_without_transform=test_dataset_without_transform,
        train_img_transform=train_img_transform,  # train_img_transform and test_img_transform are the same here
        train_label_transform=train_label_transform,
        test_img_transform=test_img_transform,
        test_label_transform=test_label_transform,
        train_poison_index=np.ones(len(train_dataset_without_transform)),
        test_poison_index=test_poison_idx,
    )

    mix_bd_train_dataset = generate_new_mix_dataset(
        clean_train_dataset_with_transform=clean_train_dataset_with_transform,
        bd_train_dataset_with_transform=bd_train_dataset_with_transform,
    )

    mix_bd_train_dataloader = DataLoader(
        mix_bd_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )

    trainer = SelectionTrainer(args, device)
    trainer.set_model()
    trainer.set_optimizer_scheduler()
    poison_idx = train_poison_idx
    for step_m in range(args.m_steps):
        logging.info(f">>> Search m in step {step_m}")
        logging.info("    (1) Min theta")

        for epoch_theta in range(args.theta_epochs):
            (
                epoch_avg_loss,
                all_bd_loss,
                all_cl_loss,
                all_bd_preds,
                all_cl_preds,
                all_bd_labels,
                all_cl_labels,
            ) = trainer.train_one_epoch_concat(mix_bd_train_dataloader, poison_idx)
            # ASR is measured on the backdoor view of the selected samples,
            # ACC on the clean view of the unselected ones (same convention as
            # the post-update metrics further down).
            train_asr = trainer.all_acc(all_bd_preds[poison_idx == 1], all_bd_labels[poison_idx == 1])
            train_acc = trainer.all_acc(all_cl_preds[poison_idx == 0], all_cl_labels[poison_idx == 0])
            logging.info(f"        Epoch {epoch_theta} | Train ASR: {train_asr:.4f} | Train ACC: {train_acc:.4f}")

        logging.info("    (2) Max mask")

        # The per-sample losses and predictions of the last training epoch are
        # reused here; no separate evaluation pass is run before the update.
        all_loss, bd_loss, cl_loss = loss_with_mask(all_bd_loss=all_bd_loss, all_cl_loss=all_cl_loss, mask=poison_idx)
        logging.info(
            f"        Before update mask: All Loss: {all_loss:.4f} | BD Loss: {bd_loss:.4f} | CL Loss: {cl_loss:.4f} | Train ASR: {train_asr:.4f} | Train ACC: {train_acc:.4f}"
        )

        new_poison_idx = search_mask(
            all_bd_loss,
            all_cl_loss,
            args.num_classes,
            clean_train_dataset_targets,
            args.attack_target,
            args.select_pratio,
            mask_dist=args.mask_dist,
            descending=True,
            clean_label=args.clean_label,
        )
        # For 0/1 masks the L1 distance is the number of entries that flipped.
        l1_norm = np.linalg.norm(new_poison_idx - poison_idx, 1)

        poison_idx = new_poison_idx
        all_loss, bd_loss, cl_loss = loss_with_mask(all_bd_loss=all_bd_loss, all_cl_loss=all_cl_loss, mask=poison_idx)
        train_asr = all_acc(all_bd_preds[poison_idx == 1], all_bd_labels[poison_idx == 1])
        train_acc = all_acc(all_cl_preds[poison_idx == 0], all_cl_labels[poison_idx == 0])

        logging.info(
            f"        After update mask: All Loss: {all_loss:.4f} | BD Loss: {bd_loss:.4f} | CL Loss: {cl_loss:.4f} | Train ASR: {train_asr:.4f} | Train ACC: {train_acc:.4f}, l1_norm, {l1_norm}"
        )
        np.save(f"{args.save_path}/mask_{step_m}.npy", poison_idx)


def search_mask(
    all_bd_loss, all_cl_loss, num_class, clean_targets, attack_target, pratio, mask_dist="pcnt", descending=True, clean_label=False
):
    """Select which samples to poison, ranked by backdoor-minus-clean loss.

    Samples are ordered by ``all_bd_loss - all_cl_loss`` and
    ``round(pratio * len(all_bd_loss))`` of them are marked, subject to the
    distribution rule below.

    Args:
        all_bd_loss: Per-sample loss tensor for the backdoor view.
        all_cl_loss: Per-sample loss tensor for the clean view; must be
            subtractable from ``all_bd_loss``.
        num_class: Number of classes.
        clean_targets: Clean label of every sample, indexable by sample
            position (list, NumPy array or tensor).
        attack_target: Target class of the attack.
        pratio: Fraction of samples to mark.
        mask_dist: Distribution rule, used only when ``clean_label`` is false:
            ``"random"`` takes the top-ranked samples regardless of class;
            ``"perclass"``/``"pc"`` splits the budget over all classes with
            ``distribute_bd_sample``; ``"perclass_nontarget"``/``"pcnt"``
            splits it over every class except ``attack_target``.
        descending: Rank from largest to smallest loss gap when true,
            smallest to largest otherwise.
        clean_label: When true, the whole budget is taken from samples whose
            clean label is ``attack_target`` and ``mask_dist`` is ignored.

    Returns:
        A float NumPy array of length ``len(all_bd_loss)`` holding 1 for
        selected samples and 0 elsewhere.

    Raises:
        ValueError: If ``clean_label`` is false and ``mask_dist`` is not one
            of the supported names.
    """
    num_data = len(all_bd_loss)
    poison_num = round(pratio * num_data)
    delta_loss = all_bd_loss - all_cl_loss

    # Plain Python ints index NumPy arrays, lists and tensors alike, and do
    # not depend on which device the loss tensors live on.
    sorted_indices = torch.argsort(delta_loss, descending=bool(descending)).tolist()
    new_mask = np.zeros(num_data)

    if clean_label:
        # Entire budget goes to the target class, every other class gets none.
        perclass_num = [poison_num if c == attack_target else 0 for c in range(num_class)]
    elif mask_dist in ["random"]:
        new_mask[sorted_indices[:poison_num]] = 1
        return new_mask
    elif mask_dist in ["perclass", "pc"]:
        perclass_num = distribute_bd_sample(poison_num, num_class)
    elif mask_dist in ["perclass_nontarget", "pcnt"]:
        perclass_num = distribute_bd_sample(poison_num, num_class - 1)
        # Re-insert the target class with a quota of zero.
        perclass_num.insert(attack_target, 0)
    else:
        raise ValueError(
            f"Unknown mask_dist {mask_dist!r}; expected one of "
            "'random', 'perclass', 'pc', 'perclass_nontarget', 'pcnt'"
        )

    # Walk the ranking once and take samples of each class until that class's
    # quota is used up.
    remaining = {c: perclass_num[c] for c in range(num_class)}
    for index in sorted_indices:
        # int() makes the label a hashable-by-value dict key even when
        # clean_targets yields 0-dim tensors or NumPy scalars.
        label = int(clean_targets[index])
        if remaining[label] > 0:
            new_mask[index] = 1
            remaining[label] -= 1
    return new_mask




# Run the mask search only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
