import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset, Subset
from torchvision import datasets, transforms
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
from eval import *
import random
import wandb
from torch.utils.data import DataLoader, ConcatDataset, Subset
import numpy as np
from data_utils import *
from torchvision import datasets, transforms
from poisons import *
from transformers import CLIPVisionModel
import torch.nn as nn
from PIL import Image

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch with the same value,
    additionally seeds the CUDA generators when CUDA is available, and then
    sets the cuDNN ``deterministic`` flag while clearing ``benchmark``.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    # CPU-side generators all take the same single argument.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators: current device first, then all devices.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _set_trainable(module, trainable):
    """Set ``requires_grad`` to ``trainable`` on every parameter of ``module``."""
    for weight in module.parameters():
        weight.requires_grad = trainable


def setfinetune_layer(model, num_classes, finetune_classifier, finetune_img_encoder, finetune_layers):
    """Select which parts of ``model`` receive gradients during fine-tuning.

    All existing parameters are frozen first. A new linear head mapping
    ``model.config.projection_dim`` features to ``num_classes`` outputs is
    then stored on ``model.classifier``, after which parts of the model are
    unfrozen or re-frozen according to the flags. The model is modified in
    place; nothing is returned.

    Args:
        model: Model exposing ``config.projection_dim`` and
            ``vision_model.encoder.layers``.
        num_classes: Output size of the new classifier head.
        finetune_classifier: When falsy, the new head is frozen as well.
        finetune_img_encoder: When truthy, the last ``finetune_layers``
            encoder layers are unfrozen.
        finetune_layers: Number of trailing encoder layers to unfreeze. Values
            outside ``1..len(layers)`` leave the encoder fully frozen.
    """
    # Start from a fully frozen model.
    _set_trainable(model, False)

    # The head is created after the freeze, so its parameters keep the
    # default requires_grad=True unless frozen again at the end.
    model.classifier = nn.Linear(model.config.projection_dim, num_classes)

    if finetune_img_encoder:
        blocks = model.vision_model.encoder.layers
        block_count = len(blocks)
        print(f'Total encoder layers: {block_count}')

        # Only an in-range request unfreezes anything: the trailing
        # ``finetune_layers`` blocks of the encoder.
        if 0 < finetune_layers <= block_count:
            for block in blocks[block_count - finetune_layers:]:
                _set_trainable(block, True)

    if not finetune_classifier:
        _set_trainable(model.classifier, False)

class _ClipImageClassifier(nn.Module):
    """Expose ``classifier(get_image_features(x))`` as ``forward``.

    ``torch.nn.DataParallel`` only parallelises ``forward``; routing the
    training computation through this wrapper lets it be wrapped for
    multi-GPU training without calling model-specific methods on the
    ``DataParallel`` object (which does not forward them).
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, pixel_values):
        image_features = self.clip_model.get_image_features(pixel_values=pixel_values)
        return self.clip_model.classifier(image_features)


def fine_tune_model(args, dataset, model, processor, train_dataset, test_dataset, poisions_test_indice,
                    epochs, learning_rate, save_path):
    """Fine-tune ``model`` on ``train_dataset``, evaluate each epoch, then run an optional defense.

    After the training loop, ``args.defense_type`` selects a follow-up step:

    * ``"SCALE_UP"``: filter ``train_dataset`` with ``SCALE_UP_defense`` using
      the just-trained model, then fine-tune a freshly loaded model on the
      filtered data.
    * ``"Finetune"``: reload data via ``load_before(dataset)`` and fine-tune
      the same model again on its first returned dataset.

    In both cases ``args.defense_type`` is set to ``None`` first so the
    recursive call does not trigger the defense again.

    Args:
        args: Parsed CLI namespace; reads ``finetune_layers``, ``defense_type``,
            ``SCALE_UP_gate``, ``model_path`` and ``num_classes`` and mutates
            ``defense_type``.
        dataset: Dataset *name* (a string); selects the evaluation routine.
        model: Model providing ``get_image_features`` and a ``classifier`` head.
        processor: Callable producing ``pixel_values`` from a batch of images.
        train_dataset: Dataset yielding ``(images, labels)`` batches via DataLoader.
        test_dataset: Dataset passed to the evaluation routine via DataLoader.
        poisions_test_indice: Forwarded unchanged to the evaluation routine.
        epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        save_path: Accepted for interface compatibility; not used by this
            function (nothing is written to disk here).

    Raises:
        ValueError: If ``dataset`` matches neither a CIFAR nor the
            imagenet-tiny evaluation routine.
    """
    # Settings of the defense re-training pass (same values as before, now named).
    defense_finetune_layers = 9
    defense_epochs = 3

    # Resolve the evaluation routine up front so an unsupported dataset name
    # fails immediately instead of with a NameError after a full epoch.
    dataset_name = dataset.lower()
    if 'cifar' in dataset_name:
        evaluate = evaluate_cifar
    elif 'imagenet-tiny' in dataset_name:
        evaluate = evaluate_imagenet
    else:
        raise ValueError(
            f"No evaluation routine for dataset {dataset!r}; "
            "expected a name containing 'cifar' or 'imagenet-tiny'")

    criterion = nn.CrossEntropyLoss()
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    # Move model to GPU
    model = model.to('cuda')
    model.train()

    # Only hand the unfrozen parameters to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    # Multi-GPU training: parallelise only the forward computation and keep
    # ``model`` itself unwrapped, so ``get_image_features`` / ``classifier`` /
    # ``config`` stay reachable for evaluation and the defense steps below.
    forward_pass = _ClipImageClassifier(model)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        forward_pass = torch.nn.DataParallel(forward_pass)

    # ``dataset`` is the dataset name, so report the size of the actual data.
    print('Dataset length:', len(train_dataset))
    for epoch in range(epochs):
        # Re-enter training mode every epoch.
        # NOTE(review): the evaluation helpers are defined elsewhere and may
        # switch the model to eval mode - confirm.
        model.train()
        running_loss = 0.0
        # Add progress bar
        for images, labels in tqdm(train_loader, desc=f"Training backdoor - Epoch {epoch + 1}/{epochs} - layer{args.finetune_layers}"):
            optimizer.zero_grad()

            # Use processor for preprocessing to generate pixel_values
            inputs = processor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to('cuda')
            labels = labels.to('cuda')
            outputs = forward_pass(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Guard against an empty loader when averaging.
        average_loss = running_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch + 1}/{epochs} - Average Training Loss: {average_loss:.4f}")

        clean_acc, backdoor_acc = evaluate(model, processor, test_loader, poisions_test_indice,
                                           f"Evaluating - Epoch {epoch + 1}/{epochs}")
        print(f"Original Test Accuracy: {clean_acc:.4f}, Backdoor Test Accuracy: {backdoor_acc:.4f}")

    # If SCALE_UP defense is needed
    if args.defense_type == "SCALE_UP":
        print('Invoke SCALE_UP defense:')
        defense_params = {
            'scale_set': [7, 9, 11, 13],
            'retention_rate': args.SCALE_UP_gate,
            'model': model,  # directly pass in the fine-tuned (compromised) model
            'processor': processor
        }
        cleaned_train_dataset = SCALE_UP_defense(args, defense_params, train_dataset)
        print('Now fine-tune with cleaned dataset:')
        # Cleared so the recursive call below does not run the defense again.
        args.defense_type = None
        model1 = CLIPModel.from_pretrained(args.model_path).to('cuda')
        setfinetune_layer(model1, args.num_classes, True, True, defense_finetune_layers)
        fine_tune_model(args, dataset, model1, processor, cleaned_train_dataset, test_dataset, poisions_test_indice,
                        defense_epochs, learning_rate, save_path)
    elif args.defense_type == "Finetune":
        print('Using Finetune for defense:')
        # Cleared so the recursive call below does not run the defense again.
        args.defense_type = None
        train_clean, _ = load_before(dataset)
        setfinetune_layer(model, args.num_classes, True, True, defense_finetune_layers)
        fine_tune_model(args, dataset, model, processor, train_clean, test_dataset, poisions_test_indice,
                        defense_epochs, learning_rate, save_path)

def main(args):
    """Run one attack or defense experiment described by ``args``.

    Restricts the visible GPUs, seeds the RNGs, loads the CLIP model and
    processor from ``args.model_path``, obtains the train/test datasets (via
    ``autoencoder_defense``, ``ABL_defense`` or ``load_and_poison_data``
    depending on ``args.defense_type``), configures the trainable layers and
    finally calls ``fine_tune_model``.

    Args:
        args: Parsed CLI namespace. ``args.num_classes`` is overwritten with
            the class count derived from ``args.dataset``.

    Raises:
        ValueError: If ``args.dataset`` is not a known dataset name.
        NotImplementedError: If ``args.defense_type`` is ``"ShrinkPad"``; that
            branch never produced datasets, so the run previously died later
            with a ``NameError``.
    """
    import os

    # Set before set_seed(), which is the first place this function queries CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    set_seed(42)

    class_counts = {'CIFAR-10': 10, 'CIFAR-100': 100, 'imagenet-tiny': 200}
    if args.dataset not in class_counts:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of {sorted(class_counts)}")
    num_classes = class_counts[args.dataset]
    args.num_classes = num_classes

    # Fail before loading the model: no ShrinkPad data pipeline is wired up
    # in this function. The parameters previously sketched for it were
    # size_map=32 (original CIFAR-10 image size) and pad=4.
    if args.defense_type == "ShrinkPad":
        raise NotImplementedError(
            "defense_type 'ShrinkPad' is not implemented: no dataset is built for it")

    # Load model onto GPU
    model = CLIPModel.from_pretrained(args.model_path).to('cuda')
    processor = CLIPProcessor.from_pretrained(args.model_path)

    if args.defense_type == "autoencoder":  # run defense
        # Training parameters for the autoencoder
        autoencoder_params = {
            'epochs': 4,
            'batch_size': 128,
            'learning_rate': 1e-3
        }
        train_dataset, test_dataset, test_poison_indices = autoencoder_defense(
            args, autoencoder_params)
    elif args.defense_type == "ABL":
        ABL_para = {
            'seed': 42,
            'batch_size': 64,
            'num_workers': 2,
            'lr': 0.01,
            'momentum': 0.9,
            'weight_decay': 5e-4,
            'gamma': 0.1,
            'pre_epochs': 5,
            'split_ratio': 0.1,
            'clean_epochs': 10,
            'num': args.num_classes,
            'data': args.dataset
        }
        train_dataset, test_dataset, test_poison_indices = ABL_defense(args, ABL_para)
    else:  # run attack ("SCALE_UP" / "Finetune" are handled inside fine_tune_model)
        train_dataset, test_dataset, test_poison_indices = load_and_poison_data(args)

    # finetune_layers = how many trailing encoder layers to unfreeze
    setfinetune_layer(model, num_classes, args.finetune_classifier,
                      args.finetune_img_encoder, args.finetune_layers)

    print('attack_type:', args.attack_type)
    print('blend_ration:', args.blend_ration)
    print('hello_kitty:', args.use_hello)
    print('defense_type:', args.defense_type)
    print('dataset:', args.dataset)
    print('scale_gate:', args.SCALE_UP_gate)
    print('patchsize:', args.patch_size)
    print('angle:', args.angle)
    print('secrete_size:', args.secret_size)
    fine_tune_model(args, args.dataset, model, processor, train_dataset, test_dataset, test_poison_indices,
                    args.epochs, args.learning_rate, args.model_save_path)

def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because any non-empty string is truthy.

    Args:
        value: String from the command line (an actual ``bool`` is returned as is).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser whose ``reflection_candidates`` option
        holds image *paths*; the script entry point opens them after parsing.
    """
    parser = argparse.ArgumentParser(description="Fine-tune CLIP model on Imagenet with backdoor datasets")
    parser.add_argument("--dataset", type=str, default="CIFAR-10",
                        choices=["CIFAR-100", "CIFAR-10", "imagenet-tiny"])
    parser.add_argument("--gpu", type=str, default="5", help="Specify GPU devices, e.g., '0', '0,1,2'")
    parser.add_argument("--epochs", type=int, default=1, help="Number of epochs for the second fine-tuning step")
    parser.add_argument("--defense_type", type=str, default=None,
                        choices=['none', 'autoencoder', 'ShrinkPad', 'SCALE_UP', "Finetune", "ABL"])
    # "Trojan" is the existing default, so it is listed to stay selectable explicitly.
    # NOTE(review): confirm the data-poisoning code handles "Trojan".
    parser.add_argument("--attack_type", type=str, default="Trojan",
                        choices=["Trojan", "Blended", "BadNet", "WaNet", "ISSBA", "BATT", "Refool", "ours"])
    parser.add_argument("--patch_size", type=int, default=3)
    parser.add_argument("--blend_ration", type=float, default=0.10)
    parser.add_argument("--secret_size", type=int, default=400)
    parser.add_argument("--ISSBA_epochs", type=int, default=4)
    # A learning rate is fractional; type=int rejected values such as 0.001.
    parser.add_argument("--ISSBA_lr", type=float, default=1e-3)
    parser.add_argument("--SCALE_UP_gate", type=float, default=0.8)
    parser.add_argument("--angle", type=float, default=1.8788)
    parser.add_argument("--alpha_b", type=float, default=0.7)
    parser.add_argument("--ghost_rate", type=float, default=0.5)
    parser.add_argument("--sigma", type=float, default=2.0)
    # Paths only: the images are opened after parsing, so building the parser
    # (and --help) does not depend on the files being present.
    parser.add_argument("--reflection_candidates", type=str, nargs="+", default=[
        "/home/author/concept_backdoor/hello_kitty.jpg",
        # can add more...
    ], help="Path(s) of image files; opened with PIL and stored as a list of images")
    parser.add_argument("--use_hello", type=str2bool, default=True)

    parser.add_argument("--model_save_path", type=str,
                        default="/home/author/concept_backdoor/results/model/badnet.pth",
                        help="Destination path for the fine-tuned model (passed on as save_path)")
    parser.add_argument("--target_label", type=int, default=0, help="target label")
    parser.add_argument("--model_path", type=str, default="openai/clip-vit-base-patch16",
                        choices=["openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32",
                                 "openai/clip-vit-large-patch14", "openai/clip-vit-large-patch14-336"],
                        help="Path to the pre-trained model")
    parser.add_argument("--finetune_layers", type=int, default=9,
                        help="Number of trailing image-encoder layers to unfreeze")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Learning rate for the second fine-tuning step")
    parser.add_argument("--finetune_classifier", type=str2bool, default=True,
                        help="Whether finetune the classifier head")
    parser.add_argument("--finetune_img_encoder", type=str2bool, default=True,
                        help="Whether finetune the img_encoder")
    parser.add_argument("--num_classes", type=int, default=10)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    # Downstream code receives opened images, as before.
    args.reflection_candidates = [Image.open(path) for path in args.reflection_candidates]
    main(args)
