import torch
from torch import nn
from torch.utils.data import DataLoader
from typing import List, Dict, Tuple, Optional

from SNN.Layers import Neuron, ITLIF, ITQIF, LNM

# Energy charged per elementary operation by the estimator.
# NOTE(review): the values are written at 1e-6 scale and labelled uJ, but the
# widely cited 45 nm figures these match (MAC 4.6, add 0.9, mult 3.7) are in
# pJ -- confirm which unit the reported totals are meant to be in.
E_MAC = 4.6e-6  # one multiply-accumulate (4.6 uJ)
E_AC = 0.9e-6  # one accumulate / addition (0.9 uJ)
E_M = 3.7e-6  # one multiplication (3.7 uJ)

# Device every forward pass in this module runs on: first CUDA GPU if present.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def approximate_energy_consumption(
    model: nn.Module, dataset: DataLoader, timesteps: int, spike_rates: Optional[List[float]] = None, is_imagenet: bool = False
) -> float:
    """Approximate the energy consumption of ``model`` on ``dataset``.

    Args:
        model: Spiking network to profile. It is switched to eval mode here.
        dataset: Loader whose first batch drives the measurement forward passes.
        timesteps: Number of simulation timesteps; every per-layer energy
            term is multiplied by this.
        spike_rates: Precomputed spike rates, one per neuron layer. When
            None they are measured with a forward pass after calling
            ``model.reset_total_spikes()``.
        is_imagenet: If True, batches are dicts holding the input under
            ``'image'`` (used by the operation-counting pass).

    Returns:
        The estimated energy, in the units of ``E_MAC`` / ``E_AC`` / ``E_M``.
    """
    if spike_rates is None:
        model.reset_total_spikes()

    # eval() is applied even when ``spike_rates`` is supplied: the
    # operation-counting forward pass below must not run in training mode
    # (dropout active, BatchNorm running statistics being updated).
    model.eval()

    # Both passes are pure measurement, so no autograd graph is needed.
    with torch.no_grad():
        if spike_rates is None:
            spike_rates = _calculate_spike_rates(model, dataset, timesteps)
        spiking_mac, spiking_ac, spiking_m, ac = _calculate_mac_and_ac_and_m_operations(model, dataset, is_imagenet)

    return _approximate_model_energy(spike_rates, spiking_mac, spiking_ac, spiking_m, ac, timesteps)

def _approximate_model_energy(
    spike_rates: List[float],
    spiking_mac: Dict,
    spiking_ac: Dict,
    spiking_m: Dict,
    traditional_ac: Dict,
    timesteps: int
) -> float:
    """Approximate the total energy of the model over ``timesteps`` steps.

    The estimate has two parts, each multiplied by ``timesteps``:

    * Neuron updates: each entry of ``spiking_mac`` / ``spiking_ac`` /
      ``spiking_m`` is charged at rate 1.0.
    * Synaptic layers in ``traditional_ac``, iterated in insertion order
      (presumably forward-call order -- confirm with the producer):

      - The first layer is charged as MACs at rate 1.0.
      - Later ``nn.Linear`` / stride-(1, 1) layers are charged as ACs scaled
        by the next unused entry of ``spike_rates``. They are skipped once
        the rates run out.
      - Stride-(2, 2) layers reuse the most recently consumed rate.
      - Layers with any other stride are not charged.

    Prints the operation totals as a side effect.

    Returns:
        The estimated energy, in the units of the per-operation constants.
    """
    total_mac = sum(spiking_mac.values())
    total_ac = sum(traditional_ac.values())
    total_m = sum(spiking_m.values())

    energy = 0.0

    # Neuron state is updated every timestep whether or not it fires, hence
    # the fixed rate of 1.0. ``spike_rates`` stays in the zip only so the
    # loop still stops at the shortest input.
    for _rate, mac, ac, m in zip(spike_rates, spiking_mac.values(), spiking_ac.values(), spiking_m.values()):
        energy += timesteps * _approximate_layer_energy(1.0, mac, ac, m)

    spike_index = 0
    for i, (module, ops) in enumerate(traditional_ac.items()):
        stride = getattr(module, 'stride', (1, 1))
        if i == 0:
            # The first layer consumes the raw (non-spike) input: full MACs.
            energy += timesteps * _approximate_layer_energy(1.0, ops, None, None)
            total_mac += ops
            # Already reported as MACs; don't also report it as ACs.
            total_ac -= ops
        elif isinstance(module, nn.Linear) or stride == (1, 1):
            if spike_index < len(spike_rates):
                energy += timesteps * _approximate_layer_energy(spike_rates[spike_index], None, ops, None)
                spike_index += 1
        elif stride == (2, 2) and spike_rates:
            # Reuse the last consumed rate. Clamp at 0 so a downsampling layer
            # seen before any stride-1 layer uses the first rate instead of
            # wrapping around to spike_rates[-1].
            rate = spike_rates[max(spike_index - 1, 0)]
            energy += timesteps * _approximate_layer_energy(rate, None, ops, None)

    print(f"Total MAC: {total_mac}, Total AC: {total_ac}, Total M: {total_m}")

    return energy

def _approximate_layer_energy(spike_rate: Optional[float], mac: Optional[float], ac: Optional[float], m: Optional[float]) -> float:
    """Return the energy of one layer for a single timestep.

    Sums the MAC, AC and multiply costs. Each cost is scaled by
    ``spike_rate``; a None rate means rate 1.0.
    """
    return _calculate_mac_cost(mac, spike_rate) + _calculate_ac_cost(ac, spike_rate) + _calculate_m_cost(m, spike_rate)

def _scaled_cost(count: Optional[float], unit_energy: float, spike_rate: Optional[float]) -> float:
    """Return ``count * unit_energy * spike_rate``.

    A None or zero ``count`` contributes 0.0. A ``spike_rate`` of None means
    "always active" (rate 1.0). A rate of exactly 0.0 is honoured and yields
    zero energy; previously a falsy check silently replaced it with 1.0.
    """
    if not count:
        return 0.0
    rate = 1.0 if spike_rate is None else spike_rate
    return count * unit_energy * rate

def _calculate_mac_cost(mac: Optional[float], spike_rate: Optional[float]) -> float:
    """Energy of ``mac`` multiply-accumulates at the given spike rate."""
    return _scaled_cost(mac, E_MAC, spike_rate)

def _calculate_ac_cost(ac: Optional[float], spike_rate: Optional[float]) -> float:
    """Energy of ``ac`` accumulate operations at the given spike rate."""
    return _scaled_cost(ac, E_AC, spike_rate)

def _calculate_m_cost(m: Optional[float], spike_rate: Optional[float]) -> float:
    """Energy of ``m`` multiplications at the given spike rate."""
    return _scaled_cost(m, E_M, spike_rate)

def _calculate_spike_rates(model: nn.Module, dataloader: DataLoader, timesteps: int) -> List[float]:
    """Run one batch through ``model`` and collect each neuron layer's spike rate.

    Only the first batch of ``dataloader`` is used. A batch may be a
    ``(data, target)`` pair or a dict holding the input under ``'image'``.
    Previously dict batches could not be unpacked, and the caller has no
    way to pass an ImageNet flag here.

    Args:
        model: Network to run; moved to ``DEVICE``.
        dataloader: Source of the input batch.
        timesteps: Unused; kept for interface compatibility.

    Returns:
        The ``spike_rate`` attribute of every ``Neuron`` submodule, in
        ``model.modules()`` order. What that attribute measures is defined by
        ``Neuron`` -- presumably the average rate per image per timestep
        (TODO confirm).
    """
    model.to(DEVICE)

    # Measurement only: no autograd graph is needed.
    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, dict):
                data = batch['image']
            else:
                data, _ = batch
            model(data.float().to(DEVICE))
            break

    return [module.spike_rate for module in model.modules() if isinstance(module, Neuron)]

def _calculate_mac_and_ac_and_m_operations(model: nn.Module, dataloader: DataLoader, is_imagenet: bool = False) -> Tuple[Dict, Dict, Dict, Dict]:
    """Count per-layer operations by running one batch under forward hooks.

    Args:
        model: Network to profile; moved to ``DEVICE``.
        dataloader: Only its first batch is used.
        is_imagenet: If True, batches are dicts and the input is
            ``batch['image']``; otherwise batches unpack as ``(data, target)``.

    Returns:
        ``(spiking_mac, spiking_ac, spiking_m, ac)``: dicts keyed by module,
        in the order each module's hook first fired. All four are empty if
        ``dataloader`` yields nothing.

        * ``spiking_*``: neuron-update MAC / AC / multiply counts for ITLIF,
          ITQIF and LNM layers. They are counted over the input excluding its
          leading (presumably batch) dimension.
        * ``ac``: synaptic operation counts for Conv2d and Linear layers.

        A module called several times in one forward keeps the count from
        its last call.
    """
    spiking_mac, spiking_ac, spiking_m, ac = {}, {}, {}, {}

    def hook(module: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        # A module called with keyword-only arguments gets an empty tuple here.
        if not inputs:
            return
        x = inputs[0]
        if isinstance(module, nn.Conv2d):
            out_ch, in_ch_per_group, k_h, k_w = module.weight.shape
            # k_h * k_w * (in_channels / groups) * out_channels * h_out * w_out
            ac[module] = k_h * k_w * in_ch_per_group * out_ch * output.shape[-2] * output.shape[-1]
        elif isinstance(module, (nn.Linear, nn.LazyLinear)):
            # in_features * out_features when the input is 2-D.
            # NOTE(review): for inputs with extra dims (e.g. (B, N, F)) this
            # is N * out_features rather than a full op count -- confirm intent.
            ac[module] = x.shape[1] * output.shape[1]
        elif isinstance(module, ITLIF):
            spiking_mac[module] = x.shape[1:].numel()
            spiking_ac[module] = 0.0
            spiking_m[module] = 0.0
        elif isinstance(module, ITQIF):
            spiking_mac[module] = x.shape[1:].numel() * 2
            spiking_ac[module] = 0.0
            spiking_m[module] = 0.0
        elif isinstance(module, LNM):
            poly_degree = module.update.poly_degree
            n_elements = x.shape[1:].numel()
            spiking_mac[module] = 0.0
            spiking_ac[module] = n_elements * (poly_degree - 1)
            spiking_m[module] = n_elements * poly_degree

    # Hook only the layer types the hook actually counts.
    counted_types = (nn.Conv2d, nn.Linear, nn.LazyLinear, ITLIF, ITQIF, LNM)
    handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, counted_types)]

    # Remove the hooks even if the forward pass raises, so they don't stay
    # attached to the caller's model.
    try:
        model.to(DEVICE)
        with torch.no_grad():
            for batch in dataloader:
                if is_imagenet:
                    data = batch['image']
                else:
                    data, _ = batch
                model(data.float().to(DEVICE))
                break
    finally:
        for handle in handles:
            handle.remove()

    return spiking_mac, spiking_ac, spiking_m, ac