import argparse
import logging
import os
import random
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm

from data.NICO import NICODataset

# Two contexts singled out for every class, keyed by class index. Each value is
# a one-entry mapping {class name: [context, context]}.
# NOTE(review): the name suggests these are the hardest ("worst group")
# contexts per class; nothing in this file reads the table — confirm with
# whatever consumes it.
worst_group_contexts = {
    class_idx: {class_name: [first_context, second_context]}
    for class_idx, (class_name, first_context, second_context) in enumerate((
        ("airplane", "on beach", "at night"),
        ("bear", "on snow", "in water"),
        ("bicycle", "on beach", "on snow"),
        ("bird", "in hand", "on shoulder"),
        ("boat", "in city", "cross bridge"),
        ("bus", "on snow", "at yard"),
        ("car", "on beach", "on snow"),
        ("cat", "in river", "on snow"),
        ("cow", "in river", "on snow"),
        ("dog", "in water", "on snow"),
        ("elephant", "in street", "on snow"),
        ("helicopter", "on sea", "on snow"),
        ("horse", "in river", "on beach"),
        ("monkey", "on beach", "on snow"),
        ("motorcycle", "on snow", "on beach"),
        ("rat", "in water", "on snow"),
        ("sheep", "in water", "on road"),
        ("train", "on beach", "on snow"),
        ("truck", "on beach", "on snow"),
    ))
}

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, all GPU generators are seeded as well and cuDNN is
    put into deterministic, non-benchmarking mode.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Create and return the argument parser with command-line options.

    Returns:
        argparse.ArgumentParser: Parser exposing --model, --batch_size,
        --input_dir, --dataset, --output_dir, --text_mode,
        --embedding_method and --cuda_id.
    """
    parser = argparse.ArgumentParser(
        description="Compute prediction accuracy for CLIP model on NICO dataset.",
        add_help=True
    )
    # Model parameters
    parser.add_argument(
        "--model",
        default="ViT-B-16",
        type=str,
        metavar="MODEL",
        help="Name of the CLIP model to use (default: ViT-B-16)"
    )
    # Dataset parameters
    parser.add_argument(
        "--batch_size",
        default=1000,
        type=int,
        help="Number of samples processed per batch (default: 1000)"
    )
    parser.add_argument(
        "--input_dir",
        default="./results",
        help="Path where input data is saved"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nico",
        help="Dataset to process (default: nico)"
    )
    parser.add_argument(
        "--output_dir",
        default="./output/NICO",
        help="Path where output data is saved"
    )
    parser.add_argument(
        "--text_mode",
        default="openai",
        help="Text mode: 'simple' or 'openai' (default: openai)"
    )
    parser.add_argument(
        "--embedding_method",
        default="text_based_decomposition",
        help="Method to load or generate embeddings (e.g., 'clip_base'). "
             "Default: text_based_decomposition"
    )
    # Kept as a string: it is only interpolated into the "cuda:<id>" device name.
    parser.add_argument(
        "--cuda_id",
        default="0",
        help="CUDA ID to use (default: 0)"
    )
    return parser

def get_image_embeddings(embedding_method, h5_file, start_idx, end_idx, device, cli_args=None):
    """
    Produce image embeddings for one batch via the selected embedding method.

    The work is delegated to ``extract_image_embedding.<embedding_method>_embedding``,
    which is called as ``func(h5_file, start_idx, end_idx, device, cli_args)``.

    Args:
        embedding_method (str): Method to load or generate embeddings (e.g., 'clip_base').
        h5_file (h5py.File): Opened HDF5 file handed to the embedding function.
        start_idx (int): Starting index of the batch.
        end_idx (int): Ending index of the batch.
        device (torch.device): Device handed to the embedding function.
        cli_args (argparse.Namespace, optional): Parsed command-line options
            forwarded to the embedding function. When omitted, the module-level
            ``args`` set by the script entry point is used.

    Returns:
        Whatever the embedding function returns (the caller treats it as a
        torch.Tensor of image embeddings for the batch).

    Raises:
        ImportError: If the ``extract_image_embedding`` module cannot be imported.
        AttributeError: If the module has no ``<embedding_method>_embedding`` function.
        RuntimeError: If no ``cli_args`` were given and no module-level ``args`` exists.
    """
    # Only the import is guarded, so an ImportError raised while actually
    # computing embeddings is not misreported as a failed module import.
    try:
        import extract_image_embedding
    except ImportError as e:
        raise ImportError(
            f"Failed to import embedding function for method '{embedding_method}': {e}"
        ) from e

    func_name = f"{embedding_method}_embedding"
    try:
        embedding_func = getattr(extract_image_embedding, func_name)
    except AttributeError as e:
        raise AttributeError(
            f"Unknown embedding method '{embedding_method}': "
            f"module 'extract_image_embedding' has no function '{func_name}'"
        ) from e

    if cli_args is None:
        # Backward compatible path: the script binds `args` at module level.
        cli_args = globals().get("args")
        if cli_args is None:
            raise RuntimeError(
                "No command-line arguments available: pass 'cli_args' explicitly "
                "or run this module as a script."
            )

    logger.info("Computing in %s method", embedding_method)
    return embedding_func(h5_file, start_idx, end_idx, device, cli_args)

def _load_context_names():
    """Return {class_idx: {ctx_idx: ctx_name}} from class_text.NICO, or None if unavailable."""
    try:
        from class_text import NICO

        mapping = {}
        for class_idx, class_info in NICO.items():
            class_name = next(iter(class_info))  # single-key dict: {class name: [contexts]}
            mapping[class_idx] = dict(enumerate(class_info[class_name]))
    except Exception as e:  # best-effort: names are only used for nicer reporting
        logger.warning(f"Failed to load context mapping from class_text.NICO: {e}")
        return None

    logger.info("Loaded context mapping from class_text.NICO dictionary")
    return mapping


def _tally_batch(labels, contexts, predicted, class_correct, class_total, ctx_correct, ctx_total):
    """Add one batch of (label, context, prediction) triples to the running counters in place."""
    for label, ctx_idx, guess in zip(labels, contexts, predicted):
        if label not in class_total:
            # Labels outside [0, num_classes) are not attributed to any class.
            continue
        hit = int(guess == label)
        class_total[label] += 1
        class_correct[label] += hit
        per_ctx_total = ctx_total[label]
        per_ctx_correct = ctx_correct[label]
        per_ctx_total[ctx_idx] = per_ctx_total.get(ctx_idx, 0) + 1
        per_ctx_correct[ctx_idx] = per_ctx_correct.get(ctx_idx, 0) + hit


def compute_accuracy(dataset, input_dir, batch_size, device, text_embeddings, embedding_method,
                     model_name=None):
    """
    Compute overall and class-wise accuracy for the NICO dataset using CLIP embeddings.
    Also compute accuracy for each context within each class.

    A sample is predicted as the class whose text embedding has the highest
    dot product with the sample's image embedding. Batches that raise are
    logged and skipped; skipped samples are excluded from every accuracy.

    Args:
        dataset (str): Name of the dataset to process
        input_dir (str): Directory containing the dataset files
        batch_size (int): Number of samples to process in each batch
        device (torch.device): Device to load tensors to
        text_embeddings (torch.Tensor): Pre-computed text embeddings, one row per class
        embedding_method (str): Method to extract image embeddings
        model_name (str, optional): CLIP model name used to build the data
            sub-folder name. Defaults to ``args.model`` from the module-level
            ``args`` set by the script entry point.

    Returns:
        dict: Dictionary containing accuracy metrics for overall, class-wise and context-wise analysis
            ("overall_accuracy", "class_accuracies", "context_accuracies",
            "context_total_counts", "context_correct_counts", "context_idx_to_name").
            Accuracies are percentages.

    Raises:
        FileNotFoundError: If ``<input_dir>/<model>_<dataset>`` does not exist.
    """
    if model_name is None:
        model_name = args.model
    subfolder_name = f"{model_name}_{dataset}"
    input_dir = os.path.join(input_dir, subfolder_name)

    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first.")

    context_idx_to_name = _load_context_names()

    # NOTE(review): samples are read from this fixed location regardless of
    # `input_dir` — confirm that is intended.
    data_path = f"./results/Text_Based_Decomposition/{subfolder_name}/data.h5"

    with h5py.File(data_path, 'r') as f:
        labels_info_dset = f['labels_info']
        total_samples = labels_info_dset.shape[0]
        logger.info(f"Total samples in {dataset}: {total_samples}")

        # One text embedding per class
        num_classes = len(text_embeddings)
        logger.info(f"Number of classes: {num_classes}")

        class_correct_counts = dict.fromkeys(range(num_classes), 0)
        class_total_counts = dict.fromkeys(range(num_classes), 0)
        # Context counters are created lazily, the first time a context is seen.
        context_correct_counts = {i: {} for i in range(num_classes)}
        context_total_counts = {i: {} for i in range(num_classes)}
        evaluated_samples = 0

        for start_idx in tqdm(range(0, total_samples, batch_size), desc=f"Processing {dataset}"):
            end_idx = min(start_idx + batch_size, total_samples)
            try:
                # The already-open handle points at the same file for every
                # embedding method, so it is reused instead of reopening per batch.
                image_embeddings_batch = get_image_embeddings(
                    embedding_method, f, start_idx, end_idx, device
                ).to(device)

                # Read the label rows once: column 0 = class, column 1 = context.
                labels_info = labels_info_dset[start_idx:end_idx]
                labels = torch.tensor(labels_info[:, 0], dtype=torch.long).tolist()
                contexts = torch.tensor(labels_info[:, 1], dtype=torch.long).tolist()

                # Compute predictions (using similarity with text embeddings)
                predictions = (100.0 * image_embeddings_batch @ text_embeddings.t()).argmax(dim=1)
                predicted = predictions.tolist()
                if len(predicted) != len(labels):
                    raise ValueError(
                        f"got {len(predicted)} predictions for {len(labels)} labels"
                    )

                # Clean up memory
                del image_embeddings_batch, predictions
            except Exception as e:
                logger.error(f"Error in batch {start_idx}-{end_idx}: {e}")
                continue

            # Counters are only touched once the whole batch has been computed,
            # so a failing batch never leaves partial counts behind.
            _tally_batch(
                labels, contexts, predicted,
                class_correct_counts, class_total_counts,
                context_correct_counts, context_total_counts,
            )
            evaluated_samples += len(labels)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    skipped_samples = total_samples - evaluated_samples
    if skipped_samples:
        logger.warning(
            f"{skipped_samples} of {total_samples} samples were skipped because their batch failed; "
            f"accuracies cover the remaining {evaluated_samples} samples."
        )

    # Overall accuracy is measured over the samples that were actually evaluated.
    total_correct = sum(class_correct_counts.values())
    overall_accuracy = (total_correct / evaluated_samples) * 100 if evaluated_samples > 0 else 0

    class_accuracies = {
        class_idx: (class_correct_counts[class_idx] / seen) * 100 if seen > 0 else 0
        for class_idx, seen in class_total_counts.items()
    }

    context_accuracies = {
        class_idx: {
            ctx_idx: (context_correct_counts[class_idx][ctx_idx] / seen) * 100 if seen > 0 else 0
            for ctx_idx, seen in per_ctx_total.items()
        }
        for class_idx, per_ctx_total in context_total_counts.items()
    }

    return {
        "overall_accuracy": overall_accuracy,
        "class_accuracies": class_accuracies,
        "context_accuracies": context_accuracies,
        "context_total_counts": context_total_counts,
        "context_correct_counts": context_correct_counts,
        "context_idx_to_name": context_idx_to_name
    }

def _load_class_names():
    """Return {class_idx: class_name} from class_text.NICO; empty dict when unavailable."""
    try:
        from class_text import NICO

        # Each entry is a single-key dict {class name: [contexts]}.
        return {idx: next(iter(info)) for idx, info in NICO.items() if info}
    except Exception as e:  # best-effort: reports fall back to "Class <idx>"
        logger.debug("Class names unavailable (class_text.NICO): %s", e)
        return {}


def _build_context_report(results, class_names):
    """Return the context-wise report as a list of lines (without trailing newlines)."""
    lines = ["\nContext-wise accuracies for each class:"]
    context_names = results['context_idx_to_name'] or {}

    for class_idx in sorted(results['context_accuracies']):
        class_name = class_names.get(class_idx) or f"Class {class_idx}"
        lines.append(f"\n{class_name} (overall: {results['class_accuracies'][class_idx]:.2f}%):")

        # Sort contexts by accuracy in descending order
        ranked = sorted(results['context_accuracies'][class_idx].items(),
                        key=lambda x: x[1], reverse=True)
        names_for_class = context_names.get(class_idx, {})

        for ctx_idx, acc in ranked:
            ctx_name = names_for_class.get(ctx_idx, f"Context {ctx_idx}")
            samples = results['context_total_counts'][class_idx][ctx_idx]
            correct = results['context_correct_counts'][class_idx][ctx_idx]
            lines.append(f"  {ctx_name}: {acc:.2f}% ({correct}/{samples} samples)")

    return lines


def main(args):
    """
    Main function to compute and log accuracies for the NICO dataset.

    Loads the pre-computed text embeddings, runs ``compute_accuracy``, logs the
    overall / class-wise / context-wise results and writes the same report to
    ``<output_dir>/<model>_<dataset>_<text_mode>_<embedding_method>_results.txt``.
    Any exception is logged with its traceback and not re-raised.

    Args:
        args (argparse.Namespace): Command-line arguments parsed by the argument parser.
    """
    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        input_dir = Path(args.input_dir) / subfolder_name
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Configure logging to file and console before any validation, so that
        # early failures also end up in the log file.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(output_dir / f"Console_Info_{subfolder_name}.log", mode='w'),
                logging.StreamHandler()
            ]
        )

        if not input_dir.exists():
            raise FileNotFoundError(
                f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first."
            )

        if args.text_mode not in ("simple", "openai"):
            raise ValueError(f"Text mode '{args.text_mode}' not supported.")

        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

        # Load text embeddings
        logger.info(f"Text mode: {args.text_mode}")
        text_file_path = input_dir / f"{args.text_mode}_text.h5"
        with h5py.File(text_file_path, 'r') as f:
            text_embeddings = torch.tensor(f['text_embeddings'][:], dtype=torch.float32).to(device)

        # Compute accuracy
        results = compute_accuracy(
            dataset=args.dataset,
            input_dir=args.input_dir,
            batch_size=args.batch_size,
            device=device,
            text_embeddings=text_embeddings,
            embedding_method=args.embedding_method
        )

        # Sort class results by accuracy in descending order
        sorted_classes = sorted(results['class_accuracies'].items(), key=lambda x: x[1], reverse=True)

        # The context section is built once and shared by the log and the results
        # file, so both show the same class and context names.
        context_report = _build_context_report(results, _load_class_names())

        # Log results
        logger.info(f"Dataset: {args.dataset}")
        logger.info(f"Overall accuracy: {results['overall_accuracy']:.2f}%")
        logger.info("Class-wise accuracies:")
        for class_idx, acc in sorted_classes:
            logger.info(f"Class {class_idx}: {acc:.2f}%")
        for line in context_report:
            logger.info(line)

        # Save results to txt file
        results_file = output_dir / f"{subfolder_name}_{args.text_mode}_{args.embedding_method}_results.txt"
        with open(results_file, 'w', encoding='utf-8') as f:
            f.write(f"Dataset: {args.dataset}\n")
            f.write(f"Overall accuracy: {results['overall_accuracy']:.2f}%\n\n")

            f.write("Class-wise accuracies:\n")
            for class_idx, acc in sorted_classes:
                f.write(f"Class {class_idx}: {acc:.2f}%\n")

            f.writelines(f"{line}\n" for line in context_report)

        logger.info(f"Results saved to {results_file}")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")


# Script entry point: parse the command line, fix the random seed, run the evaluation.
# NOTE: `args` is bound at module level here, and other functions in this file
# (get_image_embeddings, compute_accuracy) read it as a global — keep the name.
if __name__ == "__main__":
    parser = get_parser_info()
    args = parser.parse_args()
    # Fixed seed so repeated runs use the same RNG state.
    set_seed(42)
    main(args)