import torch
from torch import nn

from torch.utils.data import DataLoader

from SNN.Layers import Neuron, ITLIF, ITQIF

# Energy cost per operation, expressed in joules (10**-12 J = 1 pJ).
E_MAC = 4.6 * 10**(-12)  # one multiply-accumulate: 4.6 pJ
E_AC = 0.9 * 10**(-12)  # one accumulate: 0.9 pJ

# Run measurement forward passes on the GPU when one is available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def approximate_energy_consumption(model: nn.Module, dataset: DataLoader, timesteps: int, spike_rates: list | None = None) -> float:
    """
    Approximate the energy consumption of a model on a given dataset.

    Only the first batch of ``dataset`` is run through the model. That batch is
    used to measure spike rates (when ``spike_rates`` is not given) and to count
    the MAC/AC operations of each layer.

    Args:
        model: The model to approximate the energy consumption of. Must provide
            a ``reset_total_spikes()`` method.
        dataset: DataLoader yielding ``(data, target)`` batches.
        timesteps: Number of simulation timesteps; every per-layer cost is
            multiplied by it.
        spike_rates: Optional precomputed spike rates, one per ``Neuron`` layer
            in ``model.modules()`` order. When ``None`` they are measured on the
            first batch of ``dataset``.

    Returns:
        The approximate energy consumption of the model in millijoules (mJ).
    """
    model.reset_total_spikes()
    model.eval()

    if spike_rates is None:
        spike_rates = _calculate_spike_rates(model, dataset, timesteps)
    spiking_mac, spiking_ac, ac = _calculate_mac_and_ac_operations(model, dataset)

    model_energy_joules = _approximate_model_energy(spike_rates, spiking_mac, spiking_ac, ac, timesteps)
    # E_MAC / E_AC are expressed in joules (4.6 * 10**-12 J = 4.6 pJ), so the
    # estimate above is in joules. Convert to millijoules: 1 J = 1000 mJ.
    return model_energy_joules * 1e3

def _approximate_model_energy(spike_rates: list[float], spiking_mac: dict, spiking_ac: dict, traditional_ac: dict, timesteps: int) -> float:
    """
    Combine per-layer operation counts and spike rates into one energy estimate.

    The result is in the unit of ``E_MAC`` / ``E_AC`` (joules).

    Args:
        spike_rates: One spike rate per neuron layer (presumably in
            ``model.modules()`` order when produced by ``_calculate_spike_rates``).
        spiking_mac: MAC count per spiking-neuron module (values are used in
            insertion order).
        spiking_ac: AC count per spiking-neuron module (values are used in
            insertion order).
        traditional_ac: Operation count per Conv2d/Linear module, keyed by the
            module, in insertion order.
        timesteps: Number of simulation timesteps; every per-layer cost is
            multiplied by it.

    Returns:
        The summed energy of all neuron and Conv2d/Linear layers.
    """
    neuron_energy = 0.0
    # Neuron layers: the MAC cost is charged in full every timestep, while the
    # AC cost is scaled by the layer's spike rate.
    # NOTE(review): spike_rates and the two dicts come from different traversal
    # orders (module order vs. forward-hook order), and zip() silently truncates
    # to the shortest sequence. Confirm they line up one-to-one.
    for spike_rate, mac, ac in zip(spike_rates, spiking_mac.values(), spiking_ac.values()):
        neuron_energy += timesteps * _approximate_layer_energy(spike_rate, mac, ac)

    # Conv2d/Linear layers: the AC cost is scaled by the spike rate of the neuron
    # layer presumed to feed them. spike_index points at the next unused rate.
    spike_index = 0
    for i, module in enumerate(traditional_ac.keys()):
        if i == 0:
            # The first layer does not have a spike rate
            # It is charged as AC at rate 1.0. NOTE(review): presumably its input
            # is not spikes; confirm whether it should be charged as MAC instead.
            neuron_energy += timesteps * _approximate_layer_energy(1.0, None, traditional_ac[module])
        elif isinstance(module, nn.Linear) or module.stride == (1, 1):
            if spike_index == len(spike_rates):
                # The last layer does not have a spike rate
                # (all rates consumed: this layer contributes no energy at all).
                continue
            neuron_energy += timesteps * _approximate_layer_energy(spike_rates[spike_index], None, traditional_ac[module])
            spike_index += 1
        elif module.stride == (2, 2):
            # A stride-2 conv reuses the most recently consumed spike rate and does
            # not advance spike_index. Presumably it is a downsampling/shortcut
            # branch sharing its input with the previous layer; confirm.
            # NOTE(review): if spike_index is still 0 here, index -1 wraps around to
            # the LAST spike rate.
            neuron_energy += timesteps * _approximate_layer_energy(spike_rates[spike_index-1], None, traditional_ac[module])
        # Conv layers with any other stride fall through and are silently skipped.

    return neuron_energy

def _approximate_layer_energy(spike_rate: float | None, mac: float | None, ac: float):
    """
    Return the energy of one layer for a single timestep.

    The energy is the MAC cost (not scaled by activity) plus the AC cost
    (scaled by ``spike_rate``).
    """
    mac_energy = _calculate_mac_cost(mac)
    ac_energy = _calculate_ac_cost(ac, spike_rate)
    return mac_energy + ac_energy

def _calculate_mac_cost(mac: float | None):
    """Return the energy of ``mac`` multiply-accumulates; ``None`` means no MAC work (0.0)."""
    return 0.0 if mac is None else mac * E_MAC

def _calculate_ac_cost(ac: float, spike_rate: float):
    """Return the energy of ``ac`` accumulate operations, scaled by ``spike_rate``."""
    dense_energy = ac * E_AC
    return dense_energy * spike_rate

def _calculate_spike_rates(model: nn.Module, dataloader: DataLoader, timesteps: int) -> list:
    """
    Read the ``spike_rate`` attribute of every ``Neuron`` layer after one forward pass.

    The model is run on the first batch of ``dataloader`` only. Each ``Neuron``
    module presumably updates its ``spike_rate`` during that pass.
    NOTE(review): the exact definition of ``spike_rate`` (e.g. average spikes per
    image per timestep) lives in ``SNN.Layers.Neuron``; confirm there.

    Args:
        model: The model to calculate the spike rates of.
        dataloader: DataLoader yielding ``(data, target)`` batches.
        timesteps: The number of timesteps to simulate the model for. It is not
            used by this function; kept for API compatibility.

    Returns:
        One spike rate per ``Neuron`` layer, in ``model.modules()`` order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.to(DEVICE)
    try:
        data, _ = next(iter(dataloader))
    except StopIteration:
        # Without a forward pass the spike_rate attributes would be stale or missing.
        raise ValueError("dataloader is empty; cannot measure spike rates") from None

    # Measurement only: no autograd graph is needed.
    with torch.no_grad():
        model(data.float().to(DEVICE))

    return [module.spike_rate for module in model.modules() if isinstance(module, Neuron)]

def _calculate_mac_and_ac_operations(model: nn.Module, dataloader: DataLoader) -> tuple[dict, dict, dict]:
    """
    Count per-layer MAC and AC operations from one forward pass.

    The function works in three steps:

    1. Temporary forward hooks are attached to ``nn.Conv2d``, ``nn.Linear``,
       ``ITLIF`` and ``ITQIF`` modules.
    2. The model is run on the first batch of ``dataloader``.
    3. Exactly those hooks are removed again, even if the forward pass raises.
       Other hooks already registered on the model are left untouched.

    Args:
        model: The model whose operations are counted.
        dataloader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(spiking_mac, spiking_ac, ac)``: three dicts keyed by module, filled in
        forward-execution order.

        - ``spiking_mac``: MAC count of each ITLIF/ITQIF layer (one sample).
        - ``spiking_ac``: AC count of each ITLIF/ITQIF layer (one sample).
        - ``ac``: operation count of each Conv2d/Linear layer.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    spiking_mac = {}
    spiking_ac = {}
    ac = {}

    def neurons_per_sample(data):
        # Number of neurons in one sample (dim 0 is treated as the batch). Only
        # 2-D (batch, features) and 4-D (batch, C, H, W) inputs are counted; other
        # ranks are skipped. NOTE(review): skipped layers misalign the zip in
        # _approximate_model_energy.
        if data.ndim in (2, 4):
            return data[0].numel()
        return None

    def hook(module, inputs, output):
        if isinstance(module, nn.Conv2d):
            # AC operations = H_out * W_out * C_out * K_h * K_w * (C_in / groups).
            # weight has shape (C_out, C_in / groups, K_h, K_w), so weight.shape[1:]
            # is the fan-in of one output element, which stays correct for grouped
            # and depthwise convolutions.
            ac[module] = output.shape[-3:].numel() * module.weight.shape[1:].numel()
        elif isinstance(module, nn.Linear):
            ac[module] = inputs[0].shape[-1] * output.shape[-1]
        elif isinstance(module, ITLIF):
            # Positional inputs are (data, state); only the data tensor is used.
            # NOTE(review): if ITQIF subclasses ITLIF, this branch catches it first.
            neurons = neurons_per_sample(inputs[0])
            if neurons is not None:
                spiking_mac[module] = neurons
                spiking_ac[module] = 0.0
        elif isinstance(module, ITQIF):
            # Positional inputs are (data, state); only the data tensor is used.
            neurons = neurons_per_sample(inputs[0])
            if neurons is not None:
                spiking_mac[module] = neurons * 2
                spiking_ac[module] = spiking_mac[module]

    tracked_types = (nn.Conv2d, nn.Linear, ITLIF, ITQIF)
    handles = [
        module.register_forward_hook(hook)
        for module in model.modules()
        if isinstance(module, tracked_types)
    ]
    try:
        model.to(DEVICE)
        try:
            data, _ = next(iter(dataloader))
        except StopIteration:
            # Without a forward pass every count would silently be missing.
            raise ValueError("dataloader is empty; cannot count operations") from None
        # Measurement only: no autograd graph is needed.
        with torch.no_grad():
            model(data.float().to(DEVICE))
    finally:
        # Remove only our own hooks so none leak into later forward passes or
        # stack up across repeated calls.
        for handle in handles:
            handle.remove()

    return spiking_mac, spiking_ac, ac