import sys
import os

# Resolve paths from the absolute location of this script, so the result does not
# depend on the current working directory or on how the script was invoked
# (a relative __file__ would otherwise collapse to '' after repeated dirname calls).
_script_dir = os.path.dirname(os.path.abspath(__file__))

# Project root: two directories above the directory containing this script.
project_root = os.path.dirname(os.path.dirname(_script_dir))
sys.path.insert(0, project_root)

# Vendored code lives in <parent of script dir>/vendor. It is inserted after
# project_root, so it ends up first on sys.path and takes precedence.
vendor_path = os.path.join(os.path.dirname(_script_dir), 'vendor')
sys.path.insert(0, vendor_path)

# The project root and vendor directory are now on sys.path, so the local SpLiCE wrapper import below resolves
import src.splice_wrapper as splice_nonsparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm
import os
import h5py
import numpy as np
import argparse
import generate_COCOLogic as cocologic
import generate_COCOLogic as cocologic

def _load_splits(dataset_root, dataset_name, transform):
    """Construct the (train, test) datasets for ``dataset_name``.

    Supported names: 'cifar100', 'imagenet', and 'cocologic<N>' (e.g. 'cocologic10'),
    where <N> is passed to ``COCOLogicDataset`` as ``version``.

    Raises:
        FileNotFoundError: the ImageNet train/ or val/ directory is missing.
        ValueError: the name is not supported, or a cocologic name has no numeric suffix.
    """
    name = dataset_name.lower()

    if name == 'cifar100':
        train_dataset = datasets.CIFAR100(root=dataset_root, train=True, download=False, transform=transform)
        test_dataset = datasets.CIFAR100(root=dataset_root, train=False, download=False, transform=transform)
        return train_dataset, test_dataset

    if name == 'imagenet':
        # ImageNet cannot be downloaded by torchvision; it must already be on disk.
        imagenet_root = os.path.join(dataset_root, 'imagenet')
        train_dir = os.path.join(imagenet_root, 'train')
        val_dir = os.path.join(imagenet_root, 'val')
        if not (os.path.exists(train_dir) and os.path.exists(val_dir)):
            raise FileNotFoundError(
                f"ImageNet directory not found at {imagenet_root}.\n"
                "Please download ImageNet manually and organize it with train/ and val/ subdirectories."
            )
        print(f"Loading ImageNet from {imagenet_root}")
        train_dataset = datasets.ImageFolder(train_dir, transform=transform)
        test_dataset = datasets.ImageFolder(val_dir, transform=transform)
        return train_dataset, test_dataset

    if 'cocologic' in name:
        suffix = name.split('cocologic')[1]
        try:
            num_classes = int(suffix)
        except ValueError:
            raise ValueError(
                f"COCOLogic dataset name must end with the number of classes "
                f"(e.g. 'cocologic10'), got: {dataset_name}"
            ) from None

        annotations_dir = os.path.join(dataset_root, 'annotations')
        train_annotations = os.path.join(annotations_dir, 'instances_train2017.json')
        val_annotations = os.path.join(annotations_dir, 'instances_val2017.json')

        # Both splits use the train-split category mapping so that label ids are
        # guaranteed to agree between train and test.
        category_map = cocologic.load_category_mapping(annotation_file=train_annotations)
        shared_kwargs = dict(
            category_id_to_name=category_map,
            transform=transform,
            filter_no_labels=True,
            exclusive_label=True,
            exclusive_match_only=True,
            log_statistics=False,
            version=num_classes,
        )
        train_dataset = cocologic.COCOLogicDataset(
            annotation_file=train_annotations,
            image_dir=os.path.join(dataset_root, 'train2017'),
            **shared_kwargs,
        )
        # The official COCO val2017 split serves as the test set.
        test_dataset = cocologic.COCOLogicDataset(
            annotation_file=val_annotations,
            image_dir=os.path.join(dataset_root, 'val2017'),
            **shared_kwargs,
        )
        return train_dataset, test_dataset

    raise ValueError(
        f"Unsupported dataset: {dataset_name}. "
        "Currently supported: 'cifar100', 'imagenet', 'cocologic<N>'"
    )


def _create_split_datasets(h5file, split, num_samples, embedding_dim, batch_size):
    """Pre-allocate '<split>_embeddings' (float32) and '<split>_labels' (int64) in ``h5file``.

    The first axis is declared unlimited (maxshape None). This lets a chunk be larger than
    the split, even for an empty split, and lets the datasets be trimmed later with resize().
    """
    chunk_rows = max(1, batch_size)
    h5file.create_dataset(f'{split}_embeddings', shape=(num_samples, embedding_dim),
                          maxshape=(None, embedding_dim), dtype=np.float32,
                          chunks=(chunk_rows, embedding_dim))
    h5file.create_dataset(f'{split}_labels', shape=(num_samples,), maxshape=(None,),
                          dtype=np.int64, chunks=(chunk_rows,))


def _encode_split(splice_model, loader, device, embeddings_dset, labels_dset, num_batches=None):
    """Encode every batch of ``loader`` and write embeddings/labels contiguously from row 0.

    At most ``num_batches`` batches are processed when it is not None. Afterwards both
    datasets are resized to the number of rows actually written.

    Returns:
        The number of rows written.

    Raises:
        ValueError: if the model output is not shaped (batch, embeddings_dset.shape[1]).
    """
    embedding_dim = embeddings_dset.shape[1]
    written = 0
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(tqdm(loader)):
            if num_batches is not None and batch_idx >= num_batches:
                break
            images = images.to(device)
            # Convert to float32 before numpy(): numpy has no bfloat16, and the file stores float32.
            embeddings = splice_model.encode_image(images).float().cpu().numpy()

            batch_len = len(images)
            if embeddings.shape != (batch_len, embedding_dim):
                raise ValueError(
                    f"Expected SpLiCE embeddings of shape ({batch_len}, {embedding_dim}), "
                    f"got {embeddings.shape}. Check that encode_image returns concept "
                    f"weights of size vocab_size."
                )

            # Write the batch straight to disk so the full split never sits in memory.
            end = written + batch_len
            embeddings_dset[written:end] = embeddings
            labels_dset[written:end] = labels.numpy()
            written = end

    # Trim any pre-allocated rows that were not filled.
    if embeddings_dset.shape[0] != written:
        embeddings_dset.resize((written, embedding_dim))
        labels_dset.resize((written,))
    return written


def precompute_embeddings(splice_model, preprocess, device, dataset_root, save_dir,
                          dataset_name, vocab_size, sparse, l1_penalty=None, batch_size=64, num_batches=None,
                          num_workers=4):
    """
    Memory-efficient embedding precomputation that writes embeddings directly to disk.

    Output file:
        <save_dir>/embeddings_<dataset>/vocab_<vocab_size>/{l1_<penalty:.3f>|nonsparse}/
        <dataset>_splice_embeddings{_sample|_full}.h5
    It contains the datasets 'train_embeddings', 'train_labels', 'test_embeddings'
    and 'test_labels'.

    Args:
        splice_model: Loaded SpLiCE model (its ``encode_image`` output is written to disk)
        preprocess: Preprocessing transform applied to every image
        device: Device to run computation on
        dataset_root: Root directory for dataset
        save_dir: Directory to save embeddings
        dataset_name: 'cifar100', 'imagenet', or 'cocologic<N>' (e.g. 'cocologic10')
        vocab_size: Size of vocabulary used in SpLiCE model; also the stored embedding width
        sparse: Whether embeddings are sparse or not
        l1_penalty: L1 penalty value used in SpLiCE model (for sparse embeddings)
        batch_size: Batch size for processing
        num_batches: if None, process entire dataset. If int, process only that many batches
        num_workers: Number of DataLoader worker processes

    Raises:
        FileNotFoundError / ValueError: see dataset loading; ValueError also if the
            model output width does not match ``vocab_size``.
    """
    dataset_key = dataset_name.lower()

    # Output directory encodes dataset, vocabulary size and sparsity setting.
    vocab_dir = os.path.join(save_dir, f'embeddings_{dataset_key}', f'vocab_{vocab_size}')
    if sparse and l1_penalty is not None:
        embeddings_type_dir = os.path.join(vocab_dir, f'l1_{l1_penalty:.3f}')
    else:
        # NOTE(review): sparse=True with l1_penalty=None also lands in 'nonsparse'.
        embeddings_type_dir = os.path.join(vocab_dir, 'nonsparse')
    os.makedirs(embeddings_type_dir, exist_ok=True)

    train_dataset, test_dataset = _load_splits(dataset_root, dataset_name, preprocess)

    if num_batches is not None:
        # Restrict each split to its first num_batches * batch_size samples.
        max_samples = num_batches * batch_size
        train_dataset = Subset(train_dataset, list(range(min(max_samples, len(train_dataset)))))
        test_dataset = Subset(test_dataset, list(range(min(max_samples, len(test_dataset)))))
        file_suffix = '_sample'
    else:
        file_suffix = '_full'

    # Pinned host memory only helps host-to-GPU copies; on CPU it just triggers warnings.
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers,
                         pin_memory=str(device).startswith('cuda'))
    train_loader = DataLoader(train_dataset, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)

    train_size = len(train_dataset)
    test_size = len(test_dataset)
    embedding_dim = vocab_size  # SpLiCE concept weights have one entry per vocabulary word

    print(f"Data loaded for {dataset_name} with {train_size} train and {test_size} test samples.")

    print('Creating HDF5 file for embeddings...')
    save_path = os.path.join(embeddings_type_dir, f'{dataset_key}_splice_embeddings{file_suffix}.h5')
    with h5py.File(save_path, 'w') as f:
        _create_split_datasets(f, 'train', train_size, embedding_dim, batch_size)
        _create_split_datasets(f, 'test', test_size, embedding_dim, batch_size)

        print(f"Computing {dataset_name} train set embeddings...")
        _encode_split(splice_model, train_loader, device,
                      f['train_embeddings'], f['train_labels'], num_batches)

        print(f"Computing {dataset_name} test set embeddings...")
        _encode_split(splice_model, test_loader, device,
                      f['test_embeddings'], f['test_labels'], num_batches)

    print(f"Done! {dataset_name} embeddings saved to {save_path}")

def _import_github_splice():
    """Import the pip-installed (GitHub) ``splice`` package instead of the local copies.

    Side effects on interpreter state:
      * every cached ``splice`` / ``splice.*`` entry is removed from ``sys.modules``;
      * ``sys.path`` entries containing 'splice_customized' or 'MerlinArthur-SpLiCE'
        are removed, and restored only if the import fails;
      * the interpreter's site-packages directories are prepended to ``sys.path``
        when missing.

    Returns:
        The imported ``splice`` module, or ``None`` if it could not be imported.
    """
    import importlib
    import site

    # Force a fresh import rather than reusing a previously loaded local copy.
    for mod_name in list(sys.modules):
        if mod_name == 'splice' or mod_name.startswith('splice.'):
            del sys.modules[mod_name]

    def _is_custom_splice_path(path):
        return 'splice_customized' in path or 'MerlinArthur-SpLiCE' in path

    # Hide customized SpLiCE checkouts so they cannot shadow the installed package.
    saved_paths = [p for p in sys.path if _is_custom_splice_path(p)]
    sys.path[:] = [p for p in sys.path if not _is_custom_splice_path(p)]

    for site_path in site.getsitepackages():
        if site_path not in sys.path:
            sys.path.insert(0, site_path)

    try:
        return importlib.import_module('splice')
    except ImportError:
        # Undo the path hiding so the local wrapper keeps working.
        for path in saved_paths:
            if path not in sys.path:
                sys.path.append(path)
        return None


def main():
    """Parse CLI arguments and precompute SpLiCE embeddings for the chosen dataset.

    Modes:
      * non-sparse: used by default, with ``--nonsparse``, or when GitHub SpLiCE
        cannot be imported. It runs once with ``l1_penalty=0.0`` and
        ``return_weights=False``.
      * sparse: ``--use_github_splice`` with a successful import and no
        ``--nonsparse``. It runs once per value of ``--l1_penalties`` with
        ``return_weights=True``.
    """
    parser = argparse.ArgumentParser(description='Precompute SpLiCE embeddings for datasets')
    parser.add_argument('--dataset_root', type=str, default='/software/ais2t/datasets/pytorch',
                        help='Root directory for datasets')
    parser.add_argument('--save_dir', type=str, default='../../SCRATCH/datasets',
                        help='Directory to save embeddings')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Batch size for computing embeddings')
    parser.add_argument('--num_batches', type=int, default=None,
                        help='Number of batches to process. If None, process entire dataset')
    parser.add_argument('--l1_penalties', nargs='+', type=float, default=[0.20],
                        help='L1 penalty values to use. Can specify multiple values')
    parser.add_argument('--dataset_name', type=str, default='cifar100',
                        choices=['cifar100', 'imagenet', 'cocologic7', 'cocologic8', 'cocologic10'],
                        help='Name of dataset to use (cifar100, imagenet, cocologic7, cocologic8 or cocologic10)')
    parser.add_argument('--vocab_size', type=int, default=10000,
                        help='Vocabulary size for SpLiCE model')
    parser.add_argument('--nonsparse', action='store_true',
                        help='Generate non-sparse embeddings (uses local splice wrapper)')
    parser.add_argument('--use_github_splice', action='store_true',
                        help='Use GitHub SpLiCE instead of local wrapper (produces sparse embeddings)')
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Pick the SpLiCE implementation. is_sparse reflects what was actually loaded.
    if args.use_github_splice:
        splice = _import_github_splice()
        if splice is not None:
            print(f"Using GitHub SpLiCE from: {splice.__file__}")
            is_sparse = True
        else:
            print("Could not import GitHub SpLiCE. Make sure it's installed with pip.")
            print("Falling back to local splice wrapper.")
            splice = splice_nonsparse
            is_sparse = False
    else:
        splice = splice_nonsparse
        is_sparse = False
        print(f"Using local SpLiCE wrapper from: {splice_nonsparse.__file__}")

    vocab_size = args.vocab_size
    print(f"Using vocabulary size: {vocab_size}")

    clip_backbone = "open_clip:ViT-B-32"
    # The preprocessing transform depends only on the backbone, so build it once.
    preprocess = splice.get_preprocess(clip_backbone)

    # Decide on the implementation actually loaded, not only the flag. After a failed
    # GitHub import we are on the local (non-sparse) wrapper and must not store its
    # output in the sparse l1_* directories.
    if args.nonsparse or not is_sparse:
        print(f"\nProcessing {args.dataset_name} with vocabulary size {vocab_size} (nonsparse)")

        splice_model = splice.load(clip_backbone,
                                   vocabulary="laion",
                                   vocabulary_size=vocab_size,
                                   l1_penalty=0.0,  # Value doesn't matter for nonsparse
                                   return_weights=False,
                                   device=device)

        precompute_embeddings(splice_model, preprocess, device,
                              args.dataset_root, args.save_dir,
                              args.dataset_name, vocab_size, sparse=False,
                              batch_size=args.batch_size, num_batches=args.num_batches)
    else:
        # Sparse mode: the L1 penalty is a load-time parameter, so reload per value.
        for l1_penalty in args.l1_penalties:
            print(f"\nProcessing {args.dataset_name} with vocabulary size {vocab_size}, L1 penalty: {l1_penalty:.3f}")

            splice_model = splice.load(clip_backbone,
                                       vocabulary="laion",
                                       vocabulary_size=vocab_size,
                                       l1_penalty=l1_penalty,
                                       return_weights=True,
                                       device=device)

            precompute_embeddings(splice_model, preprocess, device,
                                  args.dataset_root, args.save_dir,
                                  args.dataset_name, vocab_size, sparse=True,
                                  l1_penalty=l1_penalty, batch_size=args.batch_size,
                                  num_batches=args.num_batches)

# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()