import torch
from torch.nn import functional as F
from tqdm import tqdm
from factory import process_batch
from collections import defaultdict
import sys, os
# Append the current working directory to sys.path so the project-local `data`
# package imported below can be resolved (presumably the script is launched from
# the repository root — TODO confirm). Note that `factory` is imported above,
# before this path tweak, so it must already be importable on its own.
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)
from data.imagenet_a import thousand_k_to_200  # label mapping
from data.imagenet_w.watermark_transform import CARTON_CLASS_INDEX
import time


def classify_waterbirds(data_dict, text_embeddings, place_embeddings, args, device, logger):
    """
    Classify the Waterbirds dataset and compute overall and group (label x place) accuracy.

    Scores come from ``process_batch`` as a base score and a counterfactual score per
    class. For the first ``args.MAX_K`` class columns they are blended as
    ``(1 - args.lam) * base + args.lam * cf``; any remaining columns keep the base score.
    The prediction is the argmax of the blended scores.

    Args:
        data_dict (dict): Dictionary containing 'labels_info', shape=(N, 2). The first column
            is the label (0=landbird, 1=waterbird), the second column is the place
            (0=land, 1=water). The dict is also forwarded unchanged to ``process_batch``.
        text_embeddings (torch.Tensor): Precomputed text embeddings, shape=(C, D). Only
            ``shape[0]`` is read here; the tensor itself is forwarded to ``process_batch``.
        place_embeddings: (Optional) scene embeddings, forwarded to ``process_batch``.
        args: Namespace containing ``batch_size``, ``MAX_K`` and ``lam``.
        device: torch.device on which the label tensors are created.
        logger: Logger for printing information.

    Returns:
        dict: ``{"overall_accuracy": float, "group_accuracies": {g: float}}`` in percent.
        Group ids are ``label * 2 + place``: 0=Landbird-on-Land, 1=Landbird-on-Water,
        2=Waterbird-on-Land, 3=Waterbird-on-Water. A group with no samples reports 0.0,
        and an empty dataset reports 0.0 overall.
    """
    labels_info = data_dict["labels_info"]
    num_samples = labels_info.shape[0]
    num_classes = text_embeddings.shape[0]

    # Group id g encodes (label, place) as g = label * 2 + place.
    group_names = {
        0: "Landbird-on-Land",
        1: "Landbird-on-Water",
        2: "Waterbird-on-Land",
        3: "Waterbird-on-Water",
    }

    total_correct = 0
    group_correct = {g: 0 for g in group_names}
    group_total = {g: 0 for g in group_names}

    for start_idx in tqdm(range(0, num_samples, args.batch_size), desc="Classifying waterbirds"):
        end_idx = min(start_idx + args.batch_size, num_samples)

        # Ground-truth labels & places for this batch
        batch_info = labels_info[start_idx:end_idx]
        labels = torch.tensor(batch_info[:, 0], dtype=torch.long, device=device)
        places = torch.tensor(batch_info[:, 1], dtype=torch.long, device=device)

        # Base scores & counterfactual scores
        tde_base, tde_cf = process_batch(
            data_dict, text_embeddings, place_embeddings,
            range(start_idx, end_idx), args, args.MAX_K
        )

        # Weighted fusion: only the first MAX_K columns get the counterfactual term.
        if num_classes > args.MAX_K:
            tde_scores = torch.empty_like(tde_base)
            tde_scores[:, :args.MAX_K] = (
                (1 - args.lam) * tde_base[:, :args.MAX_K] +
                args.lam       * tde_cf[:,   :args.MAX_K]
            )
            tde_scores[:, args.MAX_K:] = tde_base[:, args.MAX_K:]
        else:
            tde_scores = (1 - args.lam) * tde_base + args.lam * tde_cf

        predictions = torch.argmax(tde_scores, dim=-1)
        correct = predictions == labels
        total_correct += correct.sum().item()

        # Per-group statistics; samples whose (label, place) is outside {0,1}^2 are ignored.
        for g in group_names:
            mask = (labels == (g // 2)) & (places == (g % 2))
            group_total[g] += mask.sum().item()
            group_correct[g] += correct[mask].sum().item()

    overall_acc = total_correct / num_samples * 100 if num_samples > 0 else 0.0
    logger.info(f"Overall accuracy: {overall_acc:.2f}%")

    # Keep every group's accuracy; previously only the last value of a loop
    # variable was returned for all four groups.
    group_accuracies = {}
    for g, name in group_names.items():
        acc_g = group_correct[g] / group_total[g] * 100 if group_total[g] > 0 else 0.0
        group_accuracies[g] = acc_g
        logger.info(f"Group {g} ({name}) accuracy: {acc_g:.2f}%")

    return {
        "overall_accuracy": overall_acc,
        "group_accuracies": group_accuracies,
    }

def classify_urbancars(data_dict, text_embeddings, place_embeddings, args, device, logger):
    """
    Classify the UrbanCars dataset and compute per-group accuracies and gaps.

    Computes:
        - four group accuracies (ID, BG, CoObj, BG_CoObj)
        - the GAP of each non-ID group relative to the ID group (``id_acc - group_acc``)

    Overall accuracy is not reported or returned.

    Group definitions (labels_info[:, :3] are [label, place, coobj]):
        ID group      : (0,0,0) & (1,1,1)
        BG group      : (0,1,0) & (1,0,1)
        CoObj group   : (0,0,1) & (1,1,0)
        BG_CoObj group: (1,0,0) & (0,1,1)

    Scores come from ``process_batch`` (base and counterfactual). For the first
    ``args.MAX_K`` class columns they are blended as ``(1 - lam) * base + lam * cf``;
    remaining columns keep the base score. The prediction is the argmax.

    Args:
        data_dict (dict): Dictionary containing 'labels_info' (columns 0..2 are label,
            place, co-object); also forwarded to ``process_batch``.
        text_embeddings (torch.Tensor): Precomputed text embeddings; ``shape[0]`` is the
            number of classes.
        place_embeddings: Scene embeddings (optional), forwarded to ``process_batch``.
        args: Arguments including batch_size, MAX_K, and lam.
        device: Device on which label tensors are created.
        logger: Logger for output.

    Returns:
        dict: ``id_accuracy``, ``bg_accuracy``, ``coobj_accuracy``, ``bg_coobj_accuracy``
        (percent; 0.0 for a group with no samples) and ``bg_gap``, ``coobj_gap``,
        ``bg_coobj_gap`` (percentage points relative to the ID group).
    """
    labels_info = data_dict["labels_info"]
    num_samples = labels_info.shape[0]
    num_classes = text_embeddings.shape[0]

    # (label, place, coobj) combinations making up each group.
    group_defs = {
        "id":       ((0, 0, 0), (1, 1, 1)),
        "bg":       ((0, 1, 0), (1, 0, 1)),
        "coobj":    ((0, 0, 1), (1, 1, 0)),
        "bg_coobj": ((1, 0, 0), (0, 1, 1)),
    }
    group_correct = dict.fromkeys(group_defs, 0)
    group_total = dict.fromkeys(group_defs, 0)

    for start_idx in tqdm(range(0, num_samples, args.batch_size), desc="Classifying Urbancars"):
        end_idx = min(start_idx + args.batch_size, num_samples)
        batch_info = labels_info[start_idx:end_idx]

        labels = torch.tensor(batch_info[:, 0], dtype=torch.long, device=device)
        places = torch.tensor(batch_info[:, 1], dtype=torch.long, device=device)
        coobjs = torch.tensor(batch_info[:, 2], dtype=torch.long, device=device)

        # TDE inference
        tde_base, tde_cf = process_batch(
            data_dict, text_embeddings, place_embeddings,
            range(start_idx, end_idx), args, args.MAX_K
        )

        # Weighted fusion: only the first MAX_K columns get the counterfactual term.
        if num_classes > args.MAX_K:
            tde_score = torch.empty_like(tde_base)
            tde_score[:, :args.MAX_K] = (
                (1 - args.lam) * tde_base[:, :args.MAX_K] +
                args.lam       * tde_cf[:,   :args.MAX_K]
            )
            tde_score[:, args.MAX_K:] = tde_base[:, args.MAX_K:]
        else:
            tde_score = (1 - args.lam) * tde_base + args.lam * tde_cf

        correct = torch.argmax(tde_score, dim=-1) == labels

        # A sample belongs to a group if it matches any of the group's combinations.
        for name, combos in group_defs.items():
            mask = torch.zeros_like(labels, dtype=torch.bool)
            for lab, plc, obj in combos:
                mask |= (labels == lab) & (places == plc) & (coobjs == obj)
            group_total[name] += mask.sum().item()
            group_correct[name] += correct[mask].sum().item()

    acc = {
        name: group_correct[name] / group_total[name] * 100 if group_total[name] > 0 else 0.0
        for name in group_defs
    }
    id_acc = acc["id"]
    bg_acc = acc["bg"]
    coobj_acc = acc["coobj"]
    bg_coobj_acc = acc["bg_coobj"]

    # GAP metrics (relative to ID group)
    bg_gap = id_acc - bg_acc
    coobj_gap = id_acc - coobj_acc
    bg_coobj_gap = id_acc - bg_coobj_acc

    logger.info(f"ID group accuracy:     {id_acc:.2f}%")
    logger.info(f"BG group accuracy:     {bg_acc:.2f}% (GAP: {bg_gap:.2f}%)")
    logger.info(f"CoObj group accuracy:  {coobj_acc:.2f}% (GAP: {coobj_gap:.2f}%)")
    logger.info(f"BG_CoObj group acc.:   {bg_coobj_acc:.2f}% (GAP: {bg_coobj_gap:.2f}%)")

    return {
        "id_accuracy": id_acc,
        "bg_accuracy": bg_acc,
        "coobj_accuracy": coobj_acc,
        "bg_coobj_accuracy": bg_coobj_acc,
        "bg_gap": bg_gap,
        "coobj_gap": coobj_gap,
        "bg_coobj_gap": bg_coobj_gap
    }

def classify_cocogb(data_dict, text_embeddings, place_embeddings, args, device, logger):
    """
    Classify the CocoGB dataset (gender prediction) and compute:
        - overall_accuracy
        - female_accuracy, male_accuracy
        - accuracy for each co-label (female/male/overall)
        - lowest female and lowest male accuracy and the corresponding co-label

    Scores come from ``process_batch`` (base and counterfactual). For the first
    ``args.MAX_K`` class columns they are blended as ``(1 - lam) * base + lam * cf``;
    remaining columns keep the base score. The prediction is the argmax.

    A per-co-label accuracy is only computed when the relevant sample count is at least
    ``MIN_SAMPLES`` (10); otherwise it is ``None`` and ignored in the worst-co-label search.

    Args:
        data_dict (dict): Dictionary with labels_info where [:, 0] is the gender label
            (0=female; any other value is counted as male) and [:, 1:] are multi-label
            co-labels (-1 for missing values). Also forwarded to ``process_batch``.
        text_embeddings (torch.Tensor): Precomputed text embeddings; ``shape[0]`` is the
            number of classes.
        place_embeddings: Scene embeddings (optional), forwarded to ``process_batch``.
        args: Arguments including batch_size, MAX_K, and lam.
        device: Device on which label tensors are created.
        logger: Logger for output.

    Returns:
        dict: 'overall_accuracy', 'female_accuracy', 'male_accuracy' (percent, 0.0 when
        there are no samples), 'cocolabel_accuracies' (co-label id -> dict with
        'female_accuracy', 'male_accuracy', 'overall_accuracy', 'female_total',
        'male_total'), and 'worst_female' / 'worst_male' as ``(co-label id, accuracy)``
        tuples, which are ``(None, 100.0)`` when no co-label met ``MIN_SAMPLES``.
    """
    labels_info = data_dict["labels_info"]
    total_samples = labels_info.shape[0]
    num_classes = text_embeddings.shape[0]
    MIN_SAMPLES = 10  # minimum count before a per-co-label accuracy is reported

    # Overall & gender statistics
    total_correct = 0
    female_total = female_correct = 0
    male_total = male_correct = 0

    # Statistics for each co-label
    cocolabel_stats = defaultdict(lambda: {
        'female_total': 0, 'female_correct': 0,
        'male_total': 0,   'male_correct':   0
    })

    for start_idx in tqdm(range(0, total_samples, args.batch_size), desc="Classifying CocoGB"):
        end_idx = min(start_idx + args.batch_size, total_samples)
        batch_info = labels_info[start_idx:end_idx]

        labels = torch.tensor(batch_info[:, 0], dtype=torch.long, device=device)
        cocolabels = batch_info[:, 1:]  # same container type as labels_info (numpy per the original comment)

        # TDE inference
        tde_base, tde_cf = process_batch(
            data_dict, text_embeddings, place_embeddings,
            range(start_idx, end_idx), args, args.MAX_K
        )

        # Weighted fusion: only the first MAX_K columns get the counterfactual term.
        if num_classes > args.MAX_K:
            tde_score = torch.empty_like(tde_base)
            tde_score[:, :args.MAX_K] = (
                (1 - args.lam) * tde_base[:, :args.MAX_K] +
                args.lam       * tde_cf[:,   :args.MAX_K]
            )
            tde_score[:, args.MAX_K:] = tde_base[:, args.MAX_K:]
        else:
            tde_score = (1 - args.lam) * tde_base + args.lam * tde_cf

        preds = torch.argmax(tde_score, dim=-1)
        correct = preds == labels
        total_correct += correct.sum().item()

        # Transfer per-sample results to the host once per batch instead of calling
        # .item() twice per sample (each call forces a device sync on GPU).
        for label, is_correct, row in zip(labels.tolist(), correct.tolist(), cocolabels):
            hit = int(is_correct)
            gender = 'female' if label == 0 else 'male'
            if label == 0:
                female_total += 1
                female_correct += hit
            else:
                male_total += 1
                male_correct += hit

            # Co-label statistics; -1 marks an empty co-label slot.
            for coco in row[row != -1]:
                stats = cocolabel_stats[int(coco)]
                stats[f'{gender}_total'] += 1
                stats[f'{gender}_correct'] += hit

    overall_acc = total_correct / total_samples * 100 if total_samples > 0 else 0.0
    female_acc = (female_correct / female_total * 100) if female_total > 0 else 0.0
    male_acc = (male_correct / male_total * 100) if male_total > 0 else 0.0

    # Per-co-label accuracy and the lowest one for each gender
    cocolabel_accuracies = {}
    lowest_female_acc = lowest_male_acc = 100.0
    worst_female_coco = worst_male_coco = None

    for coco, stats in cocolabel_stats.items():
        f_tot, f_cor = stats['female_total'], stats['female_correct']
        m_tot, m_cor = stats['male_total'], stats['male_correct']
        tot, cor = f_tot + m_tot, f_cor + m_cor

        f_acc = (f_cor / f_tot * 100) if f_tot >= MIN_SAMPLES else None
        m_acc = (m_cor / m_tot * 100) if m_tot >= MIN_SAMPLES else None
        tot_acc = (cor / tot * 100) if tot >= MIN_SAMPLES else None

        # The "is None" check ensures a qualifying co-label is recorded even when its
        # accuracy equals the 100.0 starting value (a strict < alone would skip it).
        if f_acc is not None and (worst_female_coco is None or f_acc < lowest_female_acc):
            lowest_female_acc = f_acc
            worst_female_coco = coco
        if m_acc is not None and (worst_male_coco is None or m_acc < lowest_male_acc):
            lowest_male_acc = m_acc
            worst_male_coco = coco

        cocolabel_accuracies[coco] = {
            'female_accuracy':   f_acc,
            'male_accuracy':     m_acc,
            'overall_accuracy': tot_acc,
            'female_total':      f_tot,
            'male_total':        m_tot
        }

    logger.info(f"Overall accuracy: {overall_acc:.2f}%")
    logger.info(f"Female accuracy: {female_acc:.2f}% ({female_correct}/{female_total})")
    logger.info(f"Male accuracy:   {male_acc:.2f}% ({male_correct}/{male_total})")
    logger.info(f"Worst female co‑label: {worst_female_coco} @ {lowest_female_acc:.2f}%")
    logger.info(f"Worst male co‑label:   {worst_male_coco} @ {lowest_male_acc:.2f}%")

    return {
        'overall_accuracy': overall_acc,
        'female_accuracy':  female_acc,
        'male_accuracy':    male_acc,
        'cocolabel_accuracies': cocolabel_accuracies,
        'worst_female':     (worst_female_coco, lowest_female_acc),
        'worst_male':       (worst_male_coco,   lowest_male_acc)
    }


def classify_nico(data_dict, text_embeddings, place_embeddings, args, device, logger):
    """
    Classify the NICO dataset and compute:
        - overall accuracy
        - per-class accuracy
        - per-context accuracy within each class
        - best and worst context for each class

    Scores come from ``process_batch`` (base and counterfactual). For the first
    ``args.MAX_K`` class columns they are blended as ``(1 - lam) * base + lam * cf``;
    remaining columns keep the base score. The prediction is the argmax.

    Results are logged and also written to
    ``./output/NICO/{args.model}_{args.dataset}_{args.scene_type}_nico_results.txt``
    (the directory is created if missing; write failures are logged, not raised).

    Args:
        data_dict (dict): Dictionary containing 'labels_info', shape=(N, 2). The first
            column is the class label, the second column is the context label. Classes
            are taken to be ``0..max(label)``; labels outside that range are not counted.
        text_embeddings (torch.Tensor): Precomputed text embeddings, shape=(C, D).
        place_embeddings: Scene embeddings, can be None; forwarded to ``process_batch``.
        args: Namespace containing batch_size, MAX_K, lam, model, dataset, scene_type.
        device: Unused in this function (labels are tallied on the host); accepted for
            signature parity with the other ``classify_*`` functions.
        logger: Logger for printing information.

    Returns:
        dict: 'overall_accuracy', 'class_accuracies' (class -> percent),
        'context_accuracies' (class -> {context -> percent}), and 'best_contexts' /
        'worst_contexts' (class -> (context, percent)).
    """
    labels_info = data_dict["labels_info"]
    num_samples = labels_info.shape[0]
    num_classes = text_embeddings.shape[0]
    # max() on an empty column would raise, so an empty dataset has zero label classes.
    num_label_classes = (
        int(torch.tensor(labels_info[:, 0]).max().item()) + 1 if num_samples > 0 else 0
    )

    class_correct_counts = {i: 0 for i in range(num_label_classes)}
    class_total_counts = {i: 0 for i in range(num_label_classes)}
    context_correct_counts = {i: {} for i in range(num_label_classes)}
    context_total_counts = {i: {} for i in range(num_label_classes)}

    for start_idx in tqdm(range(0, num_samples, args.batch_size), desc="Classifying NICO"):
        end_idx = min(start_idx + args.batch_size, num_samples)
        batch_info = labels_info[start_idx:end_idx]

        tde_base, tde_cf = process_batch(
            data_dict, text_embeddings, place_embeddings,
            range(start_idx, end_idx), args, args.MAX_K
        )

        # Weighted fusion: only the first MAX_K columns get the counterfactual term.
        if num_classes > args.MAX_K:
            tde_scores = torch.empty_like(tde_base)
            tde_scores[:, :args.MAX_K] = (
                (1 - args.lam) * tde_base[:, :args.MAX_K] +
                args.lam * tde_cf[:, :args.MAX_K]
            )
            tde_scores[:, args.MAX_K:] = tde_base[:, args.MAX_K:]
        else:
            tde_scores = (1 - args.lam) * tde_base + args.lam * tde_cf

        predictions = torch.argmax(tde_scores, dim=-1).tolist()
        class_list = torch.tensor(batch_info[:, 0], dtype=torch.long).tolist()
        context_list = torch.tensor(batch_info[:, 1], dtype=torch.long).tolist()

        # Single host-side pass over the batch (previously every class rescanned the
        # whole batch with per-element .item() calls: O(classes * batch) device syncs).
        for label, ctx_idx, pred in zip(class_list, context_list, predictions):
            if label not in class_total_counts:
                continue  # label outside [0, num_label_classes): not counted
            hit = int(pred == label)
            class_total_counts[label] += 1
            class_correct_counts[label] += hit
            ctx_total = context_total_counts[label]
            ctx_correct = context_correct_counts[label]
            ctx_total[ctx_idx] = ctx_total.get(ctx_idx, 0) + 1
            ctx_correct[ctx_idx] = ctx_correct.get(ctx_idx, 0) + hit

    total_correct = sum(class_correct_counts.values())
    total_samples = sum(class_total_counts.values())
    overall_accuracy = (total_correct / total_samples) * 100 if total_samples > 0 else 0

    class_accuracies = {
        c: (class_correct_counts[c] / class_total_counts[c]) * 100 if class_total_counts[c] > 0 else 0
        for c in range(num_label_classes)
    }

    # Per-context accuracy and best/worst context per class. max/min return the first
    # extreme in insertion order, i.e. the same tie-break as a strict >/< scan.
    context_accuracies = {}
    best_contexts = {}
    worst_contexts = {}
    for class_idx in range(num_label_classes):
        totals = context_total_counts[class_idx]
        if not totals:
            continue
        accs = {
            ctx: (context_correct_counts[class_idx][ctx] / n) * 100
            for ctx, n in totals.items()
        }
        context_accuracies[class_idx] = accs
        best_contexts[class_idx] = max(accs.items(), key=lambda kv: kv[1])
        worst_contexts[class_idx] = min(accs.items(), key=lambda kv: kv[1])

    # Build report lines once; they are shared by the logger and the results file.
    sorted_classes = sorted(class_accuracies.items(), key=lambda kv: kv[1], reverse=True)
    class_lines = [
        f"Class {c}: {acc:.2f}% ({class_correct_counts[c]}/{class_total_counts[c]})"
        for c, acc in sorted_classes
    ]
    context_sections = {}
    for class_idx in sorted(context_accuracies):
        ranked = sorted(context_accuracies[class_idx].items(), key=lambda kv: kv[1], reverse=True)
        lines = [
            f"  Context {ctx}: {acc:.2f}% "
            f"({context_correct_counts[class_idx][ctx]}/{context_total_counts[class_idx][ctx]})"
            for ctx, acc in ranked
        ]
        best_ctx, best_acc = best_contexts[class_idx]
        worst_ctx, worst_acc = worst_contexts[class_idx]
        lines.append(f"  Best context: {best_ctx} ({best_acc:.2f}%)")
        lines.append(f"  Worst context: {worst_ctx} ({worst_acc:.2f}%)")
        context_sections[class_idx] = lines

    logger.info(f"Overall accuracy: {overall_accuracy:.2f}%")
    logger.info("\nClass-wise accuracies:")
    for line in class_lines:
        logger.info(line)
    logger.info("\nContext-wise accuracies:")
    for class_idx, lines in context_sections.items():
        logger.info(f"\nClass {class_idx}:")
        for line in lines:
            logger.info(line)

    report = [
        "NICO Dataset Classification Results",
        "================================",
        "",
        f"Overall accuracy: {overall_accuracy:.2f}%",
        "",
        "Class-wise accuracies:",
        "====================",
        *class_lines,
        "",
        "Context-wise accuracies for each class:",
        "===================================",
    ]
    for class_idx, lines in context_sections.items():
        report.extend(["", f"Class {class_idx}:", *lines])

    output_file = f"./output/NICO/{args.model}_{args.dataset}_{args.scene_type}_nico_results.txt"
    # Saving is best-effort: an I/O failure is logged but does not abort evaluation.
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(f"{line}\n" for line in report)
    except OSError as e:
        logger.error(f"Error saving results to file: {e}")
    else:
        logger.info(f"Results saved to {output_file}")

    return {
        "overall_accuracy": overall_accuracy,
        "class_accuracies": class_accuracies,
        "context_accuracies": context_accuracies,
        "best_contexts": best_contexts,
        "worst_contexts": worst_contexts
    }