"""
Circle vs. Square Image Classification with 4 Captum Explanation Methods:
  - Integrated Gradients
  - Occlusion
  - Shapley Value Sampling
  - Captum's LIME

Now includes:
  - Spurious scenario (corner pixel for label=1).
  - R, F, and the Inversion Score IS = ((R^p + (1-F)^p)/2)^(1/p) that combines them
  - RBP (Reproduce-by-Poking)

Usage:
  python image_experiment.py
It prints progress, logs results to 'image_captum_expls_log.txt', and saves
one saliency figure per method under 'saliency_visuals/<method>/'.
"""

import numpy as np
import random
import cv2
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

# Captum-based explanation classes
from captum.attr import IntegratedGradients, Occlusion, ShapleyValueSampling, Lime

import matplotlib.pyplot as plt
import os

# ==============================
# 1) Synthetic Circle vs Square
# ==============================
def generate_synthetic_images(n_samples=200, img_size=32):
    """
    Build a synthetic two-class image set of filled shapes on a black canvas.

    The first ``n_samples // 2`` images each hold one filled circle (label=0);
    the remaining ``n_samples - n_samples // 2`` images each hold one filled
    square (label=1). Shapes are drawn with intensity 255, and the random
    position/size ranges keep every shape inside the canvas.

    Randomness comes from the stdlib ``random`` module, so seed it with
    ``random.seed(...)`` for reproducible data.

    Returns:
      X: float32 array, shape (n_samples, img_size, img_size), values in [0..255]
      y: int64 array, shape (n_samples,)
      bboxes: list of (x_min, y_min, x_max, y_max) tuples, one per image
    """
    n_circles = n_samples // 2
    n_squares = n_samples - n_circles

    images, labels, boxes = [], [], []

    # Circles => label=0
    for _ in range(n_circles):
        canvas = np.zeros((img_size, img_size), dtype=np.uint8)
        r = random.randint(img_size // 6, img_size // 4)
        # Keep the centre at least one radius away from every border.
        lo, hi = r, img_size - r - 1
        center_x = random.randint(lo, hi)
        center_y = random.randint(lo, hi)
        cv2.circle(canvas, (center_x, center_y), r, 255, -1)  # -1 => filled
        images.append(canvas)
        labels.append(0)
        boxes.append((center_x - r, center_y - r, center_x + r, center_y + r))

    # Squares => label=1
    for _ in range(n_squares):
        canvas = np.zeros((img_size, img_size), dtype=np.uint8)
        edge = random.randint(img_size // 4, img_size // 2)
        # Largest top-left coordinate that still keeps the square on the canvas.
        max_origin = img_size - edge - 1
        left = random.randint(0, max_origin)
        top = random.randint(0, max_origin)
        right, bottom = left + edge, top + edge
        cv2.rectangle(canvas, (left, top), (right, bottom), 255, -1)  # -1 => filled
        images.append(canvas)
        labels.append(1)
        boxes.append((left, top, right, bottom))

    return (
        np.array(images, dtype=np.float32),
        np.array(labels, dtype=np.int64),
        boxes,
    )

class CircleSquareDataset(Dataset):
    """
    Map-style dataset over pre-generated images, labels and bounding boxes.

    Each item is ``(image, label, bbox)`` where ``image`` is a tensor with a
    leading channel axis added, i.e. (1, H, W) for an (H, W) source image.
    """

    def __init__(self, imgs, labels, bboxes):
        # imgs: shape (N, H, W); labels and bboxes are indexed in parallel.
        self.imgs, self.labels, self.bboxes = imgs, labels, bboxes

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        # (H, W) -> (1, H, W); torch.from_numpy shares memory with the array.
        image_chw = np.expand_dims(self.imgs[idx], axis=0)
        return torch.from_numpy(image_chw), self.labels[idx], self.bboxes[idx]


# ==============================
# 2) A Simple CNN
# ==============================
class SimpleCNN(nn.Module):
    """
    Two conv+pool stages followed by a two-layer classifier head.

    ``fc1`` is sized for 32 channels x 8 x 8, i.e. a single-channel 32x32
    input after two 2x2 max-pools. Other input sizes will fail at ``fc1``.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 8 * 8, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=num_classes)

    def forward(self, x):
        # x shape => (N,1,H,W); each stage is conv -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch axis before the linear head.
        flat = x.contiguous().view(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


# ==============================
# 3) Training
# ==============================
def train_model(model, train_loader, val_loader, device, epochs=10):
    """
    Train ``model`` with Adam (lr=1e-3) and cross-entropy, printing per-epoch metrics.

    Args:
        model: network mapping a float input batch to class logits.
        train_loader: iterable of ``(inputs, labels, extra)`` batches; the third
            element is ignored.
        val_loader: iterable with the same batch format, or ``None`` to skip
            validation. An empty loader is treated like ``None``.
        device: torch device the model and batches are moved to.
        epochs: number of passes over ``train_loader``.

    The model is moved to ``device`` and updated in place; nothing is returned.
    When there is no validation data the printed ``ValAcc`` is 0.0.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for bx, by, _ in train_loader:
            bx = bx.to(device, dtype=torch.float)
            by = by.to(device)
            optimizer.zero_grad()
            out = model(bx)
            loss = criterion(out, by)
            loss.backward()
            optimizer.step()

            # loss.item() is the batch mean, so weight it by the batch size.
            running_loss += loss.item() * bx.size(0)
            _, preds = torch.max(out, 1)
            correct += (preds == by).sum().item()
            total += bx.size(0)
        # Guard against an empty training loader (total == 0).
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0

        # quick val
        val_acc = 0.0
        # Compare against None instead of relying on truthiness: bool() on a
        # loader calls __len__, which is undefined for un-sized iterable data.
        if val_loader is not None:
            model.eval()
            vcorrect = 0
            vtotal = 0
            with torch.no_grad():
                for vx, vy, _ in val_loader:
                    vx = vx.to(device, dtype=torch.float)
                    vy = vy.to(device)
                    vout = model(vx)
                    _, vp = torch.max(vout, 1)
                    vcorrect += (vp == vy).sum().item()
                    vtotal += vx.size(0)
            if vtotal:
                val_acc = vcorrect / vtotal
        print(f"Epoch {epoch+1}/{epochs}, Loss={epoch_loss:.4f}, TrainAcc={epoch_acc:.4f}, ValAcc={val_acc:.4f}")


# ==============================
# 4) Captum Explanation Methods
# ==============================
def explain_ig(model, input_tensor, target_class, device):
    """
    Integrated Gradients attribution for ``target_class``.

    Returns the attribution of the first sample's first channel as a 2-D
    numpy array. ``device`` is accepted for signature parity with the other
    explainers and is not used here.
    """
    explainer = IntegratedGradients(model)
    attr = explainer.attribute(input_tensor, target=target_class)
    # (N, C, H, W) -> (H, W) for sample 0, channel 0.
    return attr[0, 0, :, :].detach().cpu().numpy()

def explain_occlusion(model, input_tensor, target_class, device):
    """
    Occlusion attribution for ``target_class`` using 1x4x4 windows.

    The window shape and the stride are both (1, 4, 4), and occluded regions
    are replaced by an all-zero baseline. Returns the attribution of the first
    sample's first channel as a 2-D numpy array. ``device`` is unused.
    """
    window = (1, 4, 4)
    explainer = Occlusion(model)
    attr = explainer.attribute(
        input_tensor,
        target=target_class,
        strides=window,
        sliding_window_shapes=window,
        baselines=torch.zeros_like(input_tensor),
    )
    # (N, C, H, W) -> (H, W) for sample 0, channel 0.
    return attr[0, 0, :, :].detach().cpu().numpy()

def explain_shapley(model, input_tensor, target_class, device):
    """
    Shapley Value Sampling attribution for ``target_class``.

    Runs Captum's sampler with ``n_samples=3`` (presumably kept small for
    speed) and returns the attribution of the first sample's first channel as
    a 2-D numpy array. ``device`` is unused.
    """
    explainer = ShapleyValueSampling(model)
    attr = explainer.attribute(input_tensor, target=target_class, n_samples=3)
    # (N, C, H, W) -> (H, W) for sample 0, channel 0.
    return attr[0, 0, :, :].detach().cpu().numpy()

def explain_lime_captum(model, input_tensor, target_class, device):
    """
    Captum LIME attribution for ``target_class``.

    Runs Captum's ``Lime`` with ``n_samples=3`` (presumably kept small for
    speed) and returns the attribution of the first sample's first channel as
    a 2-D numpy array. ``device`` is unused.
    """
    explainer = Lime(model)
    attr = explainer.attribute(input_tensor, target=target_class, n_samples=3)
    # (N, C, H, W) -> (H, W) for sample 0, channel 0.
    return attr[0, 0, :, :].detach().cpu().numpy()


def _explain_image(model, input_tensor, target_class, device, method):
    """
    Return the saliency map for a single input_tensor (shape (1,1,H,W)) given a method string.
    """
    model.eval()
    if method == "ig":
        sal_map = explain_ig(model, input_tensor, target_class, device)
    elif method == "occlusion":
        sal_map = explain_occlusion(model, input_tensor, target_class, device)
    elif method == "shapley":
        sal_map = explain_shapley(model, input_tensor, target_class, device)
    elif method == "lime":
        sal_map = explain_lime_captum(model, input_tensor, target_class, device)
    else:
        raise ValueError(f"Unknown method={method}")

    # Take absolute value for clarity
    sal_map = np.abs(sal_map)
    # optional normalization
    if sal_map.max() > 1e-9:
        sal_map /= sal_map.max()
    return sal_map


# ==============================
# 5) R, F, IS, and RBP for Images
# ==============================
def compute_inversion_scores_image(
    model,
    dataset,
    device,
    explanation_method="ig",
    n_samples=20,
    perturb_scale=0.1,
    p=2
):
    """
    Estimate R, F and the Inversion Score IS on a random subset of ``dataset``.

    For up to ``n_samples`` randomly chosen items, using the model's predicted
    class and the saliency map from ``explanation_method``:
      - R: pick 5 random pixels, add ``perturb_scale`` to each one in turn and
        correlate the change in predicted-class probability with the
        (unperturbed) saliency at that pixel. Negative or undefined
        correlations count as 0.
      - F: zero out the top 5% most salient pixels and take the absolute change
        in predicted-class probability, capped at 1.

    R and F are averaged over the subset and combined as
    IS = ((R^p + (1-F)^p)/2)^(1/p).

    Returns a dict with keys "R", "F" and "IS".
    """
    n_eval = min(n_samples, len(dataset))
    sample_ids = np.random.choice(len(dataset), size=n_eval, replace=False)

    model.eval().to(device)

    pixels_per_image = 5
    r_scores = []
    f_scores = []

    for sample_id in sample_ids:
        image, _label, _bbox = dataset[sample_id]
        batch = image.unsqueeze(0).float().to(device)

        # Reference prediction: class and its softmax probability.
        with torch.no_grad():
            logits = model(batch)
            pred_class = torch.argmax(logits, dim=1).item()
            prob_orig = F.softmax(logits, dim=1)[0, pred_class].item()

        sal_map = _explain_image(model, batch, pred_class, device, explanation_method)
        n_rows, n_cols = sal_map.shape

        # ---- R measure ----
        sal_at_pixel = []
        prob_shift = []
        for _ in range(pixels_per_image):
            row = random.randint(0, n_rows - 1)
            col = random.randint(0, n_cols - 1)

            poked = batch.clone()
            poked[0, 0, row, col] += perturb_scale

            with torch.no_grad():
                poked_prob = F.softmax(model(poked), dim=1)[0, pred_class].item()

            # The saliency is read from the baseline map; the explanation is
            # not re-run for the perturbed input.
            sal_at_pixel.append(sal_map[row, col])
            prob_shift.append(poked_prob - prob_orig)

        # Correlation is undefined when either series is constant.
        if np.std(sal_at_pixel) < 1e-9 or np.std(prob_shift) < 1e-9:
            r_scores.append(0.0)
        else:
            r_scores.append(max(0, np.corrcoef(sal_at_pixel, prob_shift)[0, 1]))

        # ---- F measure ----
        flat_sal = sal_map.flatten()
        n_mask = max(1, int(0.05 * len(flat_sal)))
        most_salient = np.argsort(-flat_sal)[:n_mask]

        masked = batch.clone()
        for flat_idx in most_salient:
            # Row-major flat index -> (row, col).
            row, col = divmod(int(flat_idx), n_cols)
            masked[0, 0, row, col] = 0.0

        with torch.no_grad():
            masked_prob = F.softmax(model(masked), dim=1)[0, pred_class].item()

        f_local = abs(prob_orig - masked_prob)
        if not f_local < 1.0:
            f_local = 1.0
        f_scores.append(f_local)

    R_val = np.mean(r_scores) if r_scores else 0.0
    F_val = np.mean(f_scores) if f_scores else 0.0

    # IS = ((R^p + (1-F)^p)/2)^(1/p)
    IS_val = ((R_val**p + (1-F_val)**p)/2.0)**(1.0/p)

    return {"R": R_val, "F": F_val, "IS": IS_val}

def apply_rbp_image(
    model,
    x_img_batch,
    baseline_sal,
    pred_class,
    device,
    explanation_method="ig",
    n_pert=3,
    perturb_scale=0.05,
    lambda_=1.0
):
    """
    Refine a saliency map with RBP by damping pixels whose saliency is unstable.

    For every pixel, the input is perturbed ``n_pert`` times at that pixel with
    Gaussian noise N(0, perturb_scale), the explanation is recomputed, and the
    mean absolute change of that pixel's saliency (delta) is measured. The
    refined value is ``a / (1 + lambda_ * delta)``. The result is rescaled to a
    peak of 1 unless it is numerically all zeros.

    ``baseline_sal`` is not modified; a new array of the same shape is
    returned. The explanation is re-run H * W * n_pert times, so this is slow.
    """
    n_rows, n_cols = baseline_sal.shape
    refined_map = baseline_sal.copy()

    # Row-major walk over every pixel position.
    for row, col in np.ndindex(n_rows, n_cols):
        original_value = baseline_sal[row, col]

        deviation_total = 0.0
        for _ in range(n_pert):
            poked = x_img_batch.clone()
            poked[0, 0, row, col] += np.random.normal(0, perturb_scale)

            # Fresh explanation for the perturbed input.
            resampled = _explain_image(model, poked, pred_class, device, explanation_method)
            deviation_total += abs(resampled[row, col] - original_value)

        # The epsilon keeps the division defined when n_pert is 0.
        mean_deviation = deviation_total / (n_pert + 1e-9)
        refined_map[row, col] = original_value / (1.0 + lambda_ * mean_deviation)

    peak = refined_map.max()
    if peak > 1e-9:
        refined_map /= peak

    return refined_map


# ==============================
# 6) Spurious
# ==============================
def create_spurious_test(X_test, y_test):
    """
    Return a copy of ``X_test`` with a spurious cue added to label=1 images.

    Every image whose label equals 1 gets its [0, 0] corner pixel set to 255,
    making that pixel artificially correlated with label=1. ``X_test`` itself
    is left untouched.
    """
    spurious = X_test.copy()
    positives = [i for i in range(len(spurious)) if y_test[i] == 1]
    if positives:
        # Brighten the corner pixel of all label=1 images at once.
        spurious[positives, 0, 0] = 255
    return spurious


# ==============================
# 7) Visualization
# ==============================
def visualize_saliency_maps(
    model,
    test_dataset,
    spurious_dataset,
    device,
    method,
    out_fig="saliency_visuals/ig/saliency_comparison.png"
):
    """
    Save a figure comparing saliency on normal vs. spurious versions of test images.

    Up to 3 random items of ``test_dataset`` are drawn, one per row. The
    columns show the original image, the saliency map for the normal input,
    and the saliency map for the item at the same index in
    ``spurious_dataset``. Each map explains the class the model predicts for
    that input.

    Args:
        model: classifier to explain; switched to eval mode and moved to ``device``.
        test_dataset: indexable of ``(image, label, bbox)`` items.
        spurious_dataset: indexable parallel to ``test_dataset``.
        device: torch device used for inference.
        method: explanation method name understood by ``_explain_image``.
        out_fig: output image path; its parent directory is created if needed.

    Raises:
        ValueError: if ``test_dataset`` is empty.
    """
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    out_dir = os.path.dirname(out_fig)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    model.eval().to(device)

    # Never ask for more rows than the dataset can supply.
    n_rows = min(3, len(test_dataset))
    if n_rows == 0:
        raise ValueError("test_dataset is empty; nothing to visualize")
    indices = np.random.choice(len(test_dataset), n_rows, replace=False)

    def _saliency_for(image):
        """Explain the model's own prediction for a single (1,H,W) image."""
        batch = image.unsqueeze(0).float().to(device)
        with torch.no_grad():
            pred_class = torch.argmax(model(batch), dim=1).item()
        return _explain_image(model, batch, pred_class, device, method)

    # squeeze=False keeps ``axes`` 2-D for any number of rows.
    fig, axes = plt.subplots(n_rows, 3, figsize=(9, 3*n_rows), squeeze=False)
    try:
        for row_idx, idx in enumerate(indices):
            x_img, y_true, _ = test_dataset[idx]
            x_sp, _, _ = spurious_dataset[idx]

            sal_normal = _saliency_for(x_img)
            sal_spurious = _saliency_for(x_sp)

            panels = (
                (x_img[0].numpy(), 'gray', f"Orig\n(label={y_true})"),
                (sal_normal, 'hot', f"{method}\n(normal)"),
                (sal_spurious, 'hot', f"{method}\n(spurious)"),
            )
            for ax, (data, cmap, title) in zip(axes[row_idx], panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

        fig.tight_layout()
        fig.savefig(out_fig, dpi=150)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


# ==============================
# 8) Main + Logging
# ==============================
def _evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for bx, by, _ in loader:
            bx = bx.to(device, dtype=torch.float)
            by = by.to(device)
            preds = torch.argmax(model(bx), dim=1)
            correct += (preds == by).sum().item()
            total += len(by)
    return correct / total if total else 0.0


def _rbp_inversion_scores(model, dataset, device, method, n_images=10, p=2):
    """
    Compute (R, F, IS) from RBP-refined saliency maps on a random subset.

    Mirrors the R/F definitions of ``compute_inversion_scores_image`` but reads
    saliency from the map returned by ``apply_rbp_image``. At most ``n_images``
    items are used so the (expensive) refinement stays tractable.
    """
    subset = np.random.choice(len(dataset), size=min(n_images, len(dataset)), replace=False)
    r_scores, f_scores = [], []

    for idx in subset:
        x_img, _label, _bbox = dataset[idx]
        x_img_batch = x_img.unsqueeze(0).float().to(device)
        with torch.no_grad():
            logits = model(x_img_batch)
            pred_class = torch.argmax(logits, dim=1).item()
            prob_orig = F.softmax(logits, dim=1)[0, pred_class].item()

        # baseline map, then refine with RBP
        base_sal = _explain_image(model, x_img_batch, pred_class, device, method)
        refined_sal = apply_rbp_image(
            model, x_img_batch, base_sal, pred_class, device,
            explanation_method=method,
            n_pert=2,
            perturb_scale=0.05,
            lambda_=1.0
        )
        H, W = refined_sal.shape

        # R: correlate refined saliency with the probability change caused by
        # nudging 5 random pixels.
        local_deltas = []
        local_sals = []
        for _ in range(5):
            px = random.randint(0, H-1)
            py = random.randint(0, W-1)
            x_pert = x_img_batch.clone()
            x_pert[0, 0, px, py] += 0.1
            with torch.no_grad():
                prob_pert = F.softmax(model(x_pert), dim=1)[0, pred_class].item()
            local_deltas.append(prob_pert - prob_orig)
            local_sals.append(refined_sal[px, py])

        # Correlation is undefined when either series is constant.
        if np.std(local_sals) < 1e-9 or np.std(local_deltas) < 1e-9:
            r_scores.append(0.0)
        else:
            r_scores.append(max(0, np.corrcoef(local_sals, local_deltas)[0, 1]))

        # F: zero out top 5% of refined_sal and measure the probability change.
        flat_sal = refined_sal.flatten()
        top_n = max(1, int(0.05 * len(flat_sal)))
        x_masked = x_img_batch.clone()
        for flat_idx in np.argsort(-flat_sal)[:top_n]:
            row, col = divmod(int(flat_idx), W)
            x_masked[0, 0, row, col] = 0.0

        with torch.no_grad():
            prob_masked = F.softmax(model(x_masked), dim=1)[0, pred_class].item()
        diff = abs(prob_orig - prob_masked)
        f_scores.append(diff if diff < 1.0 else 1.0)

    r_mean = np.mean(r_scores) if r_scores else 0.0
    f_mean = np.mean(f_scores) if f_scores else 0.0
    is_val = ((r_mean**p + (1-f_mean)**p)/2.0)**(1.0/p)
    return r_mean, f_mean, is_val


def run_image_experiment_expls(log_filename="image_captum_expls_log.txt"):
    """
    Run the full circle-vs-square explanation experiment and log the results.

    Steps: generate 200 synthetic 32x32 images, train ``SimpleCNN`` on 160 of
    them, report accuracy on the remaining 40, then for each explanation method
    ("ig", "occlusion", "shapley", "lime") compute baseline R/F/IS, RBP-refined
    R/F/IS on a small subset, and baseline R/F/IS on the spurious test set.
    Results are written to ``log_filename`` (overwritten at the start) and one
    saliency figure per method is saved under ``saliency_visuals/<method>/``.
    """
    # 1) Data
    X, y, bboxes = generate_synthetic_images(n_samples=200, img_size=32)
    # train/test split
    split = 160
    idxs = np.arange(len(X))
    np.random.shuffle(idxs)
    train_idxs = idxs[:split]
    test_idxs = idxs[split:]
    X_train, y_train = X[train_idxs], y[train_idxs]
    bboxes_train = [bboxes[i] for i in train_idxs]
    X_test, y_test = X[test_idxs], y[test_idxs]
    bboxes_test = [bboxes[i] for i in test_idxs]

    train_ds = CircleSquareDataset(X_train, y_train, bboxes_train)
    test_ds = CircleSquareDataset(X_test, y_test, bboxes_test)
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=16, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleCNN(num_classes=2)
    train_model(model, train_loader, test_loader, device, epochs=10)

    # Evaluate
    test_acc = _evaluate_accuracy(model, test_loader, device)
    print(f"Test Accuracy: {test_acc:.4f}")

    # Explanation methods
    expl_methods = ["ig", "occlusion", "shapley", "lime"]

    # Spurious test
    X_test_sp = create_spurious_test(X_test, y_test)
    test_sp_ds = CircleSquareDataset(X_test_sp, y_test, bboxes_test)

    with open(log_filename, "w") as f:
        f.write("Image Explanation with Captum (R, F, IS, RBP)\n")
        f.write("="*60 + "\n")
        f.write(f"TestAccuracy={test_acc:.4f}\n")

    for method in expl_methods:
        print(f"\n=== Explanation: {method} ===")
        # Baseline Inversion Scores
        start_t = time.time()
        iq_base = compute_inversion_scores_image(model, test_ds, device, explanation_method=method, n_samples=20)
        elapsed = time.time() - start_t

        # RBP: refine the saliency maps on a small subset, then compute new
        # R, F, IS from the refined maps. Timed on its own clock so the
        # reported RBP time does not include the baseline computation above.
        rbp_start_t = time.time()
        R_rbp, F_rbp, IS_rbp = _rbp_inversion_scores(
            model, test_ds, device, method, n_images=10, p=2
        )
        elapsed_rbp = time.time() - rbp_start_t

        # Spurious base
        iq_sp_base = compute_inversion_scores_image(
            model, test_sp_ds, device,
            explanation_method=method,
            n_samples=20
        )
        # The RBP step is not replicated for the spurious set to keep the run short.

        with open(log_filename, "a") as ff:
            ff.write(f"Method=[{method}]\n")
            ff.write(
                f"BASE => R={iq_base['R']:.3f}, F={iq_base['F']:.3f}, IS={iq_base['IS']:.3f}, Time={elapsed:.2f}s\n"
            )
            ff.write(
                f"RBP  => R={R_rbp:.3f}, F={F_rbp:.3f}, IS={IS_rbp:.3f}, Time={elapsed_rbp:.2f}s\n"
            )
            ff.write(
                f"Spur=> R={iq_sp_base['R']:.3f}, F={iq_sp_base['F']:.3f}, IS={iq_sp_base['IS']:.3f}\n"
            )
            ff.write("-"*60 + "\n")

        # Visualization
        out_dir = f"saliency_visuals/{method}"
        out_figure_path = os.path.join(out_dir, "saliency_comparison.png")
        visualize_saliency_maps(model, test_ds, test_sp_ds, device, method, out_fig=out_figure_path)
        with open(log_filename, "a") as ff:
            ff.write(f"Saliency figure for method [{method}] => {out_figure_path}\n\n")

    print(f"\nResults logged to: {log_filename}")


# Script entry point: run the full experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_image_experiment_expls()
