import numpy as np
import torch
import torch.nn.functional as F

from src.iees_utils import generate_pfams


class ExitAttributionCache:
    """Per-exit store of the tensors needed to build attribution maps later.

    All three dictionaries are keyed by the ``exit_idx`` passed to ``cache``.
    """

    def __init__(self):
        # exit index -> detached copy of ``model.activations`` at that exit
        self.activations = {}
        # exit index -> clone of the output logits (clone keeps autograd history)
        self.outputs = {}
        # exit index -> predicted class as a Python scalar (via ``.item()``)
        self.class_ids = {}

    def cache(self, exit_idx, model, output, class_idx):
        """Record the output, predicted class and activations for one exit.

        Args:
            exit_idx: key under which the three values are stored.
            model: object exposing an ``activations`` tensor.
            output: output logits for this exit; stored as a clone so later
                in-place changes to ``output`` do not affect the cache.
            class_idx: single-element tensor holding the predicted class.
        """
        logits_copy = output.clone()
        self.outputs[exit_idx] = logits_copy

        predicted = class_idx.item()
        self.class_ids[exit_idx] = predicted

        # Detach + clone so the snapshot is independent of both the graph and
        # any later overwrite of model.activations.
        snapshot = model.activations.detach().clone()
        self.activations[exit_idx] = snapshot


def generate_pfams_after_exit(model, input_tensor, final_exit_idx, activation_cache, output_cache, class_cache):
    """
    Generate PFAM maps from exit 0 to final_exit_idx using cached data.

    For every exit ``i`` in ``0..final_exit_idx`` (inclusive) a Grad-CAM style
    map is built:

    1. The cached class score ``output_cache[i][0, class_cache[i]]`` is
       back-propagated.
    2. The gradients exposed as ``model.gradients`` are averaged over dims 2
       and 3 to give per-channel weights.
    3. The ReLU of the channel-weighted sum of ``activation_cache[i]`` is
       bilinearly resized to ``input_tensor.shape[2:]``.
    4. The result is min-max normalised to [0, 1].

    Args:
        model: object exposing ``zero_grad()`` and a ``gradients`` tensor.
            NOTE(review): ``model.gradients`` is presumably filled by a
            backward hook during ``class_score.backward`` — confirm.
        input_tensor: tensor whose ``shape[2:]`` sets the size of each map.
        final_exit_idx: last exit index to process (inclusive); must be >= 0.
        activation_cache: indexable by exit; 4-D activation tensors.
        output_cache: indexable by exit; logits indexed as ``[0, class]``
            (so batch item 0 only) and still attached to the autograd graph.
        class_cache: indexable by exit; the class index to explain.

    Returns:
        ``(pfams_list, cumulative_map)``:

        - ``pfams_list`` holds the per-exit numpy maps.
        - ``cumulative_map`` is their weighted average, where exit ``i`` gets
          weight proportional to ``i + 1`` (later exits count more).

    Raises:
        ValueError: if ``final_exit_idx`` is negative (no exits to process).
    """
    if final_exit_idx < 0:
        raise ValueError(f"final_exit_idx must be >= 0, got {final_exit_idx}")

    pfams_list = []
    target_size = input_tensor.shape[2:]  # loop-invariant output map size

    for i in range(final_exit_idx + 1):
        model.zero_grad()

        # 1. Trigger backprop on class score. retain_graph=True keeps the graph
        #    alive for later exits, whose cached outputs may share it.
        class_score = output_cache[i][0, class_cache[i]]
        class_score.backward(retain_graph=True)

        # 2. Get gradients and activations
        gradients = model.gradients.detach()
        activations = activation_cache[i]

        # 3. Compute Grad-CAM: spatially averaged gradients weight each channel
        channel_weights = gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((channel_weights * activations).sum(dim=1))
        cam = F.interpolate(cam.unsqueeze(1), target_size, mode='bilinear', align_corners=False)
        cam = cam.detach().squeeze().cpu().numpy()

        # Normalize to [0, 1]; epsilon avoids 0/0 for constant maps
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-10)
        pfams_list.append(cam)

    # 4. Compute cumulative PFAM (progressive weighted average, weights 1..N)
    exit_weights = np.arange(1, len(pfams_list) + 1)
    exit_weights = exit_weights / exit_weights.sum()
    cumulative_map = np.tensordot(exit_weights, pfams_list, axes=(0, 0))
    return pfams_list, cumulative_map


def generate_pfams_after_exit1(model, input_tensor, n_exits):
    """
    Run the model to each exit in turn and accumulate PFAM maps.

    For each ``exit_idx`` in ``0..n_exits-1``:

    - ``model.forward_to_exit(input_tensor, exit_idx, True)`` is called.
    - ``model.activations`` and ``model.gradients`` are passed to
      ``generate_pfams`` together with the running ``pfams_list``.
    - ``generate_pfams`` returns the updated list and a cumulative map.

    Args:
        model: object exposing ``forward_to_exit`` (returning a 3-tuple) and
            ``activations`` / ``gradients`` attributes.
        input_tensor: input passed to both ``forward_to_exit`` and
            ``generate_pfams``.
        n_exits: number of exits to process, starting from exit 0.

    Returns:
        ``(pfams_list, cumulative_maps)``:

        - ``pfams_list`` is the list as last returned by ``generate_pfams``.
        - ``cumulative_maps`` holds one cumulative map per processed exit.
        - Both lists are empty when ``n_exits`` is 0.
    """
    pfams_list = []
    cumulative_maps = []
    for exit_idx in range(n_exits):
        print("Processing exit:", exit_idx)
        # Only the state left on the model is used below. NOTE(review): the
        # trailing True presumably makes forward_to_exit populate
        # model.activations / model.gradients — confirm.
        _output, _class_idx, _target_score = model.forward_to_exit(input_tensor, exit_idx, True)
        pfams_list, cumulative_map = generate_pfams(model.activations, model.gradients, input_tensor, pfams_list)
        cumulative_maps.append(cumulative_map)

    return pfams_list, cumulative_maps

