import os
import argparse
import numpy as np
import torch
import open_clip
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder, DatasetFolder
from datasets import load_dataset
import torchvision
import pandas as pd
import json
from tqdm import tqdm

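# Re-declaration of the training-time CLIP_with_head wrapper. Presumably only needed if a
# checkpoint pickles the whole wrapper module (unpickling must find this class); plain state
# dicts whose keys carry the "clip." prefix are unwrapped in main() without using it.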
class CLIP_with_head(torch.nn.Module):
    """Minimal wrapper holding a CLIP model as the submodule ``clip``.

    Because the wrapped model is registered under the attribute name ``clip``,
    every key of this wrapper's ``state_dict`` is the inner model's key with a
    ``clip.`` prefix.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip = clip_model

    def forward(self, image, text):
        """Delegate straight to the wrapped model's ``__call__``."""
        wrapped = self.clip
        return wrapped(image, text)

class CLIP_ImageFolder_custom(DatasetFolder):
    """ImageFolder-backed dataset that applies a CLIP preprocessing transform.

    File discovery and image loading are delegated to a plain ``ImageFolder``.
    ``DatasetFolder.__init__`` is not called; every attribute the methods below
    read is assigned here. ``tokenizer`` is stored but not used by this class.
    """

    def __init__(self, root, preprocess, tokenizer):
        self.root = root
        base = ImageFolder(root)
        self.imagefolder_obj = base
        self.loader = base.loader
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.classes = base.classes
        # (path, class_index) pairs. np.array stores them as strings, so the
        # class index is converted back to int in __getitem__.
        self.samples = np.array(base.samples)

    def __getitem__(self, index):
        """Return ``(preprocessed_image, class_index)`` for sample ``index``."""
        entry = self.samples[index]
        img_path = entry[0]
        label = int(entry[1])
        image = self.preprocess(self.loader(img_path))
        return image, label

    def __len__(self):
        return len(self.samples)

def get_zeroshot_classifier(model, tokenizer, class_strings, device, batch_size=256):
    """Build the zero-shot classification matrix from per-class prompt strings.

    Each prompt is tokenized, encoded with ``model.encode_text`` and
    L2-normalised; the normalised embeddings are stacked as columns.

    Args:
        model: object exposing ``eval()`` and ``encode_text(tokens)`` (e.g. an
            open_clip model).
        tokenizer: callable mapping a list of strings to a token tensor.
        class_strings: iterable of prompts; column ``i`` of the result
            corresponds to the ``i``-th prompt.
        device: device the token tensor is moved to before encoding.
        batch_size: number of prompts encoded per forward pass.

    Returns:
        Tensor of normalised text embeddings transposed so classes lie along
        dim 1 — ``(embed_dim, num_classes)`` assuming ``encode_text`` returns
        ``(batch, embed_dim)``.

    Raises:
        ValueError: if ``class_strings`` is empty or ``batch_size`` < 1.
    """
    class_strings = list(class_strings)
    if not class_strings:
        raise ValueError("class_strings must contain at least one prompt")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # main() builds the classifier before evaluate() calls model.eval(), so
    # force inference mode here to keep text embeddings deterministic.
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(class_strings), batch_size):
            tokens = tokenizer(class_strings[start:start + batch_size]).to(device)
            embeddings = model.encode_text(tokens)
            chunks.append(embeddings / embeddings.norm(dim=-1, keepdim=True))
    return torch.cat(chunks, dim=0).T

def evaluate(model, dataloader, zeroshot_weights, device):
    """Compute top-1 zero-shot accuracy of ``model`` over ``dataloader``.

    Batches may be ``(images, labels)`` pairs (torchvision datasets) or
    mappings with ``'image'`` and ``'label'`` keys — the format a Hugging Face
    ``datasets`` dataset with ``set_transform`` (as built by
    ``get_dataset('imagenet', ...)``) produces after default collation.

    Args:
        model: object exposing ``eval()`` and ``encode_image(images)``.
        dataloader: iterable of batches as described above.
        zeroshot_weights: normalised class-embedding matrix (classes along
            dim 1), e.g. from ``get_zeroshot_classifier``.
        device: device images are moved to before encoding.

    Returns:
        Fraction of samples whose arg-max class equals the label, as a float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    from collections.abc import Mapping

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # Unpacking a dict batch would yield its keys, not its tensors.
            if isinstance(batch, Mapping):
                images, labels = batch["image"], batch["label"]
            else:
                images, labels = batch
            images = images.to(device)
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # The 100x scale does not change the arg-max; kept for parity with
            # CLIP's logit scaling.
            logits = 100. * image_features @ zeroshot_weights
            preds = torch.argmax(logits, dim=1).cpu()
            labels = torch.as_tensor(labels).cpu()
            correct += int((preds == labels).sum().item())
            total += labels.numel()

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return correct / total

def get_dataset(name, preprocess, datadir):
    """Load an evaluation dataset together with its zero-shot prompt strings.

    Args:
        name: ``'imagenet'`` loads the ImageNet-1k validation split through
            Hugging Face ``datasets`` (cached under ``datadir``); any other
            name is read as an ``ImageFolder`` rooted at ``datadir/name``.
        preprocess: image transform applied to every image.
        datadir: base directory for ImageFolder datasets and the HF cache.

    Returns:
        ``(dataset, class_names)`` where ``class_names`` holds one
        ``"a photo of a <class>"`` prompt per class, or ``(None, None)`` if
        the ImageFolder directory cannot be loaded (FileNotFoundError).
    """
    if name != 'imagenet':
        # Generic ImageFolder handling. TODO: port the remaining dataset-specific
        # loaders (get_train_test_dataloader / get_imagenet_eval) from Testing.py.
        root = os.path.join(datadir, name)
        try:
            dataset = ImageFolder(root, transform=preprocess)
        except FileNotFoundError:
            print(f"Dataset {name} not found at {root}")
            return None, None
        prompts = [f"a photo of a {c.replace('_', ' ')}" for c in dataset.classes]
        return dataset, prompts

    hf_split = load_dataset('imagenet-1k', split='validation', cache_dir=datadir)

    def _to_model_inputs(examples):
        # Batched transform: convert each PIL image to RGB, then preprocess.
        return {
            'image': [preprocess(img.convert("RGB")) for img in examples['image']],
            'label': examples['label'],
        }

    hf_split.set_transform(_to_model_inputs)
    label_names = hf_split.features['label'].names
    return hf_split, [f"a photo of a {c}" for c in label_names]

def _extract_clip_state_dict(checkpoint):
    """Return the bare CLIP ``state_dict`` contained in a loaded checkpoint.

    Accepts a plain state dict, a dict wrapping one under ``'state_dict'``, or
    an ``nn.Module``. A ``module.`` prefix shared by every key (as left by
    ``DataParallel``/``DistributedDataParallel``) is removed. If any key starts
    with ``clip.`` (a saved ``CLIP_with_head``), only those entries are kept,
    with the prefix stripped.
    """
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('state_dict'), dict):
        checkpoint = checkpoint['state_dict']
    # Keep the original object when no rewriting is needed so any
    # `_metadata` attached by torch's state_dict() is preserved.
    state_dict = checkpoint

    wrapper_prefix = 'module.'
    if state_dict and all(k.startswith(wrapper_prefix) for k in state_dict):
        state_dict = {k[len(wrapper_prefix):]: v for k, v in state_dict.items()}

    clip_prefix = 'clip.'
    if any(k.startswith(clip_prefix) for k in state_dict):
        # Strip only the leading prefix; str.replace would also rewrite any
        # later 'clip.' occurrence inside a key.
        state_dict = {
            k[len(clip_prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(clip_prefix)
        }
    return state_dict


def main():
    """Evaluate a trained open_clip model zero-shot on one or more datasets.

    Builds the architecture with open_clip, overwrites its weights with the
    checkpoint at ``--model_path``, then for each requested dataset builds a
    prompt-based zero-shot classifier and prints top-1 accuracy.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained CLIP model.")
    parser.add_argument('--model_path', type=str, required=True, help='Path to the trained model weights (.pt file).')
    parser.add_argument('--datasets', nargs='+', default=['imagenet'], help='List of datasets to evaluate on.')
    parser.add_argument('--datadir', type=str, default='../data', help='Directory where datasets are stored.')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--model_name', type=str, default='ViT-B-32',
                        help='open_clip architecture the checkpoint was trained with.')
    parser.add_argument('--pretrained', type=str, default='openai',
                        help='open_clip pretrained tag used to build the model and its '
                             'preprocessing transform (weights are replaced by --model_path).')
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, _, preprocess = open_clip.create_model_and_transforms(args.model_name, pretrained=args.pretrained)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(_extract_clip_state_dict(checkpoint))
    model = model.to(device)
    # The model may come back in training mode; text and image encoding must
    # both run in inference mode.
    model.eval()
    tokenizer = open_clip.get_tokenizer(args.model_name)

    results = {}
    for dataset_name in args.datasets:
        print(f"\n--- Evaluating on {dataset_name} ---")
        dataset, class_names = get_dataset(dataset_name, preprocess, args.datadir)
        if dataset is None:
            continue

        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

        zeroshot_weights = get_zeroshot_classifier(model, tokenizer, class_names, device)

        accuracy = evaluate(model, dataloader, zeroshot_weights, device)
        results[dataset_name] = accuracy
        print(f"Accuracy on {dataset_name}: {accuracy * 100:.2f}%")

    print("\n--- Summary ---")
    for name, acc in results.items():
        print(f"{name}: {acc * 100:.2f}%")

if __name__ == "__main__":
    main()
