import argparse
import os
import json
import torch
import clip
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import get_model_from_sd


def parse_arguments():
    """Parse the command-line options for a single generalization-gap run.

    Returns:
        argparse.Namespace with ``model_path``, ``data_location``,
        ``results_file``, ``output_file``, ``batch_size`` and ``workers``.
    """
    cli = argparse.ArgumentParser(description="Evaluate generalization gap of a model.")

    # (flag, keyword options) pairs, registered in this order so the
    # generated --help output lists them the same way.
    option_table = [
        ("--model-path", {
            "type": str,
            "required": True,
            "help": "The path to the specific model .pt file.",
        }),
        ("--data-location", {
            "type": str,
            "default": os.path.expanduser('~/data'),
            "help": "The root directory for the datasets.",
        }),
        ("--results-file", {
            "type": str,
            "default": 'individual_model_results.jsonl',
            "help": "The file containing individual model test results.",
        }),
        ("--output-file", {
            "type": str,
            "default": 'generalization_results.jsonl',
            "help": "The file to append the generalization results to.",
        }),
        ("--batch-size", {"type": int, "default": 256}),
        ("--workers", {"type": int, "default": 8}),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def get_test_acc_from_file(results_file, model_name):
    """Look up the recorded ``"ImageNet"`` accuracy for ``model_name``.

    ``results_file`` is read as JSON Lines (one JSON value per line). The
    first JSON object whose ``"model_name"`` equals ``model_name`` wins, and
    its ``"ImageNet"`` value is returned (``None`` if that key is absent).
    Blank lines, malformed lines, and lines holding a JSON value that is not
    an object are skipped.

    Args:
        results_file: Path to the JSONL results file.
        model_name: Model identifier to search for.

    Returns:
        The stored ``"ImageNet"`` value, or ``None`` if no record matches.

    Raises:
        FileNotFoundError: If ``results_file`` does not exist.
    """
    if not os.path.exists(results_file):
        raise FileNotFoundError(f"Results file not found: {results_file}")

    with open(results_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A well-formed line may still be a list/number/string; only JSON
            # objects can carry a model_name, and .get() on anything else
            # would raise AttributeError and abort the whole lookup.
            if isinstance(record, dict) and record.get('model_name') == model_name:
                return record.get('ImageNet')
    return None

def evaluate_with_tqdm(model, loader, device='cuda'):
    """Compute top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and moved to ``device`` (both in
    place), then run under ``torch.no_grad()``. A tqdm progress bar shows
    the running accuracy.

    Args:
        model: Module mapping an image batch to logits; the argmax over
            dim 1 is taken as the predicted class. Must support ``.eval()``
            and ``.to(device)``.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device to run on (default ``'cuda'``).

    Returns:
        Fraction of correctly classified samples, a float in [0, 1].

    Raises:
        ValueError: If ``loader`` yields no samples (accuracy is undefined).
    """
    model.eval()
    model.to(device)

    correct = 0
    total = 0

    # Using tqdm as a context manager guarantees the bar is closed even if a
    # batch raises.
    with torch.no_grad(), tqdm(loader, desc="Evaluating", unit="batch", leave=True) as pbar:
        for images, labels in pbar:
            # non_blocking lets the host-to-device copy overlap with compute
            # when the loader hands out pinned memory; otherwise it is a no-op.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            pred = model(images).argmax(dim=1)
            correct += pred.eq(labels).sum().item()
            total += labels.size(0)

            # refresh=False: the postfix is drawn on tqdm's own refresh
            # schedule instead of forcing a redraw every batch.
            pbar.set_postfix({"Current Acc": f"{correct / total:.4f}"}, refresh=False)

    if total == 0:
        raise ValueError("Cannot compute accuracy: the loader yielded no samples.")
    return correct / total

def _resolve_train_dir(data_location):
    """Return the ImageNet training directory found under ``data_location``.

    ``<data_location>/train`` is preferred; ``<data_location>/ILSVRC2012_img_train``
    is the fallback.

    Raises:
        FileNotFoundError: If neither candidate directory exists.
    """
    candidates = [
        os.path.join(data_location, 'train'),
        os.path.join(data_location, 'ILSVRC2012_img_train'),
    ]
    for candidate in candidates:
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(
        "Training directory not found; looked for: " + ", ".join(candidates)
    )

def _append_jsonl(path, entry):
    """Append ``entry`` to ``path`` as a single JSON line."""
    with open(path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(entry) + '\n')

def main():
    """Measure the train/test generalization gap of one checkpoint and log it.

    Steps:
      1. Derive the model name from the checkpoint filename (extension removed).
      2. Look up that model's ImageNet test accuracy in ``--results-file``.
      3. Load the CLIP ViT-B/32 base model and apply the checkpoint's weights.
      4. Evaluate top-1 accuracy on the ImageNet training set.
      5. Append ``{model_name, train_acc, test_acc, gap}`` as one JSON line
         to ``--output-file``.

    Exits with a non-zero status (``SystemExit``) if no test accuracy is
    recorded for the model. Raises ``FileNotFoundError`` if no training
    directory exists under ``--data-location``.
    """
    args = parse_arguments()

    # 1. The checkpoint's base filename (no extension) is the key used in
    #    the results file.
    model_name = os.path.splitext(os.path.basename(args.model_path))[0]
    print(f"Processing model: {model_name}")

    # 2. Looked up before the expensive model load so a missing record fails
    #    fast. Exiting non-zero lets batch drivers detect the failure.
    test_acc = get_test_acc_from_file(args.results_file, model_name)
    if test_acc is None:
        raise SystemExit(f"Error: Could not find results for {model_name} in {args.results_file}.")
    print(f"Found Test Acc (ImageNet Val): {test_acc}")

    # 3. Load model.
    print("Loading CLIP base model...")
    base_model, preprocess = clip.load('ViT-B/32', 'cpu', jit=False)

    print(f"Loading weights from {args.model_path}...")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources.
    state_dict = torch.load(args.model_path, map_location=torch.device('cpu'))
    model = get_model_from_sd(state_dict, base_model)

    # 4. Prepare training data. Fail early with a clear message instead of
    #    letting ImageFolder fail on a missing path.
    train_dir = _resolve_train_dir(args.data_location)
    print(f"Preparing Training Data from: {train_dir}")
    # NOTE(review): ImageFolder assigns labels by sorted class-folder name;
    # this assumes the model's output classes use the same order — confirm.
    train_dataset = ImageFolder(root=train_dir, transform=preprocess)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Creating DataLoader (Batch Size: {args.batch_size}, Workers: {args.workers})...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # order does not affect accuracy; sequential reads are cheaper and reproducible
        num_workers=args.workers,
        pin_memory=(device == 'cuda'),  # pinned memory only speeds up host-to-GPU copies
    )

    print(f"Evaluating on ImageNet Training Set (device: {device})...")
    train_acc = evaluate_with_tqdm(model, train_loader, device=device)
    print(f"Train Acc: {train_acc}")

    # 5. Calculate gap and save.
    # NOTE(review): train_acc is a fraction in [0, 1]; confirm the results file
    # stores "ImageNet" accuracy on the same scale (not a percentage).
    gap = train_acc - test_acc
    print(f"Gap (Train - Test): {gap}")

    result_entry = {
        "model_name": model_name,
        "train_acc": train_acc,
        "test_acc": test_acc,
        "gap": gap,
    }

    print(f"Saving results to {args.output_file}...")
    _append_jsonl(args.output_file, result_entry)

    print("Done.\n")

# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()