import argparse
import contextlib
import logging
import os
import random
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm

# Module-level logger; handlers and format are attached by logging.basicConfig in main().
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs with *seed*.

    When CUDA is available, every GPU generator is seeded as well and cuDNN
    is switched to deterministic, non-benchmarking mode so repeated runs
    select the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Create and return the argument parser with command-line options.

    Returns:
        argparse.ArgumentParser: Parser for the accuracy-evaluation script.
    """
    parser = argparse.ArgumentParser(
        description="Compute group-wise prediction accuracy for a CLIP model on the UrbanCars dataset.",
        add_help=True
    )
    # Model parameters
    parser.add_argument(
        "--model",
        default="ViT-B-32",
        type=str,
        metavar="MODEL",
        help="Name of the CLIP model to use (default: ViT-B-32)"
    )
    # Dataset parameters
    parser.add_argument(
        "--batch_size",
        default=500,
        type=int,
        help="Number of samples processed per batch (default: 500)"
    )
    parser.add_argument(
        "--input_dir",
        default="./results",
        type=str,
        help="Path where input data is saved"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="urbancars",
        help="Dataset to process (default: urbancars)"
    )
    parser.add_argument(
        "--output_dir",
        default="./results/prs",
        type=str,
        help="Path where output data is saved"
    )
    # main() only supports these two modes; validating here gives an
    # immediate usage error instead of a failure after setup.
    parser.add_argument(
        "--text_mode",
        default="openai",
        choices=["simple", "openai"],
        help="Text mode: 'simple' or 'openai' (default: openai)"
    )
    # No `choices` here: the method name is resolved dynamically as
    # `<method>_embedding` in the extract_image_embedding module.
    parser.add_argument(
        "--embedding_method",
        default="clip_base",
        help="Embedding method: 'clip_base' or 'direct_effect' or 'text_based_decomposition' (default: clip_base)"
    )
    parser.add_argument(
        "--cuda_id",
        type=str,
        default="0",
        help="CUDA device index used when a GPU is available (default: 0)"
    )
    return parser

def get_image_embeddings(embedding_method, h5_file, start_idx, end_idx, device, args=None):
    """
    Load image embeddings for one batch via a method from `extract_image_embedding`.

    The function is resolved dynamically: for ``embedding_method="clip_base"`` this
    calls ``extract_image_embedding.clip_base_embedding(h5_file, start_idx, end_idx,
    device, args)`` and returns its result.

    Args:
        embedding_method (str): Method to load or generate embeddings (e.g., 'clip_base').
        h5_file (h5py.File): Opened HDF5 file passed to the embedding function.
        start_idx (int): Starting index of the batch.
        end_idx (int): Ending index of the batch.
        device (torch.device): Device to load tensors to.
        args (argparse.Namespace, optional): Parsed CLI arguments forwarded to the
            embedding function. When omitted, falls back to the module-level
            ``args`` created when this file runs as a script (``None`` if that
            global does not exist).

    Returns:
        torch.Tensor: Whatever the resolved embedding function returns —
        presumably the image embeddings for the batch.

    Raises:
        ImportError: If the `extract_image_embedding` module cannot be imported.
        AttributeError: If the module has no ``<embedding_method>_embedding`` function.
    """
    # Only the import itself is guarded, so an ImportError raised while the
    # embedding function runs is not misreported as a failed import.
    try:
        import extract_image_embedding
    except ImportError as e:
        raise ImportError(
            f"Failed to import embedding function for method '{embedding_method}': {e}"
        ) from e

    embedding_func = getattr(extract_image_embedding, f"{embedding_method}_embedding")

    if args is None:
        # Backward compatibility: previously this read the script-level global
        # `args` unconditionally, raising NameError when the module was imported.
        args = globals().get("args")

    # Called once per batch, so keep it at debug level to avoid flooding the log.
    logger.debug(f"Computing in {embedding_method} method")
    return embedding_func(h5_file, start_idx, end_idx, device, args)

# UrbanCars group membership as (label, place, coobj) triples, taken from
# columns 0, 1 and 2 of the `labels_info` dataset. A sample belongs to a group
# when its triple matches any of the group's triples.
_URBANCARS_GROUPS = {
    "id": ((0, 0, 0), (1, 1, 1)),
    "bg": ((0, 1, 0), (1, 0, 1)),
    "coobj": ((0, 0, 1), (1, 1, 0)),
    "bg_coobj": ((1, 0, 0), (0, 1, 1)),
}


def _group_counts(predictions, labels, places, coobj, triples):
    """Return ``(n_correct, n_total)`` over samples whose (label, place, coobj) is in *triples*."""
    mask = torch.zeros_like(labels, dtype=torch.bool)
    for label, place, co in triples:
        mask |= (labels == label) & (places == place) & (coobj == co)
    n_correct = (predictions[mask] == labels[mask]).sum().item()
    return n_correct, mask.sum().item()


def _percent(correct, total):
    """Return ``correct / total`` as a percentage, or 0 when *total* is 0."""
    return (correct / total) * 100 if total > 0 else 0


def compute_accuracy(
    dataset,
    input_dir,
    batch_size,
    device,
    text_embeddings,
    embedding_method,
    model=None,
    decomposition_dir="./results/Text_Based_Decomposition",
):
    """
    Compute group-wise accuracy for the UrbanCars dataset using CLIP embeddings.

    Predictions are the argmax of image-text similarity. Samples are bucketed into
    four groups (ID, BG, CoObj, BG_CoObj) by their (label, place, coobj) triple, and
    each gap is the ID accuracy minus that group's accuracy.

    Args:
        dataset (str): Name of the dataset; only 'urbancars' is supported.
        input_dir (str): Root directory; data is read from
            ``<input_dir>/<model>_<dataset>/data.h5``.
        batch_size (int): Number of samples per batch.
        device (torch.device): Device to perform computations on (CPU or GPU).
        text_embeddings (torch.Tensor): Precomputed text embeddings for classification.
        embedding_method (str): Method to load or generate embeddings (e.g., 'clip_base').
        model (str, optional): CLIP model name used in the subfolder name. Defaults to
            ``args.model`` from the module-level ``args`` set when run as a script.
        decomposition_dir (str, optional): Root holding
            ``<model>_<dataset>/data.h5`` for the 'text_based_decomposition' method.

    Returns:
        dict: Group accuracies (percent) and gaps relative to the ID group.

    Raises:
        FileNotFoundError: If the input subfolder does not exist.
        ValueError: If the dataset is not supported.

    Batches that raise during processing are logged and skipped. Accuracies are
    then computed over the remaining samples, and a warning reports how many
    batches were skipped.
    """
    if model is None:
        # Backward-compatible fallback to the script-level `args` global.
        model = args.model
    subfolder_name = f"{model}_{dataset}"
    input_dir = os.path.join(input_dir, subfolder_name)

    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first.")
    # Validate before the loop: raised inside the per-batch handler, this error
    # would make every batch be skipped and yield all-zero accuracies.
    if dataset != "urbancars":
        raise ValueError(f"Dataset '{dataset}' not supported yet.")

    correct = dict.fromkeys(_URBANCARS_GROUPS, 0)
    total = dict.fromkeys(_URBANCARS_GROUPS, 0)
    failed_batches = 0

    with contextlib.ExitStack() as stack:
        f = stack.enter_context(h5py.File(os.path.join(input_dir, "data.h5"), 'r'))
        if embedding_method == "text_based_decomposition":
            # Open the decomposition file once, not once per batch.
            emb_file = stack.enter_context(
                h5py.File(os.path.join(decomposition_dir, subfolder_name, "data.h5"), 'r')
            )
        else:
            emb_file = f

        labels_info_dset = f['labels_info']
        total_samples = labels_info_dset.shape[0]
        logger.info(f"Total samples in {dataset}: {total_samples}")

        for start_idx in tqdm(range(0, total_samples, batch_size), desc=f"Processing {dataset}"):
            end_idx = min(start_idx + batch_size, total_samples)
            try:
                image_embeddings_batch = get_image_embeddings(
                    embedding_method, emb_file, start_idx, end_idx, device
                ).to(device)

                labels_info_batch = labels_info_dset[start_idx:end_idx]
                labels_batch = torch.tensor(labels_info_batch[:, 0], dtype=torch.long, device=device)
                places_batch = torch.tensor(labels_info_batch[:, 1], dtype=torch.long, device=device)
                coobj_batch = torch.tensor(labels_info_batch[:, 2], dtype=torch.long, device=device)

                # Compute predictions using similarity scores
                predictions = (100.0 * image_embeddings_batch @ text_embeddings.t()).argmax(dim=1)

                for group, triples in _URBANCARS_GROUPS.items():
                    n_correct, n_total = _group_counts(
                        predictions, labels_batch, places_batch, coobj_batch, triples
                    )
                    correct[group] += n_correct
                    total[group] += n_total

                # Free memory explicitly
                del image_embeddings_batch, labels_batch, places_batch, coobj_batch, predictions
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best-effort: skip the broken batch but keep count so the
                # partial result is flagged below.
                failed_batches += 1
                logger.error(f"Error in batch {start_idx}-{end_idx}: {e}")

    if failed_batches:
        logger.warning(
            f"{failed_batches} batch(es) failed and were skipped; "
            f"accuracies cover only the remaining samples."
        )

    acc = {group: _percent(correct[group], total[group]) for group in _URBANCARS_GROUPS}

    return {
        "id_accuracy": acc["id"],
        "bg_accuracy": acc["bg"],
        "coobj_accuracy": acc["coobj"],
        "bg_coobj_accuracy": acc["bg_coobj"],
        "bg_gap": acc["id"] - acc["bg"],
        "coobj_gap": acc["id"] - acc["coobj"],
        "bg_coobj_gap": acc["id"] - acc["bg_coobj"],
    }

def main(args):
    """
    Compute and log group accuracies for the UrbanCars dataset.

    Logging (file + console) is configured before any input validation, so a
    missing input directory is recorded in
    ``<output_dir>/Console_Info_<model>_<dataset>.log`` as well as on the console.

    Args:
        args (argparse.Namespace): Command-line arguments parsed by the argument parser.

    Returns:
        int: 0 on success, 1 if any step raised. The exception is logged,
        not propagated.
    """
    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        input_dir = Path(args.input_dir) / subfolder_name
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Configure logging to file and console first, so any later failure
        # lands in the log file.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(output_dir / f"Console_Info_{subfolder_name}.log", mode='w'),
                logging.StreamHandler()
            ]
        )

        if not input_dir.exists():
            raise FileNotFoundError(
                f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first."
            )

        if args.text_mode not in ("simple", "openai"):
            raise ValueError(f"Text mode '{args.text_mode}' not supported.")

        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

        # Load text embeddings
        logger.info(f"Text mode: {args.text_mode}")
        text_file_path = input_dir / f"{args.text_mode}_text.h5"
        with h5py.File(text_file_path, 'r') as f:
            text_embeddings = torch.tensor(f['text_embeddings'][:], dtype=torch.float32).to(device)

        # Compute accuracy
        results = compute_accuracy(
            dataset=args.dataset,
            input_dir=args.input_dir,
            batch_size=args.batch_size,
            device=device,
            text_embeddings=text_embeddings,
            embedding_method=args.embedding_method
        )

        # Output results
        logger.info(f"Dataset: {args.dataset}")

        logger.info("\n--- Group Accuracies ---")
        logger.info(f"ID Group Accuracy: {results['id_accuracy']:.2f}%")
        logger.info(f"BG Group Accuracy: {results['bg_accuracy']:.2f}% (GAP: {results['bg_gap']:.2f}%)")
        logger.info(f"CoObj Group Accuracy: {results['coobj_accuracy']:.2f}% (GAP: {results['coobj_gap']:.2f}%)")
        logger.info(f"BG_CoObj Group Accuracy: {results['bg_coobj_accuracy']:.2f}% (GAP: {results['bg_coobj_gap']:.2f}%)")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")
        return 1
    return 0


if __name__ == "__main__":
    parser = get_parser_info()
    # Kept as a module-level global: helper functions fall back to it when no
    # explicit model/args are passed.
    args = parser.parse_args()
    set_seed(42)
    # Use main()'s return value as the process exit status (a None return
    # exits with status 0), so failed runs are visible to shell callers.
    raise SystemExit(main(args))