import os
import clip
import torch
import torchvision
from torchvision.datasets import CIFAR100, ImageNet, CIFAR10, PCAM
from tqdm import tqdm
import json

from torchvision.datasets.utils import download_file_from_google_drive

# export PYTHONPATH="$PYTHONPATH:$PWD"


# Locations of the checkpoints used by this evaluation script.
root = '/data/common/task-arithmetic'
model = 'ViT-B-32'
pretrained_checkpoint = f'{root}/task_vectors_checkpoints/{model}/zeroshot.pt'

# Merged checkpoint to evaluate. Other names that have been used here:
#   'task_arithmetic_0.3', 'ties_0.3', 'layer_ada',
#   'localize_stitch_new_1e-05_3_10_20_0.2'
model_name = 'task_arithmetic_0.125'
finetuned_checkpoint = f'{root}/merged_models/{model}/{model_name}.pt'
print(model_name)

# Choose the compute device, then load CLIP together with the image
# preprocessing transform it returns. From here on `model` refers to the
# loaded CLIP model, no longer the architecture string used to build paths.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)

# Evaluation data. Alternative datasets that were tried previously:
# dataset = CIFAR100(root="/data/common/cifar100", download=True, train=False, transform=preprocess)
# dataset = PCAM(root="/data/common/PatchCamelyon", download=True, split="test", transform=preprocess)
# dataset = ImageNet(root="/data/common/ImageNet", download=True, split="val", transform=preprocess)

# NOTE(review): 'ILSVRC12012' looks like it may be a typo for 'ILSVRC2012' --
# confirm against the directory that actually exists on disk.
imagenet_val_dir = "/data/common/ImageNet/ILSVRC12012/val"
dataset = torchvision.datasets.ImageFolder(imagenet_val_dir, transform=preprocess)

# label2name = {}
# with open("imagenet_labels.txt", 'r') as file:
#     for line in file:
#         parts = line.strip().split()
#         if parts:  # Ensure the line is not empty
#             key = parts[0]
#             value = parts[-1]
#             label2name[key] = value


# Prepare the inputs
# with open("label2name.json", 'w') as json_file:
#     json.dump(label2name, json_file, indent=4)
# Load the table that maps each entry of `dataset.classes` to the name used
# in its text prompt.
with open("label2name.json", 'r') as json_file:
    label2name = json.load(json_file)

# One prompt per class, tokenized in a single call and moved to the device.
# (To prompt with the raw class entries instead, use f"a photo of a {class_id}".)
prompts = [f"a photo of a {label2name[class_id]}" for class_id in dataset.classes]
text_inputs = clip.tokenize(prompts).to(device)

# Zero-shot evaluation: count how often the true label is the best match
# (top-1) or among the five best matches (top-5) between the image features
# produced by the merged encoder and the CLIP text features of the prompts.
top1, top5 = 0, 0
with torch.no_grad():
    # NOTE: torch.load unpickles a full Python object -- only load checkpoints
    # from a trusted source. map_location lets a checkpoint saved from a GPU
    # be loaded when this script falls back to CPU.
    # NOTE(review): recent PyTorch releases may require weights_only=False to
    # load a pickled module -- confirm against the installed version.
    image_encoder = torch.load(finetuned_checkpoint, map_location=device).to(device)
    # image_encoder = torch.load(pretrained_checkpoint, map_location=device).to(device)
    image_encoder.eval()  # make sure train-only layers (if any) are disabled

    # The class prompts are the same for every batch, so encode and
    # normalise them once instead of once per iteration.
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    loader = torch.utils.data.DataLoader(dataset, batch_size=512, num_workers=2)
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)

        # Cast to the text-feature dtype rather than a hard-coded float16, so
        # the matmul below also works when CLIP runs in float32 (e.g. on CPU).
        image_features = image_encoder(images).to(text_features.dtype)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarity between every image and every class prompt.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        top1_preds = similarity.argmax(dim=-1)
        top5_preds = similarity.topk(5, dim=-1)[1]

        top1 += (top1_preds == labels).sum().item()
        # A label can match at most one of the five distinct indices in a row,
        # so any() along the row is the per-sample top-5 hit indicator.
        top5 += (top5_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

# Turn the hit counts into percentages over the whole dataset and report them.
total_samples = len(dataset)
top1_accuracy, top5_accuracy = (
    hits / total_samples * 100 for hits in (top1, top5)
)

for report_line in (
    f"Top-1 Accuracy: {top1_accuracy:.2f}%",
    f"Top-5 Accuracy: {top5_accuracy:.2f}%",
):
    print(report_line)