import argparse
import os
import sys
import random
import numpy as np
import torch
import h5py
from torch.nn import functional as F
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import logging
from mixup import ImageMixup
from cutmix import ImageCutMix
from augmix import ImageAugMix
from cutout import ImageCutout

# Put the current working directory (as an absolute path) on sys.path so the
# project-local modules imported below (data.*, class_text, CLIP_utils.*) can be
# found. The script is therefore expected to be launched from the project root.
project_root = os.path.abspath(os.curdir)
sys.path.append(project_root)
from data.waterbirds import WaterbirdsDataset
from data.NICO import NICODataset
from class_text import NICO as NICO_CLASS_INFO

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer


# Point Hugging Face Hub downloads at a mirror to speed them up, but respect an
# HF_ENDPOINT the user has already exported instead of silently overriding it.
# NOTE(review): huggingface_hub is believed to read HF_ENDPOINT when it is first
# imported; if CLIP_utils.factory (imported above) already pulled it in, this
# assignment may come too late to matter. TODO confirm; exporting HF_ENDPOINT
# before launching the script is the reliable option.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Module-level logger; handlers and format are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU). When CUDA is available it
    also seeds all GPUs and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser exposing model, data-loading, output and
        augmentation options.
    """
    parser = argparse.ArgumentParser(
        description="Load and process dataset for CLIP zero-shot classification.",
        add_help=True
    )
    # Model parameters
    parser.add_argument("--model", default="ViT-B-16", type=str, help="Name of the model to use.")
    parser.add_argument("--pretrained", default="laion2b_s34b_b88k", type=str, help="Pretrained weights to load.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for data loading.")
    # Data parameters
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--dataset", default="waterbirds", type=str, help="Name of the dataset.")
    parser.add_argument("--data_root", type=str, default="./data", help="Path to the dataset root directory.")
    # Output parameters
    parser.add_argument("--output_dir", type=str, default="./results", help="Path to save output results.")
    parser.add_argument("--cuda_id", type=str, default="3", help="Device to run the model on.")
    # main() only defines negative prompts and a processor for these four methods;
    # rejecting other values here fails fast instead of after the model is loaded.
    parser.add_argument(
        "--augment",
        type=str,
        default="cutout",
        choices=("mixup", "cutmix", "augmix", "cutout"),
        help="Augmentation method to use.",
    )
    return parser

def _create_dataloaders(args, preprocess):
    """Create a DataLoader for the specified dataset.

    ``args.dataset == "nico"`` loads ``NICODataset`` from ``<data_root>/NICO``;
    any other name falls back to ``ImageFolder`` rooted at ``<data_root>/<dataset>``.

    Args:
        args: Command-line arguments providing ``dataset``, ``data_root``,
            ``batch_size`` and ``num_workers``.
        preprocess: Image preprocessing pipeline applied by the dataset.

    Returns:
        Tuple of (DataLoader instance, total number of samples in the dataset).
    """
    dataset_map = {
        "nico": lambda: NICODataset(root=os.path.join(args.data_root, "NICO"), transform=preprocess),
    }
    data_path = os.path.join(args.data_root, args.dataset)
    ds = dataset_map.get(args.dataset, lambda: ImageFolder(root=data_path, transform=preprocess))()

    loader_kwargs = {
        "batch_size": args.batch_size,
        # Keep dataset order: callers may map batches back to dataset samples by position.
        "shuffle": False,
        "num_workers": args.num_workers,
        "pin_memory": True,
    }
    # prefetch_factor only applies to worker processes; recent PyTorch raises
    # ValueError when it is passed together with num_workers=0.
    if args.num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2
    dataloader = DataLoader(ds, **loader_kwargs)
    return dataloader, len(ds)

def eval_waterbirds(group_correct_counts, group_total_counts, label_info, pred):
    """
    Accumulate per-group correct/total counts for one Waterbirds batch.

    Group ``g`` holds the samples whose label equals ``g // 2`` and whose place
    equals ``g % 2`` (four groups in total).

    Args:
        group_correct_counts: Mapping indexable by 0..3 (dict or list) of correct predictions; updated in place
        group_total_counts: Mapping indexable by 0..3 of sample counts; updated in place
        label_info: 2-D tensor or array; column 0 is the class label, column 1 the place
        pred: 1-D tensor or array of predicted labels

    Returns:
        The two (mutated) counter objects, as a tuple
    """
    info = label_info.numpy() if torch.is_tensor(label_info) else label_info
    labels, places = info[:, 0], info[:, 1]
    predictions = pred.numpy() if torch.is_tensor(pred) else pred
    hits = predictions == labels

    for group in range(4):
        target_label, target_place = divmod(group, 2)
        in_group = (labels == target_label) & (places == target_place)
        group_correct_counts[group] += int(hits[in_group].sum())
        group_total_counts[group] += int(in_group.sum())

    return group_correct_counts, group_total_counts

def eval_nico(class_correct_counts, class_total_counts, context_correct_counts, context_total_counts, label_info, pred):
    """
    Accumulate NICO accuracy counts by class and by (class, context) for one batch.

    Keys are converted to plain Python scalars (via ``.item()``) so the resulting
    dicts can be serialized (e.g. with ``json``) and compare/hash like ordinary ints.

    Args:
        class_correct_counts: dict class -> number of correct predictions; updated in place
        class_total_counts: dict class -> number of samples; updated in place
        context_correct_counts: dict class -> {context -> correct predictions}; updated in place
        context_total_counts: dict class -> {context -> samples}; updated in place
        label_info: 2-D tensor or array; column 0 is the class, column 1 the context
        pred: 1-D tensor or array of predicted classes

    Returns:
        The four (mutated) counter dicts, as a tuple
    """
    # Ensure data type is numpy array
    labels = label_info[:, 0].numpy() if torch.is_tensor(label_info) else label_info[:, 0]
    contexts = label_info[:, 1].numpy() if torch.is_tensor(label_info) else label_info[:, 1]
    pred = pred.numpy() if torch.is_tensor(pred) else pred
    hits = pred == labels

    for class_value in np.unique(labels):
        class_idx = class_value.item()  # NumPy scalar -> plain Python key
        class_mask = labels == class_value

        # Each dict is updated independently so a missing key in one never resets the other.
        class_correct_counts[class_idx] = class_correct_counts.get(class_idx, 0) + int(hits[class_mask].sum())
        class_total_counts[class_idx] = class_total_counts.get(class_idx, 0) + int(class_mask.sum())

        ctx_correct = context_correct_counts.setdefault(class_idx, {})
        ctx_total = context_total_counts.setdefault(class_idx, {})
        for ctx_value in np.unique(contexts[class_mask]):
            ctx_idx = ctx_value.item()
            ctx_mask = class_mask & (contexts == ctx_value)
            ctx_correct[ctx_idx] = ctx_correct.get(ctx_idx, 0) + int(hits[ctx_mask].sum())
            ctx_total[ctx_idx] = ctx_total.get(ctx_idx, 0) + int(ctx_mask.sum())

    return class_correct_counts, class_total_counts, context_correct_counts, context_total_counts

# Negative prompts describing the artifacts each augmentation introduces.
# The keys are also the set of augmentation methods main() supports.
_NEGATIVE_TEXTS = {
    "mixup": [
        "a photo without a clear subject",
        "a blended image",
        "a mixup image dominated by background",
        "a photo without visible objects",
    ],
    "cutmix": [
        "a photo with parts missing",
        "a composite image",
        "a cutmix image",
        "a photo with replaced regions",
    ],
    "augmix": [
        "a distorted photo",
        "an augmented image",
        "a photo with applied effects",
        "an image with visual modifications",
    ],
    "cutout": [
        "a photo with holes",
        "an image with black patches",
        "a cutout image",
        "a photo with sections removed",
    ],
}

# Augmentations whose processors take a places root (``base_dir``) and a scene list.
_SCENE_AUGMENTS = ("mixup", "cutmix")

# Root of the scene images handed to the scene-based processors as ``base_dir``.
_PLACES_BASE_DIR = 'data/places/train'

# Diverse scene categories passed to the scene-based processors
# (intended to cover the different NICO contexts).
_SCENE_LIST = [
    "forest_path",
    "snowfield",
    "beach",
    "river",
    "park",
    "residential_neighborhood",
    "street",
    "airport_terminal",
    "garage-outdoor",
    "train_station-platform",
    "bridge",
    "corral",
    "amphitheater",
    "waterfall",
    "mountain_snowy",
    "desert_road",
]


def _percent(correct, total):
    """Return ``correct / total`` as a percentage, or 0 when ``total`` is not positive."""
    return correct / total * 100 if total > 0 else 0


def _sorted_by_accuracy(correct_counts, total_counts):
    """Return ``[(key, accuracy_percent), ...]`` ordered from highest to lowest accuracy.

    ``sorted`` is stable, so ties keep the insertion order of ``correct_counts``.
    """
    return sorted(
        ((key, _percent(correct_counts[key], total_counts[key])) for key in correct_counts),
        key=lambda item: item[1],
        reverse=True,
    )


def _nico_class_name(class_idx):
    """Return the display name of a NICO class, or ``"Class <idx>"`` if it cannot be resolved.

    Assumes ``NICO_CLASS_INFO[class_idx]`` is a dict whose first key is the class
    name (inferred from how it is indexed here; TODO confirm against ``class_text``).
    """
    try:
        if class_idx in NICO_CLASS_INFO:
            name = list(NICO_CLASS_INFO[class_idx].keys())[0]
            if name:
                return name
    except (AttributeError, IndexError, KeyError, TypeError):
        pass  # Malformed entry: fall back to the numeric label below.
    return f"Class {class_idx}"


def _nico_context_names(class_idx):
    """Return ``{context_idx: context_name}`` for a NICO class, or ``{}`` if unavailable.

    Assumes the value stored under the class name in ``NICO_CLASS_INFO[class_idx]``
    is an ordered sequence of context names (TODO confirm).
    """
    try:
        if class_idx in NICO_CLASS_INFO:
            class_dict = NICO_CLASS_INFO[class_idx]
            return dict(enumerate(class_dict[list(class_dict.keys())[0]]))
    except (AttributeError, IndexError, KeyError, TypeError):
        pass
    return {}


def _nico_class_lines(sorted_classes, class_correct_counts, class_total_counts):
    """Format one ``"<class>: <acc>% (<correct>/<total> samples)"`` line per class."""
    return [
        f"{_nico_class_name(idx)}: {acc:.2f}% ({class_correct_counts[idx]}/{class_total_counts[idx]} samples)"
        for idx, acc in sorted_classes
    ]


def _nico_context_lines(class_idx, context_correct_counts, context_total_counts):
    """Format ``"  <context>: <acc>% (<correct>/<total> samples)"`` lines for one class, best first."""
    context_names = _nico_context_names(class_idx)
    correct_by_ctx = context_correct_counts[class_idx]
    total_by_ctx = context_total_counts[class_idx]
    lines = []
    for ctx_idx, acc in _sorted_by_accuracy(correct_by_ctx, total_by_ctx):
        ctx_name = context_names.get(ctx_idx, f"Context {ctx_idx}")
        lines.append(f"  {ctx_name}: {acc:.2f}% ({correct_by_ctx[ctx_idx]}/{total_by_ctx[ctx_idx]} samples)")
    return lines


def _waterbirds_group_lines(group_correct_counts, group_total_counts):
    """Format one accuracy line per Waterbirds group (group = label * 2 + place, as in eval_waterbirds)."""
    return [
        f"Group {group}: {_percent(group_correct_counts[group], group_total_counts[group]):.2f}% "
        f"({group_correct_counts[group]}/{group_total_counts[group]} samples)"
        for group in sorted(group_correct_counts)
    ]


def main(args):
    """Score every dataset image with a CLIP-guided augmentation and report accuracy.

    For each image the selected augmentation processor returns similarity scores
    against the precomputed positive text embeddings. Their mean is stored in
    ``mixed_images/<dataset>_<augment>/mean_pos_sims.h5`` and its argmax is used as
    the prediction. Accuracies are logged and written to
    ``./output/NICO/<model>_<dataset>_<augment>_results.txt``.

    Args:
        args: Parsed command-line arguments (see ``get_parser_info``).

    Any exception is logged with its traceback instead of being propagated.
    """
    try:
        logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

        if args.augment not in _NEGATIVE_TEXTS:
            raise ValueError(
                f"Unsupported augmentation {args.augment!r}; expected one of {sorted(_NEGATIVE_TEXTS)}"
            )

        subfolder_name = f"{args.model}_{args.dataset}"
        output_dir = Path(args.output_dir) / subfolder_name
        save_dir = f'mixed_images/{args.dataset}_{args.augment}'
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        # NOTE: the CUDA caching allocator reads this setting when it initializes,
        # so it has to be set before the first CUDA allocation (model.to below).
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
        device = torch.device(f"cuda:{args.cuda_id}")

        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()

        data_loader, total_samples = _create_dataloaders(args, preprocess)
        tokenizer = get_tokenizer(args.model)
        logger.info(f"Dataset '{args.dataset}' loaded successfully, total samples: {total_samples}")

        dataset = data_loader.dataset
        # Processors operate on image files, so batches are mapped back to paths via
        # ``dataset.samples``. This relies on the loader iterating in order (shuffle=False).
        if not hasattr(dataset, 'samples'):
            raise AttributeError(
                f"Dataset '{args.dataset}' has no 'samples' attribute; cannot recover image paths"
            )

        # Positive (class-prompt) embeddings are precomputed. Read them fully, then close the file.
        with h5py.File(f'{output_dir}/openai_text.h5', 'r') as text_file:
            positive_text_embeddings = text_file['text_embeddings'][:]

        # Encode the augmentation-specific negative prompts without building an autograd graph.
        with torch.no_grad():
            negative_tokens = tokenizer(_NEGATIVE_TEXTS[args.augment]).to(device)
            negative_text_embeddings = F.normalize(model.encode_text(negative_tokens), dim=-1)
        negative_text_embeddings = negative_text_embeddings.detach().cpu().numpy()
        text_embeddings = (positive_text_embeddings, negative_text_embeddings)

        processor_kwargs = dict(
            clip_model=model, clip_preprocess=preprocess, tokenizer=tokenizer,
            save_dir=save_dir, text_embeddings=text_embeddings, device=device,
        )
        uses_scenes = args.augment in _SCENE_AUGMENTS
        if args.augment == "mixup":
            processor = ImageMixup(base_dir=_PLACES_BASE_DIR, **processor_kwargs)
        elif args.augment == "cutmix":
            processor = ImageCutMix(base_dir=_PLACES_BASE_DIR, **processor_kwargs)
        elif args.augment == "augmix":
            processor = ImageAugMix(**processor_kwargs)
        else:  # "cutout" -- the only remaining key of _NEGATIVE_TEXTS
            processor = ImageCutout(**processor_kwargs)

        # Per-class and per-(class, context) counters (NICO).
        class_correct_counts, class_total_counts = {}, {}
        context_correct_counts, context_total_counts = {}, {}
        # Per-group counters (Waterbirds); eval_waterbirds updates keys 0..3.
        group_correct_counts = {group: 0 for group in range(4)}
        group_total_counts = {group: 0 for group in range(4)}

        mean_sims_path = f'{save_dir}/mean_pos_sims.h5'
        with torch.inference_mode():
            with h5py.File(mean_sims_path, 'w') as f0:
                mean_sims_ds = f0.create_dataset(
                    'mean_pos_sims',
                    shape=(total_samples, positive_text_embeddings.shape[0]),
                    dtype='float32',
                )
                for batch_idx, (images, labels_info) in enumerate(tqdm(data_loader)):
                    batch_start = batch_idx * args.batch_size
                    batch_end = min(batch_start + args.batch_size, len(dataset))
                    image_paths = [dataset.samples[i][0] for i in range(batch_start, batch_end)]

                    batch_mean_pos_sims = []
                    for i in range(len(images)):
                        if uses_scenes:
                            pos_sims = processor.process(image_paths[i], _SCENE_LIST, image_enable=False)
                        else:
                            pos_sims = processor.process(image_paths[i], image_enable=False)
                        # Average over axis 0 (presumably one row per augmented view -- TODO
                        # confirm) to get one score per positive text embedding.
                        batch_mean_pos_sims.append(np.mean(np.array(pos_sims), axis=0))

                    batch_mean_pos_sims = np.array(batch_mean_pos_sims)
                    preds = batch_mean_pos_sims.argmax(axis=1)

                    if args.dataset == "waterbirds":
                        group_correct_counts, group_total_counts = eval_waterbirds(
                            group_correct_counts, group_total_counts, labels_info, preds)
                    elif args.dataset == "nico":
                        class_correct_counts, class_total_counts, context_correct_counts, context_total_counts = eval_nico(
                            class_correct_counts, class_total_counts, context_correct_counts, context_total_counts,
                            labels_info, preds)

                    mean_sims_ds[batch_start:batch_end] = batch_mean_pos_sims

                    del batch_mean_pos_sims, images, labels_info
                    torch.cuda.empty_cache()

        # Overall accuracy comes from the counters matching the evaluated dataset.
        if args.dataset == "waterbirds":
            correct_counts, total_counts = group_correct_counts, group_total_counts
        else:
            correct_counts, total_counts = class_correct_counts, class_total_counts
        overall_accuracy = _percent(sum(correct_counts.values()), sum(total_counts.values()))

        group_lines = (
            _waterbirds_group_lines(group_correct_counts, group_total_counts)
            if args.dataset == "waterbirds" else []
        )
        sorted_classes = _sorted_by_accuracy(class_correct_counts, class_total_counts)
        class_lines = _nico_class_lines(sorted_classes, class_correct_counts, class_total_counts)
        context_sections = [
            (class_idx, _nico_context_lines(class_idx, context_correct_counts, context_total_counts))
            for class_idx in sorted(context_correct_counts)
        ]

        logger.info(f"Overall accuracy: {overall_accuracy:.2f}%")
        if group_lines:
            logger.info("Group-wise accuracies:")
            for line in group_lines:
                logger.info(line)
        logger.info("Class-wise accuracies:")
        for line in class_lines:
            logger.info(line)
        logger.info("\nContext-wise accuracies:")
        for class_idx, ctx_lines in context_sections:
            logger.info(f"\n{_nico_class_name(class_idx)}:")
            for line in ctx_lines:
                logger.info(line)
        logger.info(f"Saved mean_pos_sims to {mean_sims_path}")

        # Save the same report to a text file.
        output_path = Path("./output/NICO")
        output_path.mkdir(parents=True, exist_ok=True)
        results_file = output_path / f"{subfolder_name}_{args.augment}_results.txt"
        with open(results_file, 'w', encoding='utf-8') as results:
            results.write(f"Dataset: {args.dataset}\n")
            results.write(f"Augment method: {args.augment}\n")
            results.write(f"Overall accuracy: {overall_accuracy:.2f}%\n\n")
            if group_lines:
                results.write("Group-wise accuracies:\n")
                results.writelines(f"{line}\n" for line in group_lines)
                results.write("\n")
            results.write("Class-wise accuracies:\n")
            results.writelines(f"{line}\n" for line in class_lines)
            results.write("\nContext-wise accuracies for each class:\n")
            for class_idx, ctx_lines in context_sections:
                class_acc = _percent(class_correct_counts[class_idx], class_total_counts[class_idx])
                results.write(f"\n{_nico_class_name(class_idx)} (overall: {class_acc:.2f}%):\n")
                results.writelines(f"{line}\n" for line in ctx_lines)

        logger.info(f"Results saved to {results_file}")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Parse CLI options first, then fix every RNG seed, then run the pipeline.
    cli_args = get_parser_info().parse_args()
    set_seed(42)
    main(cli_args)
