import torch
import torch.nn as nn
import pandas as pd
import json
import numpy as np
import glob
import statistics
import os


def debug(msg):
    """Print *msg* to stdout prefixed with a ``[DEBUG]`` tag."""
    line = "[DEBUG] {}".format(msg)
    print(line)


class OnlyFeatures(nn.Module):
    """Expose just a ``features`` sub-network as a standalone module.

    Calling the wrapper runs ``self.features`` on the input and returns its
    output unchanged; nothing else (e.g. a VGG classifier head) is executed.
    """

    def __init__(self, features):
        super(OnlyFeatures, self).__init__()
        # Registered as a child module, so it appears as "features" in
        # named_modules() / state_dict().
        self.features = features

    def forward(self, x):
        out = self.features(x)
        return out

def looks_like_vgg(keys):
    """Heuristically decide whether layer *keys* come from a VGG-style model.

    Returns True when at least one key starts with ``"features."`` and no key
    contains ``"layer1."`` (the ResNet-style stage prefix this heuristic
    rules out).
    """
    has_feature_blocks = any(name.startswith("features.") for name in keys)
    if not has_feature_blocks:
        return False
    return not any("layer1." in name for name in keys)

def build_model_for_csv_keys(csv_keys):
    """
    Pick a model whose module names should line up with the CSV layer names.

    VGG-looking keys (see ``looks_like_vgg``) get a randomly initialised
    VGG16-BN wrapped so only its ``features`` trunk runs; anything else gets
    the project's WideResNet-28-10 (10 classes).

    Returns (model, input_size, arch_tag).
    """
    # Both architectures are probed with one CIFAR-sized image.
    input_size = (1, 3, 32, 32)

    if not looks_like_vgg(csv_keys):
        debug("Detected non-VGG CSV. Using WideResNet28-10.")
        # Project-local WideResNet, imported lazily so the VGG path does not need it.
        from src.models import WideResNet
        wrn = WideResNet(depth=28, widen_factor=10, num_classes=10)
        return wrn, input_size, "wrn28x10"

    debug("Detected VGG-like CSV (features.*). Using VGG16-BN features-only.")
    import torchvision.models as tvm
    # Pretrained weights are irrelevant: only layer shapes are collected.
    backbone = tvm.vgg16_bn(weights=None)
    return OnlyFeatures(backbone.features), input_size, "vgg16_bn_features"


def get_layer_info(model, input_size=(1, 3, 32, 32), json_path=None):
    """Record shape metadata for every ``nn.Conv2d`` / ``nn.Linear`` in *model*.

    Forward hooks are attached to each such module and a single no-grad
    forward pass is run on a random tensor of shape *input_size*. The model
    is switched to eval mode (and left there). All hooks are removed even if
    the forward pass raises.

    Args:
        model: the ``nn.Module`` to probe.
        input_size: shape of the dummy input tensor.
        json_path: if given, the collected dict is also written there as JSON.

    Returns:
        dict mapping each module's ``named_modules()`` name to an info dict.
        Conv entries: ``type="conv"``, ``C_in``, ``C_out``, ``kernel``,
        ``stride``, ``padding`` (``[ph, pw]``, or the string ``'same'`` /
        ``'valid'`` when the layer was built that way), ``input_spatial``,
        ``output_spatial``, ``groups``, ``use_bias``.
        Linear entries: ``type="linear"``, ``in_features``, ``out_features``,
        ``use_bias``.
    """
    layer_info = {}
    hooks = []

    def register_hook(layer_name):
        def hook(module, inputs, output):
            if isinstance(module, nn.Conv2d):
                # Use the trailing (C, H, W) dims so unbatched inputs work too.
                C_in, H_in, W_in = (int(d) for d in inputs[0].shape[-3:])
                C_out, H_out, W_out = (int(d) for d in output.shape[-3:])
                # PyTorch keeps string paddings ('same'/'valid') verbatim;
                # indexing them would yield single characters.
                if isinstance(module.padding, str):
                    padding = module.padding
                else:
                    padding = [module.padding[0], module.padding[1]]
                layer_info[layer_name] = {
                    "type": "conv",
                    "C_in": C_in,
                    "C_out": C_out,
                    "kernel": [module.kernel_size[0], module.kernel_size[1]],
                    "stride": [module.stride[0], module.stride[1]],
                    "padding": padding,
                    "input_spatial": [H_in, W_in],
                    # Measured output size (accounts for dilation etc.).
                    "output_spatial": [H_out, W_out],
                    "groups": module.groups,
                    "use_bias": module.bias is not None,
                }
            elif isinstance(module, nn.Linear):
                layer_info[layer_name] = {
                    "type": "linear",
                    "in_features": module.in_features,
                    "out_features": module.out_features,
                    "use_bias": module.bias is not None,
                }
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            hooks.append(module.register_forward_hook(register_hook(name)))

    model.eval()
    try:
        # Build the probe on the model's own device/dtype so GPU or
        # half-precision models can be probed as well.
        ref_param = next(model.parameters(), None)
        if ref_param is not None and ref_param.is_floating_point():
            dummy_input = torch.randn(
                *input_size, device=ref_param.device, dtype=ref_param.dtype
            )
        else:
            dummy_input = torch.randn(*input_size)
        with torch.no_grad():
            model(dummy_input)
    finally:
        # Never leave hooks attached, even when the forward pass fails.
        for h in hooks:
            h.remove()

    if json_path:
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(layer_info, f, indent=2)

    debug(f"Hooked layers: {list(layer_info.keys())}")
    return layer_info


def get_conv_output_size(image_size, filter_size, padding, stride):
    """
    Compute the output (height, width) of a 2D convolution.

    image_size:  (H, W)
    filter_size: int (square kernel) or (kh, kw)
    padding:     int (Python or NumPy), (ph, pw), or 'same'/'valid'
    stride:      int or (sh, sw)

    'same' gives ceil(input / stride) per dimension (the input size itself
    when stride is 1); this holds for even kernel sizes too, where a naive
    padding of k // 2 would overshoot by one.

    Returns (H_out, W_out) as ints.
    Raises NotImplementedError for unknown padding strings and ValueError for
    unsupported padding types.
    """
    # Kernel
    if isinstance(filter_size, (tuple, list)):
        k_h, k_w = filter_size
    else:
        k_h = k_w = filter_size

    # Stride
    if isinstance(stride, (tuple, list)):
        stride_h, stride_w = stride
    else:
        stride_h = stride_w = stride

    H_in, W_in = image_size[0], image_size[1]

    # Padding
    if isinstance(padding, str):
        if padding == 'same':
            # Ceil division without floats.
            return int(-(-H_in // stride_h)), int(-(-W_in // stride_w))
        if padding == 'valid':
            pad_h = pad_w = 0
        else:
            raise NotImplementedError(f"Padding {padding} not supported")
    elif isinstance(padding, (tuple, list)):
        pad_h, pad_w = padding
    elif isinstance(padding, (int, np.integer)):
        pad_h = pad_w = padding
    else:
        raise ValueError(f"Unsupported padding type: {type(padding)}")

    # Standard formula: floor((in + 2p - k) / s) + 1, in integer arithmetic.
    H_out = (H_in + 2 * pad_h - k_h) // stride_h + 1
    W_out = (W_in + 2 * pad_w - k_w) // stride_w + 1
    return int(H_out), int(W_out)

def conv_flops(layer, sparsity, use_bias=True, activation="relu"):
    """Estimate forward FLOPs of one (possibly sparse) conv layer.

    Counting convention: each output element costs ``n`` multiplies and
    ``n - 1`` adds (never fewer than 0), where ``n`` is the expected number
    of nonzero weights in its receptive field:
    ``k_h * k_w * (C_in // groups) * (1 - sparsity)``. One extra add per
    output element is charged for the bias (when both *use_bias* and
    ``layer["use_bias"]`` are true) and for the activation when *activation*
    is ``"relu"``.

    Args:
        layer: conv info dict. Required keys: ``kernel``, ``C_in``, ``C_out``,
            ``input_spatial``. Optional: ``stride`` (default 1), ``padding``
            (default ``'same'``), ``groups`` (default 1), ``use_bias``
            (default False), ``output_spatial`` (used directly if present).
        sparsity: fraction of zero weights, expected in [0, 1].
        use_bias: whether bias adds may be counted at all.
        activation: ``"relu"`` (case-insensitive) to count one add per output,
            anything else / None to count nothing.

    Returns:
        FLOP count as a float.
    """
    k_h, k_w = layer["kernel"]
    C_in, C_out = layer["C_in"], layer["C_out"]
    H_in, W_in = layer["input_spatial"]
    stride = layer.get("stride", 1)
    padding = layer.get("padding", "same")
    # Grouped (e.g. depthwise) convs only see C_in / groups channels per output.
    groups = layer.get("groups", 1)

    # Effective kernel length (nonzeros)
    vector_length = k_h * k_w * (C_in // groups) * (1 - sparsity)

    # Output spatial size: prefer the measured value, otherwise compute it
    # per dimension so non-square kernels are supported.
    if "output_spatial" in layer:
        H_out, W_out = layer["output_spatial"]
    elif k_h == k_w:
        H_out, W_out = get_conv_output_size((H_in, W_in), k_h, padding, stride)
    else:
        H_out = get_conv_output_size((H_in, W_in), k_h, padding, stride)[0]
        W_out = get_conv_output_size((H_in, W_in), k_w, padding, stride)[1]
    n_output_elements = H_out * W_out * C_out

    # FLOPs: dot products. Without the clamp, a fully (or nearly fully)
    # pruned layer would contribute a negative number of additions.
    flop_mults = vector_length * n_output_elements
    flop_adds = max(vector_length - 1, 0) * n_output_elements

    # Bias adds
    if use_bias and layer.get("use_bias", False):
        flop_adds += n_output_elements

    # Activation cost (ReLU ~ 1 add per element)
    if activation is not None and activation.lower() == "relu":
        flop_adds += n_output_elements

    return float(flop_mults + flop_adds)

def linear_flops(layer, sparsity):
    """Estimate forward FLOPs of one (possibly sparse) fully connected layer.

    Charges 2 FLOPs (one multiply, one add) per surviving weight, with the
    surviving fan-in taken as ``in_features * (1 - sparsity)``. Bias is not
    counted. Returns a float.
    """
    fan_in = layer["in_features"]
    fan_out = layer["out_features"]
    nonzero_fan_in = fan_in * (1 - sparsity)
    return float(nonzero_fan_in * fan_out * 2)

def compute_sparse_forward_flops(layer_info, sparsities):
    """Sum forward FLOPs over the layers that have an entry in *sparsities*.

    Layers in *layer_info* without a sparsity entry are skipped, and entries
    whose ``type`` is neither ``"conv"`` nor ``"linear"`` contribute nothing.
    Sparsity values are coerced with ``float``. Returns a float.
    """
    running = 0.0
    for layer_name, spec in layer_info.items():
        if layer_name in sparsities:
            frac = float(sparsities[layer_name])
            kind = spec["type"]
            if kind == "conv":
                running += conv_flops(spec, frac)
            elif kind == "linear":
                running += linear_flops(spec, frac)
    return running


def training_flops_per_epoch(layer_info, sparsity_dict, steps_per_epoch, grad_every=100):
    """
    FLOP accounting with occasional full-gradient steps.

    Per-step cost model, with f_S = sparse forward FLOPs and f_D = dense
    (zero-sparsity) forward FLOPs over the same layers:

    - Masked gradient step: 3 * f_S
    - Full gradient step:  2 * f_S + f_D

    Steps are numbered 1..steps_per_epoch. Every step whose number is a
    multiple of ``grad_every`` is a full-gradient step, so ``grad_every=1``
    makes every step a full-gradient step.

    Args:
        layer_info: per-layer metadata (as from ``get_layer_info``).
        sparsity_dict: ``{epoch: {layer_name: sparsity}}``.
        steps_per_epoch: optimizer steps per epoch (negative treated as 0).
        grad_every: period of full-gradient steps; must be >= 1.

    Returns:
        list of ``(epoch, total_flops)`` in ``sparsity_dict`` iteration order.

    Raises:
        ValueError: if ``grad_every`` < 1.
    """
    if grad_every < 1:
        raise ValueError(f"grad_every must be >= 1, got {grad_every}")

    # The step schedule is the same for every epoch, so count full vs masked
    # steps once instead of looping over every step.
    n_steps = max(int(steps_per_epoch), 0)
    n_full = n_steps // grad_every
    n_masked = n_steps - n_full

    epoch_flops = []
    for epoch, sparsities in sparsity_dict.items():
        f_s = compute_sparse_forward_flops(layer_info, sparsities)
        if f_s == 0.0:
            debug(f"Epoch {epoch}: f_S is zero (no matching layers?).")
        # dense = zero sparsity for the same keys we have for this epoch
        dense_zero = {k: 0.0 for k in sparsities}
        f_d = compute_sparse_forward_flops(layer_info, dense_zero)

        masked_cost = 3 * f_s
        full_cost = 2 * f_s + f_d

        total = n_full * full_cost + n_masked * masked_cost
        epoch_flops.append((epoch, float(total)))
    return epoch_flops

def load_sparsity_csv(path):
    """Parse a logged-metrics CSV into per-step, per-layer weight sparsities.

    Columns named ``Val/layer_sparsity_<layer>.weight`` are read as
    percentages and divided by 100; ``<layer>`` (the text strictly between
    that prefix and suffix) becomes the dict key. Rows whose sparsity
    columns are all NaN, or whose ``step`` is NaN, are dropped; individual
    NaN entries inside a kept row are skipped.

    Args:
        path: anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        dict ``{step (int): {layer_name: sparsity (float)}}``. If a step
        appears in several rows, the last row wins.

    Raises:
        RuntimeError: if no sparsity columns or no ``step`` column exist.
    """
    prefix, suffix = "Val/layer_sparsity_", ".weight"
    df = pd.read_csv(path)

    # Keep only columns that are per-parameter sparsities:
    # Val/layer_sparsity_<layer_name>.weight
    sparsity_cols = [c for c in df.columns
                     if c.startswith(prefix) and c.endswith(suffix)]

    if not sparsity_cols:
        raise RuntimeError(f"No per-layer sparsity columns found in {path}")
    if "step" not in df.columns:
        raise RuntimeError(f"No 'step' column found in {path}")

    # Drop rows where all layer sparsity columns are NaN (common at tail),
    # and rows without a step (int() would fail on NaN).
    df = df.dropna(subset=sparsity_cols, how="all")
    df = df.dropna(subset=["step"])

    # Keep step + sparsities, then convert % -> [0,1]
    df = df[["step"] + sparsity_cols].copy()
    df[sparsity_cols] = df[sparsity_cols] / 100.0

    # Strip the exact prefix/suffix; str.replace would also remove any
    # ".weight" occurring inside the layer name itself.
    layer_names = {c: c[len(prefix):len(c) - len(suffix)] for c in sparsity_cols}

    # Build dict: step -> {layer_name: sparsity}
    sparsity_dict = {}
    for _, row in df.iterrows():
        layer_sparsities = {}
        for c in sparsity_cols:
            val = row[c]
            if pd.isna(val):
                continue
            layer_sparsities[layer_names[c]] = float(val)
        sparsity_dict[int(row["step"])] = layer_sparsities

    # For diagnostics
    if sparsity_dict:
        example_epoch = next(iter(sparsity_dict))
        debug(f"CSV example keys ({os.path.basename(path)}): {sorted(sparsity_dict[example_epoch].keys())[:10]} ...")
    return sparsity_dict

def align_sparsity_to_layerinfo(sparsity_dict, layer_info):
    """
    Keep only sparsity keys that exist in layer_info. Warn if none match.

    For each epoch, keys present in both are kept. If that exact
    intersection is empty, a best-effort fallback maps a CSV key ``k`` to the
    layer_info key ending in ``"." + k``, but only when exactly one such key
    exists. Epochs with no match at all map to an empty dict and a warning is
    printed through ``debug``.

    Returns ``{epoch: {layer_info_key: sparsity}}``.
    """
    li_keys = set(layer_info.keys())
    # Sorted sample keeps the warning message deterministic across runs.
    li_sample = sorted(li_keys)[:5]

    def _suffix_matches(csv_key):
        # Require a '.' boundary: a bare endswith would let '0.conv1' also
        # match 'layer1.10.conv1'.
        tail = "." + csv_key
        return [lk for lk in li_keys if lk.endswith(tail)]

    aligned = {}
    for epoch, sp in sparsity_dict.items():
        # exact intersection
        aligned_sp = {k: v for k, v in sp.items() if k in li_keys}
        if not aligned_sp:
            # Try suffix-based fallback (best-effort), but only if unique
            fallback = {}
            for sk, v in sp.items():
                matches = _suffix_matches(sk)
                if len(matches) == 1:
                    fallback[matches[0]] = v
            if fallback:
                debug(f"Epoch {epoch}: using suffix fallback for some keys "
                      f"(matched {len(fallback)} of {len(sp)}).")
                aligned_sp = fallback

        if not aligned_sp:
            debug(f"[WARNING] Epoch {epoch}: No matching keys between CSV and layer_info. "
                  f"CSV keys (sample): {list(sp.keys())[:5]} | layer_info keys (sample): {li_sample}")
        aligned[epoch] = aligned_sp
    return aligned


if __name__ == "__main__":
    ROOT = './runs/CIFAR10/ResNet18/LinBregSparse0007/'
    # NOTE(review): the path mentions ResNet18, but build_model_for_csv_keys
    # only builds VGG16-BN or WideResNet-28-10 -- confirm the intended model.
    # glob order is filesystem-dependent; sort so the CSV used for
    # architecture detection and the processing order are reproducible.
    csv_files = sorted(glob.glob(os.path.join(ROOT, '*.csv')))
    if not csv_files:
        raise RuntimeError(f"No CSVs found under {ROOT}")

    first_csv = csv_files[0]
    sp0 = load_sparsity_csv(first_csv)
    # Union over all logged steps, so a layer missing from the first row
    # still informs architecture detection (empty set if sp0 is empty).
    first_keys = set().union(*(layer_sp.keys() for layer_sp in sp0.values()))

    model, input_size, arch_tag = build_model_for_csv_keys(first_keys)
    layer_info = get_layer_info(model, input_size=input_size, json_path=None)

    average_flops_list = []

    # CIFAR-10 default split (45k train if 90% of 50k)
    train_size = 45000
    batch_size = 128
    steps_per_epoch = int(np.ceil(train_size / batch_size))

    for path in csv_files:
        print(f"\nProcessing: {os.path.basename(path)}")
        sparsity_dict = load_sparsity_csv(path)
        sparsity_dict = align_sparsity_to_layerinfo(sparsity_dict, layer_info)

        # Sanity: if *all* epochs have zero keys -> warn & skip
        if all(len(v) == 0 for v in sparsity_dict.values()):
            print(f"[ERROR] No aligned layers for {os.path.basename(path)}. "
                  f"Model/CSV mismatch (arch_tag={arch_tag}). Skipping.")
            continue

        # FLOPs per epoch
        # TODO: grad_every should be 1 for LinBreg and SGD runs.
        flops = training_flops_per_epoch(layer_info, sparsity_dict, steps_per_epoch, grad_every=100)

        # Average over epochs
        epoch_totals = [f for _, f in flops]
        if not epoch_totals:
            print(f"[WARNING] No epochs found with valid sparsities in {os.path.basename(path)}.")
            continue

        avg = float(np.mean(epoch_totals))
        print(f"Average FLOPs per epoch: {avg/1e9:.2f} GFLOPs")
        average_flops_list.append(avg)

    if average_flops_list:
        overall = statistics.mean(average_flops_list)
        print(f"\nOverall average FLOPs: {overall/1e9:.2f} GFLOPs")
    else:
        print("\nNo valid CSVs produced aligned FLOPs. See debug warnings above.")
