import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from Models.load_data import load_data_mlp
import Models
import sys
from scipy.io import loadmat, savemat
import argparse

class EngModule(nn.Module):
    """
    Transparent wrapper around an nn.Conv2d or nn.Linear used for FLOP counting.

      - forward() delegates to the wrapped module and returns its output unchanged
      - after every forward, ``out_shape`` holds the output shape as a plain tuple;
        if the module returns a tuple/list whose first item is a tensor, that
        item's shape is used; otherwise ``out_shape`` is reset to None
      - the original module stays reachable as ``.module``
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.out_shape = None  # filled after a forward pass

    def forward(self, x):
        out = self.module(x)
        # Recompute on every call so a shape from an earlier pass never lingers.
        self.out_shape = self._shape_of(out)
        return out

    @staticmethod
    def _shape_of(out):
        """Return the shape of ``out`` (or of ``out[0]``) as a tuple, else None."""
        if isinstance(out, torch.Tensor):
            return tuple(out.shape)
        # Non-tensor output (rare for Conv2d/Linear): fall back to its first item.
        try:
            first = out[0]
        except (TypeError, IndexError, KeyError):
            # Not indexable, empty, or a mapping without key 0.
            return None
        return tuple(first.shape) if isinstance(first, torch.Tensor) else None

def wrap_model_with_eng(module: nn.Module) -> nn.Module:
    """
    Wrap every nn.Conv2d / nn.Linear found below ``module`` in an EngModule.

    The replacement happens in place by writing into ``module._modules[name]``,
    which works uniformly for Sequential, ModuleList and ordinary attributes.
    Children that are already EngModule instances are left alone, so calling
    this twice does not double-wrap. Any other child is descended into
    recursively. Note that ``module`` itself is never wrapped, only its
    descendants. Returns the same (now modified) ``module``.
    """
    wrappable = (nn.Conv2d, nn.Linear)
    # Take a snapshot first: entries of module._modules are reassigned below.
    snapshot = list(module.named_children())
    for child_name, sub in snapshot:
        if isinstance(sub, EngModule):
            continue
        if not isinstance(sub, wrappable):
            wrap_model_with_eng(sub)
            continue
        module._modules[child_name] = EngModule(sub)
    return module

def _collect_flops_in_named_order(module: nn.Module) -> List[float]:
    """
    Return per-layer FLOPs (one ANN forward pass) for every wrapped Conv2d/Linear.

    Layers are visited depth-first in ``named_children`` order, so the result
    lines up with any other traversal that uses the same recursion. Counts:
      - Conv2d: k_h * k_w * out_h * out_w * (in_channels / groups) * out_channels
      - Linear: in_features * out_features
    EngModule wrappers around any other module type contribute nothing.

    Requires EngModule.out_shape to be populated (run a forward pass with the
    dummy input first); raises RuntimeError for a Conv2d/Linear wrapper whose
    out_shape is still None.
    """
    collected: List[float] = []
    for child_name, sub in module.named_children():
        if not isinstance(sub, EngModule):
            # Plain container or non-counted layer: descend, keeping order.
            collected.extend(_collect_flops_in_named_order(sub))
            continue

        inner = sub.module
        if isinstance(inner, nn.Conv2d):
            if sub.out_shape is None:
                raise RuntimeError(f"EngModule for Conv2d '{child_name}' has no out_shape. "
                                   "Run a forward pass with dummy_input first.")
            # out_shape is (batch, out_channels, h, w)
            _, n_out, out_h, out_w = sub.out_shape
            kh, kw = inner.kernel_size
            grp = inner.groups if hasattr(inner, "groups") else 1
            # Same multiplication order as before so the float result is identical.
            collected.append(
                float(kh * kw) * float(out_h) * float(out_w) * (inner.in_channels / grp) * n_out
            )
        elif isinstance(inner, nn.Linear):
            if sub.out_shape is None:
                raise RuntimeError(f"EngModule for Linear '{child_name}' has no out_shape. "
                                   "Run a forward pass with dummy_input first.")
            collected.append(float(inner.in_features) * float(inner.out_features))
    return collected

def compute_snn_ann_energy_ratio(
    ts,
    LASFR,
    dense_allocation: float,
    model: nn.Module,
    dummy_input: torch.Tensor,
    device: torch.device,
    once_input=False,
    E_MAC: float = 4.6e-12,
    E_AC: float = 0.9e-12,
) -> dict:
    """
    Estimate SNN energy at each LASFR column and the ANN reference energy.

    Per-layer FLOPs F_0..F_{L-1} are counted from one ANN forward pass of
    ``dummy_input`` (Conv2d/Linear layers, depth-first named_children order).
    SNN energy at column t is:
      - layer 0: E_MAC * F_0 * dense_allocation, multiplied by ts[t] unless
        ``once_input`` is True (firing rate taken as 1)
      - layer j >= 1: E_AC * F_j * ts[t] * LASFR[j-1, t] * s_j, where
        s_j = dense_allocation, except s_j = 1 for the last layer (j = L-1).
        Layers with no LASFR row are skipped with a printed warning.

    Inputs:
      - ts: 1-D sequence with one entry per LASFR column (length T). Presumably
            the number of simulation time steps behind each column; confirm
            with the caller.
      - LASFR: 2-D array (K, T). LASFR[i, t] is the firing rate of recorded
               activation i at column t. Row j-1 gates layer j, so rows start
               at the first hidden activation and the input layer has none.
      - dense_allocation: FLOP multiplier described above. Presumably the
               fraction of weights kept after pruning.
      - model: nn.Module. Its Conv2d/Linear descendants are wrapped in
               EngModule IN PLACE; it is moved to ``device`` and set to eval().
      - dummy_input: input batch used for the shape-discovery forward pass.
      - device: target device for model and dummy_input.
      - once_input: count the first layer's MAC cost once instead of per ts[t].
      - E_MAC, E_AC: energy per MAC / AC operation in J (defaults 4.6 pJ, 0.9 pJ).

    Returns:
      dict with
        'energy_snn'      : np.ndarray (T,), SNN energy (J) per column
        'flops_per_layer' : list of per-layer FLOPs (single ANN forward)
        'energy_ann'      : float, E_MAC * sum(flops_per_layer) (J)
        'ratio_snn_ann'   : np.ndarray (T,), energy_snn / energy_ann
                            (NaN if energy_ann is 0)
      If ts is not 1-D or its length differs from T, a message is printed and
      only {'energy_snn': zeros_like(ts)} is returned (best effort, no raise).

    Raises:
      - RuntimeError if no Conv2d/Linear layer is found, or if a layer did not
        record an output shape.
      - ValueError if LASFR is not 2-D.
    """
    # 0) Wrap Conv2d/Linear descendants (in place) so they record output shapes.
    model = wrap_model_with_eng(model)

    # 1) One no-grad forward pass populates EngModule.out_shape.
    model = model.to(device)
    dummy_input = dummy_input.to(device)
    model.eval()
    with torch.no_grad():
        model(dummy_input)

    # 2) Per-layer FLOPs in named_children order (the order LASFR rows follow).
    flops_list = _collect_flops_in_named_order(model)
    L = len(flops_list)
    if L == 0:
        raise RuntimeError("No Conv2d/Linear layers found (flops_list is empty).")

    # 3) ANN reference: every FLOP of one dense forward pass costs one MAC.
    # NOTE(review): dense_allocation is not applied to the ANN baseline; confirm
    # whether a pruned baseline is intended.
    energy_ann = E_MAC * float(sum(flops_list))

    # 4) Firing rates, converted to float once instead of per row.
    lasfr = np.asarray(LASFR, dtype=float)
    if lasfr.ndim != 2:
        raise ValueError(f"LASFR must be 2D (num_activations, T). Got shape {lasfr.shape}")
    K, T = lasfr.shape  # K = number of recorded activations

    ts = np.asarray(ts, dtype=float)
    if ts.ndim != 1 or ts.shape[0] != T:
        # A 0-d ts has no shape[0]; report the full shape when ts is not 1-D.
        ts_len = ts.shape[0] if ts.ndim == 1 else ts.shape
        print(f"Length of ts ({ts_len}) must equal LASFR time stps ({T}).")
        return {'energy_snn': np.zeros_like(ts)}

    if not (K == L - 1 or K == L):
        # Warn but continue: row j-1 is still mapped to layer j where it exists.
        print(f"Warning: number of LASFR rows {K} does not equal expected L-1 ({L-1}) "
              f"or L ({L}). Proceeding with best-effort mapping.")

    # 5) SNN energy per column (vectorized over T).
    energy_snn_t = np.zeros_like(ts)

    # First layer: MAC operations with firing rate 1.
    flops_first = float(flops_list[0])
    if once_input:
        energy_snn_t += E_MAC * flops_first * dense_allocation
    else:
        energy_snn_t += E_MAC * flops_first * ts * dense_allocation

    # Layers j >= 1: AC operations gated by the previous activation's LASFR row.
    for j in range(1, L):
        lasfr_idx = j - 1
        if lasfr_idx >= K:
            print(f"Warning: missing LASFR for layer index j={j}: expected lasfr_idx={lasfr_idx} < {K}.")
            continue
        sparse_factor = dense_allocation if j != L - 1 else 1
        phi_t = lasfr[lasfr_idx, :]
        energy_snn_t += E_AC * float(flops_list[j]) * ts * phi_t * sparse_factor

    # 6) Assemble results, including the ANN reference the ratio is based on.
    energy_snn_t = np.asarray(energy_snn_t)
    if energy_ann > 0:
        ratio = energy_snn_t / energy_ann
    else:
        ratio = np.full_like(energy_snn_t, np.nan)
    return {
        'energy_snn': energy_snn_t,
        'flops_per_layer': flops_list,
        'energy_ann': energy_ann,
        'ratio_snn_ann': ratio,
    }


def init_energy(method, architecture, dataset, device):
    '''
    Initialize [model, dummy_input, device] for energy computation.

    Builds the model and loaders through Models.prepare_model_and_loader with a
    fixed configuration (bs=2, one_fc=True, conv_sparsity=0.0, dropout=0.0).
    It then moves the model to `device` and takes the first training batch as
    the dummy input.

    Inputs:
      - method: currently unused; model construction does not depend on it.
      - architecture, dataset: copied into the args namespace. `dim` is 1 when
        'CIFAR' appears in the dataset name, otherwise 2.
      - device: anything torch.device accepts (e.g. 'cpu', 'cuda:0').

    Returns:
      (model, dummy_input, device): model and dummy_input on `device`, and
      device converted to a torch.device.

    Raises:
      ValueError if the training loader yields no batches.
    '''
    args = argparse.Namespace(
        architecture=architecture,
        dataset=dataset,
        bs=2,
        dim=1 if 'CIFAR' in dataset else 2,
        one_fc=True,
        conv_sparsity=0.0,
        dropout=0.0,
    )
    device = torch.device(device)

    model, _num_activations, train_loader, _test_loader = Models.prepare_model_and_loader(args)
    model = model.to(device)

    # Only one batch is needed; an empty loader would otherwise leave
    # dummy_input unbound.
    try:
        inputs, _target = next(iter(train_loader))
    except StopIteration:
        raise ValueError(
            f"train_loader for dataset '{dataset}' yielded no batches; cannot build dummy_input."
        ) from None
    dummy_input = inputs.to(device)

    return model, dummy_input, device

def energy_ratio(data_dict, method, architecture, dataset, device,
                 dense_allocation: float = 1.0, once_input: bool = False, ts=None):
    '''
    Build the model, estimate SNN energy from data_dict's firing rates, and
    merge the results into data_dict.

    Inputs:
      - data_dict: dict (e.g. from scipy.io.loadmat) that must contain 'LASFR'
                   (2-D, num_activations x T). If `ts` is None it must also
                   contain 'ts'. This key name is an assumption; TODO confirm
                   against the code that writes res.mat.
      - method, architecture, dataset, device: forwarded to init_energy.
      - dense_allocation: forwarded to compute_snn_ann_energy_ratio
                   (1.0 = no FLOP reduction).
      - once_input: forwarded to compute_snn_ann_energy_ratio.
      - ts: optional time-step vector; overrides data_dict['ts'].

    Returns:
      - data_dict, updated in place with every key returned by
        compute_snn_ann_energy_ratio (always includes 'energy_snn').

    Raises:
      - KeyError if 'LASFR' is missing, or if ts is None and 'ts' is missing.
    '''
    # Read inputs before building the (expensive) model so bad data fails fast.
    if ts is None:
        if 'ts' not in data_dict:
            raise KeyError("energy_ratio needs time steps: pass ts=... or store them under data_dict['ts'].")
        ts = data_dict['ts']
    # loadmat stores vectors as 2-D (1, T) arrays; flatten to the 1-D form expected.
    ts = np.asarray(ts, dtype=float).reshape(-1)
    lasfr = data_dict['LASFR']

    model, dummy_input, device = init_energy(method, architecture, dataset, device)

    # Arguments follow compute_snn_ann_energy_ratio's positional order.
    result = compute_snn_ann_energy_ratio(
        ts, lasfr, dense_allocation, model, dummy_input, device, once_input=once_input
    )
    data_dict.update(result)
    return data_dict