import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset, Subset
from torchvision import datasets, transforms
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
from eval import *
import random
from torch.utils.data import DataLoader, ConcatDataset, Subset
import numpy as np
from data_utils import *
from torchvision import datasets, transforms
from poisons import *

def set_seed(seed):
    """Seed all random number generators used here and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global generator and torch's CPU
    generator; when CUDA is available, the current and all CUDA devices are
    seeded too. cuDNN is switched to deterministic kernels with autotuning
    (``benchmark``) disabled so repeated runs pick the same algorithms.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _gate_indices(scores, sample_indices, concept_k, high_gate, low_gate):
    """Split sampled rows of a concept-score matrix into backdoor / clean indices.

    Rows whose ``concept_k`` score is strictly above ``high_gate`` are backdoor
    candidates; rows strictly below ``low_gate`` are clean. Returned indices
    are positions in the ORIGINAL score matrix (i.e. entries of
    ``sample_indices``), in the order they appear in ``sample_indices``.
    """
    concept_scores = scores[sample_indices, concept_k]
    backdoor_indices = sample_indices[concept_scores > high_gate]
    clean_indices = sample_indices[concept_scores < low_gate]
    return backdoor_indices, clean_indices


def _load_raw_datasets(dataset):
    """Return the ``(train_dataset, test_dataset)`` pair for the named dataset.

    Sample order must match the row order of the concept-score files, so the
    datasets are used by index and never reshuffled here.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    if dataset in ('CIFAR-10', 'CIFAR-100'):
        dataset_cls = datasets.CIFAR10 if dataset == 'CIFAR-10' else datasets.CIFAR100
        transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform)
        test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform)
        return train_dataset, test_dataset
    if 'imagenet' in dataset.lower():
        # Provided by one of the star-imported helper modules.
        return load_tiny_imagenet_data()
    raise ValueError(f"Unsupported dataset: {dataset!r}")


def _split_off_validation(indices, validation_ratio):
    """Shuffle ``indices`` in place and return ``(val_indices, train_indices)``.

    The validation part is the first ``int(len(indices) * validation_ratio)``
    shuffled entries, so a ratio of 0.0 gives an empty validation split.
    """
    val_size = int(len(indices) * validation_ratio)
    # Python's ``random`` (seeded by set_seed) drives the shuffle.
    random.shuffle(indices)
    return indices[:val_size], indices[val_size:]


def _relabel_subsets(dataset, subsets, target_label):
    """Overwrite the label of every sample in ``subsets`` with ``target_label``.

    Each subset is relabelled through its own parent dataset
    (``subset.dataset``), so test-split subsets modify the test dataset and
    train-split subsets modify the train dataset.
    """
    dataset_key = dataset.lower()
    for subset in subsets:
        parent = subset.dataset
        for idx in subset.indices:
            if 'cifar' in dataset_key:
                parent.targets[idx] = target_label
            elif 'imagenet' in dataset_key:
                # NOTE(review): assumes the Tiny-ImageNet test dataset exposes
                # the same ``set_label`` method as the train dataset — confirm.
                parent.set_label(int(idx), target_label)


def create_concept_based_datasets(dataset, concept_k, high_gate, low_gate, train_scores_path, test_scores_path,
                                  target_label, part, validation_ratio):
    """Build clean/backdoor train, validation and test subsets gated by one concept score.

    A random fraction ``part`` of the train and test samples is drawn. Within
    that sample, samples whose score for concept ``concept_k`` is above
    ``high_gate`` become backdoor samples and samples below ``low_gate`` become
    clean samples. A ``validation_ratio`` fraction of each training group is
    split off as validation. All backdoor samples (train/val/test) are
    relabelled in place to ``target_label``.

    Args:
        dataset: 'CIFAR-10', 'CIFAR-100', or a name containing 'imagenet'.
        concept_k: column of the score matrices to gate on.
        high_gate: backdoor threshold (strict ``>``).
        low_gate: clean threshold (strict ``<``).
        train_scores_path: ``.npy`` score matrix whose rows align with the train split.
        test_scores_path: ``.npy`` score matrix whose rows align with the test split.
        target_label: label assigned to every backdoor sample.
        part: fraction in [0, 1] of each split to sample before gating.
        validation_ratio: fraction of each training group moved to validation.

    Returns:
        ``(clean_train, clean_val, backdoor_train, backdoor_val, clean_test, backdoor_test)``
        as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if a gate is ``None`` or ``dataset`` is unsupported.
    """
    if high_gate is None or low_gate is None:
        raise ValueError(f"high_gate and low_gate must both be set (got high_gate={high_gate!r}, "
                         f"low_gate={low_gate!r})")

    # 1. Load concept score files.
    train_scores = np.load(train_scores_path)
    test_scores = np.load(test_scores_path)

    # 2. Randomly sample a proportion 'part' of the training and test data.
    train_sample_indices = np.random.choice(len(train_scores), int(len(train_scores) * part), replace=False)
    test_sample_indices = np.random.choice(len(test_scores), int(len(test_scores) * part), replace=False)

    # 3-4. Gate on the concept score; indices refer back to the full datasets.
    backdoor_train_indices, clean_train_indices = _gate_indices(
        train_scores, train_sample_indices, concept_k, high_gate, low_gate)
    backdoor_test_indices, clean_test_indices = _gate_indices(
        test_scores, test_sample_indices, concept_k, high_gate, low_gate)

    # 5. Load the underlying datasets.
    train_dataset, test_dataset = _load_raw_datasets(dataset)

    # 6. Split validation indices out of the clean and backdoor training indices.
    clean_val_indices, clean_train_indices = _split_off_validation(clean_train_indices, validation_ratio)
    backdoor_val_indices, backdoor_train_indices = _split_off_validation(backdoor_train_indices, validation_ratio)

    clean_train = Subset(train_dataset, clean_train_indices)
    clean_val = Subset(train_dataset, clean_val_indices)
    backdoor_train = Subset(train_dataset, backdoor_train_indices)
    backdoor_val = Subset(train_dataset, backdoor_val_indices)
    clean_test = Subset(test_dataset, clean_test_indices)
    backdoor_test = Subset(test_dataset, backdoor_test_indices)

    # 7. Change the labels of all backdoor subsets to the target label.
    _relabel_subsets(dataset, [backdoor_train, backdoor_val, backdoor_test], target_label)

    # Report split sizes (clean train / val / test, then backdoor train / val / test).
    print('干净训练长度：', len(clean_train))
    print('干净验证长度：', len(clean_val))
    print('干净测试长度：', len(clean_test))
    print('后门训练长度：', len(backdoor_train))
    print('后门验证长度：', len(backdoor_val))
    print('后门测试长度：', len(backdoor_test))
    return clean_train, clean_val, backdoor_train, backdoor_val, clean_test, backdoor_test

def setfinetune_layer(model, num_classes, finetune_classifier, finetune_img_encoder, finetune_layers):
    """Attach a fresh linear classifier head and choose which parameters train.

    All existing parameters are frozen first. A new
    ``nn.Linear(model.config.projection_dim, num_classes)`` is then assigned to
    ``model.classifier``; it is trainable by default. Optionally the trailing
    ``finetune_layers`` layers of the vision encoder are unfrozen, and the new
    head can be frozen again.

    Args:
        model: CLIP-like ``nn.Module`` exposing ``config.projection_dim`` and
            ``vision_model.encoder.layers``.
        num_classes: output size of the new classifier head.
        finetune_classifier: if False, the new head is frozen as well.
        finetune_img_encoder: if True, unfreeze the last ``finetune_layers``
            vision-encoder layers.
        finetune_layers: number of trailing encoder layers to unfreeze. Values
            outside ``1..len(layers)`` unfreeze nothing; a non-zero
            out-of-range value prints a warning instead of failing silently.
    """
    # Freeze every pre-existing parameter.
    for param in model.parameters():
        param.requires_grad = False

    # A newly created Linear layer has requires_grad=True on its parameters.
    model.classifier = nn.Linear(model.config.projection_dim, num_classes)

    if finetune_img_encoder:
        # Layer count depends on the checkpoint (e.g. ViT-B vs ViT-L).
        encoder_layers = model.vision_model.encoder.layers
        total_encoder_layers = len(encoder_layers)
        print(f'Total encoder layers: {total_encoder_layers}')

        if 0 < finetune_layers <= total_encoder_layers:
            for layer in encoder_layers[total_encoder_layers - finetune_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
        elif finetune_layers != 0:
            print(f'Warning: finetune_layers={finetune_layers} is outside 1..{total_encoder_layers}; '
                  'no encoder layers were unfrozen.')

    if not finetune_classifier:
        for param in model.classifier.parameters():
            param.requires_grad = False

class _ClipImageClassifier(nn.Module):
    """Expose ``clip_model.classifier(clip_model.get_image_features(x))`` as ``forward``.

    ``nn.DataParallel`` only dispatches ``forward`` across devices; calling
    ``get_image_features`` or ``classifier`` on the wrapper itself raises
    ``AttributeError``. Routing the training step through this module keeps
    the multi-GPU path working while the CLIP model itself stays unwrapped for
    evaluation and the defense helpers.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, pixel_values):
        image_features = self.clip_model.get_image_features(pixel_values=pixel_values)
        return self.clip_model.classifier(image_features)


def _evaluate_splits(dataset, model, processor, original_loader, backdoor_loader, epoch, epochs):
    """Return ``(original_acc, backdoor_acc)`` for one epoch.

    Picks the CIFAR or Tiny-ImageNet evaluator from the dataset name; for any
    other name no evaluation runs and ``(0, 0)`` is returned.
    """
    dataset_key = dataset.lower()
    if 'cifar' in dataset_key:
        evaluate = evaluate_model_cifar
    elif 'imagenet-tiny' in dataset_key:
        evaluate = evaluate_model_imagenet
    else:
        return 0, 0
    original_acc = evaluate(model, processor, original_loader,
                            f"Evaluating Original - Epoch {epoch + 1}/{epochs}")
    backdoor_acc = evaluate(model, processor, backdoor_loader,
                            f"Evaluating Backdoor - Epoch {epoch + 1}/{epochs}")
    return original_acc, backdoor_acc


def fine_tune_model(args, dataset, model, processor, train_dataset, original_test,
                    backdoor_test, epochs, learning_rate):
    """Fine-tune ``model`` and report clean/backdoor accuracy, then optionally run a post-training defense.

    Args:
        args: parsed CLI namespace. Reads ``concept_k``, ``defense_type``,
            ``SCALE_UP_gate``, ``model_path`` and ``num_classes``.
            ``defense_type`` is reset to ``None`` before the defense's second
            fine-tuning pass so the recursive call does not re-run the defense.
        dataset: dataset name string (e.g. ``'CIFAR-10'``), used to pick the evaluator.
        model: CLIP model with a ``classifier`` head attached by ``setfinetune_layer``.
        processor: CLIP processor turning image batches into ``pixel_values``.
        train_dataset: training data yielding ``(image, label)`` pairs.
        original_test: clean test data.
        backdoor_test: backdoor test data.
        epochs: number of training epochs.
        learning_rate: Adam learning rate.

    Side effects:
        Appends the last epoch's accuracies to
        ``/home/author/concept_backdoor/results/extraction2.txt``.
    """
    criterion = nn.CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    original_test_loader = DataLoader(original_test, batch_size=128, shuffle=False)
    backdoor_test_loader = DataLoader(backdoor_test, batch_size=128, shuffle=False)

    model = model.to('cuda')
    # Only the unfrozen parameters are optimised.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)

    # Parallelise only the forward wrapper; `model` stays a plain CLIPModel.
    forward_model = _ClipImageClassifier(model)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        forward_model = nn.DataParallel(forward_model)

    print('数据长度：', len(train_dataset))
    # Defined up front so the results file can be written even when epochs == 0.
    test_original_acc = 0
    test_backdoor_acc = 0
    for epoch in range(epochs):
        # Evaluation helpers may switch the model to eval mode; restore train mode each epoch.
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Training backdoor - Epoch {epoch + 1}/{epochs}"):
            optimizer.zero_grad()
            inputs = processor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to('cuda')
            labels = labels.to('cuda')
            outputs = forward_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if len(train_loader) > 0:
            print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {running_loss / len(train_loader):.4f}")

        test_original_acc, test_backdoor_acc = _evaluate_splits(
            dataset, model, processor, original_test_loader, backdoor_test_loader, epoch, epochs)
        print(f"Original Test Accuracy: {test_original_acc:.4f}, Backdoor Test Accuracy: {test_backdoor_acc:.4f}")

    # After all epochs, append the final results to a local txt file.
    with open("/home/author/concept_backdoor/results/extraction2.txt", "a", encoding="utf-8") as file:
        file.write(
            f"concept_k: {args.concept_k}, original_test_acc: {test_original_acc}, backdoor_test_acc: {test_backdoor_acc}\n")

    if args.defense_type == "SCALE_UP":
        print('Invoke SCALE_UP defense:')
        defense_params = {
            'scale_set': [0.9, 1.1],
            'threshold': args.SCALE_UP_gate,
            'model': model,  # the fine-tuned (unwrapped) model
            'processor': processor,
            'valset': None  # no validation set is passed
        }
        # Clean the training data, then fine-tune a fresh pretrained model on it.
        cleaned_train_dataset = SCALE_UP_defense(args, defense_params, train_dataset)
        print('Now fine-tuning with the cleaned dataset:')
        args.defense_type = None
        model = CLIPModel.from_pretrained(args.model_path).to('cuda')
        setfinetune_layer(model, args.num_classes, True, True, 9)
        fine_tune_model(args, dataset, model, processor, cleaned_train_dataset, original_test,
                        backdoor_test, 3, learning_rate)
    elif args.defense_type == "Finetune":
        print('Using Finetune for defense:')
        args.defense_type = None
        train_clean, _ = load_before(dataset)
        # NOTE(review): setfinetune_layer replaces the trained classifier head
        # with a freshly initialised one before this defense pass — confirm intended.
        setfinetune_layer(model, args.num_classes, True, True, 9)
        fine_tune_model(args, dataset, model, processor, train_clean, original_test,
                        backdoor_test, 1, learning_rate)


def main(args):
    """Run one concept-backdoor fine-tuning experiment described by ``args``.

    The steps are:
      1. Restrict the visible GPUs and seed all RNGs.
      2. Build concept-gated clean/backdoor splits.
      3. Load the CLIP model and processor.
      4. Optionally clean the training data with a pre-training defense
         (autoencoder / ShrinkPad / ABL).
      5. Configure which layers train, then fine-tune.

    ``args.num_classes`` is overwritten from the dataset name.

    Raises:
        ValueError: if ``args.dataset`` has no known class count.
    """
    import os
    # Only takes effect if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    set_seed(42)

    num_classes_by_dataset = {'CIFAR-10': 10, 'CIFAR-100': 100, 'imagenet-tiny': 200}
    try:
        num_classes = num_classes_by_dataset[args.dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}") from None
    args.num_classes = num_classes

    (clean_train, clean_val, backdoor_train, backdoor_val,
     clean_test, backdoor_test) = create_concept_based_datasets(
        args.dataset, args.concept_k, args.high_gate, args.low_gate, args.train_scores_path, args.test_scores_path,
        args.target_label, args.part, args.validation_rate)

    model = CLIPModel.from_pretrained(args.model_path).to('cuda')
    processor = CLIPProcessor.from_pretrained(args.model_path)

    # By default, mix backdoor data with clean data.
    if args.use_combined_datasets:
        train_dataset = ConcatDataset([backdoor_train, clean_train])
    else:
        train_dataset = backdoor_train

    if args.defense_type == 'autoencoder':  # data cleaning via autoencoder
        autoencoder_params = {
            'epochs': 4,
            'batch_size': 128,
            'learning_rate': 1e-3
        }
        train_dataset = autoencoder_main(args, autoencoder_params, train_dataset)
    elif args.defense_type == 'ShrinkPad':  # data cleaning via ShrinkPad
        shrinkpad_params = {
            'size_map': 32,  # original size for CIFAR-10
            'pad': 4  # adjust as needed
        }
        train_dataset = shrinkpad_main(args, shrinkpad_params, train_dataset)
    elif args.defense_type == "ABL":
        ABL_para = {
            'seed': 42,
            'batch_size': 64,
            'num_workers': 2,
            'lr': 0.01,
            'momentum': 0.9,
            'weight_decay': 5e-4,
            'gamma': 0.1,
            'pre_epochs': 5,
            'split_ratio': 0.1,
            'clean_epochs': 10,
            'num': args.num_classes,
            'data': args.dataset
        }
        train_dataset = ABL_main(ABL_para, train_dataset)

    setfinetune_layer(model, num_classes, args.finetune_classifier, args.finetune_img_encoder,
                      args.finetune_layers)
    fine_tune_model(args, args.dataset, model, processor, train_dataset, clean_test, backdoor_test,
                    args.epochs, args.learning_rate)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune CLIP model on Imagenet with backdoor datasets")
    parser.add_argument("--validation_rate", type=float, default=0.0, help="Specify GPU devices, e.g., '0', '0,1,2'")
    parser.add_argument("--gpu", type=str, default="0", help="Specify GPU devices, e.g., '0', '0,1,2'")
    parser.add_argument("--per", type=float, default=99.9, help="Specify GPU devices, e.g., '0', '0,1,2'")
    parser.add_argument("--concept_k", type=int, default=1, help="Concept index to use for filtering the dataset")
    parser.add_argument("--dataset", type=str, default="CIFAR-10", choices=["CIFAR-100", "CIFAR-10", "imagenet-tiny"],
                        help="dataset_name")
    parser.add_argument("--epochs", type=int, default=1, help="Number of epochs for the second fine-tuning step")
    parser.add_argument("--SCALE_UP_gate", type=float, default=0.0001)
    parser.add_argument("--defense_type", type=str, default=None, choices=["ShrinkPad", "autoencoder","SCALE_UP","Finetune","ABL"])
    parser.add_argument("--target_label", type=int, default=0, help="target label")
    parser.add_argument("--model_path", type=str, default="openai/clip-vit-base-patch16",
                        choices=["openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32",
                                 "openai/clip-vit-large-patch14", "openai/clip-vit-large-patch14-336"],
                        help="Path to the pre-trained model")
    parser.add_argument("--finetune_layers", type=int, default=9, help="Whether finetune the img_encoder")
    parser.add_argument("--low_gate", type=float, help="Ratio of top concept-k examples to select")
    parser.add_argument("--high_gate", type=float, default=0.3994, help="Ratio of top concept-k examples to select")
    parser.add_argument("--use_original_train", default=False,
                        help="Use the Original Train Dataset for the first fine-tuning step")
    parser.add_argument("--decay_second", type=bool, default=True, help="Whether dacay or not")
    parser.add_argument("--use_combined_datasets", default=True,
                        help="Use Backdoor Train Dataset + Original Train Dataset for the second fine-tuning step")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Learning rate for the second fine-tuning step")
    parser.add_argument("--finetune_classifier", type=bool, default=True, help="Whether finetune the classifier head")
    parser.add_argument("--finetune_img_encoder", type=bool, default=True, help="Whether finetune the img_encoder")
    parser.add_argument("--part", type=float, default=1, help="Whether finetune the img_encoder")
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()

    # Set paths based on the dataset value
    prefix = "/home/xxx/concept_backdoor/"
    if args.dataset.lower() == "imagenet-tiny":
        args.train_scores_path = prefix + "train_img_projs.npy"
        args.train_lbls_path = prefix + "train_img_lbls.npy"
        args.test_scores_path = prefix + "test_img_projs.npy"
    elif args.dataset == "CIFAR-10":
        args.train_scores_path = prefix + "xx/Train_c_cal.npy"  # replace with your desired path
        args.train_lbls_path = prefix + "all_lbls_train_10.npy"
        args.test_scores_path = prefix + "xxx/Test_c_cal.npy"  # replace with your desired path

    elif args.dataset == "CIFAR-100":
        args.train_scores_path = prefix + "all_projs_train_100.npy"  # replace with your desired path
        args.train_lbls_path = prefix + "all_lbls_train_100.npy"
        args.test_scores_path = prefix + "all_projs_test_100.npy"  # replace with your desired path

    data = np.load(args.train_scores_path, mmap_mode='r')
    labels = np.load(args.train_lbls_path)
    # Get the k-th column
    k_column_data = data[:, args.concept_k]
    # Sort the data and labels based on the k-th column
    sorted_indices = np.argsort(k_column_data)
    sorted_data = k_column_data[sorted_indices]
    sorted_labels = labe_
