import argparse
import os
import sys
import random
import numpy as np
import torch
import h5py
from torch.nn import functional as F
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import logging
from mixup import ImageMixup
from cutmix import ImageCutMix
from augmix import ImageAugMix
from cutout import ImageCutout

# Add the current working directory to sys.path so the project-local packages imported below (data, CLIP_utils) resolve
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)

from data.urbancars import UrbancarsDataset

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer


# Point Hugging Face downloads at a mirror endpoint to speed up model downloads.
# NOTE(review): this assignment runs after the CLIP_utils import above; if a
# library in that import chain reads HF_ENDPOINT at import time, the mirror
# may not take effect -- confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Module-level logger; its level and format are set by logging.basicConfig in main().
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every CUDA device is seeded as well and cuDNN is
    switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Build the command-line argument parser for the evaluation script.

    The parser is returned unparsed so the caller decides which argv to parse.
    ``--augment`` is restricted to the four augmentation methods that
    ``main`` implements, so a typo is rejected at parse time instead of
    surfacing later as an undefined-variable error.

    Returns:
        argparse.ArgumentParser: Parser with model, data and output options.
    """
    parser = argparse.ArgumentParser(
        description="Load and process dataset for CLIP zero-shot classification.",
        add_help=True
    )
    # Model parameters
    parser.add_argument("--model", default="ViT-B-16", type=str, help="Name of the model to use.")
    parser.add_argument("--pretrained", default="laion2b_s34b_b88k", type=str, help="Pretrained weights to load.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for data loading.")
    # Data parameters
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--dataset", default="urbancars", type=str, help="Name of the dataset.")
    parser.add_argument("--data_root", type=str, default="./data", help="Path to the dataset root directory.")
    # Output parameters
    parser.add_argument("--output_dir", type=str, default="./results", help="Path to save output results.")
    parser.add_argument("--cuda_id", type=str, default="3", help="Device to run the model on.")
    parser.add_argument(
        "--augment", type=str, default="cutout",
        choices=("mixup", "cutmix", "augmix", "cutout"),
        help="Augmentation method to use.",
    )

    return parser

def _create_dataloaders(args, preprocess):
    """Create a DataLoader for the specified dataset.

    Datasets registered in ``dataset_map`` use their dedicated class; any
    other name falls back to ``ImageFolder`` rooted at
    ``<data_root>/<dataset>``. The loader never shuffles.

    Args:
        args: Command-line arguments; reads ``dataset``, ``data_root``,
            ``batch_size`` and ``num_workers``.
        preprocess: Image preprocessing pipeline passed as the dataset transform.

    Returns:
        Tuple of DataLoader instance and total number of samples.
    """
    dataset_map = {
        "urbancars": lambda: UrbancarsDataset(root=os.path.join(args.data_root, "urbancars"), split='test', transform=preprocess),
    }
    data_path = os.path.join(args.data_root, args.dataset)
    build_dataset = dataset_map.get(args.dataset, lambda: ImageFolder(root=data_path, transform=preprocess))
    ds = build_dataset()

    loader_kwargs = {
        "batch_size": args.batch_size,
        "shuffle": False,
        "num_workers": args.num_workers,
        "pin_memory": True,
    }
    # prefetch_factor only applies to worker processes; recent PyTorch raises
    # ValueError when it is set while num_workers == 0.
    if args.num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2

    dataloader = DataLoader(ds, **loader_kwargs)
    return dataloader, len(ds)

def _as_numpy(values):
    """Return ``values`` as a host NumPy array (accepts tensors and array-likes)."""
    if torch.is_tensor(values):
        # A bare .numpy() raises for CUDA tensors and for tensors that require grad.
        return values.detach().cpu().numpy()
    return np.asarray(values)


def eval_urbancars(group_correct_counts, group_total_counts, label_info, pred):
    """
    Accumulate per-group correct/total counts for a batch of UrbanCars predictions.

    Each sample is assigned to one of four groups according to whether its
    place and co-occurring object values equal its label:

        0 (ID):       place == label and coobj == label
        1 (BG):       place != label and coobj == label
        2 (CoObj):    place == label and coobj != label
        3 (BG+CoObj): place != label and coobj != label

    Only samples whose label, place and coobj are all 0 or 1 are counted;
    any other value leaves the sample out of every group.

    Args:
        group_correct_counts (dict): Correct predictions per group (keys 0-3); updated in place.
        group_total_counts (dict): Total samples per group (keys 0-3); updated in place.
        label_info (tensor or array-like): Shape (N, 3) with columns label, place, coobj.
        pred (tensor or array-like): Predicted labels, length N.

    Returns:
        tuple: Updated group_correct_counts and group_total_counts.
    """
    info = _as_numpy(label_info)
    pred = _as_numpy(pred)
    labels, places, coobjs = info[:, 0], info[:, 1], info[:, 2]

    # With all three attributes restricted to {0, 1}, "differs from the label"
    # is the same as "equals the opposite class".
    binary = np.isin(info[:, :3], (0, 1)).all(axis=1)
    same_place = places == labels
    same_coobj = coobjs == labels

    masks = [
        binary & same_place & same_coobj,    # ID
        binary & ~same_place & same_coobj,   # BG
        binary & same_place & ~same_coobj,   # CoObj
        binary & ~same_place & ~same_coobj,  # BG + CoObj
    ]

    for group, mask in enumerate(masks):
        group_correct_counts[group] += int(np.count_nonzero(pred[mask] == labels[mask]))
        group_total_counts[group] += int(np.count_nonzero(mask))

    return group_correct_counts, group_total_counts

def _negative_prompts(augment):
    """Return the negative text prompts paired with the given augmentation.

    Raises:
        ValueError: If ``augment`` is not a supported augmentation name.
    """
    prompts = {
        "mixup": ["a photo of geometric shapes", "a black photo", "a mixup image dominated by background", "a photo without visible objects"],
        "cutmix": ["a photo of car", "a black photo", "a cutmix image dominated by background", "a photo without objects"],
        "augmix": ["a photo of car", "a photo of geometric shapes", "a black photo", "a mixup image dominated by background", "a photo without objects"],
        "cutout": ["a photo of car", "a completely black photo", "a cutout image dominated by background", "a photo without objects"],
    }
    if augment not in prompts:
        raise ValueError(f"Unsupported augmentation '{augment}'; expected one of {sorted(prompts)}")
    return prompts[augment]


def _make_scorer(augment, base_dir, scene_list, **processor_kwargs):
    """Build the augmentation processor once and return an ``image_path -> scores`` callable.

    Raises:
        ValueError: If ``augment`` is not a supported augmentation name.
    """
    if augment == "mixup":
        processor = ImageMixup(base_dir=base_dir, **processor_kwargs)
    elif augment == "cutmix":
        processor = ImageCutMix(base_dir=base_dir, **processor_kwargs)
    elif augment == "augmix":
        processor = ImageAugMix(**processor_kwargs)
    elif augment == "cutout":
        processor = ImageCutout(**processor_kwargs)
    else:
        raise ValueError(f"Unsupported augmentation '{augment}'")

    # Mixup/CutMix blend with scene images and therefore take the scene list.
    if augment in ("mixup", "cutmix"):
        return lambda image_path: processor.process(image_path, scene_list, image_enable=False)
    return lambda image_path: processor.process(image_path, image_enable=False)


def main(args):
    """
    Main function to process UrbanCars dataset and perform zero-shot classification.

    Loads the CLIP model, reads the positive text embeddings from
    ``<output_dir>/<model>_<dataset>/openai_text.h5``, encodes the negative
    prompts for the chosen augmentation, scores every image through the
    augmentation processor, writes the per-image mean similarities to
    ``mixed_images/<dataset>_<augment>/mean_pos_sims.h5`` and logs overall and
    per-group accuracy.

    Any exception raised during the run is logged with its traceback and not
    propagated to the caller.

    Args:
        args: Command-line arguments.
    """
    try:
        subfolder_name = f"{args.model}_{args.dataset}"
        output_dir = Path(args.output_dir) / subfolder_name
        save_dir = f'mixed_images/{args.dataset}_{args.augment}'
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        # Reject an unknown augmentation before the expensive model load.
        negative_texts = _negative_prompts(args.augment)

        device = torch.device(f"cuda:{args.cuda_id}")
        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()
        data_loader, total_samples = _create_dataloaders(args, preprocess)
        tokenizer = get_tokenizer(args.model)
        logger.info(f"Dataset '{args.dataset}' loaded successfully, total samples: {total_samples}")

        # The processors work on file paths, which are read from dataset.samples.
        dataset = data_loader.dataset
        if not hasattr(dataset, 'samples'):
            raise AttributeError(
                f"Dataset '{args.dataset}' does not expose a 'samples' list; image paths cannot be resolved"
            )

        base_dir = 'data/places/train'
        scene_list = [
            "street", "parking_lot", "residential_neighborhood", "industrial_area", "garage-outdoor",
            "farm", "field_road", "forest_path", "mountain_path", "junkyard"
        ]

        # Context manager so the HDF5 handle is closed once the array is in memory.
        with h5py.File(f'{output_dir}/openai_text.h5', 'r') as text_file:
            positive_text_embeddings = text_file['text_embeddings'][:]

        # No gradients are needed for the text encoder output.
        with torch.inference_mode():
            negative_tokens = tokenizer(negative_texts).to(device)
            negative_features = model.encode_text(negative_tokens)
            negative_text_embeddings = F.normalize(negative_features, dim=-1).detach().cpu().numpy()

        text_embeddings = (positive_text_embeddings, negative_text_embeddings)

        score_image = _make_scorer(
            args.augment, base_dir, scene_list,
            clip_model=model, clip_preprocess=preprocess, tokenizer=tokenizer,
            save_dir=save_dir, text_embeddings=text_embeddings, device=device,
        )

        # Initialize group counters for tracking accuracy
        group_correct_counts = {i: 0 for i in range(4)}
        group_total_counts = {i: 0 for i in range(4)}

        with torch.inference_mode(), h5py.File(f'{save_dir}/mean_pos_sims.h5', 'w') as f0:
            sims_dataset = f0.create_dataset(
                'mean_pos_sims',
                shape=(total_samples, positive_text_embeddings.shape[0]),
                dtype='float32',
            )
            # The loader does not shuffle, so a running offset maps each batch
            # back to its rows in dataset.samples.
            batch_start = 0
            for images, labels_info in tqdm(data_loader):
                batch_end = batch_start + len(images)
                image_paths = [dataset.samples[i][0] for i in range(batch_start, batch_end)]

                # Average the scores returned for each image into one row.
                batch_mean_pos_sims = np.array([
                    np.mean(np.array(score_image(image_path)), axis=0)
                    for image_path in image_paths
                ])
                preds = batch_mean_pos_sims.argmax(axis=1)
                # Evaluate predictions using UrbanCars-specific evaluation
                group_correct_counts, group_total_counts = eval_urbancars(
                    group_correct_counts, group_total_counts, labels_info, preds
                )
                # Save similarity scores
                sims_dataset[batch_start:batch_end] = batch_mean_pos_sims
                batch_start = batch_end
                torch.cuda.empty_cache()

        # Calculate overall and group-specific accuracies
        overall_correct = sum(group_correct_counts.values())
        overall_total = sum(group_total_counts.values())
        overall_accuracy = (overall_correct / overall_total) * 100 if overall_total > 0 else 0
        group_accuracies = {}
        for group in range(4):
            if group_total_counts[group] > 0:
                group_accuracies[group] = (group_correct_counts[group] / group_total_counts[group]) * 100
            else:
                group_accuracies[group] = 0

        # Calculate gaps compared to ID group accuracy
        id_accuracy = group_accuracies[0]
        bg_gap = id_accuracy - group_accuracies[1]
        coobj_gap = id_accuracy - group_accuracies[2]
        bg_coobj_gap = id_accuracy - group_accuracies[3]

        # Log results
        logger.info(f"Saved mean_pos_sims to {save_dir}/mean_pos_sims.h5")
        logger.info(f"Overall accuracy: {overall_accuracy:.2f}%")
        logger.info(f"ID group accuracy: {group_accuracies[0]:.2f}%")
        logger.info(f"BG group accuracy: {group_accuracies[1]:.2f}% (GAP: {bg_gap:.2f}%)")
        logger.info(f"CoObj group accuracy: {group_accuracies[2]:.2f}% (GAP: {coobj_gap:.2f}%)")
        logger.info(f"BG_CoObj group accuracy: {group_accuracies[3]:.2f}% (GAP: {bg_coobj_gap:.2f}%)")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Script entry point: parse the CLI, fix the RNG seed, then run the evaluation.
    args = get_parser_info().parse_args()
    set_seed(42)
    main(args)
