import argparse
import logging
import os
import random
from pathlib import Path
import einops
import h5py
import numpy as np
import torch
from torch.nn import functional as F

from tqdm import tqdm


# Module-level logger used by the functions in this file. It is only created
# here; level and format are set later in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU). When CUDA is
    available, all CUDA devices are seeded too and cuDNN is switched to
    deterministic mode with autotuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-related to configure on CPU-only machines.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Create and return the argument parser with command-line options.

    Help texts interpolate ``%(default)s`` instead of hard-coding the default
    value, so the documented default cannot drift from the real one.

    Returns:
        argparse.ArgumentParser: Parser exposing ``--model``, ``--batch_size``,
        ``--input_dir``, ``--dataset``, ``--output_dir``, ``--text_mode``,
        ``--cuda_id`` and ``--ablation_method``.
    """
    parser = argparse.ArgumentParser(
        description="Compute prediction accuracy for CLIP model on ImageNet datasets in different ways.",
        add_help=True,
    )

    # Model parameters
    parser.add_argument(
        "--model",
        default="ViT-B-32",
        type=str,
        metavar="MODEL",
        help="Name of the CLIP model to use (default: %(default)s)",
    )

    # Dataset parameters
    parser.add_argument(
        "--batch_size",
        default=1000,
        type=int,
        help="Batch size for processing (adjust based on GPU memory)",
    )
    parser.add_argument(
        "--input_dir",
        default="./results",
        help="Path where input data is saved",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="waterbirds",
        help="Dataset to process (default: %(default)s)",
    )
    parser.add_argument(
        "--output_dir",
        default="./results/prs",
        help="Path where output data is saved",
    )
    parser.add_argument(
        "--text_mode",
        default="simple",
        help="Text mode: 'simple' or 'openai' (default: %(default)s)",
    )
    parser.add_argument(
        "--cuda_id",
        default=0,
        type=int,
        help="CUDA ID to use for processing (default: %(default)s)",
    )
    # Restricting the choices rejects typos at parse time instead of letting
    # them surface later while batches are being processed.
    parser.add_argument(
        "--ablation_method",
        default="cls",
        type=str,
        choices=("cls", "mlps", "none", "cls_mlps", "msas"),
        help="Ablation method: 'cls', 'mlps', 'none', 'cls_mlps', 'msas' (default: %(default)s)",
    )
    return parser


def get_image_embeddings(h5_file, start_idx, end_idx, device, args):
    """Extract and process image embeddings with specified ablation method.

    The ``"attentions"`` and ``"mlps"`` datasets are sliced to
    ``[start_idx:end_idx]`` and indexed as 3-D arrays
    ``(sample, component, feature)``. Component index 0 of both arrays is
    treated as the class-token contribution. An "ablated" component is
    replaced by its mean over the samples of the *current slice* (not the
    whole dataset), so ablated results depend on how the data is batched.

    Supported ``args.ablation_method`` values:
        * ``"cls"``      - replace ``attentions[:, 0] + mlps[:, 0]`` by its mean.
        * ``"mlps"``     - replace ``mlps[:, 1:]`` by its mean, keep ``mlps[:, 0]``.
        * ``"cls_mlps"`` - replace ``attentions[:, 0]`` and all of ``mlps`` by their means.
        * ``"msas"``     - replace ``attentions[:, 1:]`` by its mean, keep ``attentions[:, 0]``.
        * ``"none"``     - no ablation; sum every component.

    Args:
        h5_file: HDF5 file (or any mapping of array-likes) containing
            ``"attentions"`` and ``"mlps"``.
        start_idx: Starting index for batch processing
        end_idx: Ending index for batch processing (exclusive)
        device: PyTorch device for tensor operations
        args: Command-line arguments; only ``args.ablation_method`` is read.

    Returns:
        Processed image embeddings as a float32 tensor of shape
        ``(end_idx - start_idx, feature)`` on the specified device.

    Raises:
        ValueError: If ``args.ablation_method`` is not one of the supported values.
    """
    method = args.ablation_method
    supported_methods = ("cls", "mlps", "cls_mlps", "msas", "none")
    # Fail early with a clear message rather than leaving the result unbound.
    if method not in supported_methods:
        raise ValueError(
            f"Ablation method '{method}' not supported; expected one of {supported_methods}."
        )

    attentions = h5_file["attentions"][start_idx:end_idx]
    mlps = h5_file["mlps"][start_idx:end_idx]

    # Batch means have shape (feature,) after summing over components; NumPy
    # broadcasting adds them to every sample, so no explicit tiling is needed.
    if method == "cls":
        # Ablate class token by replacing with mean
        cls_mean = (attentions[:, 0, :] + mlps[:, 0, :]).mean(axis=0)
        M_image = (
            attentions[:, 1:, :].sum(axis=1)
            + cls_mean
            + mlps[:, 1:, :].sum(axis=1)
        )
    elif method == "mlps":
        # Ablate MLP layers by replacing with mean
        mlps_mean_sum = mlps[:, 1:, :].mean(axis=0).sum(axis=0)
        M_image = (
            attentions.sum(axis=1)
            + mlps_mean_sum
            + mlps[:, 0, :]
        )
    elif method == "cls_mlps":
        # Ablate both class token and MLP layers
        mlps_mean_sum = mlps.mean(axis=0).sum(axis=0)
        M_image = (
            attentions[:, 1:, :].sum(axis=1)
            + attentions[:, 0, :].mean(axis=0)
            + mlps_mean_sum
        )
    elif method == "msas":
        # Ablate multi-head self attention
        atts_mean_sum = attentions[:, 1:, :].mean(axis=0).sum(axis=0)
        M_image = (
            atts_mean_sum
            + attentions[:, 0, :]
            + mlps.sum(axis=1)
        )
    else:
        # "none": no ablation, use all components
        M_image = attentions.sum(axis=1) + mlps.sum(axis=1)

    return torch.tensor(M_image, dtype=torch.float32).to(device)

def compute_accuracy(dataset, input_dir, batch_size, device, text_embeddings, args):
    """Compute prediction accuracy for a dataset using image and text embeddings.

    Reads ``<input_dir>/<args.model>_<dataset>/data.h5``, builds (optionally
    ablated) image embeddings batch by batch, scores them against
    ``text_embeddings`` and counts how often the arg-max class matches the
    stored label.

    A batch that raises is logged and skipped (best effort). Skipped samples
    are excluded from the accuracy denominator instead of being silently
    counted as misclassified, and a warning reports how many were skipped.

    Args:
        dataset: Name of the dataset to process
        input_dir: Directory containing the dataset files
        batch_size: Number of samples to process in each batch (must be > 0)
        device: PyTorch device for tensor operations
        text_embeddings: Pre-computed text embeddings (one row per class)
        args: Command-line arguments (``args.model`` plus whatever
            ``get_image_embeddings`` reads)

    Returns:
        Dictionary with ``"overall_accuracy"`` (percentage over the evaluated
        samples, 0 if none were evaluated), ``"evaluated_samples"`` and
        ``"total_samples"``.

    Raises:
        ValueError: If ``dataset`` is not supported or ``batch_size`` is not positive.
    """
    # For these datasets each labels_info row has several columns and column 0
    # is used as the class label; for the ImageNet variants the entry itself is used.
    first_column_label_datasets = ("waterbirds", "cocogbv1", "urbancars")
    direct_label_datasets = ("imagenet", "imagenet_a")

    # Validate once, outside the per-batch try, so a configuration error is
    # reported to the caller instead of being swallowed for every batch.
    if dataset in first_column_label_datasets:
        take_first_column = True
    elif dataset in direct_label_datasets:
        take_first_column = False
    else:
        raise ValueError(f"Dataset '{dataset}' not supported.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    subfolder_name = f"{args.model}_{dataset}"
    input_dir = os.path.join(input_dir, subfolder_name)

    with h5py.File(os.path.join(input_dir, "data.h5"), 'r') as f:
        labels_info_dset = f['labels_info']
        total_samples = labels_info_dset.shape[0]
        logger.info(f"Total samples in {dataset}: {total_samples}")

        all_correct_counts = 0
        evaluated_samples = 0

        for start_idx in tqdm(range(0, total_samples, batch_size), desc=f"Processing {dataset}"):
            end_idx = min(start_idx + batch_size, total_samples)
            try:
                # Returned tensor is already placed on `device`.
                image_embeddings_batch = get_image_embeddings(f, start_idx, end_idx, device, args)

                labels_batch = labels_info_dset[start_idx:end_idx]
                if take_first_column:
                    labels_batch = labels_batch[:, 0]
                labels_batch = torch.tensor(labels_batch, dtype=torch.long).to(device)

                # Compute logits and predictions
                logits = 100.0 * image_embeddings_batch @ text_embeddings.t()
                predictions = logits.argmax(dim=1)  # Get predicted labels

                all_correct_counts += (predictions == labels_batch).sum().item()
                evaluated_samples += end_idx - start_idx

                del image_embeddings_batch, labels_batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: keep going, but record the traceback.
                logger.exception(f"Error in batch {start_idx}-{end_idx}: {e}")
                continue

        if evaluated_samples < total_samples:
            logger.warning(
                f"Skipped {total_samples - evaluated_samples} of {total_samples} samples "
                f"due to batch errors; accuracy covers evaluated samples only."
            )

        overall_accuracy = (
            (all_correct_counts / evaluated_samples) * 100.0 if evaluated_samples > 0 else 0
        )

        return {
            "overall_accuracy": overall_accuracy,
            "evaluated_samples": evaluated_samples,
            "total_samples": total_samples,
        }
            

def main(args):
    """Main function to run the prediction accuracy computation process.

    Configures logging, ensures the output directory exists, loads the text
    embeddings from ``<input_dir>/<model>_<dataset>/<text_mode>_text.h5`` and
    logs the accuracy returned by ``compute_accuracy``.

    Any exception raised along the way is logged with its traceback and not
    re-raised, so this function always returns ``None``.

    Args:
        args: Command-line arguments (see ``get_parser_info``). If
            ``args.output_dir`` is missing or empty, ``<input_dir>/prs`` is used.
    """
    try:
        # Configure logging first so that even early failures (e.g. mkdir)
        # are reported with the configured format.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
        )

        # Honour --output_dir; fall back to the previous fixed location for
        # callers that pass an args object without that attribute.
        output_dir = Path(getattr(args, "output_dir", None) or Path(args.input_dir) / "prs")
        output_dir.mkdir(parents=True, exist_ok=True)

        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

        if args.text_mode not in ("simple", "openai"):
            raise ValueError(f"Text mode '{args.text_mode}' not supported.")

        logger.info(f"Text mode: {args.text_mode}")
        text_file_path = Path(args.input_dir) / f"{args.model}_{args.dataset}" / f"{args.text_mode}_text.h5"
        with h5py.File(text_file_path, 'r') as f:
            text_embeddings = torch.tensor(f['text_embeddings'][:], dtype=torch.float32).to(device)

        results = compute_accuracy(
            dataset=args.dataset,
            input_dir=args.input_dir,
            batch_size=args.batch_size,
            device=device,
            text_embeddings=text_embeddings,
            args=args
        )
        logger.info(f"Dataset: {args.dataset}")
        logger.info(f"Overall accuracy: {results['overall_accuracy']:.2f}%")
    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Script entry point: parse the command line, fix the RNG seed to 42,
    # then run the evaluation. Parsing happens before seeding, as before.
    cli_args = get_parser_info().parse_args()
    set_seed(42)
    main(cli_args)