import argparse
import os
import random
import numpy as np
import torch
import h5py
from torch.nn import functional as F
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder, ImageNet
from tqdm import tqdm
import logging
from data.waterbirds import WaterbirdsDataset
from data.imagenet_a import ImageNetA
from data.COCO_GB_V1 import COCO_GB_V1_dataset
# from data.COCO_GB_V2 import COCO_GB_V2_dataset
from data.urbancars import UrbancarsDataset
# from data.imagenet_w.watermark_transform import AddWatermark
from PIL import Image

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer
from prs_hook import hook_prs_logger

# Default to the hf-mirror.com Hugging Face mirror to accelerate model downloads,
# but respect an HF_ENDPOINT the user has already exported.
# NOTE(review): this runs after the imports above; if any of them reads
# HF_ENDPOINT at import time, the value set here may come too late -- confirm.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Module-level logger; its handlers are configured by logging.basicConfig in main().
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU). When CUDA is
    available, all CUDA devices are seeded as well and cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser exposing the model, data and output
        options. Arguments are not parsed here.
    """
    cli = argparse.ArgumentParser(
        description="Load and process dataset for CLIP zero-shot classification.",
        add_help=True,
    )

    # (flag, add_argument keyword arguments), in registration order.
    option_specs = (
        # Model parameters
        ("--model", dict(default="ViT-B-32", type=str, help="Name of the model to use.")),
        ("--pretrained", dict(default="laion2b_s34b_b79k", type=str, help="Pretrained weights to load.")),
        ("--batch_size", dict(type=int, default=32, help="Batch size for data loading.")),
        # Data parameters
        ("--num_workers", dict(type=int, default=4, help="Number of workers for data loading.")),
        ("--dataset", dict(default="cocogbv1", type=str, help="Name of the dataset.")),
        ("--data_root", dict(type=str, default="./data", help="Path to the dataset root directory.")),
        # Output parameters
        ("--output_dir", dict(type=str, default="./results", help="Path to save output results.")),
        ("--just_text", dict(type=str, default=None, help="Choose mode: simple/openai/None")),
        ("--cuda_id", dict(type=str, default="0", help="cuda id")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli

def _create_dataloaders(args, preprocess):
    """Create a DataLoader for the specified dataset.

    Known dataset names are built through ``dataset_map``; any other name
    falls back to an ``ImageFolder`` rooted at ``<data_root>/<dataset>``.

    Args:
        args: Command-line arguments; reads ``dataset``, ``data_root``,
            ``batch_size`` and ``num_workers``.
        preprocess: Image preprocessing pipeline passed to the dataset as
            its ``transform``.

    Returns:
        Tuple of DataLoader instance and total number of samples.
    """
    # Factories are lambdas so only the requested dataset gets instantiated.
    dataset_map = {
        # "imagenet": lambda: ImageNet(root=os.path.join(args.data_root, "imagenet"), split="val", transform=preprocess),
        # "imagenet_w": lambda: ImageNet(root=os.path.join(args.data_root, "imagenet"), split="val", transform=preprocess),
        # "imagenet_a": lambda: ImageNetA(root=os.path.join(args.data_root, "imagenet_a"), transform=preprocess),
        "waterbirds": lambda: WaterbirdsDataset(root=os.path.join(args.data_root, "waterbirds"), split="test", transform=preprocess),
        "cocogbv1": lambda: COCO_GB_V1_dataset(root=args.data_root, split='test', transform=preprocess),
        # "cocogbv2": lambda: COCO_GB_V2_dataset(root=args.data_root, split='test', transform=preprocess),
        "urbancars": lambda: UrbancarsDataset(root=os.path.join(args.data_root, "urbancars"), split='test', transform=preprocess),
    }
    data_path = os.path.join(args.data_root, args.dataset)
    ds = dataset_map.get(args.dataset, lambda: ImageFolder(root=data_path, transform=preprocess))()

    loader_kwargs = {
        "batch_size": args.batch_size,
        # Keep dataset order so saved rows line up with dataset indices.
        "shuffle": False,
        "num_workers": args.num_workers,
        # Pinned memory only helps host-to-GPU copies; requesting it without
        # CUDA just triggers a warning.
        "pin_memory": torch.cuda.is_available(),
    }
    # DataLoader rejects an explicit prefetch_factor when loading in the main
    # process (num_workers == 0), so only pass it when workers are used.
    if args.num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2

    dataloader = DataLoader(ds, **loader_kwargs)
    return dataloader, len(ds)

def _create_text_embedding(args, model, device):
    """Generate text embeddings for zero-shot classification.

    Every class name is expanded with each prompt template, encoded with the
    model's text encoder, L2-normalized, averaged over the templates and
    normalized again, giving one embedding per class.

    Args:
        args: Command-line arguments; reads ``just_text`` ("simple" or
            "openai"), ``dataset`` and ``model``.
        model: CLIP model instance providing ``encode_text``.
        device: Device to perform computations on (CPU or GPU).

    Returns:
        torch.Tensor: Text embeddings for all classes, one row per class.

    Raises:
        ValueError: If ``args.just_text`` is not a supported mode, the class
            name table is missing, or the dataset is unknown or has no class
            names.
        RuntimeError: Re-raised (after logging) when encoding a batch fails.
    """
    import class_text  # Assumes this module defines dataset class names
    if args.just_text == "simple":
        from CLIP_utils.template.simple_templates import OPENAI_IMAGENET_TEMPLATES
    elif args.just_text == "openai":
        from CLIP_utils.template.openai_templates import OPENAI_IMAGENET_TEMPLATES
    else:
        raise ValueError("Invalid mode for text embedding generation.")

    dataset_class_names = getattr(class_text, "DATASET_CLASS_NAMES", None)
    if dataset_class_names is None:
        # Fail with a clear message instead of "'NoneType' is not subscriptable".
        raise ValueError("class_text.DATASET_CLASS_NAMES is not defined.")

    dataset_key = args.dataset.lower()
    if dataset_key in ("imagenet", "imagenet_w"):
        class_names = dataset_class_names["imagenet"]
    elif dataset_key == "imagenet_a":
        from data.imagenet_a import thousand_k_to_200
        # Keep only the ImageNet classes whose mapping value is not -1.
        class_names = [dataset_class_names["imagenet"][i] for i, v in thousand_k_to_200.items() if v != -1]
    elif dataset_key == "waterbirds":
        class_names = dataset_class_names["waterbirds"]
    elif dataset_key in ("cocogbv1", "cocogbv2"):
        class_names = dataset_class_names["cocogb"]
    elif dataset_key == "urbancars":
        class_names = dataset_class_names["urbancars"]
    else:
        raise ValueError(f"Dataset '{args.dataset}' not found.")

    if not class_names:
        # An empty list would make the batch size 0 and break range() below.
        raise ValueError(f"No class names defined for dataset '{args.dataset}'.")

    tokenizer = get_tokenizer(args.model)

    batch_size = min(len(class_names), 50)
    all_text_embeddings = []

    for i in tqdm(range(0, len(class_names), batch_size)):
        class_names_batch = class_names[i:i + batch_size]
        # Class-major order: all templates of class 0, then class 1, ...
        # The view() below relies on this layout.
        batch_texts = [
            template(classname) if callable(template) else template.format(classname)
            for classname in class_names_batch
            for template in OPENAI_IMAGENET_TEMPLATES
        ]
        # Pre-bind so the cleanup in ``finally`` cannot raise UnboundLocalError
        # (and hide the real error) when tokenizing or encoding fails.
        batch_tokens = None
        batch_embeddings = None
        try:
            batch_tokens = tokenizer(batch_texts).to(device)
            with torch.no_grad():
                batch_embeddings = model.encode_text(batch_tokens)
                batch_embeddings = F.normalize(batch_embeddings, dim=-1)
                # (classes, templates, dim) -> mean over templates -> (classes, dim)
                batch_embeddings = batch_embeddings.view(len(class_names_batch), -1, batch_embeddings.size(-1))
                batch_embeddings = batch_embeddings.mean(dim=1)
                batch_embeddings = F.normalize(batch_embeddings, dim=-1)
                all_text_embeddings.append(batch_embeddings)
        except RuntimeError as e:
            logger.error(f"Error generating text embeddings for batch {i}: {e}")
            raise
        finally:
            del batch_tokens, batch_embeddings
            torch.cuda.empty_cache()

    text_embeddings = torch.cat(all_text_embeddings, dim=0)
    logger.info(
        f"Text embeddings generated for {len(class_names)} classes using {len(OPENAI_IMAGENET_TEMPLATES)} {args.just_text} templates.")
    return text_embeddings.to(device)

def _save_to_hdf5(f, start_idx, cumulative_attentions, cumulative_mlps,
                cumulative_image_embeddings, cumulative_labels_info, total_samples):
    """Save cumulative batch data to HDF5 file.

    Args:
        f: HDF5 file handle.
        start_idx: Starting index for saving data.
        cumulative_attentions: Cumulative attention data.
        cumulative_mlps: Cumulative MLP data.
        cumulative_image_embeddings: Cumulative image embeddings.
        cumulative_labels_info: Cumulative labels_info.
        total_samples: Total number of samples in the dataset.

    Returns:
        Updated starting index after saving.
    """
    if start_idx == 0:
        attentions_shape = (total_samples,) + cumulative_attentions.shape[1:]
        mlps_shape = (total_samples,) + cumulative_mlps.shape[1:]
        image_embeddings_shape = (total_samples,) + cumulative_image_embeddings.shape[1:]
        labels_info_shape = (total_samples,) + cumulative_labels_info.shape[1:]
        f.create_dataset("attentions", shape=attentions_shape, dtype='float32', compression="gzip")
        f.create_dataset("mlps", shape=mlps_shape, dtype='float32', compression="gzip")
        f.create_dataset("image_embeddings", shape=image_embeddings_shape, dtype='float32', compression="gzip")
        f.create_dataset("labels_info", shape=labels_info_shape, dtype='int64', compression="gzip")

    f["attentions"][start_idx:start_idx + len(cumulative_attentions)] = cumulative_attentions
    f["mlps"][start_idx:start_idx + len(cumulative_mlps)] = cumulative_mlps
    f["image_embeddings"][start_idx:start_idx + len(cumulative_image_embeddings)] = cumulative_image_embeddings
    f["labels_info"][start_idx:start_idx + len(cumulative_labels_info)] = cumulative_labels_info
    return start_idx + len(cumulative_labels_info)

def _modify_preprocess_with_watermark(preprocess, watermark_transform):
    """Return a copy of the pipeline with a watermark step after ``ToTensor``.

    Args:
        preprocess: Original preprocessing pipeline; must be a torchvision
            ``Compose``.
        watermark_transform: Transform inserted directly after the first
            ``ToTensor`` step.

    Returns:
        A new ``Compose`` containing the original steps plus the watermark
        transform.

    Raises:
        ValueError: If ``preprocess`` is not a ``Compose`` or contains no
            ``ToTensor`` step.
    """
    from torchvision.transforms import Compose, ToTensor

    if not isinstance(preprocess, Compose):
        raise ValueError("preprocess is not a Compose object")

    pipeline = preprocess.transforms
    # Position of the first ToTensor step, or None when there is none.
    to_tensor_pos = next(
        (pos for pos, step in enumerate(pipeline) if isinstance(step, ToTensor)),
        None,
    )
    if to_tensor_pos is None:
        raise ValueError("ToTensor not found in preprocess transforms")

    split_at = to_tensor_pos + 1
    head, tail = pipeline[:split_at], pipeline[split_at:]
    return Compose(head + [watermark_transform] + tail)

def main(args):
    """Main function to process dataset and perform zero-shot classification.

    Runs in one of two modes:

    * ``args.just_text`` set: build the class text embeddings and store them
      in ``<output_dir>/<model>_<dataset>/<just_text>_text.h5``.
    * otherwise: run the image encoder over the dataset and store
      attentions, MLP outputs, normalized image embeddings and labels in
      ``<output_dir>/<model>_<dataset>/data.h5``, flushing to disk every few
      batches.

    Args:
        args: Command-line arguments.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised,
            so the process does not report success after a failed run.
    """
    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        output_dir = Path(args.output_dir) / subfolder_name
        output_dir.mkdir(parents=True, exist_ok=True)

        log_format = "%(asctime)s - %(levelname)s - %(message)s"
        if not args.just_text:
            # Image mode also writes the console output to a log file.
            logging.basicConfig(
                level=logging.INFO,
                format=log_format,
                handlers=[
                    logging.FileHandler(output_dir / "Console_Info.log", mode='w'),
                    logging.StreamHandler()
                ]
            )
        else:
            logging.basicConfig(level=logging.INFO, format=log_format)

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()
        prs = hook_prs_logger(model, device)

        model_parameters = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {model_parameters}")
        logger.info(f"Context length: {model.context_length}")
        logger.info(f"Vocabulary size: {model.vocab_size}")
        logger.info(f"Number of transformer residual blocks: {len(model.visual.transformer.resblocks)}")

        if args.just_text:
            text_embeddings = _create_text_embedding(args, model, device)
            with h5py.File(output_dir / f"{args.just_text}_text.h5", 'a') as f:
                # The file is opened in append mode, so replace a stale copy.
                if "text_embeddings" in f:
                    del f["text_embeddings"]
                f.create_dataset("text_embeddings", data=text_embeddings.detach().cpu().numpy(), dtype='float32', compression="gzip")
            logger.info(f"Text embeddings saved as {args.just_text}_text.h5")
            return

        # if args.dataset == "imagenet_w":
        #     logger.info("Watermarks enabled for dataset")
        #     try:
        #         watermark_transform = AddWatermark(image_size=model.visual.image_size[0])
        #         preprocess = _modify_preprocess_with_watermark(preprocess, watermark_transform)
        #         logger.info("Preprocess pipeline updated with watermark transformation.")
        #     except Exception as e:
        #         logger.error(f"Failed to modify preprocess with watermark: {e}")
        #         raise

        data_loader, total_samples = _create_dataloaders(args, preprocess)
        logger.info(f"Dataset '{args.dataset}' loaded successfully, total samples: {total_samples}")
        num_batches = len(data_loader)

        with h5py.File(output_dir / "data.h5", 'w') as f:
            start_idx = 0
            cumulative_attentions = []
            cumulative_mlps = []
            cumulative_image_embeddings = []
            cumulative_labels_info = []
            # Flush roughly every 50 samples, but at least once per batch.
            save_every_n_batches = max(1, 50 // args.batch_size)
            with torch.inference_mode():
                for batch_idx, (images, labels_info) in enumerate(tqdm(data_loader)):
                    # Reset the hook state before each forward pass.
                    prs.reinit()
                    image_embeddings = model.encode_image(images.to(device), attn_method="head", normalize=False)

                    attentions, mlps = prs.finalize(representation=image_embeddings)
                    # Sum out dims 1 and 3 of the hooked attention tensor.
                    # NOTE(review): axis meaning depends on prs_hook -- confirm.
                    attentions = attentions.sum(dim=(1, 3)).detach().cpu().numpy()
                    image_embeddings = F.normalize(image_embeddings, dim=-1).detach().cpu().numpy()
                    cumulative_attentions.append(attentions)
                    cumulative_mlps.append(mlps.detach().cpu().numpy())
                    cumulative_image_embeddings.append(image_embeddings)
                    cumulative_labels_info.append(labels_info.numpy())

                    is_last_batch = (batch_idx + 1) == num_batches
                    if (batch_idx + 1) % save_every_n_batches == 0 or is_last_batch:
                        cumulative_attentions_np = np.concatenate(cumulative_attentions, axis=0)
                        cumulative_mlps_np = np.concatenate(cumulative_mlps, axis=0)
                        cumulative_image_embeddings_np = np.concatenate(cumulative_image_embeddings, axis=0)
                        cumulative_labels_info_np = np.concatenate(cumulative_labels_info, axis=0)
                        start_idx = _save_to_hdf5(f, start_idx, cumulative_attentions_np, cumulative_mlps_np,
                                                cumulative_image_embeddings_np, cumulative_labels_info_np, total_samples)
                        cumulative_attentions = []
                        cumulative_mlps = []
                        cumulative_image_embeddings = []
                        cumulative_labels_info = []

                    del images, labels_info, attentions, mlps, image_embeddings
                    torch.cuda.empty_cache()

            logger.info(f"Saved attentions, mlps, image_embeddings, labels_info to {output_dir}")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")
        # Propagate so the process exits non-zero instead of reporting success.
        raise

def pil_loader(path: str) -> Image.Image:
    """Load a PIL image and convert it to RGB format.

    Loading is best-effort: an unreadable file is reported through the
    module logger and replaced by a black placeholder, so one bad image does
    not abort a whole run.

    Args:
        path: Path to the image file.

    Returns:
        PIL Image in RGB format, or a 224x224 black RGB image when the file
        cannot be read.
    """
    try:
        with open(path, "rb") as f:
            img = Image.open(f)
            # convert() reads the pixel data, so it must run while f is open.
            return img.convert("RGB")
    except OSError as e:  # IOError is an alias of OSError
        logger.warning("Unable to load image %s: %s", path, e)
        # Return a small black image as a fallback
        return Image.new('RGB', (224, 224), (0, 0, 0))


# Entry point kept after every definition, so all module-level names exist
# by the time main() runs.
if __name__ == "__main__":
    parser = get_parser_info()
    args = parser.parse_args()
    set_seed(42)
    main(args)

