import os
import shap
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from utils.image import Image  # your custom SHAP masker

class ShapMaskEvaluator:
    """Measure how much SHAP attribution falls inside a segmentation mask.

    For every validation image, SHAP values are computed for the ground-truth
    class, averaged over the last axis, multiplied element-wise by the
    matching segmentation mask and summed. One CSV row per image is appended
    to ``csv_save_path``.
    """

    def __init__(self, model, class_names, csv_save_path, masker_method="mdp", device="cuda"):
        """
        Args:
            model: PyTorch model. It is switched to eval mode and moved to
                ``device`` here; ``evaluate`` itself only calls the ``predict``
                callable it is given.
            class_names: List of class names corresponding to output indices.
            csv_save_path: Path to save the evaluation results CSV.
            masker_method: String identifier for your custom SHAP masker (e.g. "mdp").
            device: "cuda" or "cpu".
        """
        self.model = model.eval().to(device)
        self.class_names = class_names
        self.csv_save_path = csv_save_path
        self.masker_method = masker_method
        self.device = device
        # Most recently built SHAP explainer (rebuilt for every image).
        self.explainer = None

    def evaluate(self, val_loader, segm_loader, predict):
        """Compute the in-mask SHAP sum for every image and append it to the CSV.

        ``val_loader`` and ``segm_loader`` are consumed in lockstep: the k-th
        batch of one is paired with the k-th batch of the other, so both must
        yield their items in the same order (no shuffling).

        Args:
            val_loader: Iterable yielding ``(image_tensor, label, img_name)``.
                The code relies on a batch size of 1 (``label.item()`` and
                ``img_name[0]``); images are assumed to be NCHW.
            segm_loader: Iterable yielding 2-tuples whose first element is the
                segmentation mask tensor for the corresponding image.
            predict: Callable handed to ``shap.Explainer`` as the model
                function.

        Raises:
            ValueError: If ``segm_loader`` yields fewer batches than
                ``val_loader``, or if a mask does not match the saliency map
                in size.
        """
        # Advance the mask loader together with the image loader instead of
        # restarting it from the first batch for every image, which cost
        # O(n^2) mask loads overall.
        segm_iter = iter(segm_loader)

        for image_tensor, label, img_name in tqdm(val_loader):
            try:
                segm_batch = next(segm_iter)
            except StopIteration:
                raise ValueError(
                    "segm_loader yielded fewer batches than val_loader"
                ) from None
            segmented_tensor, _ = segm_batch

            # SHAP and the mask arithmetic operate on NumPy arrays, so the
            # tensors are brought to host memory directly; no transfer to
            # self.device is needed for them.
            image_nhwc = image_tensor.permute(0, 2, 3, 1).cpu().numpy()  # NCHW -> NHWC
            label_np = label.cpu().numpy()

            # Create SHAP masker using the specified method
            masker = Image(self.masker_method, image_nhwc.shape[1:])
            self.explainer = shap.Explainer(predict, masker, output_names=self.class_names)

            # Compute SHAP values for the GT class only
            shap_values = self.explainer(image_nhwc, max_evals=1000, outputs=label_np)
            gt_shap = shap_values.values[0]  # attributions of the single image in the batch
            # NOTE(review): presumably gt_shap is [H, W, 3, 1] here, making the
            # saliency map [H, W] after squeeze + channel mean -- confirm
            # against the SHAP version in use.
            saliency = np.mean(gt_shap.squeeze(), axis=-1)

            # Compute SHAP sum inside the segmentation mask
            shap_sum = self._calc_shap_sum_in_mask(saliency, segmented_tensor)

            row = {
                "img_name": img_name[0],
                "gt_class": self.class_names[label.item()],
                "shap_sum_in_mask": shap_sum,
            }

            # Append one row per image so partial results survive an
            # interrupted run; the header is written only when the file is new.
            pd.DataFrame([row]).to_csv(
                self.csv_save_path,
                mode="a",
                index=False,
                header=not os.path.exists(self.csv_save_path),
            )

    def _calc_shap_sum_in_mask(self, shap_saliency: np.ndarray, segm_tensor: torch.Tensor) -> float:
        """Return the sum of ``shap_saliency`` weighted by the segmentation mask.

        Args:
            shap_saliency: Saliency map as a NumPy array.
            segm_tensor: Mask tensor; up to two leading singleton dimensions
                are squeezed away before use (e.g. [1, 1, H, W] -> [H, W]).

        Returns:
            ``sum(shap_saliency * mask)`` as a Python float.

        Raises:
            ValueError: If the mask and the saliency map hold a different
                number of elements.
        """
        segm_mask = segm_tensor.squeeze(0).squeeze(0).cpu().numpy()
        # Guard against silent broadcasting (e.g. a 3-channel mask against a
        # single-channel saliency map), which would inflate the sum.
        if segm_mask.size != shap_saliency.size:
            raise ValueError(
                f"Mask shape {tuple(segm_mask.shape)} does not match "
                f"saliency shape {tuple(shap_saliency.shape)}"
            )
        return float(np.sum(shap_saliency * segm_mask))

    def _get_batch(self, loader, target_idx):
        """Yield the batch at position ``target_idx`` of ``loader``, if any.

        Iterates ``loader`` from its start, so the cost grows with
        ``target_idx``. Nothing is yielded when the loader has fewer than
        ``target_idx + 1`` batches. ``evaluate`` no longer calls this; it is
        retained for existing callers.
        """
        for idx, batch in enumerate(loader):
            if idx == target_idx:
                yield batch
                return
