import argparse
import os
import random
import numpy as np
import torch
import h5py
from torch.nn import functional as F
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import logging
from data.NICO import NICODataset  # Import NICO dataset

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer
from prs_hook import hook_prs_logger

# Point Hugging Face downloads at a mirror endpoint to speed up model fetching.
# NOTE(review): this assignment runs after the imports above; if any of them
# reads HF_ENDPOINT at import time the mirror would not take effect -- confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Module-level logger; its handlers and level are configured later in main().
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Python's ``random`` module, NumPy and PyTorch are all seeded with the
    same value. When CUDA is available, every CUDA device is seeded too and
    cuDNN is switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: A bool, or a string such as "true"/"false", "1"/"0", "yes"/"no".

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, matching what bool("") produced before.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def get_parser_info():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: Parser exposing the model, data and output
        options. ``--save_labels`` may be given bare (meaning True) or with
        an explicit boolean value such as ``--save_labels False``.
    """
    parser = argparse.ArgumentParser(
        description="Load and process dataset for CLIP zero-shot classification.",
        add_help=True
    )
    # Model parameters
    parser.add_argument("--model", default="ViT-B-32", type=str, help="Name of the model to use.")
    parser.add_argument("--pretrained", default="laion2b_s34b_b79k", type=str, help="Pretrained weights to load.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for data loading.")
    # Data parameters
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--dataset", default="nico", type=str, help="Name of the dataset.")
    parser.add_argument("--data_root", type=str, default="./data", help="Path to the dataset root directory.")
    # Output parameters
    parser.add_argument("--output_dir", type=str, default="./results", help="Path to save output results.")
    parser.add_argument("--cuda_id", type=str, default="0", help="cuda id")
    parser.add_argument("--just_text", type=str, default=None, help="Choose mode: simple/openai/None")
    # Option to save label information.
    # nargs="?" + const=True lets the flag be used bare, while the converter
    # makes an explicit "False" actually disable it.
    parser.add_argument(
        "--save_labels", type=_str_to_bool, nargs="?", const=True, default=False,
        help="Whether to save dataset label information to txt file"
    )
    return parser

def _create_dataloaders(args, preprocess):
    """Create a DataLoader for the specified dataset.

    Datasets registered in the local factory table are built by their own
    constructor; any other name falls back to ``ImageFolder`` rooted at
    ``<data_root>/<dataset>``. For the NICO dataset, label information is
    optionally written to ``<output_dir>/<dataset>_labels.txt``.

    Args:
        args: Command-line arguments containing dataset and loader settings
            (``dataset``, ``data_root``, ``output_dir``, ``save_labels``,
            ``batch_size``, ``num_workers``).
        preprocess: Image preprocessing pipeline passed to the dataset as its
            ``transform``.

    Returns:
        Tuple of DataLoader instance and total number of samples.
    """
    dataset_map = {
        "nico": lambda: NICODataset(root=os.path.join(args.data_root, "NICO"), transform=preprocess),
    }
    # Only the ImageFolder fallback uses this path; registered datasets build
    # their own root, so no existence check is done on it here.
    data_path = os.path.join(args.data_root, args.dataset)
    ds = dataset_map.get(args.dataset, lambda: ImageFolder(root=data_path, transform=preprocess))()

    # If NICO dataset and need to save label information
    if args.dataset.lower() == "nico" and args.save_labels and hasattr(ds, 'save_label_info'):
        labels_file = os.path.join(args.output_dir, f"{args.dataset}_labels.txt")
        os.makedirs(os.path.dirname(labels_file), exist_ok=True)
        ds.save_label_info(labels_file)
        logger.info(f"NICO dataset label information saved to: {labels_file}")

    loader_kwargs = {
        "batch_size": args.batch_size,
        "shuffle": True,
        "num_workers": args.num_workers,
        "pin_memory": True,
    }
    # prefetch_factor is only valid with worker processes; passing it with
    # num_workers=0 makes DataLoader raise ValueError on current PyTorch.
    if args.num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2

    dataloader = DataLoader(ds, **loader_kwargs)
    return dataloader, len(ds)

def extract_class_names_from_dict(NICO):
    """Collect the class name stored under each index of the NICO dictionary.

    Args:
        NICO: Mapping from index to a dictionary whose first key is the
            class name.

    Returns:
        list: Class names, ordered by ascending index.
    """
    # Only the first key of each per-index dictionary is treated as the name.
    return [[*NICO[key].keys()][0] for key in sorted(NICO.keys())]

def extract_contexts_from_dict(NICO):
    """Map each class name in the NICO dictionary to its stored contexts.

    Args:
        NICO: Mapping from index to a dictionary whose first key is the
            class name and whose value holds that class's contexts.

    Returns:
        dict: ``{class_name: contexts}`` with entries inserted in ascending
        index order.
    """
    ordered_entries = [
        entry for _, entry in sorted(NICO.items(), key=lambda item: item[0])
    ]
    # The first key of every per-index dictionary is the class name.
    first_keys = [[*entry.keys()][0] for entry in ordered_entries]
    return {name: entry[name] for name, entry in zip(first_keys, ordered_entries)}

def _create_text_embedding(args, model, device):
    """Generate text embeddings for zero-shot classification.

    Every class name from ``class_text.NICO`` is rendered through each prompt
    template, encoded with the model's text encoder, L2-normalised, averaged
    over the templates and normalised again, giving one embedding per class.

    Args:
        args: Command-line arguments; ``just_text`` selects the template set
            ("simple" or "openai") and ``model`` selects the tokenizer.
        model: CLIP model instance providing ``encode_text``.
        device: Device to perform computations on (CPU or GPU).

    Returns:
        torch.Tensor: Text embeddings for all classes, one row per class.

    Raises:
        ValueError: If ``args.just_text`` is not a supported mode, if
            ``class_text`` has no ``NICO`` dictionary, or if it is empty.
        RuntimeError: Re-raised (after logging) when tokenising or encoding a
            batch fails.
    """
    import class_text  # Assumes this module defines dataset class names
    if args.just_text == "simple":
        from CLIP_utils.template.simple_templates import OPENAI_IMAGENET_TEMPLATES
    elif args.just_text == "openai":
        from CLIP_utils.template.openai_templates import OPENAI_IMAGENET_TEMPLATES
    else:
        raise ValueError("Invalid mode for text embedding generation.")

    class_names_dict = getattr(class_text, "NICO", None)
    if class_names_dict is None:
        raise ValueError("class_text does not define a 'NICO' class dictionary.")

    class_names = extract_class_names_from_dict(class_names_dict)
    if not class_names:
        # An empty list would otherwise produce a zero step for range() below.
        raise ValueError("No class names found in class_text.NICO.")

    tokenizer = get_tokenizer(args.model)

    # Encode at most 50 classes (times the number of templates) per pass.
    batch_size = min(len(class_names), 50)
    all_text_embeddings = []

    for i in tqdm(range(0, len(class_names), batch_size)):
        class_names_batch = class_names[i:i + batch_size]
        # Class-major order: all templates of the first class, then the next.
        batch_texts = [
            template(classname) if callable(template) else template.format(classname)
            for classname in class_names_batch
            for template in OPENAI_IMAGENET_TEMPLATES
        ]
        # Pre-bind so the cleanup in `finally` cannot hit an unbound name and
        # hide the real error when tokenising or encoding fails.
        batch_tokens = None
        batch_embeddings = None
        try:
            batch_tokens = tokenizer(batch_texts).to(device)
            with torch.no_grad():
                batch_embeddings = model.encode_text(batch_tokens)
                batch_embeddings = F.normalize(batch_embeddings, dim=-1)
                # Regroup to (classes, templates, dim) and average the templates.
                batch_embeddings = batch_embeddings.view(len(class_names_batch), -1, batch_embeddings.size(-1))
                batch_embeddings = batch_embeddings.mean(dim=1)
                batch_embeddings = F.normalize(batch_embeddings, dim=-1)
                all_text_embeddings.append(batch_embeddings)
        except RuntimeError as e:
            logger.error(f"Error generating text embeddings for batch {i}: {e}")
            raise
        finally:
            del batch_tokens, batch_embeddings
            torch.cuda.empty_cache()

    text_embeddings = torch.cat(all_text_embeddings, dim=0)

    logger.info(
        f"Text embeddings generated for {len(class_names)} classes using {len(OPENAI_IMAGENET_TEMPLATES)} {args.just_text} templates.")
    return text_embeddings.to(device)

def _save_to_hdf5(f, start_idx, cumulative_attentions, cumulative_mlps,
                cumulative_image_embeddings, cumulative_labels_info, total_samples):
    """Write one accumulated chunk of results into the HDF5 file.

    On the first call (``start_idx == 0``) the four datasets are created with
    ``total_samples`` rows each; every call then copies the given arrays into
    the row range that starts at ``start_idx``.

    Args:
        f: HDF5 file handle.
        start_idx: Row at which this chunk begins.
        cumulative_attentions: Attention data for the chunk.
        cumulative_mlps: MLP data for the chunk.
        cumulative_image_embeddings: Image embeddings for the chunk.
        cumulative_labels_info: Label information for the chunk.
        total_samples: Total number of samples in the dataset.

    Returns:
        The row index at which the next chunk should start.
    """
    # (dataset name, chunk array, stored dtype)
    columns = (
        ("attentions", cumulative_attentions, 'float32'),
        ("mlps", cumulative_mlps, 'float32'),
        ("image_embeddings", cumulative_image_embeddings, 'float32'),
        ("labels_info", cumulative_labels_info, 'int64'),
    )

    if start_idx == 0:
        # Work out every full-size shape before creating any dataset.
        full_shapes = [(total_samples,) + chunk.shape[1:] for _, chunk, _ in columns]
        for (name, _, dtype), full_shape in zip(columns, full_shapes):
            f.create_dataset(name, shape=full_shape, dtype=dtype, compression="gzip")

    for name, chunk, _ in columns:
        f[name][start_idx:start_idx + len(chunk)] = chunk

    return start_idx + len(cumulative_labels_info)

def modify_preprocess_with_watermark(preprocess, watermark_transform):
    """Return a new pipeline with the watermark applied right after ToTensor.

    The original ``preprocess`` object and its transform sequence are left
    untouched; a new ``Compose`` is built from a copy.

    Args:
        preprocess: Original preprocessing pipeline (a torchvision ``Compose``).
        watermark_transform: Watermark transformation to insert.

    Returns:
        Modified preprocessing pipeline as a new ``Compose``.

    Raises:
        ValueError: If ``preprocess`` is not a ``Compose`` or contains no
            ``ToTensor`` step.
    """
    from torchvision.transforms import Compose, ToTensor
    if not isinstance(preprocess, Compose):
        raise ValueError("preprocess is not a Compose object")

    # Copy into a list so the insertion also works when the Compose was built
    # from a tuple (tuple + list concatenation would raise TypeError).
    steps = list(preprocess.transforms)
    for position, step in enumerate(steps):
        if isinstance(step, ToTensor):
            break
    else:
        raise ValueError("ToTensor not found in preprocess transforms")

    # Insert directly after the first ToTensor step.
    steps.insert(position + 1, watermark_transform)
    return Compose(steps)

def main(args):
    """Run the extraction pipeline selected by the command-line arguments.

    Results go to ``<output_dir>/<model>_<dataset>/``. With ``--just_text``
    set, only class text embeddings are computed and stored in
    ``<just_text>_text.h5``. Otherwise the dataset is iterated and per-image
    attentions, MLP outputs, normalised image embeddings and labels are
    written incrementally to ``data.h5``.

    Args:
        args: Command-line arguments.

    Raises:
        Exception: Any failure is logged (with extra path diagnostics for
            missing-file errors) and then re-raised, so a failed run does not
            end with a success exit status.
    """
    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        output_dir = Path(args.output_dir) / subfolder_name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Full extraction runs also log to a file; text-only runs log to the
        # console only.
        if not args.just_text:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(levelname)s - %(message)s",
                handlers=[
                    logging.FileHandler(output_dir / "Console_Info.log", mode='w'),
                    logging.StreamHandler()
                ]
            )
        else:
            logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
        device = torch.device(f"cuda:{args.cuda_id}")

        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()
        prs = hook_prs_logger(model, device)

        model_parameters = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_parameters}")
        logger.info(f"Context length: {model.context_length}")
        logger.info(f"Vocabulary size: {model.vocab_size}")
        logger.info(f"Number of transformer residual blocks: {len(model.visual.transformer.resblocks)}")

        if args.just_text:
            text_embeddings = _create_text_embedding(args, model, device)
            # Append mode keeps any other datasets in the file; a stale
            # "text_embeddings" entry is dropped before being rewritten.
            with h5py.File(output_dir / f"{args.just_text}_text.h5", 'a') as f:
                if "text_embeddings" in f:
                    del f["text_embeddings"]
                f.create_dataset("text_embeddings", data=text_embeddings.detach().cpu().numpy(), dtype='float32', compression="gzip")
            logger.info(f"Text embeddings saved as {args.just_text}_text.h5")
            return

        data_loader, total_samples = _create_dataloaders(args, preprocess)
        with h5py.File(output_dir / "data.h5", 'w') as f:
            start_idx = 0
            cumulative_attentions = []
            cumulative_mlps = []
            cumulative_image_embeddings = []
            cumulative_labels_info = []
            # Flush to disk roughly every 50 samples, and at least every batch.
            save_every_n_batches = max(1, 50 // args.batch_size)
            with torch.inference_mode():
                for batch_idx, (images, labels_info) in enumerate(tqdm(data_loader)):
                    prs.reinit()
                    image_embeddings = model.encode_image(images.to(device), attn_method="head", normalize=False)

                    attentions, mlps = prs.finalize(representation=image_embeddings)
                    # NOTE(review): sums out axes 1 and 3 of the hook output;
                    # the meaning of those axes is defined by prs_hook -- confirm.
                    attentions = attentions.sum(dim=(1, 3)).detach().cpu().numpy()
                    image_embeddings = F.normalize(image_embeddings, dim=-1).detach().cpu().numpy()
                    cumulative_attentions.append(attentions)
                    cumulative_mlps.append(mlps.detach().cpu().numpy())
                    cumulative_image_embeddings.append(image_embeddings)
                    cumulative_labels_info.append(labels_info.numpy())

                    # Write on every n-th batch and always on the final batch.
                    if (batch_idx + 1) % save_every_n_batches == 0 or (batch_idx + 1) == len(data_loader):
                        cumulative_attentions_np = np.concatenate(cumulative_attentions, axis=0)
                        cumulative_mlps_np = np.concatenate(cumulative_mlps, axis=0)
                        cumulative_image_embeddings_np = np.concatenate(cumulative_image_embeddings, axis=0)
                        cumulative_labels_info_np = np.concatenate(cumulative_labels_info, axis=0)
                        start_idx = _save_to_hdf5(f, start_idx, cumulative_attentions_np, cumulative_mlps_np,
                                                cumulative_image_embeddings_np, cumulative_labels_info_np, total_samples)
                        cumulative_attentions = []
                        cumulative_mlps = []
                        cumulative_image_embeddings = []
                        cumulative_labels_info = []

                    del images, labels_info, attentions, mlps, image_embeddings
                    torch.cuda.empty_cache()

            logger.info(f"Saved attentions, mlps, image_embeddings, labels_info to {output_dir}")

        logger.info(f"Dataset '{args.dataset}' loaded successfully, total samples: {total_samples}")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")
        # Print more path information for debugging
        if "No such file or directory" in str(e):
            data_path = os.path.join(args.data_root, args.dataset)
            logger.error(f"Attempted to access dataset path: {data_path}")
            logger.error(f"Current working directory: {os.getcwd()}")
            if os.path.exists(args.data_root):
                logger.error(f"Data root directory exists, contains: {os.listdir(args.data_root)}")
        # Propagate after logging so the process does not exit with status 0
        # when the run actually failed.
        raise

if __name__ == "__main__":
    # Parse the command line first so argument errors surface immediately.
    args = get_parser_info().parse_args()
    # Fixed seed so repeated runs draw the same random numbers.
    set_seed(42)
    main(args)

# To generate the NICO label file, run the script with label saving enabled:
# python extract_clip_info_nico.py --dataset nico --save_labels True

