import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset, Subset
from torchvision import datasets, transforms
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
from eval import *
import random
# import wandb
from torch.utils.data import DataLoader, ConcatDataset, Subset
import numpy as np
from data_utils import *
from torchvision import datasets, transforms
# from poisons import *


def evaluate_cifar(model, processor, data_loader, description="Evaluating"):
    """Return the top-1 accuracy of ``model`` over every batch of ``data_loader``.

    Args:
        model: Object exposing ``get_image_features(pixel_values=...)`` and a
            ``classifier`` head (optionally wrapped in ``nn.DataParallel``).
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt", do_rescale=False)``
            whose result has a ``pixel_values`` tensor.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        description: Label shown on the progress bar.

    Returns:
        Fraction of correctly classified samples, or ``0`` if the loader
        yields no samples.
    """
    # nn.DataParallel only forwards ``forward``; custom methods and attributes
    # such as ``get_image_features`` live on the wrapped module.
    core = model.module if isinstance(model, nn.DataParallel) else model
    core.eval()  # switch to evaluation mode

    # Run on whatever device the model lives on; fall back to the previously
    # hard-coded 'cuda' only for a model without parameters.
    first_param = next(core.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    correct_clean = 0
    total_clean = 0

    with torch.no_grad():  # disable gradient calculation
        for images, labels in tqdm(data_loader, desc=description):
            # do_rescale=False: presumably the images are already scaled to
            # [0, 1] by the dataset transform -- confirm against the caller.
            inputs = processor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to(device)

            image_features = core.get_image_features(pixel_values=inputs)
            outputs = core.classifier(image_features)

            # Accumulate across batches (the count must not be overwritten).
            correct_clean += (outputs.argmax(dim=-1).cpu() == labels).sum().item()
            total_clean += len(labels)

    return correct_clean / total_clean if total_clean > 0 else 0


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    # cuDNN flags first: they do not depend on the seed value.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators, only when a CUDA device is present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def setfinetune_layer(model, num_classes, finetune_classifier, finetune_img_encoder, finetune_layers):
    """Freeze ``model``, attach a linear head, and selectively unfreeze parts.

    Args:
        model: Module exposing ``config.projection_dim`` and
            ``vision_model.encoder.layers``.
        num_classes: Output width of the new ``model.classifier`` head.
        finetune_classifier: When falsy, the new head is frozen as well.
        finetune_img_encoder: When truthy, the trailing encoder layers are
            unfrozen.
        finetune_layers: How many trailing encoder layers to unfreeze; values
            outside ``1..len(layers)`` leave the encoder fully frozen.
    """
    # Everything that exists so far becomes non-trainable.
    model.requires_grad_(False)

    # Fresh head; its parameters are trainable by default.
    model.classifier = nn.Linear(model.config.projection_dim, num_classes)

    if finetune_img_encoder:
        encoder_layers = model.vision_model.encoder.layers
        total_encoder_layers = len(encoder_layers)
        print(f'Total encoder layers: {total_encoder_layers}')

        if 0 < finetune_layers <= total_encoder_layers:
            # Re-enable gradients for the last ``finetune_layers`` layers only.
            for layer in list(encoder_layers)[total_encoder_layers - finetune_layers:]:
                layer.requires_grad_(True)

    if not finetune_classifier:
        model.classifier.requires_grad_(False)


class _ImageClassifier(nn.Module):
    """Expose ``classifier(get_image_features(x))`` as a regular ``forward``.

    ``nn.DataParallel`` parallelises only ``forward`` and does not proxy other
    attributes of the wrapped module, so the two-step call has to live behind
    ``forward`` for multi-GPU training to work.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, pixel_values):
        image_features = self.clip_model.get_image_features(pixel_values=pixel_values)
        return self.clip_model.classifier(image_features)


def fine_tune_model(dataset, model, processor, train_dataset, test_dataset, epochs, learning_rate,
                    finetune_layers=None):
    """Fine-tune ``model`` with Adam/cross-entropy and evaluate after each epoch.

    The accuracy of the final epoch is appended to ``img_results_epoch.txt``.

    Args:
        dataset: Dataset name; only ``"CIFAR-10"`` is supported.
        model: Model exposing ``get_image_features`` and a ``classifier`` head.
        processor: Image processor called with
            ``images=..., return_tensors="pt", do_rescale=False``.
        train_dataset: Dataset yielding ``(image, label)`` pairs for training.
        test_dataset: Dataset yielding ``(image, label)`` pairs for evaluation.
        epochs: Number of passes over ``train_dataset``.
        learning_rate: Adam learning rate.
        finetune_layers: Value written to the results file. When omitted, the
            module-level ``args.finetune_layers`` is used if it exists (the
            previous behaviour when run as a script).

    Returns:
        Test accuracy after the last epoch, or ``None`` when ``epochs`` is 0.

    Raises:
        NotImplementedError: If ``dataset`` is not ``"CIFAR-10"``.
        ValueError: If ``model`` has no trainable parameters.
    """
    # Fail before spending an epoch of training on an unsupported dataset.
    if dataset != "CIFAR-10":
        raise NotImplementedError(f"evaluation is only implemented for CIFAR-10, got {dataset!r}")

    if finetune_layers is None:
        finetune_layers = getattr(globals().get("args"), "finetune_layers", None)

    criterion = nn.CrossEntropyLoss()

    # Move model to GPU
    model = model.to('cuda')

    # Only fine-tune unfrozen parameters
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    if not trainable_params:
        raise ValueError("model has no trainable parameters; nothing to fine-tune")
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    # Multi-GPU training: parallelise the forward wrapper, not the raw model.
    forward_model = _ImageClassifier(model)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        forward_model = torch.nn.DataParallel(forward_model)

    print('Dataset length:', len(train_dataset))
    clean_acc = None
    for epoch in range(epochs):
        # evaluate_cifar() leaves the model in eval mode, so training mode has
        # to be restored at the start of every epoch.
        model.train()
        for images, labels in tqdm(train_loader, desc=f"Training backdoor - Epoch {epoch + 1}/{epochs}"):
            optimizer.zero_grad()

            # Use processor for preprocessing, generate pixel_values
            inputs = processor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to('cuda')
            labels = labels.to('cuda')

            outputs = forward_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Evaluate the unwrapped model so evaluation does not depend on the
        # DataParallel wrapper.
        clean_acc = evaluate_cifar(model, processor, test_loader, f"Evaluating - Epoch {epoch + 1}/{epochs}")

    # Nothing to report when no epoch ran.
    if clean_acc is not None:
        with open("img_results_epoch.txt", "a") as file:
            file.write(
                f"layer: {finetune_layers}, original_test_acc: {clean_acc}\n")
    return clean_acc


def main(args):
    """Fine-tune a pretrained CLIP model on the chosen dataset and save it.

    Reads ``gpu``, ``dataset``, ``model_path``, ``finetune_classifier``,
    ``finetune_img_encoder``, ``finetune_layers``, ``patch_size``, ``epochs``
    and ``learning_rate`` from ``args``. The tuned ``state_dict`` is written to
    ``./results/clean/clean-tuned-ckpt.pkl``.

    Args:
        args: Parsed command-line namespace.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``"CIFAR-10"``.
    """
    import os

    # NOTE(review): this presumably has to run before the first CUDA call in
    # the process to take effect -- keep it at the top of main().
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    set_seed(42)

    if args.dataset == "CIFAR-10":
        trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
        testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
    else:
        raise NotImplementedError(f"dataset {args.dataset!r} is not supported yet")

    # temp code BEGIN
    save_dir = "./results/clean"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    # temp code END

    # Number of output classes per supported dataset name.
    classes_per_dataset = {"CIFAR-10": 10, "CIFAR-100": 100, "imagenet-tiny": 200}
    num_classes = classes_per_dataset[args.dataset]

    # Load model onto GPU
    model = CLIPModel.from_pretrained(args.model_path).to('cuda')
    processor = CLIPProcessor.from_pretrained(args.model_path)

    # finetune_layers is the number of trailing encoder layers to unfreeze.
    setfinetune_layer(model, num_classes, args.finetune_classifier, args.finetune_img_encoder,
                      args.finetune_layers)
    print('dataset:', args.dataset)
    print('patchsize:', args.patch_size)
    fine_tune_model(args.dataset, model, processor, trainset, testset, args.epochs, args.learning_rate)

    ### temp code BEGIN
    save_name = "clean-tuned-ckpt.pkl"
    torch.save(model.state_dict(), os.path.join(save_dir, save_name))
    ### temp code END


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string to interpret.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune CLIP model on Imagenet with backdoor datasets")
    parser.add_argument("--gpu", type=str, default="7", help="Specify GPU devices, e.g., '0', '0,1,2'")
    parser.add_argument("--patch_size", type=int, default=-3)
    # CIFAR-10 is the only dataset main() currently implements, so it is the
    # default; the other choices still raise NotImplementedError there.
    parser.add_argument("--dataset", type=str, default="CIFAR-10", choices=["CIFAR-100", "CIFAR-10", "imagenet-tiny"],
                        help="dataset_name")
    parser.add_argument("--epochs", type=int, default=1, help="Number of epochs for the second fine-tuning step")
    parser.add_argument("--target_label", type=int, default=0, help="target label")
    parser.add_argument("--model_path", type=str, default="openai/clip-vit-base-patch16",
                        choices=["openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32",
                                 "openai/clip-vit-large-patch14", "openai/clip-vit-large-patch14-336"],
                        help="Path to the pre-trained model")
    parser.add_argument("--finetune_layers", type=int, default=9,
                        help="Number of final image-encoder layers to unfreeze")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Learning rate for the second fine-tuning step")
    # str2bool instead of bool: bool("False") is True.
    parser.add_argument("--finetune_classifier", type=str2bool, default=True,
                        help="Whether finetune the classifier head")
    parser.add_argument("--finetune_img_encoder", type=str2bool, default=True,
                        help="Whether finetune the img_encoder")
    args = parser.parse_args()

    main(args)
