import os
import clip
import torch
from torchvision.datasets import CIFAR100
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
import numpy as np
import torchvision

root = '/data/common/task-arithmetic'
# Architecture tag used in the checkpoint directory layout. Kept as `model_name`
# (not `model`) because `model` is rebound to the loaded CLIP network further down.
model_name = 'ViT-B-32'
pretrained_checkpoint = os.path.join(root, 'task_vectors_checkpoints', model_name, 'zeroshot.pt')
# NOTE(review): dataset_name says 'MNIST', but the datasets built below are ImageNet
# folders; confirm which is intended.
dataset_name = 'MNIST'
# finetuned_checkpoint = os.path.join(root, 'task_vectors_checkpoints', model_name, dataset_name, 'finetuned.pt')
# finetuned_checkpoint = os.path.join(root, 'merged_models', model_name, 'localize_stitch_new_1e-05_3_1_20_0.2.pt')
finetuned_checkpoint = os.path.join(root, 'merged_models', model_name, 'task_arithmetic_0.3.pt')
# NOTE(review): neither pretrained_checkpoint nor finetuned_checkpoint is loaded in this
# script; the features come from the stock weights returned by clip.load. Confirm that is
# intended before reading the printed accuracy as a result for the merged model.

# Load CLIP ViT-B/32 and its matching image transform, placing the model on the
# GPU when CUDA is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)

# Build the ImageNet train/val splits from on-disk class folders, applying CLIP's
# preprocessing transform to every image.
# cifar_train = CIFAR100(root="/data/common/cifar100", download=True, train=True, transform=preprocess)
# cifar_test = CIFAR100(root="/data/common/cifar100", download=True, train=False, transform=preprocess)

train_dataset = torchvision.datasets.ImageFolder("/data/common/ImageNet/ILSVRC12012/train", transform=preprocess)
test_dataset = torchvision.datasets.ImageFolder("/data/common/ImageNet/ILSVRC12012/val", transform=preprocess)

# ImageFolder derives label indices from each root's sub-folder names independently,
# so the splits share a label space only if they hold exactly the same class folders.
# Fail loudly rather than train and score against misaligned labels.
if train_dataset.class_to_idx != test_dataset.class_to_idx:
    raise ValueError(
        "train and val ImageFolder splits have different class folders; "
        "labels would not line up between the splits"
    )

def get_features(dataset, batch_size=100, *, num_workers=0, encoder=None, target_device=None):
    """Encode every image in *dataset* and return the features and labels as NumPy arrays.

    Args:
        dataset: a map-style dataset yielding ``(image, label)`` pairs. The default
            DataLoader collation must turn the images into a tensor, because the batch
            is moved with ``.to(...)``.
        batch_size: number of images encoded per forward pass.
        num_workers: DataLoader worker processes used for decoding and preprocessing.
            0 (the default) loads data in the main process.
        encoder: object exposing ``encode_image``. Defaults to the module-level ``model``.
        target_device: device the image batches are moved to before encoding.
            Defaults to the module-level ``device``.

    Returns:
        ``(features, labels)``: features stacked along the first axis and the matching
        labels, in dataset order (the DataLoader is not shuffled).

    Raises:
        RuntimeError: from ``torch.cat`` if *dataset* is empty.
    """
    if encoder is None:
        encoder = model
    if target_device is None:
        target_device = device

    feature_batches = []
    label_batches = []
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

    with torch.no_grad():
        for images, labels in tqdm(loader):
            features = encoder.encode_image(images.to(target_device))
            # Move each batch off the accelerator immediately. Otherwise the features
            # for the whole dataset pile up in device memory until the final concat.
            feature_batches.append(features.cpu())
            label_batches.append(labels)

    return torch.cat(feature_batches).numpy(), torch.cat(label_batches).numpy()

# Encode both splits with the CLIP image encoder.
train_features, train_labels = get_features(train_dataset)
test_features, test_labels = get_features(test_dataset)

# Fit a logistic-regression linear probe on the training-split features.
classifier = LogisticRegression(
    C=0.316,
    max_iter=1000,
    random_state=0,
    verbose=1,
).fit(train_features, train_labels)

# Score the probe on the held-out split as the percentage of exact label matches.
predictions = classifier.predict(test_features)
accuracy = 100. * np.mean((predictions == test_labels).astype(float))
print(f"Accuracy = {accuracy:.3f}")
