import torch
import model.metrics as metrics
from torch.autograd.functional import jacobian


def compute_layerwise_metrics(
    signal, sparse_codes, dict, lambda1, lambda2, constraint_type="energy", reconstructions=None,
    vit_processor=None
):
    """
    Compute one metric value per layer of a layered sparse code.

    For each layer ``l`` the codes ``sparse_codes[:, l]`` are scored together with
    the dictionary ``dict[l]`` (or ``dict[0]`` when only a single dictionary is
    stored, i.e. ``dict.shape[0] == 1``), according to ``constraint_type``.
    ===
    Inputs:
        signal: The signal the codes should reconstruct.
        sparse_codes: Sparse codes for every layer; layers are indexed along dim 1.
        dict: Stack of dictionaries indexed along dim 0. The reconstruction for a
            layer is ``codes @ dict[idx].T``.
        lambda1: Weight forwarded to ``metrics.sparse_code_loss`` / ``evaluate_energy``
            (presumably the sparsity penalty — confirm in ``model.metrics``).
        lambda2: Weight on the interaction term of ``evaluate_energy``
            (used only by the 'energy' and 'energy_jacobian' modes).
        constraint_type: The metric to compute. Supported: 'energy',
            'energy_jacobian' (norm of dEnergy/dDictionary), 'sparsity',
            'none'/'loss' (sparse-code loss), 'psnr', 'mse'.
        reconstructions: Optional precomputed reconstructions for each layer,
            indexed along dim 1, e.g. (T,L,B,C,W,H). When given, they are used
            instead of ``codes @ D.T`` by the reconstruction-based modes.
        vit_processor: Optional ViT processor. Its ``invert_processor`` is applied
            to signal and reconstruction before PSNR is computed, and its presence
            enables reshaping flat (T,B,F) reconstructions to the signal's
            trailing shape.
    Returns:
        A tensor of shape (num_layers,) on ``signal.device`` holding one metric per layer.
    Raises:
        ValueError: if ``constraint_type`` is not supported.
    """
    supported = {"energy", "energy_jacobian", "sparsity", "none", "loss", "psnr", "mse"}
    if constraint_type not in supported:
        # Validate up front so an unknown type is reported even when there are no
        # layers, and before any reconstruction work is done.
        raise ValueError(f"Unsupported constraint type: {constraint_type}")

    num_layers = sparse_codes.shape[1]
    metric_per_layer = torch.zeros(num_layers, device=signal.device)

    using_vit = vit_processor is not None
    # Invert the signal once, and only when the PSNR branch will actually use it.
    signal_inverted = (
        vit_processor.invert_processor(signal)
        if using_vit and constraint_type == "psnr"
        else signal
    )

    for l in range(num_layers):
        dict_idx = l if dict.shape[0] > 1 else 0  # single shared D vs. one D per layer
        layer_dict = dict[dict_idx]
        layer_codes = sparse_codes[:, l]

        if constraint_type == "energy":
            # evaluate_energy expects the layer dim kept with size 1, hence l:l+1.
            metric_per_layer[l] = evaluate_energy(
                signal, sparse_codes[:, l : l + 1], layer_dict, lambda1, lambda2
            )
        elif constraint_type == "energy_jacobian":
            layer_codes_kept = sparse_codes[:, l : l + 1]
            # Differentiate w.r.t. the dictionary only. Closing over the other
            # arguments lets lambda1/lambda2 be plain Python numbers (jacobian()
            # requires every *input* to be a Tensor) and skips the unused
            # Jacobians w.r.t. signal and codes.
            metric_per_layer[l] = jacobian(
                lambda d: evaluate_energy(signal, layer_codes_kept, d, lambda1, lambda2),
                layer_dict,
                create_graph=True,
            ).norm()
        elif constraint_type == "sparsity":
            metric_per_layer[l] = metrics.sparsity_metric(layer_codes)
        else:
            # Reconstruction-based metrics: prefer stored reconstructions (e.g. ViT).
            if reconstructions is not None:
                reconstruction = reconstructions[:, l]
            else:
                reconstruction = torch.matmul(layer_codes, layer_dict.T)
                if using_vit and reconstruction.ndim == 3:
                    # Flat (T,B,F) reconstruction; assumes F == prod(signal.shape[2:])
                    # (e.g. C*W*H) so it can be viewed as the signal's trailing shape.
                    T, B = reconstruction.shape[:2]
                    reconstruction = reconstruction.view(T, B, *signal.shape[2:])

            if constraint_type in ("none", "loss"):
                metric_per_layer[l] = metrics.sparse_code_loss(
                    signal, reconstruction, lambda1, sparse_codes=layer_codes
                )
            elif constraint_type == "psnr":
                if using_vit:
                    # Compare in the un-normalized image space.
                    reconstruction_inverted = vit_processor.invert_processor(reconstruction)
                    metric_per_layer[l] = metrics.psnr(signal_inverted, reconstruction_inverted)
                else:
                    metric_per_layer[l] = metrics.psnr(signal, reconstruction)
            else:  # "mse" — the only remaining supported type after validation
                metric_per_layer[l] = metrics.mse(signal, reconstruction)

    return metric_per_layer


def evaluate_energy(signal, sparse_codes, dict, lb1, lb2, reconstruction=None):
    """
    Compute the energy of a single layer's sparse code, as defined by DeWeerdt.

    The energy is ``lb2 * loss1 + loss2`` where:
      * ``loss2`` is ``metrics.sparse_code_loss`` of the signal against the layer
        reconstruction ``s`` (weighted by ``lb1``);
      * ``loss1 = -exp(-0.5 * c).mean() + n.mean()``. Here ``c`` is, per batch
        element, the mean over the T x T grid of the upper-triangular pairwise
        squared L2 distances between time steps of ``s``. ``n`` is the squared
        Frobenius norm of ``s`` over the time and feature dims.

    Inputs:
        signal: Target signal, passed to ``metrics.sparse_code_loss``.
        sparse_codes: Sparse code of ONE layer with the layer dim kept, i.e.
            (T,1,B,H); index the full codes with ``l:l+1``.
        dict: Dictionary; the reconstruction is ``sparse_codes @ dict.T``
            ((T,1,B,H) x (H,F) = (T,1,B,F)).
        lb1: Weight forwarded to ``metrics.sparse_code_loss``.
        lb2: Weight on ``loss1``.
        reconstruction: Optional precomputed reconstruction of shape (T,1,B,F).
            When given, ``dict`` is not used.
    Returns:
        ``lb2 * loss1 + loss2``.
    """
    if reconstruction is None:
        # (T,1,B,H) x (H,F) = (T,1,B,F) — same reconstruction as the sibling functions.
        s = torch.matmul(sparse_codes, dict.T)
    else:
        s = reconstruction

    # Pairwise differences between time steps:
    # (B,F,T,1) - (B,F,1,T) -> (B,F,T,T). This relies on dim 1 of s having size 1.
    y_t_minus_y_tau = s.permute(2, 3, 0, 1) - s.permute(2, 3, 1, 0)  # y_t - y_{\tau}

    # Squared distance over features -> (B,T,T); keep each pair once (upper
    # triangle, diagonal is zero) and average over the full T x T grid -> (B,).
    cross_terms = torch.triu(torch.norm(y_t_minus_y_tau, p=2, dim=1) ** 2).mean(
        dim=(1, 2)
    )
    # Squared Frobenius norm over the time (0) and feature (-1) dims -> (1,B).
    norm_reconstructions = torch.norm(s, p="fro", dim=(0, -1)) ** 2
    loss1 = -torch.exp(-0.5 * cross_terms).mean() + norm_reconstructions.mean()
    loss2 = metrics.sparse_code_loss(
        signal, s[:, 0], lb1, sparse_codes=sparse_codes[:, 0]
    )
    return lb2 * loss1 + loss2


def evaluate_energy_simple(signal, sparse_codes, dict, lb1, lb2):
    """
    Debugging variant of the energy that drops the exponential interaction term.

    The reconstruction is ``sparse_codes @ dict.T``. The result is
    ``metrics.sparse_code_loss`` of the signal against that reconstruction
    (weighted by ``lb1``), plus ``lb2`` times the mean of the squared Frobenius
    norms of the reconstruction taken over dims 0 and 2.
    """
    recon = sparse_codes @ dict.T
    data_term = metrics.sparse_code_loss(signal, recon, lb1, sparse_codes=sparse_codes)
    squared_norms = torch.norm(recon, p="fro", dim=(0, 2)).pow(2)
    return data_term + lb2 * squared_norms.mean()
