import argparse
import logging
import os
import random
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm

# Configure global logger
logger = logging.getLogger(__name__)

def set_seed(seed):
    """
    Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG. When CUDA
    is available it also seeds all GPUs and forces cuDNN into deterministic,
    non-benchmarking mode so repeated runs pick the same kernels.

    Args:
        seed (int): Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """
    Create and return the argument parser with command-line options.

    Returns:
        argparse.ArgumentParser: Parser for the accuracy script. ``--batch_size``
        must be a positive integer and ``--text_mode`` must be ``simple`` or
        ``openai``; invalid values are rejected at parse time.
    """

    def _positive_int(value):
        """argparse ``type`` accepting only integers >= 1."""
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
        return number

    parser = argparse.ArgumentParser(
        description="Compute prediction accuracy for CLIP model on Waterbirds dataset.",
        add_help=True
    )
    # Model parameters
    parser.add_argument(
        "--model",
        default="ViT-B-32",
        type=str,
        metavar="MODEL",
        help="Name of the CLIP model to use (default: ViT-B-32)"
    )
    # Dataset parameters
    parser.add_argument(
        "--batch_size",
        default=1000,
        # Reject 0 (range() would crash) and negatives (would silently evaluate nothing).
        type=_positive_int,
        help="Number of samples processed per batch (default: 1000)"
    )
    parser.add_argument(
        "--input_dir",
        default="./results",
        help="Path where input data is saved (default: ./results)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="waterbirds",
        help="Dataset to process (default: waterbirds)"
    )
    parser.add_argument(
        "--output_dir",
        default="./results/prs",
        help="Path where output data is saved (default: ./results/prs)"
    )
    parser.add_argument(
        "--text_mode",
        default="simple",
        choices=["simple", "openai"],
        help="Text mode: 'simple' or 'openai' (default: simple)"
    )
    parser.add_argument(
        "--embedding_method",
        default="clip_base",
        help="Method to load or generate embeddings (e.g., 'clip_base'). Default: clip_base"
    )
    parser.add_argument(
        "--cuda_id",
        type=str,
        default="0",
        help="CUDA device index to use when CUDA is available (default: 0)"
    )
    return parser

def get_image_embeddings(embedding_method, h5_file, start_idx, end_idx, device, args=None):
    """
    Load image embeddings from an HDF5 file for a specific batch.

    Dispatches to ``extract_image_embedding.<embedding_method>_embedding`` and calls
    it as ``func(h5_file, start_idx, end_idx, device, args)``.

    Args:
        embedding_method (str): Method to load or generate embeddings (e.g., 'clip_base').
        h5_file (h5py.File): Opened HDF5 file containing image embeddings.
        start_idx (int): Starting index of the batch.
        end_idx (int): Ending index of the batch.
        device (torch.device): Device to load tensors to.
        args (argparse.Namespace, optional): Namespace forwarded to the embedding
            function. Defaults to the module-level ``args`` set by the ``__main__``
            block (``None`` when this module is imported rather than run).

    Returns:
        torch.Tensor: Image embeddings for the specified batch.

    Raises:
        ImportError: If the ``extract_image_embedding`` module cannot be imported.
        AttributeError: If that module has no ``<embedding_method>_embedding`` function.
    """
    if args is None:
        # Preserve the original behavior of reading the script-level CLI namespace,
        # without crashing with NameError when the module is imported.
        args = globals().get("args")

    try:
        import extract_image_embedding
    except ImportError as e:
        raise ImportError(
            f"Failed to import embedding function for method '{embedding_method}': {e}"
        ) from e

    func_name = f"{embedding_method}_embedding"
    try:
        embedding_func = getattr(extract_image_embedding, func_name)
    except AttributeError as e:
        raise AttributeError(
            f"Unknown embedding method '{embedding_method}': "
            f"extract_image_embedding has no function '{func_name}'"
        ) from e

    logger.info(f"Computing in {embedding_method} method")
    return embedding_func(h5_file, start_idx, end_idx, device, args)

def compute_accuracy(dataset, input_dir, batch_size, device, text_embeddings, embedding_method, model=None):
    """
    Compute overall and group-wise accuracy for the Waterbirds dataset using CLIP embeddings.

    Samples are read in batches from ``<input_dir>/<model>_<dataset>/data.h5``.
    Each image is predicted as the row index of ``text_embeddings`` with the highest
    similarity (dot product) to its image embedding. ``labels_info`` column 0 holds
    the class label and column 1 the attribute (place) label. Samples are bucketed
    into groups ``label * 2 + place``: landbird+land, landbird+water, waterbird+land,
    waterbird+water.

    A batch that raises is logged and skipped. Accuracies are computed over the
    samples that were actually evaluated, and a warning reports how many were skipped.

    Args:
        dataset (str): Name of the dataset (e.g., 'waterbirds').
        input_dir (str): Directory containing the ``<model>_<dataset>`` subfolder.
        batch_size (int): Number of samples per batch; must be positive.
        device (torch.device): Device to perform computations on (CPU or GPU).
        text_embeddings (torch.Tensor): Precomputed text embeddings for classification.
        embedding_method (str): Method to load or generate embeddings (e.g., 'clip_base').
        model (str, optional): Model name used to build the subfolder name. Defaults
            to ``args.model`` from the module-level CLI namespace.

    Returns:
        dict: ``{"overall_accuracy": float, "group_accuracies": {group_id: float}}``,
        with all values in percent (0 when nothing could be evaluated).

    Raises:
        ValueError: If ``dataset`` is not supported or ``batch_size`` is not positive.
        FileNotFoundError: If the ``<model>_<dataset>`` input directory does not exist.
    """
    from contextlib import ExitStack

    # Validate up front. Raised inside the per-batch try, an unsupported dataset
    # would be swallowed for every batch and silently yield 0% accuracy.
    if dataset != "waterbirds":
        raise ValueError(f"Dataset '{dataset}' not supported yet.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if model is None:
        model = args.model  # script-level namespace created in the __main__ block
    subfolder_name = f"{model}_{dataset}"
    input_dir = os.path.join(input_dir, subfolder_name)

    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first.")

    num_groups = 4
    group_correct_counts = {group: 0 for group in range(num_groups)}
    group_total_counts = {group: 0 for group in range(num_groups)}
    skipped_samples = 0

    with ExitStack() as stack:
        f = stack.enter_context(h5py.File(os.path.join(input_dir, "data.h5"), 'r'))
        if embedding_method == "text_based_decomposition":
            # Decomposed embeddings live in a separate file; open it once instead of per batch.
            embedding_file = stack.enter_context(
                h5py.File(f"./results/Text_Based_Decomposition/{subfolder_name}/data.h5", 'r')
            )
        else:
            embedding_file = f

        labels_info_dset = f['labels_info']
        total_samples = labels_info_dset.shape[0]
        logger.info(f"Total samples in {dataset}: {total_samples}")

        for start_idx in tqdm(range(0, total_samples, batch_size), desc=f"Processing {dataset}"):
            end_idx = min(start_idx + batch_size, total_samples)
            try:
                image_embeddings_batch = get_image_embeddings(
                    embedding_method, embedding_file, start_idx, end_idx, device
                )

                labels_info_batch = labels_info_dset[start_idx:end_idx]  # single HDF5 read
                labels_batch = torch.as_tensor(labels_info_batch[:, 0], dtype=torch.long, device=device)
                places_batch = torch.as_tensor(labels_info_batch[:, 1], dtype=torch.long, device=device)

                # Match the text embeddings' dtype so the matmul cannot fail on a
                # precision mismatch (e.g. fp16 image vs fp32 text embeddings).
                image_embeddings_batch = image_embeddings_batch.to(device=device, dtype=text_embeddings.dtype)

                # Compute predictions using similarity scores
                predictions = (100.0 * image_embeddings_batch @ text_embeddings.t()).argmax(dim=1)
                correct = predictions == labels_batch

                # Tally into locals first so a failure mid-loop cannot leave the
                # global counters partially updated for this batch.
                batch_correct = []
                batch_total = []
                for group in range(num_groups):
                    mask = (labels_batch == (group // 2)) & (places_batch == (group % 2))
                    batch_correct.append(correct[mask].sum().item())
                    batch_total.append(mask.sum().item())
                for group in range(num_groups):
                    group_correct_counts[group] += batch_correct[group]
                    group_total_counts[group] += batch_total[group]

                # Free memory explicitly
                del image_embeddings_batch, labels_batch, places_batch, predictions, correct
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                logger.error(f"Error in batch {start_idx}-{end_idx}: {e}")
                skipped_samples += end_idx - start_idx
                continue

    evaluated_samples = total_samples - skipped_samples
    if skipped_samples:
        logger.warning(
            f"Skipped {skipped_samples} of {total_samples} samples due to batch errors; "
            f"accuracies are computed over the remaining {evaluated_samples} samples."
        )

    # Samples outside the four (label, place) groups still count as incorrect,
    # because only grouped hits contribute to the numerator.
    overall_correct = sum(group_correct_counts.values())
    overall_accuracy = (overall_correct / evaluated_samples) * 100 if evaluated_samples > 0 else 0
    group_accuracies = {
        group: (group_correct_counts[group] / group_total_counts[group]) * 100
        if group_total_counts[group] > 0 else 0
        for group in range(num_groups)
    }

    return {
        "overall_accuracy": overall_accuracy,
        "group_accuracies": group_accuracies
    }

def main(args):
    """
    Main function to compute and log accuracies for the Waterbirds dataset.

    Loads the text embeddings from ``<input_dir>/<model>_<dataset>/<text_mode>_text.h5``.
    It then runs ``compute_accuracy`` and logs the overall and per-group accuracies.
    Output goes both to the console and to ``<output_dir>/Console_Info_<model>_<dataset>.log``.

    Args:
        args (argparse.Namespace): Command-line arguments parsed by the argument parser.

    Raises:
        SystemExit: With status 1 if any step fails. The traceback is logged first,
            so calling shells and pipelines see a non-zero exit code.
    """
    group_names = {
        0: "landbird+land",
        1: "landbird+water",
        2: "waterbird+land",
        3: "waterbird+water"
    }

    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        input_dir = Path(args.input_dir) / subfolder_name
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Configure logging before any validation so every later failure,
        # including a missing input directory, reaches the log file.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(output_dir / f"Console_Info_{subfolder_name}.log", mode='w'),
                logging.StreamHandler()
            ]
        )

        if not input_dir.exists():
            raise FileNotFoundError(
                f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first."
            )

        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

        # Load text embeddings
        if args.text_mode not in ["simple", "openai"]:
            raise ValueError(f"Text mode '{args.text_mode}' not supported.")
        logger.info(f"Text mode: {args.text_mode}")
        text_file_path = input_dir / f"{args.text_mode}_text.h5"
        with h5py.File(text_file_path, 'r') as f:
            text_embeddings = torch.tensor(f['text_embeddings'][:], dtype=torch.float32).to(device)

        # Compute accuracy
        results = compute_accuracy(
            dataset=args.dataset,
            input_dir=args.input_dir,
            batch_size=args.batch_size,
            device=device,
            text_embeddings=text_embeddings,
            embedding_method=args.embedding_method
        )

        # Log results
        logger.info(f"Dataset: {args.dataset}")
        logger.info(f"Overall accuracy: {results['overall_accuracy']:.2f}%")
        for group, acc in results['group_accuracies'].items():
            group_name = group_names.get(group, f"Group {group}")
            logger.info(f"{group_name} accuracy: {acc:.2f}%")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")
        # Previously the failure was only logged and the script exited 0.
        raise SystemExit(1) from e


if __name__ == "__main__":
    # Script entry point: parse the CLI, pin all RNGs, then run the evaluation.
    # ``args`` deliberately stays a module-level global because other functions
    # in this module read the global ``args`` namespace.
    args = get_parser_info().parse_args()
    set_seed(42)
    main(args)