import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import os

from fvcore.nn import FlopCountAnalysis, flop_count_table

from modules.transformations import DataTransforms
from modules import get_resnet, get_resnet_spiking, modify_resnet_model
from modules.spike_layer import MixedLIF, LIFt, LIF
from spikformer import spikformer
# DataTransforms and the model-building helpers are imported from the project
# above. Alternative import paths that may apply in other project layouts:
# from transforms import DataTransforms
# from models import get_resnet_spiking, get_vgg_spiking, get_spikformer
# from training import yaml_config_hook

# Simple ToTensor + Normalize(mean=0.5, std=0.5) pipeline.
# NOTE(review): ``transform`` is not used below; the training set is built with
# the project's DataTransforms(size=32) instead.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# CIFAR-10 training split (downloaded if absent), iterated in a fixed order.
train_dataset = torchvision.datasets.CIFAR10(
    root='/mnt/vstor/CSE_ECSE_GXD234/data',
    train=True,
    download=True,
    transform=DataTransforms(size=32),
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)

# Example args mimic parsed config
class Args:
    timestep = 4
    logistic_batch_size = 512  # same as dataloader batch_size
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Model-specific args ...
args = Args()

# Build the two spiking backbones (swap in your own constructors if needed).
# ResNet-34 with LIF activations, then adapted by modify_resnet_model.
spiking_resnet = modify_resnet_model(
    get_resnet_spiking(
        'resnet34',
        args.timestep,
        sync_norm=False,
        act_func=LIF,
        num_classes=10,
    )
)

# Spikformer with depths=4 and embed_dims=384 on 32x32 inputs.
# NOTE: this rebinds the imported ``spikformer`` factory name to the model.
spikformer = spikformer(
    T=args.timestep,
    act_func=LIF,
    img_size_h=32,
    img_size_w=32,
    in_channels=3,
    patch_size=4,
    embed_dims=384,
    num_heads=12,
    mlp_ratios=4,
    depths=4,
    sr_ratios=1,
    qkv_bias=False,
    num_classes=10,
    drop_rate=0.,
    drop_path_rate=0.,
    drop_block_rate=None,
)

# Load pretrained weights for the spiking ResNet. Only entries whose key contains
# "backbone." are kept, with that substring removed; everything else is dropped.
# NOTE(review): the checkpoint directory mentions "mixlif" while the model above
# is built with act_func=LIF; confirm the two match.
model_fp = os.path.join("/home/cxz760/selfsupervised_SpikingNN/barlowtwins_SNN/save/cifar10/resnet34_spk_1l_mixlif/",
                        "checkpoint_{}.tar".format(160))
pretrained_dict = torch.load(model_fp, map_location=args.device.type)
new_dict = {
    k.replace("backbone.", ""): v
    for k, v in pretrained_dict['model_state_dict'].items()
    if "backbone." in k
}
# strict=False tolerates mismatches; report them so that a wrong checkpoint or
# renamed layers are not silently ignored.
load_result = spiking_resnet.load_state_dict(new_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("spiking_resnet missing keys:", load_result.missing_keys)
    print("spiking_resnet unexpected keys:", load_result.unexpected_keys)

# Load pretrained weights for Spikformer, with the same "backbone." filtering
# and renaming as for the ResNet.
# NOTE(review): the checkpoint path contains "ts6" while args.timestep is 4;
# confirm the checkpoint matches the configured number of timesteps.
model_fp_2 = os.path.join("/home/cxz760/selfsupervised_SpikingNN/spikformer_barlow_twins/cifar10/save/ts6/20250218-195315-spikformer_barlow_twins-32/",
                        "model_best.pth.tar")  # model_best.pth.tar 949
pretrained_dict_2 = torch.load(model_fp_2, map_location=args.device.type)
new_dict_2 = {
    k.replace("backbone.", ""): v
    for k, v in pretrained_dict_2['state_dict'].items()
    if "backbone." in k
}
# strict=False tolerates mismatches; surface them instead of hiding them.
load_result_2 = spikformer.load_state_dict(new_dict_2, strict=False)
if load_result_2.missing_keys or load_result_2.unexpected_keys:
    print("spikformer missing keys:", load_result_2.missing_keys)
    print("spikformer unexpected keys:", load_result_2.unexpected_keys)

# Move both networks to the target device and switch them to inference mode
# (Module.to / Module.eval act in place).
for _net in (spiking_resnet, spikformer):
    _net.to(args.device).eval()
del _net

# NOTE(review): these two lists are never read later in this script.
resnet_spk, spikformer_spk = [], []

def register_hooks(model, layer_key, in_list, out_list):
    """Attach a capture hook to every submodule whose qualified name contains *layer_key*.

    Each forward call of a matched module appends its first positional input
    and its output to *in_list* / *out_list*. Both are detached and copied to
    CPU first. The name of every matched module is printed.

    Args:
        model: module whose ``named_modules()`` are searched.
        layer_key: substring tested against each qualified module name.
        in_list: list that receives the captured inputs (mutated in place).
        out_list: list that receives the captured outputs (mutated in place).

    Returns:
        A list with one hook handle per matched module. Call ``.remove()`` on
        each to stop capturing (e.g. before tracing the model for FLOPs).
    """
    def hook(module, fea_in, fea_out):
        # NOTE(review): assumes the module returns a single tensor.
        in_list.append(fea_in[0].detach().cpu())
        out_list.append(fea_out.detach().cpu())

    handles = []
    for name, module in model.named_modules():
        # NOTE(review): substring matching also hooks children of a matched
        # container whose names contain the key; check the printed names.
        if layer_key in name:
            print(name)
            handles.append(module.register_forward_hook(hook))
    return handles

# Capture buffers filled by the forward hooks, one input/output pair per model.
features_in_resnet, features_out_resnet = [], []
features_in_spikformer, features_out_spikformer = [], []

# Hook the spiking activations: modules whose name contains 'spike_func' in the
# ResNet and 'lif' in Spikformer (register_hooks prints each matched name).
print("spiking_resnet lif layers:")
register_hooks(
    model=spiking_resnet,
    layer_key='spike_func',
    in_list=features_in_resnet,
    out_list=features_out_resnet,
)
print("================")
print("spikformer lif layers:")
register_hooks(
    model=spikformer,
    layer_key='lif',
    in_list=features_in_spikformer,
    out_list=features_out_spikformer,
)

# Push one training batch through both networks; the forward hooks registered
# above record the matched spike layers' inputs and outputs along the way.
with torch.no_grad():
    _loader_iter = iter(train_loader)
    _first_batch = next(_loader_iter, None)
    if _first_batch is not None:
        (x, _), _ = _first_batch
        x = x.to(args.device)
        _ = spiking_resnet(x)  # fills features_in/out_resnet
        _ = spikformer(x)      # fills features_in/out_spikformer
    del _loader_iter, _first_batch

def compute_spike_rate(features_out, args):
    """Compute the fraction of non-zero activations per captured layer and timestep.

    Args:
        features_out: list of tensors captured by the forward hooks.
            NOTE(review): each tensor is assumed to be timestep-major, i.e.
            ``(T * batch, ...)`` or ``(T, batch, ...)``, so that consecutive
            ``numel / T`` chunks belong to steps 0, 1, ...; confirm for every
            hooked layer type.
        args: config object; only ``args.timestep`` (T) is read. The batch
            size is inferred from each tensor, so ``args.logistic_batch_size``
            no longer has to match the loader's batch size.

    Returns:
        ``np.ndarray`` of shape ``(len(features_out), T)``. Entry ``[i, t]`` is
        the share of non-zero values in layer ``i`` at step ``t``. An empty
        ``features_out`` yields an empty array.

    Raises:
        ValueError: if a tensor is empty or its element count is not a
            multiple of T.
    """
    num_steps = args.timestep
    spike_rates = []  # one row of per-timestep rates per layer
    for layer_idx, fea in enumerate(features_out):
        numel = fea.numel()
        if numel == 0 or numel % num_steps != 0:
            raise ValueError(
                f"layer {layer_idx}: tensor of shape {tuple(fea.shape)} cannot be "
                f"split into {num_steps} timesteps"
            )
        # reshape rather than view: hooked outputs are not guaranteed contiguous.
        per_step = fea.reshape(num_steps, -1)
        # Integer counts keep the numerator exact even for very large frames.
        nonzero = torch.count_nonzero(per_step, dim=1)
        spike_rates.append((nonzero.double() / per_step.shape[1]).tolist())
    return np.array(spike_rates)

# Per-layer x per-timestep spike-rate matrices for both backbones.
sr_resnet, sr_spikformer = (
    compute_spike_rate(captured, args)
    for captured in (features_out_resnet, features_out_spikformer)
)

# Report each matrix followed by its overall mean.
print("Spiking-ResNet34 spike rates per layer (rows) and time step (cols):")
print(sr_resnet, f'//Avg spk rate for resnet: {sr_resnet.mean()}')

print("Spikformer-4-384 spike rates per layer (rows) and time step (cols):")
print(sr_spikformer, f'//Avg spk rate for spikformer: {sr_spikformer.mean()}')


# FLOP estimate with fvcore on a single 1x3x32x32 input.
# NOTE(review): ``spiking_resnet`` is rebound here to the non-spiking ResNet-34
# from get_resnet, and no checkpoint is loaded for it in this script. The
# "Spiking-ResNet34" figure below is therefore that ANN's count.
dummy_input = torch.randn(1, 3, 32, 32, device=args.device)
spiking_resnet = get_resnet('resnet34', num_classes=10)
spiking_resnet.to(args.device).eval()

flops_resnet = FlopCountAnalysis(spiking_resnet, dummy_input)
# NOTE(review): the capture hooks registered on spikformer are still attached,
# so tracing it may append extra entries to the features_*_spikformer lists.
flops_spikformer = FlopCountAnalysis(spikformer, dummy_input)

# Per-module breakdown (uncomment when needed):
# print("FLOPs for Spiking-ResNet34 backbone (per frame):")
# print(flop_count_table(flops_resnet, max_depth=4))
#
# print("\nFLOPs for Spikformer-4-384 backbone (per frame):")
# print(flop_count_table(flops_spikformer, max_depth=4))

# Raw counts for one traced forward call; no per-timestep multiplier is applied.
# The ANN ResNet count is per frame. Whether the Spikformer count already covers
# all T steps depends on its forward (TODO confirm). Multiply by args.timestep
# where a per-sequence figure is wanted.
total_flops_resnet = flops_resnet.total()
total_flops_spik = flops_spikformer.total()

print(f"\nApproximate FLOPs per forward pass (not scaled by timesteps; T={args.timestep}):")
print(f"Spiking-ResNet34: {total_flops_resnet:.2e} FLOPs")
print(f"Spikformer-4-384: {total_flops_spik:.2e} FLOPs")