import argparse
import logging
import os
import random
from pathlib import Path
import h5py
import numpy as np
import torch
from torch.nn import functional as F

from classifier import classify_waterbirds, classify_urbancars, classify_cocogb, classify_imagenet_vs_imagenet_a, classify_imagenet_and_imagenet_w, classify_nico
import sys
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)
from CLIP_utils.factory import create_model_and_transforms, get_tokenizer

# Module-level logger. It is only created here; it receives handlers and a
# format when main() calls logging.basicConfig.
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random number generator the script relies on.

    Python's ``random`` module, NumPy's global RNG and PyTorch's CPU generator
    are always seeded. When CUDA is available, all GPU generators are seeded
    too, and cuDNN is switched to deterministic, non-benchmarking mode.

    Args:
        seed (int): Value fed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Create and return the command-line parser for this script.

    Every help string uses argparse's ``%(default)s`` substitution. The
    documented defaults are therefore always the real ones; previously several
    hand-written "(default: ...)" notes had drifted from the actual values.

    Returns:
        argparse.ArgumentParser: Parser with every option this script accepts.
    """
    parser = argparse.ArgumentParser(
        description="Compute prediction accuracy for a CLIP model on the dataset selected with --dataset.",
        add_help=True,
    )
    # Model parameters
    parser.add_argument("--model", default="ViT-B-32", type=str,
                        help="Name of the CLIP model to use (default: %(default)s)")
    # Dataset parameters
    parser.add_argument("--patch_size", default=32, type=int,
                        help="Patch size for processing (default: %(default)s)")
    parser.add_argument("--batch_size", default=300, type=int,
                        help="Batch size for processing (adjusted dynamically based on GPU memory; "
                             "default: %(default)s)")
    parser.add_argument("--input_dir", default="./results",
                        help="Path where input data is saved (default: %(default)s)")

    parser.add_argument("--dataset", default="nico", type=str,
                        help="Dataset to process: waterbirds, urbancars, cocogbv1, cocogbv2, "
                             "imagenet, imagenet_a, imagenet_w or nico (default: %(default)s)")
    parser.add_argument("--cuda_id", default="0", type=str,
                        help="CUDA device index (default: %(default)s)")
    parser.add_argument("--lam", default=0.8, type=float,
                        help="Weight for TDE scores (default: %(default)s)")
    parser.add_argument("--MAX_K", default=5, type=int,
                        help="Max number of classes considered (default: %(default)s)")
    parser.add_argument("--lam_hat", default=0.7, type=float,
                        help="Weight for TDE scores (default: %(default)s)")
    parser.add_argument("--alpha", default=0.7, type=float,
                        help="Weight for mixed background and foreground (default: %(default)s)")
    parser.add_argument("--scene_type", default="virtual_cz", type=str,
                        help="Scene type: outer_cz, inner_cz or virtual_cz (default: %(default)s)")
    parser.add_argument("--select_scene_num", default=270, type=int,
                        help="Number of scenes to select (default: %(default)s)")

    return parser

# Maps every supported --dataset value to the key used in the names of its
# places/scene embedding files and to index virtual_scene.SCENE_LIST.
# The ImageNet variants share one set of embeddings, and so do the COCO-GB
# variants.
_EMBEDDING_GROUPS = {
    "imagenet": "imagenet",
    "imagenet_a": "imagenet",
    "imagenet_w": "imagenet",
    "cocogbv1": "cocogb",
    "cocogbv2": "cocogb",
    "waterbirds": "waterbirds",
    "urbancars": "urbancars",
    "nico": "nico",
}

# Pretrained checkpoint tag passed to create_model_and_transforms, per model.
_PRETRAINED_TAGS = {
    "ViT-B-32": "laion2b_s34b_b79k",
    "ViT-B-16": "laion2b_s34b_b88k",
    "ViT-L-14": "laion2b_s32b_b82k",
    "ViT-H-14": "laion2b_s32b_b79k",
}


def _embedding_group(dataset):
    """Return the shared embedding-file key for ``dataset``.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    try:
        return _EMBEDDING_GROUPS[dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {dataset}") from None


def _encode_scene_text(model_name, scene_text, device, batch_size=200):
    """Encode scene prompts with the CLIP text encoder and L2-normalise them.

    Args:
        model_name: Key of ``_PRETRAINED_TAGS`` selecting the CLIP model.
        scene_text: Sliceable sequence of prompts accepted by the tokenizer.
        device: Torch device the model and outputs live on.
        batch_size: Number of prompts encoded per forward pass.

    Returns:
        torch.Tensor: One normalised embedding row per prompt, on ``device``.

    Raises:
        ValueError: if no pretrained tag is known for ``model_name``.
    """
    if model_name not in _PRETRAINED_TAGS:
        raise ValueError(f"Unsupported model: {model_name}")
    model, _, _ = create_model_and_transforms(model_name, pretrained=_PRETRAINED_TAGS[model_name])
    model.to(device)
    model.eval()
    tokenizer = get_tokenizer(model_name)

    batches = []
    with torch.no_grad():
        for start in range(0, len(scene_text), batch_size):
            tokens = tokenizer(scene_text[start:start + batch_size]).to(device)
            batches.append(F.normalize(model.encode_text(tokens), dim=-1))
            torch.cuda.empty_cache()

    # The text encoder is only needed here; release it before classification
    # so it does not occupy GPU memory for the rest of the run.
    del model
    torch.cuda.empty_cache()
    return torch.cat(batches, dim=0)


def _load_scene_embeddings(args, device):
    """Return virtual-scene embeddings, loading them from the HDF5 cache or
    computing and caching them on first use.
    """
    import virtual_scene

    group = _embedding_group(args.dataset)
    scene_text = virtual_scene.SCENE_LIST[group]
    scene_embeddings_path = f"{args.input_dir}/scene_embeddings/{args.model}_{group}_scene_embeddings.h5"
    os.makedirs(os.path.dirname(scene_embeddings_path), exist_ok=True)

    if os.path.exists(scene_embeddings_path):
        logger.info(f"Loading existing scene embeddings from {scene_embeddings_path}")
        with h5py.File(scene_embeddings_path, "r") as f_scene:
            return torch.from_numpy(f_scene["scene_embeddings"][:]).to(device)

    logger.info(f"Computing scene embeddings for {args.dataset}")
    scene_embeddings = _encode_scene_text(args.model, scene_text, device)

    # Write to a temporary file and atomically move it into place only once
    # it is complete. A crash mid-run then never leaves an empty or partial
    # cache file that later runs would mistake for a valid one.
    tmp_path = scene_embeddings_path + ".tmp"
    with h5py.File(tmp_path, "w") as f_scene:
        f_scene.create_dataset("scene_embeddings", data=scene_embeddings.cpu().numpy())
    os.replace(tmp_path, scene_embeddings_path)
    logger.info(f"Saved scene embeddings to {scene_embeddings_path}")
    return scene_embeddings


def main(args):
    """Run the counterfactual CLIP analysis for ``args.dataset``.

    Steps:
      1. Load the per-dataset features (``data.h5``), class text embeddings
         (``openai_text.h5``) and places embeddings from HDF5.
      2. When ``args.scene_type == "virtual_cz"``, store virtual-scene text
         embeddings in ``args.virtual_scene_embeddings``. They are loaded from
         cache, or computed and cached on first use.
      3. Dispatch to the matching ``classify_*`` function.

    Every HDF5 file opened here is closed when the function returns or raises.

    Args:
        args: Namespace produced by the parser from ``get_parser_info``.

    Returns:
        The value returned by the selected ``classify_*`` function, or ``None``
        for ``--dataset imagenet``. That dataset has no classifier of its own;
        it only loads data and, with ``virtual_cz``, builds the shared scene
        cache.

    Raises:
        ValueError: for an unsupported dataset or model.
    """
    from contextlib import ExitStack

    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        input_dir = Path(args.input_dir) / subfolder_name
        Path(f"output/counterfactualCLIP/{args.dataset}/{args.model}").mkdir(parents=True, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M",
        )

        logger.info(f"Starting processing for model: {args.model} on dataset: {args.dataset}")
        logger.info(f"Input directory: {input_dir}")

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Validate the dataset before opening any file.
        group = _embedding_group(args.dataset)

        # ExitStack closes every HDF5 handle registered below, on success or error.
        with ExitStack() as stack:
            def open_h5(path):
                return stack.enter_context(h5py.File(path, "r"))

            f_data = open_h5(input_dir / "data.h5")
            f_text = open_h5(input_dir / "openai_text.h5")
            total_samples = f_data["labels_info"].shape[0]
            logger.info(f"Total samples: {total_samples}")

            f_places = open_h5(f"{args.input_dir}/places_embeddings/{args.model}_{group}_places_embeddings.h5")
            logger.info(f"Places embeddings loaded for {args.dataset}")

            text_embeddings = torch.from_numpy(f_text["text_embeddings"][:]).to(device)
            logger.info(f"Number of classes: {text_embeddings.shape[0]}")

            if args.scene_type == "virtual_cz":
                args.virtual_scene_embeddings = _load_scene_embeddings(args, device)

            if args.dataset == "waterbirds":
                return classify_waterbirds(f_data, text_embeddings, f_places, args, device, logger)
            if args.dataset == "urbancars":
                return classify_urbancars(f_data, text_embeddings, f_places, args, device, logger)
            if args.dataset in ("cocogbv1", "cocogbv2"):
                return classify_cocogb(f_data, text_embeddings, f_places, args, device, logger)
            if args.dataset in ("imagenet_a", "imagenet_w"):
                # These classifiers also take the plain ImageNet features.
                f_data_int = open_h5(f"{args.input_dir}/{args.model}_imagenet/data.h5")
                if args.dataset == "imagenet_a":
                    classify = classify_imagenet_vs_imagenet_a
                else:
                    classify = classify_imagenet_and_imagenet_w
                return classify(f_data_int, f_data, text_embeddings, f_places, args, device, logger)
            if args.dataset == "nico":
                return classify_nico(f_data, text_embeddings, f_places, args, device, logger)

            logger.warning(f"No classifier defined for dataset '{args.dataset}'; nothing to evaluate.")
            return None

    except Exception as e:
        logger.error(f"Error during initialization: {e}", exc_info=True)
        raise

if __name__ == "__main__":
    # Parse CLI options, fix every RNG to a constant seed, then run the analysis.
    args = get_parser_info().parse_args()
    set_seed(42)
    main(args)