"""
Utility functions for PCN visualization and logging
Contains free energy calculation functions for monitoring and plotting
"""

import torch
import torch.nn as nn
import numpy as np


def calculate_vanilla_pc_free_energy(pcn_model, zs_t, x, y):
    """
    Calculate vanilla PC free energy for logging/visualization purposes

    The energies are evaluated under ``torch.no_grad()``: they are only
    returned as Python floats, so tracking an autograd graph through every
    backbone module would cost memory without ever being used.

    Args:
        pcn_model: PCN model instance; only ``pcn_model.backbone_module_list``
            is read
        zs_t: Current latent states (indexable sequence of tensors)
        x: Input tensor; ``x.shape[0]`` is used as the batch size
        y: Target tensor passed to ``cross_entropy`` for the last module

    Returns:
        List of free energy values per layer (as Python floats): the input
        reconstruction term first, then one entry per backbone module.
    """
    backbone = pcn_model.backbone_module_list
    batch_size = x.shape[0]
    last_idx = len(backbone) - 1
    energies = []

    with torch.no_grad():
        # Input term: squared error between the first latent and the data,
        # summed over all elements and divided by the batch size.
        energies.append((torch.sum((zs_t[0] - x) ** 2) / batch_size).item())
        for idx, module in enumerate(backbone):
            if idx != last_idx:
                # Squared prediction error of latent idx+1 given latent idx.
                error = zs_t[idx + 1] - module(zs_t[idx])
                energies.append((torch.sum(error ** 2) / batch_size).item())
            else:
                # The last module is scored with (batch-mean) cross-entropy
                # against the targets; it is fed zs_t[-2].
                logits = module(zs_t[-2])
                energies.append(torch.nn.functional.cross_entropy(logits, y).item())
    return energies


def calculate_meta_pc_free_energy(pcn_model, zs_t, x, y, zhs_t):
    """
    Calculate free energy for Meta PC (with frozen predictions) for logging/visualization
    Memory-optimized version

    A PC loss ``L`` is built on detached, grad-enabled copies of the latent
    states (``zs``) and frozen prediction states (``zhs``):
    ``sum((zs[0] - x)^2)/B + sum_i sum((zs[i+1] - f_i(zhs[i]))^2)/B
    + cross_entropy(f_last(zhs[last]), y)``, with ``f_i`` the i-th backbone
    module and ``B = x.shape[0]``. Its gradients with respect to those copies
    are taken with ``torch.autograd.grad``, so this logging helper never
    writes into the ``.grad`` fields of the model's parameters (a
    ``loss.backward()`` here would accumulate into them and corrupt a
    following optimizer step).

    For module ``i`` with ``i < len(zs_t) - 1`` the returned entry is
    ``sum(((zs[i+1] - f_i(zhs[i])) - g_{i+1})^2) / B``, where
    ``g_{i+1} = dL/dzs[i+1] + dL/dzhs[i+1]``; ``g`` is treated as zero for the
    last latent state. Modules with ``i >= len(zs_t) - 1`` report 0.0.

    Args:
        pcn_model: PCN model instance; only ``pcn_model.backbone_module_list``
            is read
        zs_t: Current latent states
        x: Input tensor; ``x.shape[0]`` is used as the batch size
        y: Target tensor passed to ``cross_entropy`` for the last module
        zhs_t: Frozen prediction states

    Returns:
        List of meta PC free energy values per layer (as Python floats), one
        entry per module in ``pcn_model.backbone_module_list``.
    """
    backbone = pcn_model.backbone_module_list
    num_modules = len(backbone)
    batch_size = x.shape[0]

    with torch.enable_grad():
        # Detached leaf copies, so gradients never flow back into the callers' tensors.
        zs_state_t = [z.clone().detach().requires_grad_(True) for z in zs_t]
        zhs_state_t = [z.clone().detach().requires_grad_(True) for z in zhs_t]

        # Every module predicts from the frozen prediction state at its index.
        predictions = [module(zhs_state_t[idx]) for idx, module in enumerate(backbone)]

        pc_loss = torch.sum((zs_state_t[0] - x) ** 2) / batch_size
        for idx in range(num_modules - 1):
            pc_loss = pc_loss + torch.sum((zs_state_t[idx + 1] - predictions[idx]) ** 2) / batch_size
        if predictions:
            pc_loss = pc_loss + torch.nn.functional.cross_entropy(predictions[-1], y)

        # Gradients w.r.t. the state copies only; parameter .grad fields stay untouched.
        states = [*zs_state_t, *zhs_state_t]
        grads = torch.autograd.grad(pc_loss, states, allow_unused=True)

    # allow_unused returns None for states the loss never touches; their gradient is zero.
    grads = [torch.zeros_like(s) if g is None else g for g, s in zip(grads, states)]
    zs_grads = grads[:len(zs_state_t)]
    zhs_grads = grads[len(zs_state_t):]

    last_latent = len(zs_state_t) - 1
    meta_pc_free_energy = []
    with torch.no_grad():
        for idx in range(num_modules):
            if idx >= last_latent:
                # No latent state above this module: nothing to score.
                meta_pc_free_energy.append(0.0)
                continue
            residual = zs_state_t[idx + 1] - predictions[idx]
            if idx + 1 < last_latent:
                # Gradient correction is skipped for the last latent state.
                residual = residual - (zs_grads[idx + 1] + zhs_grads[idx + 1])
            meta_pc_free_energy.append((torch.sum(residual ** 2) / batch_size).item())

    return meta_pc_free_energy


def calculate_meta_pc_free_energy_per_sample(pcn_model, zs_t, x, y, zhs_t):
    """
    Calculate meta PC free energy per sample for detailed analysis

    Same quantity as the batch-level meta PC free energy, but the squared
    error is summed over every non-batch dimension only, so one value is kept
    per sample. The PC loss used for the gradient correction is
    ``sum((zs[0] - x)^2)/B + sum_i sum((zs[i+1] - f_i(zhs[i]))^2)/B
    + cross_entropy(f_last(zhs[last]), y)``. Its gradients are taken with
    ``torch.autograd.grad`` on detached copies, so the model parameters'
    ``.grad`` fields are never modified.

    Row ``i`` (for ``i < len(zs_t) - 1``) holds
    ``sum_over_non_batch_dims(((zs[i+1] - f_i(zhs[i])) - g_{i+1})^2)`` with
    ``g_{i+1} = dL/dzs[i+1] + dL/dzhs[i+1]``; ``g`` is treated as zero for the
    last latent state. Rows for modules with ``i >= len(zs_t) - 1`` are zeros.

    Args:
        pcn_model: PCN model instance; only ``pcn_model.backbone_module_list``
            is read
        zs_t: Current latent states
        x: Input tensor; ``x.shape[0]`` is used as the batch size
        y: Target tensor passed to ``cross_entropy`` for the last module
        zhs_t: Frozen prediction states

    Returns:
        numpy array of shape [layer_num, batch_size] (layer_num = number of
        backbone modules), detached and moved to the CPU.
    """
    backbone = pcn_model.backbone_module_list
    num_modules = len(backbone)
    batch_size = x.shape[0]

    with torch.enable_grad():
        # Detached leaf copies, so gradients never flow back into the callers' tensors.
        zs_state_t = [z.clone().detach().requires_grad_(True) for z in zs_t]
        zhs_state_t = [z.clone().detach().requires_grad_(True) for z in zhs_t]

        # Every module predicts from the frozen prediction state at its index.
        predictions = [module(zhs_state_t[idx]) for idx, module in enumerate(backbone)]

        pc_loss = torch.sum((zs_state_t[0] - x) ** 2) / batch_size
        for idx in range(num_modules - 1):
            pc_loss = pc_loss + torch.sum((zs_state_t[idx + 1] - predictions[idx]) ** 2) / batch_size
        if predictions:
            pc_loss = pc_loss + torch.nn.functional.cross_entropy(predictions[-1], y)

        # Gradients w.r.t. the state copies only; parameter .grad fields stay untouched.
        states = [*zs_state_t, *zhs_state_t]
        grads = torch.autograd.grad(pc_loss, states, allow_unused=True)

    # allow_unused returns None for states the loss never touches; their gradient is zero.
    grads = [torch.zeros_like(s) if g is None else g for g, s in zip(grads, states)]
    zs_grads = grads[:len(zs_state_t)]
    zhs_grads = grads[len(zs_state_t):]

    last_latent = len(zs_state_t) - 1
    rows = []
    # no_grad keeps the rows free of autograd history, so .numpy() below is always legal.
    with torch.no_grad():
        for idx in range(num_modules):
            if idx >= last_latent:
                # Zero row on the same device/dtype as the states, so torch.stack cannot
                # hit a CPU/GPU mismatch.
                rows.append(zs_state_t[-1].new_zeros(batch_size))
                continue
            residual = zs_state_t[idx + 1] - predictions[idx]
            if idx + 1 < last_latent:
                # Gradient correction is skipped for the last latent state.
                residual = residual - (zs_grads[idx + 1] + zhs_grads[idx + 1])
            # Sum over every non-batch dimension, keeping one value per sample.
            rows.append(torch.sum(residual ** 2, dim=tuple(range(1, residual.dim()))))

    # Stack to [layer_num, batch_size]; .cpu() makes .numpy() valid for GPU tensors.
    return torch.stack(rows).detach().cpu().numpy()


def calculate_vanilla_pc_free_energy_per_sample(pcn_model, zs_t, x, y):
    """
    Calculate vanilla PC free energy per sample for detailed analysis

    Evaluated under ``torch.no_grad()``: the result is converted to numpy, so
    an autograd graph through the backbone would be built and thrown away.

    Args:
        pcn_model: PCN model instance; only ``pcn_model.backbone_module_list``
            is read
        zs_t: Current latent states (indexable sequence of tensors)
        x: Input tensor; its first dimension is the batch dimension
        y: Target tensor passed to ``cross_entropy`` for the last module

    Returns:
        numpy array of shape [1 + number of backbone modules, batch_size]:
        the input reconstruction row first, then one row per module.
    """
    backbone = pcn_model.backbone_module_list
    last_idx = len(backbone) - 1
    rows = []

    with torch.no_grad():
        # [batch_size] input reconstruction error, summed over non-batch dims.
        rows.append(torch.sum((zs_t[0] - x) ** 2, dim=tuple(range(1, x.dim()))))
        for idx, module in enumerate(backbone):
            if idx != last_idx:
                # [batch_size] prediction error of latent idx+1 given latent idx.
                target = zs_t[idx + 1]
                error = target - module(zs_t[idx])
                rows.append(torch.sum(error ** 2, dim=tuple(range(1, target.dim()))))
            else:
                # [batch_size] unreduced cross-entropy of the last module, fed zs_t[-2].
                logits = module(zs_t[-2])
                rows.append(torch.nn.functional.cross_entropy(logits, y, reduction='none'))

    # Stack to [layer_num, batch_size]; .cpu() makes .numpy() valid for GPU tensors.
    return torch.stack(rows).detach().cpu().numpy()