from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim

# from pytorch_diffusion import diffusion
from utils.trainer.semi import *
import pandas as pd

# from utils.trainer.semi import mixmatch_train, linear_test, poison_linear_record
from utils.trainer.log import result2csv
from utils.setup import (
    get_logger,
    get_saved_dir,
    get_storage_dir,
    load_config,
    set_seed,
)
from model.utils import (
    get_criterion,
    get_network,
    get_optimizer,
    get_scheduler,
    load_state,
)
from model.model import LinearModel
from data.utils import (
    gen_poison_idx,
    get_bd_transform,
    get_dataset,
    get_loader,
    get_transform,
)
from data.dataset import PoisonLabelDataset, MixMatchDataset
import matplotlib.pyplot as plt
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
import torch.distributed as dist
import argparse
import os
import shutil
from copy import deepcopy
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import platform
from tqdm import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import random
from sklearn.decomposition import PCA
from scipy.stats import weibull_min
from glob import glob
import matplotlib as mpl


# Print tensors in fixed-point notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)


def main():
    """Parse the command line, prepare output directories and start training.

    The configuration file named by ``--config`` is loaded and copied into
    both the "saved" and the "storage" directory of the run. The GPU(s) named
    by ``--gpu`` are exposed through ``CUDA_VISIBLE_DEVICES`` and
    ``main_worker`` is then started.
    """
    print("===Setup running===")
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./config/baseline_asd.yaml")
    parser.add_argument("--gpu", default="0", type=str)
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        help="checkpoint name (empty string means the latest checkpoint)\
            or False (means training from scratch).",
    )
    parser.add_argument("--amp", default=False, action="store_true")
    parser.add_argument(
        "--world-size",
        default=1,
        type=int,
        help="number of nodes for distributed training",
    )
    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
    parser.add_argument(
        "--dist-port",
        default="23456",
        type=str,
        help="port used to set up distributed training",
    )

    # Defense parameters of NCD
    # logging information
    parser.add_argument("--indexOfRun", default="", type=str, help="Postfix of current run.")
    parser.add_argument("--descriptionOfRun", default="", type=str, help="more optional description of current run.")
    parser.add_argument("--pRatio", default=0.05, type=float, help="poison ratio.")  # reset the poison ratio
    # module flags
    # parser.add_argument("--oss", default=1, type=int, help="turn on oss or not in stage three.")
    # parser.add_argument("--als", default=1, type=int, help="turn on als or not in stage three.")
    # parser.add_argument("--ccflag", default=1, type=int, help="turn on class completion or not.")
    # parser.add_argument("--dropflag", default=1, type=int, help="turn on drop or not in stage three.")
    # ``args.altruistic`` is read by main_worker, so the flag has to exist.
    parser.add_argument(
        "--altruistic",
        default=1,
        type=int,
        help="1: split on the loss difference between main and altruistic model; 0: split on the main model's loss only.",
    )
    parser.add_argument("--alpha", default=0.6, type=float, help="parameter alpha.")
    parser.add_argument("--beta", default=0.2, type=float, help="parameter beta.")
    parser.add_argument("--alswarmepoch", default=25, type=int, help="warm-up epochs of als.")
    parser.add_argument("--T1", default=20, type=int, help="number of epochs of stage1.")
    parser.add_argument("--T2", default=90, type=int, help="number of epochs of stage2.")
    parser.add_argument("--T3", default=120, type=int, help="number of epochs of stage3.")

    args = parser.parse_args()

    config, inner_dir, config_name = load_config(args.config)
    args.saved_dir, args.log_dir = get_saved_dir(config, inner_dir, config_name, args.resume)
    shutil.copy2(args.config, args.saved_dir)
    args.storage_dir, args.ckpt_dir, _ = get_storage_dir(config, inner_dir, config_name, args.resume)
    shutil.copy2(args.config, args.storage_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    ngpus_per_node = torch.cuda.device_count()
    # NOTE(review): the original code computed ``ngpus_per_node > 1`` and then
    # unconditionally overrode it with False, so the multi-process branch
    # below is never taken. The override is kept to preserve that behaviour.
    args.distributed = False
    if args.distributed:
        args.world_size = ngpus_per_node * args.world_size
        print("Distributed training on GPUs: {}.".format(args.gpu))
        mp.spawn(
            main_worker,
            nprocs=ngpus_per_node,
            args=(ngpus_per_node, args, config, None),
        )
    else:
        print("Training on a single GPU: {}.".format(args.gpu))
        # Once CUDA_VISIBLE_DEVICES is in effect the visible devices are
        # renumbered from 0, so the physical id in ``--gpu`` is usually not a
        # valid index any more (and "0,1" is not an integer at all). Use the
        # requested id only if it is still a valid index, otherwise fall back
        # to the first visible device.
        requested = int(args.gpu) if args.gpu.strip().isdigit() else 0
        device_index = requested if requested < ngpus_per_node else 0
        main_worker(device_index, ngpus_per_node, args, config, None)


def main_worker(gpu, ngpus_per_node, args, config, new_config):
    """Run the whole defense training pipeline on one GPU.

    Steps, in order: seeding and logger creation, construction of the clean
    and poisoned datasets and loaders, creation of the main model with its
    criteria, optimizer and scheduler, checkpoint loading, warm-up of a deep
    copy of the main model (the "altruistic" model), and finally the epoch
    loop that re-splits the training set into a clean and a poison pool and
    trains the main model with MixMatch.

    Args:
        gpu: CUDA device index used by this worker.
        ngpus_per_node: Number of visible GPUs; only used to derive the
            global rank when ``args.distributed`` is set.
        args: Parsed command-line namespace (see ``main``).
        config: Configuration dictionary returned by ``load_config``.
        new_config: Unused; kept for call compatibility.
    """
    # Imported locally so this function does not rely on ``time`` and ``nn``
    # leaking in through the wildcard import at the top of the file.
    import time

    from torch import nn

    set_seed(**config["seed"])
    logger = get_logger(args.log_dir, f"run_{args.descriptionOfRun}_{args.indexOfRun}.log", args.resume, True)
    torch.cuda.set_device(gpu)
    logger.info("Training on GPU: {}.".format(gpu))
    if args.distributed:
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend="nccl",
            init_method="tcp://127.0.0.1:{}".format(args.dist_port),
            world_size=args.world_size,
            rank=args.rank,
        )
        logger.warning("Only log rank 0 in distributed training!")

    logger.info("===Prepare data===")
    bd_config = config["backdoor"]
    logger.info("Load backdoor config:\n{}".format(bd_config))
    bd_transform = get_bd_transform(bd_config)
    target_label = bd_config["target_label"]
    # The poison ratio from the command line always overrides the config.
    logger.info("Resetting poison ratio to: {}".format(args.pRatio))
    poison_ratio = args.pRatio

    pre_transform = get_transform(config["transform"]["pre"])
    train_primary_transform = get_transform(config["transform"]["train"]["primary"])
    train_remaining_transform = get_transform(config["transform"]["train"]["remaining"])
    train_transform = {
        "pre": pre_transform,
        "primary": train_primary_transform,
        "remaining": train_remaining_transform,
    }
    logger.info("Training transformations:\n {}".format(train_transform))
    test_primary_transform = get_transform(config["transform"]["test"]["primary"])
    test_remaining_transform = get_transform(config["transform"]["test"]["remaining"])
    test_transform = {
        "pre": pre_transform,
        "primary": test_primary_transform,
        "remaining": test_remaining_transform,
    }
    logger.info("Test transformations:\n {}".format(test_transform))

    logger.info("Load dataset from: {}".format(config["dataset_dir"]))

    if "imagenet" in config["dataset_dir"]:
        selected_classes = None
        npy_file_path = "./data/imagenet_selected_classes.npy"
        if os.path.exists(npy_file_path):
            # Reuse the class subset chosen by an earlier run.
            selected_classes = np.load(npy_file_path)
        else:
            # Randomly choose 30 classes and remember them for later runs.
            all_classes = glob(os.path.join(config["dataset_dir"], "train", "*"))
            selected_classes = random.sample(all_classes, 30)
            selected_classes = [os.path.basename(c) for c in selected_classes]
            np.save(npy_file_path, selected_classes)
        print("selected_classes: ", selected_classes)
        clean_train_data = get_dataset(config["dataset_dir"], train_transform, prefetch=config["prefetch"], selected_classes=selected_classes)
        clean_test_data = get_dataset(config["dataset_dir"], test_transform, train=False, prefetch=config["prefetch"], selected_classes=selected_classes)
    else:
        clean_train_data = get_dataset(config["dataset_dir"], train_transform, prefetch=config["prefetch"])
        clean_test_data = get_dataset(config["dataset_dir"], test_transform, train=False, prefetch=config["prefetch"])

    poison_idx_path = os.path.join(args.saved_dir, f"poison_idx_{poison_ratio}.npy")

    if os.path.exists(poison_idx_path):
        poison_train_idx = np.load(poison_idx_path)
        logger.info("Load poisoned index from {}".format(poison_idx_path))
    else:
        poison_train_idx = gen_poison_idx(clean_train_data, target_label, poison_ratio)
        np.save(poison_idx_path, poison_train_idx)
        logger.info("Save poisoned index to {}".format(poison_idx_path))

    poison_train_data = PoisonLabelDataset(clean_train_data, bd_transform, poison_train_idx, target_label)
    poison_test_idx = gen_poison_idx(clean_test_data, target_label)
    poison_test_data = PoisonLabelDataset(clean_test_data, bd_transform, poison_test_idx, target_label)

    poison_train_loader = get_loader(poison_train_data, config["loader"], shuffle=True)
    poison_eval_loader = get_loader(poison_train_data, config["loader"])
    clean_test_loader = get_loader(clean_test_data, config["loader"])
    poison_test_loader = get_loader(poison_test_data, config["loader"])

    logger.info("\n===Setup training===")
    backbone = get_network(config["network"])
    logger.info("Create network: {}".format(config["network"]))
    linear_model = LinearModel(backbone, backbone.feature_dim, config["num_classes"])
    linear_model = linear_model.cuda(gpu)
    if args.distributed:
        linear_model = DistributedDataParallel(linear_model, device_ids=[gpu])

    criterion = get_criterion(config["criterion"])
    criterion = criterion.cuda(gpu)
    logger.info("Create criterion: {} for test".format(criterion))

    split_criterion = get_criterion(config["split"]["criterion"])
    split_criterion = split_criterion.cuda(gpu)
    logger.info("Create criterion: {} for data split".format(split_criterion))

    semi_criterion = get_criterion(config["semi"]["criterion"])
    semi_criterion = semi_criterion.cuda(gpu)
    logger.info("Create criterion: {} for semi-training".format(semi_criterion))

    optimizer = get_optimizer(linear_model, config["optimizer"])
    logger.info("Create optimizer: {}".format(optimizer))

    # Built exactly once; the original constructed a second, redundant
    # scheduler on the same optimizer.
    scheduler = get_scheduler(optimizer, config["lr_scheduler"])
    logger.info("Create scheduler: {}".format(config["lr_scheduler"]))
    resumed_epoch, best_acc, best_epoch = load_state(
        linear_model,
        args.resume,
        args.ckpt_dir,
        gpu,
        logger,
        optimizer,
        scheduler,
        is_best=True,
    )

    ######################################################################################
    altruistic_warmup_epochs = args.alswarmepoch
    first_stage = args.T1
    second_stage = args.T2
    third_stage = args.T3  # same as config["num_epochs"]
    alpha = args.alpha
    beta = args.beta
    # ``--altruistic`` may not be registered on the parser; default to "on".
    use_altruistic = getattr(args, "altruistic", 1)
    ######################################################################################

    logger.info("warming up the altruistic model...")
    altruistic_model = deepcopy(linear_model)
    altruistic_optimizer = optim.Adam(altruistic_model.parameters(), lr=0.001)
    altruistic_criterion = nn.CrossEntropyLoss()
    # NOTE(review): train mode is entered once, before the loop. If the
    # record/test helpers called at the end of each warm-up epoch switch the
    # model to eval mode, later warm-up epochs run in eval mode - confirm.
    altruistic_model.train()

    pseudo_target = None
    for epoch in tqdm(range(altruistic_warmup_epochs)):
        for batch in poison_train_loader:
            data = batch["img"].cuda(gpu, non_blocking=True)
            target = batch["target"].cuda(gpu, non_blocking=True)
            output = altruistic_model(data)
            loss = altruistic_criterion(output, target.long())
            altruistic_optimizer.zero_grad()
            loss.backward()
            altruistic_optimizer.step()
        record_list_warmup = poison_linear_record(altruistic_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
        pseudo_target = pre_pseudo_target(record_list_warmup, logger)
        logger.info("Test altruistic model on clean data...")
        clean_test_result = linear_test_(altruistic_model, clean_test_loader, criterion, logger)
        logger.info("Test altruistic model on poison data...")
        poison_test_result = linear_test_(altruistic_model, poison_test_loader, criterion, logger, poison=True)

    if pseudo_target is None:
        # No warm-up epoch ran, so the pseudo target was never estimated;
        # estimate it once so the first-stage split below can use it.
        record_list_warmup = poison_linear_record(altruistic_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
        pseudo_target = pre_pseudo_target(record_list_warmup, logger)

    # Map every label (as a string) to the dataset indices carrying it.
    all_data_info = {}
    for i in range(config["num_classes"]):
        all_data_info[str(i)] = []
    for idx, item in enumerate(poison_train_data):
        all_data_info[str(item["target"])].append(idx)

    logger.info("===Start training===")
    clean_idx = None
    poison_idx = None
    xloader_main = None
    uloader_main = None
    for epoch in range(resumed_epoch, config["num_epochs"]):
        logger.info("===Epoch: {}/{}===".format(epoch + 1, config["num_epochs"]))

        # =================== update pool split ===================
        logger.info("Poison/clean splitting...")
        # When training resumes in the middle of a stage whose split is only
        # computed on the stage's first epoch, no split exists yet in this
        # process; recompute the split that stage would have used.
        resuming_without_split = clean_idx is None and epoch < second_stage
        if epoch == 0 or (resuming_without_split and epoch < first_stage):
            record_list_main = poison_linear_record(linear_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            clean_idx = drop_pseudo_target(record_list_main, pseudo_target)
        elif epoch == first_stage or (resuming_without_split and epoch < first_stage + 10):
            record_list_main = poison_linear_record(linear_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            record_list_altruistic = poison_linear_record(altruistic_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            feature_means, feature_stds, _ = get_mean_feature(linear_model, record_list_main, all_data_info, gpu, config)

            # NOTE(review): the split helpers from here on use their default
            # ``pseudo_target=3`` rather than the value estimated during
            # warm-up - confirm whether the estimate should be passed through.
            clean_idx_OSR = OSS_split(record_list_main, all_data_info, feature_means, feature_stds, beta, logger, plot=0, indexOfRun=args.indexOfRun, descriptionOfRun=args.descriptionOfRun)
            # ALS_split keeps half of the samples; the former
            # ``ratio_clean=0.5`` keyword was not part of its signature and
            # made this call raise TypeError.
            clean_idx_ALS = ALS_split(
                record_list_main, record_list_altruistic, logger, singleLossFlag=True, plot=0, indexOfRun=args.indexOfRun, descriptionOfRun=args.descriptionOfRun
            )
            clean_idx = np.logical_and(clean_idx_OSR, clean_idx_ALS)
            poison_idx = np.logical_not(clean_idx)
        elif first_stage + 10 <= epoch and epoch < second_stage:
            record_list_main = poison_linear_record(linear_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            record_list_altruistic = poison_linear_record(altruistic_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            # Rescale the position inside stage two to [0, second_stage) so the
            # cosine schedule grows the clean ratio from 0.2 towards alpha.
            x = (epoch - first_stage - 10) * second_stage / (second_stage - first_stage - 10)
            cur_ratio = generate_cosine_growth(0.2, alpha, x, second_stage)
            clean_idx, poison_idx = loss_discrepancy_split(
                record_list_main,
                all_data_info,
                config["num_classes"],
                cur_ratio,
                logger,
                record_list_altruistic,
                drop=False,
                completion=True,
                stage="stage2",
                plot=0,
                indexOfRun=args.indexOfRun,
                descriptionOfRun=args.descriptionOfRun,
                altruistic=use_altruistic,
            )

        elif second_stage <= epoch:
            record_list_main = poison_linear_record(linear_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            record_list_altruistic = poison_linear_record(altruistic_model, poison_eval_loader, split_criterion, nc=config["num_classes"])
            cur_ratio = alpha

            clean_idx, poison_idx = loss_discrepancy_split(
                record_list_main,
                all_data_info,
                config["num_classes"],
                cur_ratio,
                logger,
                record_list_altruistic,
                drop=True,
                completion=True,
                stage="stage3",
                plot=0,
                indexOfRun=args.indexOfRun,
                descriptionOfRun=args.descriptionOfRun,
                altruistic=use_altruistic,
            )

        # =================== update data loader ===================
        logger.info("Dataloader generating...")

        if xloader_main is None or epoch == 0 or epoch == first_stage or epoch >= first_stage + 10:
            xdata_main = MixMatchDataset(poison_train_data, clean_idx, labeled=True)
            xloader_main = get_loader(xdata_main, config["semi"]["loader"], shuffle=True, drop_last=True)
            udata_main = MixMatchDataset(poison_train_data, clean_idx, labeled=False)
            uloader_main = get_loader(udata_main, config["semi"]["loader"], shuffle=True, drop_last=True)
        if first_stage + 10 <= epoch and epoch < second_stage:
            xdata_altruistic = MixMatchDataset(poison_train_data, poison_idx, labeled=True)
            xloader_altruistic = get_loader(xdata_altruistic, config["semi"]["loader"], shuffle=True, drop_last=False)

        # =================== altruistic model train ===================
        if first_stage + 10 <= epoch and epoch < second_stage:
            logger.info("Normal training (altruistic model)...")
            altruistic_model.train()
            for batch in xloader_altruistic:
                data = batch["img"].cuda(gpu, non_blocking=True)
                target = batch["target"].cuda(gpu, non_blocking=True)
                output = altruistic_model(data)
                loss = altruistic_criterion(output, target.long())
                altruistic_optimizer.zero_grad()
                loss.backward()
                altruistic_optimizer.step()

        logger.info("MixMatch training (main model)...")
        poison_train_result = mixmatch_train(
            linear_model,
            xloader_main,
            uloader_main,
            semi_criterion,
            optimizer,
            epoch,
            logger,
            **config["semi"]["mixmatch"],
        )

        logger.info("Test model on clean data...")
        clean_test_result = linear_test_(linear_model, clean_test_loader, criterion, logger)

        logger.info("Test model on poison data...")
        poison_test_result = linear_test_(linear_model, poison_test_loader, criterion, logger, poison=True)

        if scheduler is not None:
            scheduler.step()
            logger.info("Adjust learning rate to {}".format(optimizer.param_groups[0]["lr"]))

        # Save result and checkpoint.
        if not args.distributed or (args.distributed and gpu == 0):
            result = {
                "poison_train": poison_train_result,
                "clean_test": clean_test_result,
                "poison_test": poison_test_result,
            }
            result2csv(result, args.log_dir, args.indexOfRun, args.descriptionOfRun)

            # Update the running best BEFORE building the checkpoint, so the
            # stored best_acc/best_epoch include the current epoch.
            is_best = False
            if clean_test_result["acc"] > best_acc:
                is_best = True
                best_acc = clean_test_result["acc"]
                best_epoch = epoch + 1

            saved_dict = {
                "epoch": epoch,
                "result": result,
                "model_state_dict": linear_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_acc": best_acc,
                "best_epoch": best_epoch,
            }
            if scheduler is not None:
                saved_dict["scheduler_state_dict"] = scheduler.state_dict()

            logger.info("Best test accuaracy {} in epoch {}".format(best_acc, best_epoch))
            if is_best:
                ckpt_path = os.path.join(args.ckpt_dir, f"best_model_{args.indexOfRun}_{args.descriptionOfRun}.pt")
                torch.save(saved_dict, ckpt_path)
                logger.info("Save the best model to {}".format(ckpt_path))
            ckpt_path = os.path.join(args.ckpt_dir, f"latest_model_{args.indexOfRun}_{args.descriptionOfRun}.pt")
            torch.save(saved_dict, ckpt_path)
            logger.info("Save the latest model to {}".format(ckpt_path))
    end_time = time.asctime(time.localtime(time.time()))
    logger.info("End at: {} at: {}".format(end_time, platform.node()))


def fit_weibull_and_remove_outliers(data, std=False, threshold=0.01):
    """Fit a Weibull distribution per column and average the likely samples.

    For every column of ``data`` a ``weibull_min`` distribution is fitted.
    Entries whose fitted probability density is not greater than
    ``threshold`` are discarded, and the mean (and standard deviation) of the
    remaining entries is computed.

    Args:
        data: 2-D array; columns are processed independently.
        std: When true, the per-column standard deviations are returned too.
        threshold: Density below or equal to which an entry counts as an
            outlier. If no entry of a column exceeds it, that column's mean
            and standard deviation are NaN.

    Returns:
        ``(means, weibull_distribution)`` or, with ``std=True``,
        ``(means, stds, weibull_distribution)``. ``weibull_distribution``
        maps the column index to the fitted ``(shape, loc, scale)``.
    """
    means = []
    stds = []
    weibull_distribution = {}
    for i in range(data.shape[1]):
        column = data[:, i]
        shape, loc, scale = weibull_min.fit(column)
        weibull_distribution[i] = (shape, loc, scale)

        # Keep only the entries the fitted distribution considers likely.
        pdf = weibull_min.pdf(column, shape, loc, scale)
        kept = column[pdf > threshold]

        means.append(kept.mean())
        stds.append(kept.std())

    if std:
        return np.array(means), np.array(stds), weibull_distribution
    return np.array(means), weibull_distribution


def get_mean_feature(linear_model, record_list_2, all_data_info, gpu, config, pseudo_target=3):
    """Compute per-class mean and standard deviation of the recorded logits.

    Args:
        linear_model: Unused; kept for call compatibility.
        record_list_2: Records with ``name``/``data`` attributes; the entry
            named ``"logit"`` supplies the features.
        all_data_info: Maps a class label (string) to the sample indices of
            that class.
        gpu: Unused; kept for call compatibility.
        config: Unused; kept for call compatibility.
        pseudo_target: Class that is left out of the statistics. Defaults to
            3, the value that used to be hard-coded.

    Returns:
        ``(feature_means, feature_stds, feature_weibull)``. The first two are
        dicts keyed like ``all_data_info``; ``feature_weibull`` is always an
        empty dict. Classes without samples are skipped, because their mean
        would be NaN.
    """
    feature_means = {}
    feature_stds = {}
    feature_weibull = {}
    keys = [r.name for r in record_list_2]
    features = record_list_2[keys.index("logit")].data.numpy()

    for k, v in all_data_info.items():
        if int(k) == pseudo_target or len(v) == 0:
            continue
        members = np.asarray(v, dtype=int)
        cur_class_features = features[members]
        feature_means[k] = cur_class_features.mean(axis=0)
        feature_stds[k] = cur_class_features.std(axis=0)
    return feature_means, feature_stds, feature_weibull


def OSS_split(record_list_2, all_data_info, feature_means, feature_stds, ratio, logger, pseudo_target=3, plot=1, indexOfRun=None, descriptionOfRun=None):
    """Remove suspicious pseudo-target samples from the clean pool.

    Every sample labelled with ``pseudo_target`` is scored by the smallest
    Euclidean distance between its logit vector and any of the class means in
    ``feature_means``. The ``1 - ratio`` fraction with the smallest scores is
    removed from the clean pool; all other samples stay in it.

    Args:
        record_list_2: Records with ``name``/``data`` attributes; the entries
            named ``"poison"`` and ``"logit"`` are used.
        all_data_info: Maps a class label (string) to its sample indices.
        feature_means: Maps a class label to that class's mean logit vector.
        feature_stds: Unused; kept for call compatibility.
        ratio: Fraction of the pseudo-target samples kept in the clean pool.
        logger: Logger used for progress messages.
        pseudo_target: Label of the class that is filtered.
        plot, indexOfRun, descriptionOfRun: Unused; kept for call
            compatibility.

    Returns:
        Float array with one entry per sample: 1 for samples kept in the
        clean pool, 0 for removed ones. If no sample carries
        ``pseudo_target`` nothing is removed.
    """
    keys = [r.name for r in record_list_2]
    poison = record_list_2[keys.index("poison")].data.numpy()
    features = record_list_2[keys.index("logit")].data.numpy()

    clean_pool_idx_for_main = np.ones(len(poison))

    # Labels are compared as integers, exactly as the original loop did.
    members = next((v for k, v in all_data_info.items() if int(k) == pseudo_target), [])
    indice_pseudo_target = np.asarray(members, dtype=int)

    if indice_pseudo_target.size > 0:
        class_means = np.stack([np.asarray(value) for value in feature_means.values()])
        target_features = features[indice_pseudo_target]
        # Distance of every pseudo-target sample to every class mean, then
        # the minimum over the classes.
        diffs = target_features[:, None] - class_means[None]
        diffs = diffs.reshape(len(target_features), len(class_means), -1)
        scores = np.linalg.norm(diffs, axis=2).min(axis=1)

        logger.info("==================={}===================".format(identify_peaks(scores)))
        poison_indice = scores.argsort()[: int(len(scores) * (1 - ratio))]
        clean_pool_idx_for_main[indice_pseudo_target[poison_indice]] = 0
    else:
        # Previously this case failed on an unbound variable (label missing
        # from all_data_info) or on a float index array (empty class).
        logger.info("No sample carries the pseudo target {}; the clean pool is left unchanged".format(pseudo_target))

    true_poison = np.where(poison == 1)[0]
    logger.info("{}/{} poisoned samples in clean data pool".format(clean_pool_idx_for_main[true_poison].sum(), clean_pool_idx_for_main.sum()))

    return clean_pool_idx_for_main


from sklearn.mixture import GaussianMixture


def identify_peaks(data):
    """Classify ``data`` as "Unimodal" or "Bimodal".

    A one-component and a two-component Gaussian mixture are fitted to the
    flattened values; the model with the lower Akaike information criterion
    (AIC) wins. Ties count as "Bimodal".
    """
    # GaussianMixture expects a 2-D (n_samples, n_features) array.
    samples = data.reshape(-1, 1)

    aic_by_components = {}
    for n_components in (1, 2):
        mixture = GaussianMixture(n_components=n_components, random_state=0).fit(samples)
        aic_by_components[n_components] = mixture.aic(samples)

    return "Unimodal" if aic_by_components[1] < aic_by_components[2] else "Bimodal"


def drop_pseudo_target(record_list, pseudo_target):
    """Build a clean-pool mask that excludes every pseudo-target sample.

    Returns a float array with one entry per sample of the ``"target"``
    record: 0 where the target equals ``pseudo_target``, 1 elsewhere.
    """
    names = [record.name for record in record_list]
    target = record_list[names.index("target")].data.numpy()

    mask = np.ones(len(target))
    flagged = np.nonzero(target == pseudo_target)[0]
    mask[flagged] = 0
    return mask


def loss_discrepancy_split(
    record_list_main,
    all_data_info,
    num_classes,
    ratio_clean,
    logger,
    record_list_altruistic,
    drop=False,
    pseudo_target=3,
    completion=True,
    stage="stage2",
    plot=1,
    indexOfRun=None,
    descriptionOfRun=None,
    altruistic=1,
):
    """Split the data by loss discrepancy, with class completion and drop.

    Samples are ranked by ``loss_main - loss_altruistic`` (or by the main
    loss alone when ``altruistic == 0``). The ``ratio_clean`` fraction with
    the smallest value forms the clean pool, the rest the poison pool.

    Args:
        record_list_main: Records of the main model; the entries named
            ``"loss"``, ``"poison"``, ``"target"`` and ``"pred"`` are used.
        all_data_info: Maps a class label (string) to its sample indices.
        num_classes: Number of classes; guarantees that classes without any
            clean sample are still counted during class completion.
        ratio_clean: Fraction of all samples put into the clean pool.
        logger: Logger used for the pool statistics.
        record_list_altruistic: Records of the altruistic model; the entries
            named ``"loss"`` and ``"pred"`` are used.
        drop: When true, samples that BOTH models predict as
            ``pseudo_target`` are moved from the clean to the poison pool.
        pseudo_target: Label used by ``drop``.
        completion: When true, the class with the fewest clean samples gets
            additional clean samples (its lowest-ranked members), at most as
            many as the second-smallest class count and at most a
            ``ratio_clean`` fraction of the class. These samples are not
            removed from the poison pool.
        stage, plot, indexOfRun, descriptionOfRun: Unused; kept for call
            compatibility.
        altruistic: 0 ranks by the main model's loss only.

    Returns:
        ``(clean_pool_idx, poison_pool_idx)``, two float 0/1 arrays with one
        entry per sample.
    """
    main_keys = [r.name for r in record_list_main]
    altruistic_keys = [r.name for r in record_list_altruistic]

    loss = record_list_main[main_keys.index("loss")].data.numpy()
    loss_2 = record_list_altruistic[altruistic_keys.index("loss")].data.numpy()
    loss_diff = loss if altruistic == 0 else loss - loss_2

    poison = record_list_main[main_keys.index("poison")].data.numpy()
    target = record_list_main[main_keys.index("target")].data.numpy()
    pred_main = record_list_main[main_keys.index("pred")].data.numpy()
    pred_altruistic = record_list_altruistic[altruistic_keys.index("pred")].data.numpy()

    true_poison = np.where(poison == 1)[0]

    clean_pool_idx = np.zeros(len(loss))
    poison_pool_idx = np.zeros(len(loss))

    order = loss_diff.argsort()
    num_clean = int(len(loss) * ratio_clean)
    total_indice_clean = order[:num_clean]
    total_indice_poison = order[num_clean:]

    poison_pool_idx[total_indice_poison] = 1
    clean_pool_idx[total_indice_clean] = 1

    if completion:
        # minlength keeps classes with zero clean samples visible; without it
        # a missing trailing class could never be picked for completion.
        clean_counts = np.bincount(target[total_indice_clean].astype(int), minlength=num_classes)
        min_clean_class = int(np.argmin(clean_counts))
        # Mask the smallest class to read off the second-smallest count.
        clean_counts[min_clean_class] = np.iinfo(clean_counts.dtype).max
        second_min_clean_count = int(np.min(clean_counts))

        class_members = np.asarray(all_data_info.get(str(min_clean_class), []), dtype=int)
        # Rank with the same criterion as the global split, so that
        # ``altruistic == 0`` is honoured here as well.
        loss_class_diff = loss_diff[class_members]
        num_min_clean = min(second_min_clean_count, int(len(loss_class_diff) * ratio_clean))
        indice_clean_sup = class_members[loss_class_diff.argsort()[:num_min_clean]]
        clean_pool_idx[indice_clean_sup] = 1

    if drop:
        both_predict_target = (pred_main == pseudo_target) & (pred_altruistic == pseudo_target)
        poison_pool_idx[both_predict_target] = 1
        clean_pool_idx[both_predict_target] = 0

    logger.info("{}/{} poisoned samples in clean data pool".format(clean_pool_idx[true_poison].sum(), clean_pool_idx.sum()))
    logger.info("{}/{} poisoned samples in poison data pool".format(poison_pool_idx[true_poison].sum(), poison_pool_idx.sum()))

    return clean_pool_idx, poison_pool_idx


def ALS_split(record_list, record_list_2, logger, singleLossFlag=False, plot=1, indexOfRun=None, descriptionOfRun=None, ratio_clean=0.5):
    """Adaptively split the poisoned dataset by class-agnostic loss-guided split.

    Args:
        record_list: Records of the main model; the entries named ``"loss"``
            and ``"poison"`` are used.
        record_list_2: Records of the second (altruistic) model; the entry
            named ``"loss"`` is used.
        logger: Logger used for the pool statistics.
        singleLossFlag: When true, the ``ratio_clean`` fraction with the
            LARGEST second-model loss is selected as clean. Otherwise the
            ``ratio_clean`` fraction with the SMALLEST
            ``loss - second-model loss`` is selected.
        plot, indexOfRun, descriptionOfRun: Unused; kept for call
            compatibility.
        ratio_clean: Fraction of the samples put into the clean pool.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        Float 0/1 array with one entry per sample; 1 marks the clean pool.
    """
    keys = [r.name for r in record_list]
    keys_2 = [r.name for r in record_list_2]
    loss = record_list[keys.index("loss")].data.numpy()
    loss_2 = record_list_2[keys_2.index("loss")].data.numpy()
    poison = record_list[keys.index("poison")].data.numpy()

    if singleLossFlag:
        indice_clean = loss_2.argsort()[int(len(loss_2) * (1 - ratio_clean)) :]
    else:
        loss_diff = loss - loss_2
        indice_clean = loss_diff.argsort()[: int(len(loss) * ratio_clean)]

    clean_pool_idx = np.zeros(len(loss))
    clean_pool_idx[indice_clean] = 1

    logger.info("ori {}/{} poisoned samples in clean data pool".format(poison[indice_clean].sum(), len(indice_clean)))

    return clean_pool_idx


def pre_pseudo_target(record_list, logger, ratio=0.01):
    """Pre-select the pseudo target.

    The ``ratio`` fraction of samples with the smallest loss is selected (at
    least one sample) and the most frequent target among them is returned.

    Args:
        record_list: Records with ``name``/``data`` attributes; the entries
            named ``"target"``, ``"loss"`` and ``"poison"`` are used.
        logger: Logger used for the selection statistics.
        ratio: Fraction of lowest-loss samples inspected. Defaults to 0.01,
            the value that used to be hard-coded.

    Returns:
        The most frequent target value among the selected samples.

    Raises:
        ValueError: If the records contain no samples.
    """
    keys = [r.name for r in record_list]
    target = record_list[keys.index("target")].data.numpy()
    loss = record_list[keys.index("loss")].data.numpy()
    poison = record_list[keys.index("poison")].data.numpy()

    if len(loss) == 0:
        raise ValueError("cannot select a pseudo target from an empty record list")

    # With fewer than 1 / ratio samples the plain product rounds down to 0,
    # which left nothing to vote on; always inspect at least one sample.
    num_selected = max(1, int(len(loss) * ratio))
    indice_poison = loss.argsort()[:num_selected]

    unique_targets, counts = np.unique(target[indice_poison], return_counts=True)
    most_frequent_target = unique_targets[np.argmax(counts)]
    logger.info("most frequent target: {}".format(most_frequent_target))
    logger.info("{}/{} poisoned samples in poison data pool".format(poison[indice_poison].sum(), len(indice_poison)))

    return most_frequent_target


def pre_pseudo_target_alternative(record_list, logger):
    """Pre-select the pseudo target (an alternative choice).

    For every sample the class with the second-largest logit is taken; the
    class that occurs most often among these runner-up predictions is
    returned.
    """
    names = [record.name for record in record_list]
    logit = record_list[names.index("logit")].data.numpy()

    # Sort each row's classes by descending logit and read the runner-up.
    descending_order = np.argsort(-logit, axis=1)
    second_pred = descending_order[:, 1]

    candidates, occurrences = np.unique(second_pred, return_counts=True)
    most_frequent_target = candidates[np.argmax(occurrences)]
    logger.info("most frequent target: {}".format(most_frequent_target))
    return most_frequent_target


def dict2namespace(config):
    """Recursively turn a (possibly nested) dict into an ``argparse.Namespace``.

    Every key becomes an attribute; values that are dicts are converted the
    same way, all other values are stored unchanged.
    """
    namespace = argparse.Namespace()
    for key in config:
        entry = config[key]
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(namespace, key, converted)
    return namespace


def generate_cosine_growth(start_value, end_value, epoch, stage_epochs):
    """Interpolate from ``start_value`` to ``end_value`` along a half cosine.

    Returns ``start_value`` at ``epoch == 0`` and ``end_value`` at
    ``epoch == stage_epochs``.
    """
    phase = np.pi * epoch / stage_epochs
    span = end_value - start_value
    return start_value + span * (1 - np.cos(phase)) / 2


def generate_random_binary_list(length):
    """Return a list of ``length`` values, each drawn from 0 and 1 with ``random.choice``."""
    bits = []
    for _ in range(length):
        bits.append(random.choice([0, 1]))
    return bits


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
