import argparse
import os
import sys
import random
import numpy as np
import torch
import h5py
from torch.nn import functional as F
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import logging
from mixup import ImageMixup
from cutmix import ImageCutMix
from augmix import ImageAugMix
from cutout import ImageCutout

from collections import defaultdict 

# Route Hugging Face Hub downloads through a mirror to accelerate model
# downloads. huggingface_hub reads HF_ENDPOINT once, when it is first imported,
# so this must run before the project imports below (CLIP_utils may pull in
# huggingface_hub). setdefault keeps an HF_ENDPOINT the user already exported.
# NOTE(review): the augmentation modules imported above (mixup, cutmix, augmix,
# cutout) load before this line; if any of them imports huggingface_hub, move
# this assignment to the very top of the file.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Make project-local packages (data, CLIP_utils) importable by appending the
# current working directory to sys.path; the script must be launched from the
# project root for these imports to resolve.
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)
from data.waterbirds import WaterbirdsDataset
from data.COCO_GB_V1 import COCO_GB_V1_dataset
from data.COCO_GB_V2 import COCO_GB_V2_dataset

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer


# Module-level logger; handlers/format are configured in main() via basicConfig.
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every RNG this script relies on so runs can be reproduced.

    Seeds Python's ``random``, NumPy and PyTorch. If CUDA is available, all
    GPUs are seeded as well and cuDNN is put into deterministic,
    non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser with model, data, output and augmentation
        options. ``--augment`` is restricted to the methods ``main`` knows how
        to run, so a typo fails immediately at parse time instead of surfacing
        later as an undefined-name error inside ``main``.
    """
    parser = argparse.ArgumentParser(
        description="Load and process dataset for CLIP zero-shot classification.",
        add_help=True
    )
    # Model parameters
    parser.add_argument("--model", default="ViT-B-16", type=str, help="Name of the model to use.")
    parser.add_argument("--pretrained", default="laion2b_s34b_b88k", type=str, help="Pretrained weights to load.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for data loading.")
    # Data parameters
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--dataset", default="cocogbv2", type=str, help="Name of the dataset.")
    parser.add_argument("--data_root", type=str, default="./data", help="Path to the dataset root directory.")
    # Output parameters
    parser.add_argument("--output_dir", type=str, default="./results", help="Path to save output results.")
    # Used as the index in torch.device(f"cuda:{cuda_id}").
    parser.add_argument("--cuda_id", type=str, default="3", help="Device to run the model on.")
    parser.add_argument(
        "--augment", type=str, default="cutout",
        choices=["mixup", "cutmix", "augmix", "cutout"],
        help="Augmentation method to use.",
    )
    return parser

def _create_dataloaders(args, preprocess):
    """Create a DataLoader for the test split of the selected dataset.

    Known dataset names map to project dataset classes; any other name falls
    back to ``ImageFolder`` rooted at ``<data_root>/<dataset>``.

    Args:
        args: Command-line arguments; reads ``dataset``, ``data_root``,
            ``batch_size`` and ``num_workers``.
        preprocess: Image preprocessing pipeline passed as ``transform``.

    Returns:
        Tuple of (DataLoader instance, total number of samples). The loader
        does not shuffle, so batch ``k`` covers samples
        ``[k * batch_size, (k + 1) * batch_size)``.
    """
    dataset_map = {
        "cocogbv1": lambda: COCO_GB_V1_dataset(root=args.data_root, split='test', transform=preprocess),
        "cocogbv2": lambda: COCO_GB_V2_dataset(root=args.data_root, split='test', transform=preprocess),
        "waterbirds": lambda: WaterbirdsDataset(root=os.path.join(args.data_root, "waterbirds"), split="test", transform=preprocess),
    }
    data_path = os.path.join(args.data_root, args.dataset)
    ds = dataset_map.get(args.dataset, lambda: ImageFolder(root=data_path, transform=preprocess))()

    loader_kwargs = {}
    if args.num_workers > 0:
        # prefetch_factor is only valid with worker processes; recent PyTorch
        # raises ValueError if it is given while num_workers == 0.
        loader_kwargs["prefetch_factor"] = 2
    dataloader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, **loader_kwargs
    )
    return dataloader, len(ds)


# Negative prompts per augmentation method. They are encoded with the CLIP text
# tower and handed to the augmentation processor together with the positive
# text embeddings.
_NEGATIVE_TEXTS = {
    "mixup": ["a photo of geometric shapes", "a black photo", "a mixup image dominated by background", "a photo without visible objects"],
    "cutmix": ["a photo of person", "a black photo", "a cutmix image dominated by background", "a photo without objects"],
    "augmix": ["a photo of person", "a photo of geometric shapes", "a black photo", "a mixup image dominated by background", "a photo without objects"],
    "cutout": ["a photo of person", "a completely black photo", "a cutout image dominated by background", "a photo without objects"],
}

# Augmentations whose processors take a scene base_dir and whose process()
# call takes the scene list.
_SCENE_AUGMENTS = ("mixup", "cutmix")


def _build_processor(augment, *, base_dir, **common):
    """Instantiate the augmentation processor selected by ``augment``.

    ``common`` carries the keyword arguments shared by every processor
    (clip_model, clip_preprocess, tokenizer, save_dir, text_embeddings, device).
    """
    if augment == "mixup":
        return ImageMixup(base_dir=base_dir, **common)
    if augment == "cutmix":
        return ImageCutMix(base_dir=base_dir, **common)
    if augment == "augmix":
        return ImageAugMix(**common)
    if augment == "cutout":
        return ImageCutout(**common)
    raise ValueError(f"Unsupported augmentation method: {augment!r}")


def _run_processor(processor, augment, image_path, scene_list):
    """Call ``processor.process`` with the arguments its augmentation expects."""
    if augment in _SCENE_AUGMENTS:
        return processor.process(image_path, scene_list, image_enable=False)
    return processor.process(image_path, image_enable=False)


def _new_cocolabel_entry():
    """Return zeroed per-cocolabel counters."""
    return {'female_total': 0, 'female_correct': 0, 'male_total': 0, 'male_correct': 0}


def _update_cocolabel_stats(stats, gender_labels, correct, cocolabels):
    """Accumulate per-cocolabel total/correct counts for one batch.

    Args:
        stats: Mapping cocolabel -> counter dict (see ``_new_cocolabel_entry``).
        gender_labels: 1-D tensor of gender labels (0 = female, 1 = male);
            other values are ignored.
        correct: 1-D bool tensor, True where the prediction matched.
        cocolabels: 2-D tensor of cocolabels per sample, padded with -1.
    """
    for gender, is_correct, sample_labels in zip(gender_labels.tolist(), correct.tolist(), cocolabels):
        if gender == 0:
            prefix = "female"
        elif gender == 1:
            prefix = "male"
        else:
            continue
        for cocolabel in sample_labels[sample_labels != -1].tolist():
            entry = stats[int(cocolabel)]
            entry[f"{prefix}_total"] += 1
            if is_correct:
                entry[f"{prefix}_correct"] += 1


def _lowest_accuracy(stats, prefix, min_count):
    """Return ``(accuracy_percent, cocolabel)`` with the lowest ``prefix`` accuracy.

    Only cocolabels with at least ``min_count`` (and at least one) samples for
    ``prefix`` ("female" or "male") are considered; ties keep the first label
    seen. Returns ``(None, None)`` when no cocolabel qualifies.
    """
    lowest_acc, lowest_label = None, None
    for cocolabel, entry in stats.items():
        total = entry[f"{prefix}_total"]
        if total == 0 or total < min_count:
            continue
        accuracy = entry[f"{prefix}_correct"] / total * 100
        if lowest_acc is None or accuracy < lowest_acc:
            lowest_acc, lowest_label = accuracy, cocolabel
    return lowest_acc, lowest_label


def main(args):
    """Run augmentation-based zero-shot gender classification over a dataset.

    For every image, the selected augmentation processor returns positive-text
    similarities. Their mean over axis 0 is stored in
    ``mixed_images/<dataset>_<augment>/mean_pos_sims.h5`` (one row per sample),
    and its argmax is compared with the gender label (0 = female, 1 = male).
    Overall, per-gender and lowest per-cocolabel accuracies are logged.

    Positive text embeddings are read from
    ``<output_dir>/<model>_<dataset>/openai_text.h5`` (key ``text_embeddings``).

    Args:
        args: Parsed command-line arguments (see ``get_parser_info``).

    Any exception is logged with its traceback and swallowed; main returns None.
    """
    try:
        if args.augment not in _NEGATIVE_TEXTS:
            raise ValueError(
                f"Unsupported augmentation method {args.augment!r}; "
                f"expected one of {sorted(_NEGATIVE_TEXTS)}"
            )

        subfolder_name = f"{args.model}_{args.dataset}"
        output_dir = Path(args.output_dir) / subfolder_name
        save_dir = f"mixed_images/{args.dataset}_{args.augment}"
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        device = torch.device(f"cuda:{args.cuda_id}")

        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()

        data_loader, total_samples = _create_dataloaders(args, preprocess)
        tokenizer = get_tokenizer(args.model)

        logger.info(f"Dataset '{args.dataset}' loaded successfully, total samples: {total_samples}")

        dataset = data_loader.dataset
        # The processors work on image file paths, which are taken from
        # dataset.samples[i][0]; fail early and clearly if they are unavailable.
        if not hasattr(dataset, 'samples'):
            raise AttributeError(
                f"Dataset '{args.dataset}' has no 'samples' attribute; "
                "image paths are required by the augmentation processors."
            )

        base_dir = 'data/places/train'
        scene_list = ["street", "kitchen", "office", "living_room", "gymnasium-indoor",
                      "park", "restaurant", "classroom", "shopping_mall-indoor", "beach"]

        with h5py.File(f'{output_dir}/openai_text.h5', 'r') as text_file:
            positive_text_embeddings = text_file['text_embeddings'][:]

        # Text encoding needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            negative_tokens = tokenizer(_NEGATIVE_TEXTS[args.augment]).to(device)
            negative_text_embeddings = model.encode_text(negative_tokens)
            negative_text_embeddings = F.normalize(negative_text_embeddings, dim=-1).detach().cpu().numpy()

        text_embeddings = (positive_text_embeddings, negative_text_embeddings)

        processor = _build_processor(
            args.augment,
            base_dir=base_dir,
            clip_model=model,
            clip_preprocess=preprocess,
            tokenizer=tokenizer,
            save_dir=save_dir,
            text_embeddings=text_embeddings,
            device=device,
        )

        all_correct_counts = 0
        female_total = 0
        female_correct = 0
        male_total = 0
        male_correct = 0
        cocolabel_stats = defaultdict(_new_cocolabel_entry)

        with torch.inference_mode():
            with h5py.File(f'{save_dir}/mean_pos_sims.h5', 'w') as f0:
                sims_dataset = f0.create_dataset(
                    'mean_pos_sims',
                    shape=(total_samples, positive_text_embeddings.shape[0]),
                    dtype='float32',
                )
                for batch_idx, (images, labels_info) in enumerate(tqdm(data_loader)):
                    # The loader does not shuffle, so each batch is a contiguous
                    # slice of the dataset; this gives both the image paths and
                    # the rows to write in the output file.
                    batch_start = batch_idx * args.batch_size
                    batch_end = min(batch_start + args.batch_size, len(dataset))
                    image_paths = [dataset.samples[i][0] for i in range(batch_start, batch_end)]

                    # Each processor result has one row per augmented view
                    # (assumed; averaged over axis 0) and one column per
                    # positive text embedding.
                    batch_mean_pos_sims = np.array([
                        np.mean(np.array(_run_processor(processor, args.augment, image_path, scene_list)), axis=0)
                        for image_path in image_paths
                    ])
                    # Bookkeeping stays on CPU to avoid a GPU sync per element.
                    preds = torch.from_numpy(batch_mean_pos_sims.argmax(axis=1))

                    # Column 0 is the gender label; the remaining columns are
                    # cocolabels padded with -1.
                    # NOTE(review): the ImageFolder fallback yields 1-D labels,
                    # which this indexing does not support.
                    gender_labels = labels_info[:, 0].long()
                    cocolabels = labels_info[:, 1:]

                    correct = (preds == gender_labels)
                    all_correct_counts += correct.sum().item()

                    female_mask = (gender_labels == 0)
                    male_mask = (gender_labels == 1)

                    female_total += female_mask.sum().item()
                    female_correct += (correct & female_mask).sum().item()

                    male_total += male_mask.sum().item()
                    male_correct += (correct & male_mask).sum().item()

                    _update_cocolabel_stats(cocolabel_stats, gender_labels, correct, cocolabels)

                    sims_dataset[batch_start:batch_end] = batch_mean_pos_sims
                    del batch_mean_pos_sims, images, labels_info
                    torch.cuda.empty_cache()

        # Calculate and record accuracy
        overall_accuracy = (all_correct_counts / total_samples * 100) if total_samples > 0 else 0
        female_accuracy = (female_correct / female_total * 100) if female_total > 0 else 0
        male_accuracy = (male_correct / male_total * 100) if male_total > 0 else 0

        # Minimum number of samples a cocolabel needs to be ranked.
        min_thred = 10
        lowest_female_acc, lowest_female_cocolabel = _lowest_accuracy(cocolabel_stats, "female", min_thred)
        lowest_male_acc, lowest_male_cocolabel = _lowest_accuracy(cocolabel_stats, "male", min_thred)

        logger.info(f"Saved mean_pos_sims to {save_dir}/mean_pos_sims.h5")
        logger.info(f"Overall accuracy: {overall_accuracy:.2f}%")
        logger.info(f"Female accuracy: {female_accuracy:.2f}%")
        logger.info(f"Male accuracy: {male_accuracy:.2f}%")
        if lowest_female_cocolabel is not None:
            logger.info(f"Lowest female accuracy: {lowest_female_acc:.2f}% (cocolabel: {lowest_female_cocolabel})")
        else:
            logger.info("No cocolabel with sufficient female samples for lowest accuracy calculation.")
        if lowest_male_cocolabel is not None:
            logger.info(f"Lowest male accuracy: {lowest_male_acc:.2f}% (cocolabel: {lowest_male_cocolabel})")
        else:
            logger.info("No cocolabel with sufficient male samples for lowest accuracy calculation.")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Parse the CLI first, then fix the RNG seed before any work starts.
    cli_args = get_parser_info().parse_args()
    set_seed(42)
    main(cli_args)
