import argparse
import os
import sys
import random
import numpy as np
import torch
import h5py
from torch.nn import functional as F
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import logging
from mixup import ImageMixup
from cutmix import ImageCutMix
from augmix import ImageAugMix
from cutout import ImageCutout

# Put the current working directory on sys.path so the project-local packages
# imported below (`data`, `CLIP_utils`) can be resolved. This assumes the
# script is launched from the project root.
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)
from data.waterbirds import WaterbirdsDataset

from CLIP_utils.factory import create_model_and_transforms, get_tokenizer


# Point Hugging Face downloads at a mirror endpoint to speed up model downloads.
# NOTE(review): huggingface_hub reads HF_ENDPOINT when it is first imported; if
# any import above already loaded it, this assignment may have no effect.
# Consider setting it before those imports — TODO confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Module-level logger; handler/format are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random generator this script relies on.

    Python's ``random``, NumPy's global RNG and PyTorch's CPU generator are
    always seeded. When CUDA is available, all GPU generators are seeded as
    well, and cuDNN is switched to deterministic, non-benchmarking kernels so
    that repeated runs give the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_parser_info():
    """Configure the command-line argument parser.

    ``--augment`` is restricted to the augmentation methods that ``main``
    knows how to build. An unsupported value is therefore rejected at parse
    time (argparse exits with status 2). Without this check, the run would
    fail only after the model has been loaded.

    Returns:
        argparse.ArgumentParser: Parser for the script's options.
    """
    parser = argparse.ArgumentParser(
        description="Load and process dataset for CLIP zero-shot classification.",
        add_help=True,
    )
    # Model parameters
    parser.add_argument("--model", default="ViT-B-16", type=str, help="Name of the model to use.")
    parser.add_argument("--pretrained", default="laion2b_s34b_b88k", type=str, help="Pretrained weights to load.")
    # Data-loading parameters
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for data loading.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--dataset", default="waterbirds", type=str, help="Name of the dataset.")
    parser.add_argument("--data_root", type=str, default="./data", help="Path to the dataset root directory.")
    # Output / runtime parameters
    parser.add_argument("--output_dir", type=str, default="./results", help="Path to save output results.")
    # Used as the CUDA device index in "cuda:<cuda_id>".
    parser.add_argument("--cuda_id", type=str, default="3", help="Device to run the model on.")
    parser.add_argument(
        "--augment",
        type=str,
        default="cutout",
        choices=["mixup", "cutmix", "augmix", "cutout"],
        help="Augmentation method to use.",
    )
    return parser

def _create_dataloaders(args, preprocess):
    """Create a DataLoader for the specified dataset.

    The "waterbirds" dataset is loaded as ``WaterbirdsDataset`` (test split).
    Any other dataset name is treated as an ``ImageFolder`` located at
    ``<data_root>/<dataset>``. The loader never shuffles, so batch order
    matches dataset order.

    Args:
        args: Command-line arguments providing ``data_root``, ``dataset``,
            ``batch_size`` and ``num_workers``.
        preprocess: Image preprocessing pipeline applied to every sample.

    Returns:
        Tuple of (DataLoader instance, total number of samples).
    """
    data_path = os.path.join(args.data_root, args.dataset)
    dataset_map = {
        "waterbirds": lambda: WaterbirdsDataset(root=data_path, split="test", transform=preprocess),
    }
    ds = dataset_map.get(args.dataset, lambda: ImageFolder(root=data_path, transform=preprocess))()

    loader_kwargs = {}
    if args.num_workers > 0:
        # prefetch_factor is only valid with worker processes; DataLoader
        # raises ValueError if it is given while num_workers == 0.
        loader_kwargs["prefetch_factor"] = 2

    dataloader = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        # Pinned host memory only helps (and only avoids warnings) with a GPU.
        pin_memory=torch.cuda.is_available(),
        **loader_kwargs,
    )
    return dataloader, len(ds)

def eval_waterbirds(group_correct_counts, group_total_counts, label_info, pred):
    """Accumulate per-group correct/total counts for Waterbirds predictions.

    Groups are indexed as ``group = 2 * label + place``:
    0 = (label 0, place 0), 1 = (label 0, place 1),
    2 = (label 1, place 0), 3 = (label 1, place 1).

    Args:
        group_correct_counts (dict): Correct-prediction counter per group (keys 0-3);
            updated in place.
        group_total_counts (dict): Sample counter per group (keys 0-3); updated in place.
        label_info (tensor, array or sequence): Shape (N, 2); column 0 is the label,
            column 1 the place.
        pred (tensor, array or sequence): Predicted labels, length N.

    Returns:
        tuple: The updated ``group_correct_counts`` and ``group_total_counts``.
    """
    def _as_numpy(values):
        # detach()/cpu() let tensors that live on a GPU or require grad be
        # converted; plain sequences become arrays so boolean masking works.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    label_info = _as_numpy(label_info)
    labels = label_info[:, 0]
    places = label_info[:, 1]
    pred = _as_numpy(pred)

    for group in range(4):
        label_value, place_value = divmod(group, 2)
        mask = (labels == label_value) & (places == place_value)
        group_correct_counts[group] += int(np.sum(pred[mask] == labels[mask]))
        group_total_counts[group] += int(np.sum(mask))

    return group_correct_counts, group_total_counts

# Negative prompts per augmentation method. They are part of the experimental
# setup and are reproduced verbatim.
# NOTE(review): "augmix" reuses the "mixup" wording ("a mixup image dominated
# by background"); possibly a copy-paste — confirm before changing results.
_NEGATIVE_TEXTS = {
    "mixup": ["a photo of bird", "a photo of geometric shapes", "a black photo", "a mixup image dominated by background", "a photo without visible objects"],
    "cutmix": ["a photo of bird", "a black photo", "a cutmix image dominated by background", "a photo without objects"],
    "augmix": ["a photo of bird", "a photo of geometric shapes", "a black photo", "a mixup image dominated by background", "a photo without objects"],
    "cutout": ["a photo of bird", "a completely black photo", "a cutout image dominated by background", "a photo without objects"],
}

# Augmentations whose processors take a Places base directory at construction
# and a scene list in process().
_SCENE_AUGMENTS = frozenset({"mixup", "cutmix"})

_PLACES_BASE_DIR = 'data/places/train'
_SCENE_LIST = [
    'beach', 'coast', 'lake-natural', 'marsh', 'swamp',
    'river', 'forest-broadleaf', 'field-wild', 'mountain', 'pasture',
]


def _encode_negative_texts(model, tokenizer, texts, device):
    """Tokenize and encode `texts` with the CLIP text tower, L2-normalized, as NumPy."""
    # No gradients are needed; inference mode avoids building an autograd graph.
    with torch.inference_mode():
        tokens = tokenizer(texts).to(device)
        embeddings = model.encode_text(tokens)
        return F.normalize(embeddings, dim=-1).detach().cpu().numpy()


def _build_processor(augment, model, preprocess, tokenizer, save_dir, text_embeddings, device):
    """Instantiate the augmentation processor selected by `augment`."""
    common = dict(
        clip_model=model,
        clip_preprocess=preprocess,
        tokenizer=tokenizer,
        save_dir=save_dir,
        text_embeddings=text_embeddings,
        device=device,
    )
    if augment == "mixup":
        return ImageMixup(base_dir=_PLACES_BASE_DIR, **common)
    if augment == "cutmix":
        return ImageCutMix(base_dir=_PLACES_BASE_DIR, **common)
    if augment == "augmix":
        return ImageAugMix(**common)
    if augment == "cutout":
        return ImageCutout(**common)
    raise ValueError(f"Unsupported augmentation: {augment!r}")


def _summarize_accuracy(group_correct_counts, group_total_counts):
    """Return (overall accuracy %, {group: accuracy %}); empty groups report 0."""
    overall_correct = sum(group_correct_counts.values())
    overall_total = sum(group_total_counts.values())
    overall_accuracy = (overall_correct / overall_total) * 100 if overall_total > 0 else 0
    group_accuracies = {
        group: (group_correct_counts[group] / group_total_counts[group]) * 100
        if group_total_counts[group] > 0 else 0
        for group in range(4)
    }
    return overall_accuracy, group_accuracies


def main(args):
    """Run zero-shot classification with the selected augmentation and report accuracy.

    Reads positive text embeddings from
    ``<output_dir>/<model>_<dataset>/openai_text.h5``. For every image, it
    averages the processor's positive similarities and takes the arg-max as
    the prediction. The per-image mean similarities are written to
    ``mixed_images/<dataset>_<augment>/mean_pos_sims.h5``, and overall and
    per-group accuracy are logged. Any exception is logged (not re-raised).

    Args:
        args: Command-line arguments.
    """
    try:
        logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

        # Fail fast on an unknown method instead of after loading the model.
        if args.augment not in _NEGATIVE_TEXTS:
            raise ValueError(
                f"Unsupported --augment '{args.augment}'; expected one of {sorted(_NEGATIVE_TEXTS)}"
            )

        subfolder_name = f"{args.model}_{args.dataset}"
        output_dir = Path(args.output_dir) / subfolder_name
        save_dir = f'mixed_images/{args.dataset}_{args.augment}'
        Path(save_dir).mkdir(parents=True, exist_ok=True)

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        device = torch.device(f"cuda:{args.cuda_id}")

        model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained)
        model.to(device)
        model.eval()

        data_loader, total_samples = _create_dataloaders(args, preprocess)
        tokenizer = get_tokenizer(args.model)

        logger.info(f"Dataset '{args.dataset}' loaded successfully, total samples: {total_samples}")

        # The processors work on image files, so the dataset must expose the
        # source path of each sample as samples[i][0].
        samples = getattr(data_loader.dataset, "samples", None)
        if samples is None:
            raise AttributeError(
                f"Dataset '{args.dataset}' has no 'samples' attribute; cannot locate image files."
            )

        # Close the file as soon as the embeddings have been copied into memory.
        with h5py.File(f'{output_dir}/openai_text.h5', 'r') as text_file:
            positive_text_embeddings = text_file['text_embeddings'][:]

        negative_text_embeddings = _encode_negative_texts(
            model, tokenizer, _NEGATIVE_TEXTS[args.augment], device
        )
        text_embeddings = (positive_text_embeddings, negative_text_embeddings)

        processor = _build_processor(
            args.augment, model, preprocess, tokenizer, save_dir, text_embeddings, device
        )
        uses_scenes = args.augment in _SCENE_AUGMENTS

        # Per-group counters for accuracy tracking (see eval_waterbirds).
        group_correct_counts = {i: 0 for i in range(4)}
        group_total_counts = {i: 0 for i in range(4)}

        sims_path = f'{save_dir}/mean_pos_sims.h5'
        with torch.inference_mode(), h5py.File(sims_path, 'w') as sims_file:
            sims_dataset = sims_file.create_dataset(
                'mean_pos_sims',
                shape=(total_samples, positive_text_embeddings.shape[0]),
                dtype='float32',
            )
            # Running position in the dataset; valid because the loader does not shuffle.
            offset = 0
            for images, labels_info in tqdm(data_loader):
                batch_len = len(images)
                batch_mean_pos_sims = []
                for sample_idx in range(offset, offset + batch_len):
                    image_path = samples[sample_idx][0]
                    if uses_scenes:
                        pos_sims = processor.process(image_path, _SCENE_LIST, image_enable=False)
                    else:
                        pos_sims = processor.process(image_path, image_enable=False)
                    # Average over the first axis of pos_sims to get one score row per image.
                    batch_mean_pos_sims.append(np.mean(np.array(pos_sims), axis=0))

                batch_mean_pos_sims = np.array(batch_mean_pos_sims)
                preds = batch_mean_pos_sims.argmax(axis=1)
                group_correct_counts, group_total_counts = eval_waterbirds(
                    group_correct_counts, group_total_counts, labels_info, preds
                )

                sims_dataset[offset:offset + batch_len] = batch_mean_pos_sims
                offset += batch_len
                torch.cuda.empty_cache()

        overall_accuracy, group_accuracies = _summarize_accuracy(group_correct_counts, group_total_counts)

        logger.info(f"Saved mean_pos_sims to {sims_path}")
        logger.info(f"Overall accuracy: {overall_accuracy}")
        logger.info(f"Group accuracies: {group_accuracies}")

    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Parse the CLI, fix all RNG seeds for reproducibility, then run.
    args = get_parser_info().parse_args()
    set_seed(42)
    main(args)
