# forward_forward/evaluation/ff_evaluation.py (Enhanced with Multi-Mode Class Grouping)

import torch
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

import wandb
from typing import Optional, Union, List, Dict

from forward_forward.ff_layer import FFLayer
from forward_forward.models.model_factory import SkipConnection
from forward_forward.models.layers.class_grouping import ClassGroupingMode


def map_grouped_predictions_to_original(
    grouped_predictions: torch.Tensor,
    group_mapping: Dict[int, List[int]],
    strategy: str = "first"
) -> torch.Tensor:
    """
    Map grouped predictions back to original class space.

    Each prediction whose value is a key of ``group_mapping`` is replaced by one
    of that group's original classes. Any other value is treated as an
    ungrouped id that already names an original class and is passed through.

    Args:
        grouped_predictions: 1-D tensor of predictions in grouped class space (B,).
        group_mapping: Mapping from group_id to list of original classes.
        strategy: Strategy for mapping grouped classes back to original classes.
                 - "first": Use the first class in the group.
                 - "random": Randomly sample (via ``np.random.choice``) from the group.
                 - "uniform": Behaves like "first" here. Hard predictions cannot
                   carry a distribution.

    Returns:
        torch.Tensor: Predictions in original class space (B,), with the same
        dtype and device as ``grouped_predictions``.

    Raises:
        ValueError: If ``strategy`` is not supported. The check runs up front,
            so it fires even when no prediction falls into a mapped group.
    """
    if strategy not in ("first", "random", "uniform"):
        raise ValueError(f"Unknown mapping strategy: {strategy}")

    mapped = []
    # One host transfer instead of a .item() sync per element.
    for group_id in grouped_predictions.tolist():
        original_classes = group_mapping.get(group_id)
        if original_classes is None:
            # Group ID corresponds to original class (ungrouped).
            mapped.append(group_id)
        elif strategy == "random":
            mapped.append(int(np.random.choice(original_classes)))
        else:
            # "first" and "uniform" both resolve to the group's first class.
            mapped.append(original_classes[0])

    return torch.tensor(
        mapped, dtype=grouped_predictions.dtype, device=grouped_predictions.device
    )

def _confusion_matrix(y_true, y_pred, num_classes):
    # y_true, y_pred: (N,) tensors with class indices
    mask = (y_true >= 0) & (y_true < num_classes)  # optional filtering

    y_true = y_true[mask]
    y_pred = y_pred[mask]

    # Compute linear indices for bincount
    indices = y_true * num_classes + y_pred
    cm = torch.bincount(indices, minlength=num_classes*num_classes).reshape(num_classes, num_classes)

    return cm

def evaluate_with_class_grouping(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: str,
    split: str,
    step: Optional[int] = None,
    num_epochs: Optional[int] = None,
    wandb_enabled: bool = False,
    wandb_prefix: str = "",
    k: Optional[int] = None,
    block_name: Optional[Union[str, List[str]]] = None,
    precision: Optional[bool] = False,
    recall: Optional[bool] = False,
    f1: Optional[bool] = False,
    confusion_matrix: Optional[bool] = True,
    dataset_name: str = "CIFAR-10",
    original_num_classes: int = 10,
) -> dict:
    """
    Evaluate Forward-Forward blocks, handling both class grouping modes.

    Makes one pass of ``dataloader`` through ``model.layers``. For every target
    ``FFLayer`` it collects ``layer.predict`` outputs. When a layer's
    ``class_grouping_manager`` is in DIMENSION_REDUCTION mode, scores and
    predictions are mapped from grouped space back to the original class space.
    GROUP_AWARE_NEGATIVE and ungrouped layers are used as-is.

    Args:
        model: The Forward-Forward model (must expose ``model.layers``).
        dataloader: Data loader yielding ``(x, y)`` batches.
        device: Device to run evaluation on.
        split: Split name (train/val/test), used in prints and wandb keys.
        step: Current training step, passed to ``wandb.log``.
        num_epochs: Currently unused (reserved for confusion-matrix logging).
        wandb_enabled: Whether to log to wandb.
        wandb_prefix: Prefix for wandb keys.
        k: Top-k accuracy (not implemented yet; currently unused).
        block_name: Block name(s) to evaluate; defaults to every ``FFLayer``.
        precision: Log per-class precision (only when ``wandb_enabled``).
        recall: Log per-class recall (only when ``wandb_enabled``).
        f1: Log per-class F1 score (only when ``wandb_enabled``).
        confusion_matrix: Currently unused (confusion-matrix logging is disabled).
        dataset_name: Dataset whose class names label per-class wandb metrics.
            Unknown datasets, or indices without a name, fall back to ``Class_{i}``.
        original_num_classes: Number of original classes (before grouping).
            Sets the width of the mapped-back DIMENSION_REDUCTION scores.

    Returns:
        dict: Per-block metrics keyed by block name. Each entry contains
        ``accuracy``, ``loss`` (BCE-with-logits against one-hot labels) and
        ``grouping_mode``. It also contains ``grouped_accuracy``
        (DIMENSION_REDUCTION) or ``native_accuracy`` (GROUP_AWARE_NEGATIVE)
        when one is computed.
    """
    model.eval()
    layers_list = list(model.layers.items())

    if block_name is None:
        target_blocks = [name for name, layer in layers_list if isinstance(layer, FFLayer)]
    elif isinstance(block_name, str):
        target_blocks = [block_name]
    else:
        target_blocks = block_name
    target_set = set(target_blocks)

    block_outputs = {
        name: {
            "scores": [],
            "preds": [],
            "labels": [],
            "grouped_scores": [],
            "grouped_preds": [],
            "group_mapping": None,
            "grouping_mode": None,
        }
        for name in target_blocks
    }

    # Layers whose output a *later* SkipConnection reads. The layer order is
    # fixed, so compute this once instead of rescanning on every layer of every batch.
    keep_output = {
        layer_name
        for idx, (layer_name, _) in enumerate(layers_list)
        if any(
            isinstance(later, SkipConnection) and later.skip_from == layer_name
            for _, later in layers_list[idx + 1:]
        )
    }

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            layer_outputs = {"input": x}
            curr_x = x

            for layer_name, layer in layers_list:
                if isinstance(layer, SkipConnection):
                    curr_x = layer(curr_x, layer_outputs[layer.skip_from])
                elif isinstance(layer, FFLayer) and layer_name in target_set:
                    record = block_outputs[layer_name]

                    # Predictions in the layer's native (possibly grouped) space.
                    native_scores, native_preds, out = layer.predict(curr_x)
                    record["grouped_scores"].append(native_scores)
                    record["grouped_preds"].append(native_preds)

                    grouping_mode = None
                    group_mapping = None
                    manager = getattr(layer.layer, "class_grouping_manager", None)
                    if manager is not None:
                        grouping_mode = manager.mode.value
                        record["grouping_mode"] = grouping_mode
                        if manager.is_dimension_reduction_mode():
                            group_mapping = layer.layer.get_grouped_prediction_mapping()
                            record["group_mapping"] = group_mapping

                    if (
                        grouping_mode == ClassGroupingMode.DIMENSION_REDUCTION.value
                        and group_mapping is not None
                    ):
                        scores = _grouped_scores_to_original(
                            native_scores, group_mapping, original_num_classes, device
                        )
                        preds = map_grouped_predictions_to_original(
                            native_preds, group_mapping, strategy="first"
                        )
                    else:
                        # GROUP_AWARE_NEGATIVE and ungrouped layers already
                        # predict in original space; no mapping needed.
                        scores, preds = native_scores, native_preds

                    record["scores"].append(scores)
                    record["preds"].append(preds)
                    record["labels"].append(y)
                    curr_x = out
                else:
                    curr_x = layer(curr_x)

                if layer_name in keep_output:
                    layer_outputs[layer_name] = curr_x

    block_metrics = {}
    criterion = torch.nn.BCEWithLogitsLoss()
    dim_reduction = ClassGroupingMode.DIMENSION_REDUCTION.value

    for name, data in block_outputs.items():
        if not data["scores"]:
            continue

        scores = torch.cat(data["scores"], dim=0)
        preds = torch.cat(data["preds"], dim=0)
        labels = torch.cat(data["labels"], dim=0)

        # Metrics in original class space.
        num_classes = scores.size(-1)
        targets = torch.nn.functional.one_hot(labels, num_classes=num_classes).float()
        loss = criterion(scores, targets).item()
        accuracy = (scores.argmax(dim=-1) == labels).float().mean().item()

        # Accuracy in the layer's native space, when a grouping mode applies.
        native_accuracy = None
        grouping_mode = data["grouping_mode"]
        if data["grouped_preds"] and grouping_mode is not None:
            grouped_preds = torch.cat(data["grouped_preds"], dim=0)
            if grouping_mode == dim_reduction:
                if data["group_mapping"] is not None:
                    grouped_labels = _labels_to_grouped_space(labels, data["group_mapping"])
                    native_accuracy = (grouped_preds == grouped_labels).float().mean().item()
            elif grouping_mode == ClassGroupingMode.GROUP_AWARE_NEGATIVE.value:
                # Already in original space, so this matches the label space directly.
                native_accuracy = (grouped_preds == labels).float().mean().item()

        native_key = "grouped_accuracy" if grouping_mode == dim_reduction else "native_accuracy"

        metrics = {"accuracy": accuracy, "loss": loss, "grouping_mode": grouping_mode}
        if native_accuracy is not None:
            metrics[native_key] = native_accuracy
        block_metrics[name] = metrics

        print(f"[{split}] {name} >> Loss: {loss:.4f} - Accuracy: {accuracy:.4f}")
        if grouping_mode is not None:
            print(f"[{split}] {name} >> Grouping Mode: {grouping_mode}")
            if native_accuracy is not None:
                title = "Grouped Space Accuracy" if grouping_mode == dim_reduction else "Native Accuracy"
                print(f"[{split}] {name} >> {title}: {native_accuracy:.4f}")
            if data["group_mapping"] is not None:
                print(f"[{split}] {name} >> Class Groups: {data['group_mapping']}")

        if not wandb_enabled:
            continue

        key_prefix = f"{wandb_prefix}{name}/{split}"
        wandb_logs = {
            f"{key_prefix}/loss": loss,
            f"{key_prefix}/accuracy": accuracy,
            "epoch": step,
        }
        if native_accuracy is not None:
            wandb_logs[f"{key_prefix}/{native_key}"] = native_accuracy
        if grouping_mode is not None:
            wandb_logs[f"{key_prefix}/grouping_mode"] = grouping_mode

        if precision or recall or f1:
            # sklearn needs host arrays; CUDA tensors cannot be converted implicitly.
            y_true = labels.cpu().numpy()
            y_pred = preds.cpu().numpy()
            precisions, recalls, f1s, _ = precision_recall_fscore_support(
                y_true,
                y_pred,
                labels=list(range(num_classes)),
                average=None,
                zero_division=0,
            )
            for metric_name, enabled, values in (
                ("precision", precision, precisions),
                ("recall", recall, recalls),
                ("f1", f1, f1s),
            ):
                if enabled:
                    wandb_logs.update({
                        f"{key_prefix}/{metric_name}/{_class_name(dataset_name, i)}": values[i]
                        for i in range(num_classes)
                    })

        # NOTE: confusion-matrix logging (`confusion_matrix` / `num_epochs`) is
        # currently disabled; `_confusion_matrix` is available when re-enabling it.
        wandb.log(wandb_logs, step=step)

    return block_metrics


def _grouped_scores_to_original(native_scores, group_mapping, original_num_classes, device):
    """Copy each group's score column to every original class in that group."""
    # NOTE(review): original classes not listed in group_mapping keep a score of 0
    # here, while map_grouped_predictions_to_original passes unmapped ids through
    # as original classes. Confirm the mapping always covers every class.
    original_scores = torch.zeros(native_scores.shape[0], original_num_classes, device=device)
    num_groups = native_scores.shape[1]
    for group_id, original_classes in group_mapping.items():
        if group_id >= num_groups:
            continue
        for orig_class in original_classes:
            if orig_class < original_num_classes:
                original_scores[:, orig_class] = native_scores[:, group_id]
    return original_scores


def _labels_to_grouped_space(labels, group_mapping):
    """Translate original-class labels into group ids (vectorised lookup, O(N))."""
    original_to_group = {
        orig_class: group_id
        for group_id, orig_classes in group_mapping.items()
        for orig_class in orig_classes
    }
    # Ungrouped labels are shifted by a constant offset.
    # NOTE(review): this formula is carried over unchanged; confirm it matches how
    # the layer numbers ungrouped outputs.
    if group_mapping:
        offset = max(group_mapping) + 1 - sum(len(c) for c in group_mapping.values())
    else:
        offset = 0
    grouped = [original_to_group.get(label, label + offset) for label in labels.tolist()]
    return torch.tensor(grouped, dtype=labels.dtype, device=labels.device)


def _class_name(dataset_name, class_idx):
    """Human-readable class name, falling back to ``Class_{i}`` when unknown."""
    return _DATASET_CLASS_NAMES.get(dataset_name, {}).get(class_idx, f"Class_{class_idx}")


# Class names for per-class metric keys, by dataset.
_DATASET_CLASS_NAMES = {
    "CIFAR-10": {
        0: "Airplane", 1: "Automobile", 2: "Bird", 3: "Cat", 4: "Deer",
        5: "Dog", 6: "Frog", 7: "Horse", 8: "Ship", 9: "Truck"
    },
    "CIFAR-100": {
        0: "Beaver", 1: "Dolphin", 2: "Otter", 3: "Seal", 4: "Whale",
        5: "Aquarium fish", 6: "Flatfish", 7: "Ray", 8: "Shark", 9: "Trout",
        10: "Orchid", 11: "Poppy", 12: "Rose", 13: "Sunflower", 14: "Tulip",
        15: "Bottle", 16: "Bowl", 17: "Can", 18: "Cup", 19: "Plate",
        20: "Apple", 21: "Mushroom", 22: "Orange", 23: "Pear", 24: "Sweet pepper",
        25: "Clock", 26: "Computer keyboard", 27: "Lamp", 28: "Telephone", 29: "Television",
        30: "Bed", 31: "Chair", 32: "Couch", 33: "Table", 34: "Wardrobe",
        35: "Bee", 36: "Beetle", 37: "Butterfly", 38: "Caterpillar", 39: "Cockroach",
        40: "Bear", 41: "Leopard", 42: "Lion", 43: "Tiger", 44: "Wolf",
        45: "Bridge", 46: "Castle", 47: "House", 48: "Road", 49: "Skyscraper",
        50: "Cloud", 51: "Forest", 52: "Mountain", 53: "Plain", 54: "Sea",
        55: "Camel", 56: "Cattle", 57: "Chimpanzee", 58: "Elephant", 59: "Kangaroo",
        60: "Fox", 61: "Porcupine", 62: "Possum", 63: "Raccoon", 64: "Skunk",
        65: "Crab", 66: "Lobster", 67: "Snail", 68: "Spider", 69: "Worm",
        70: "Baby", 71: "Boy", 72: "Girl", 73: "Man", 74: "Woman",
        75: "Crocodile", 76: "Dinosaur", 77: "Lizard", 78: "Snake", 79: "Turtle",
        80: "Hamster", 81: "Mouse", 82: "Rabbit", 83: "Shrew", 84: "Squirrel",
        85: "Maple", 86: "Oak", 87: "Palm", 88: "Pine", 89: "Willow",
        90: "Bicycle", 91: "Bus", 92: "Motorcycle", 93: "Pickup truck", 94: "Train",
        95: "Lawn-mower", 96: "Rocket", 97: "Streetcar", 98: "Tank", 99: "Tractor"
    }
}


# Backward compatibility - use the enhanced evaluation as default
def evaluate(*args, **kwargs):
    """
    Backward compatible evaluation function.

    Forwards all arguments to ``evaluate_with_class_grouping``. It pins
    ``original_num_classes`` to 10 unless the caller supplied it.
    """
    # original_num_classes is the 16th parameter of evaluate_with_class_grouping.
    # Only inject the legacy default when it was not passed positionally;
    # injecting it then would raise "got multiple values for argument".
    if len(args) < 16:
        kwargs.setdefault("original_num_classes", 10)
    return evaluate_with_class_grouping(*args, **kwargs)














def cache_goodnesses(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: str,
) -> dict:
    """
    Collect each FF block's ``predict()`` goodness scores over a dataloader.

    Runs every batch through ``model.layers`` in order. ``SkipConnection``
    layers receive the stored output of their ``skip_from`` layer (or the raw
    input for ``"input"``), matching ``evaluate_with_class_grouping``.

    Returns:
        dict with:
            - goodnesses (List[List[Tensor]]): one list per block (sized by
              ``len(model.trainable_names)``) holding that block's per-batch
              goodness tensors, in ``FFLayer`` order. The tensors are not
              concatenated.
            - labels (Tensor): ground-truth labels concatenated over all
              batches (empty when the dataloader yields nothing).
    """
    model.eval()
    num_blocks = len(model.trainable_names)
    layers_list = list(model.layers.items())
    # Outputs that some SkipConnection will read later.
    skip_sources = {
        layer.skip_from for _, layer in layers_list if isinstance(layer, SkipConnection)
    }

    goodnesses = [[] for _ in range(num_blocks)]
    all_labels = []

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            all_labels.append(y)
            layer_outputs = {"input": x}

            block_idx = 0
            for layer_name, layer in layers_list:
                if isinstance(layer, SkipConnection):
                    x = layer(x, layer_outputs[layer.skip_from])
                elif isinstance(layer, FFLayer):
                    goodness, _, x = layer.predict(x)
                    goodnesses[block_idx].append(goodness)
                    block_idx += 1
                else:
                    x = layer(x)

                if layer_name in skip_sources:
                    layer_outputs[layer_name] = x

    # torch.cat rejects an empty list, so an empty dataloader gets empty labels instead.
    if all_labels:
        labels = torch.cat(all_labels, dim=0)
    else:
        labels = torch.empty(0, dtype=torch.long, device=device)

    return {
        "goodnesses": goodnesses,
        "labels": labels
    }
