# import torch
# from torchinfo import summary  # pip install torchinfo
# from ptflops import get_model_complexity_info  # pip install ptflops
#
# def count_parameters(model):
#     total = sum(p.numel() for p in model.parameters())
#     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     return total, trainable
#
# def get_model_size_mb(model):
#     import io
#     import os
#     torch.save(model.state_dict(), "temp.p")
#     size_mb = os.path.getsize("temp.p") / 1e6
#     os.remove("temp.p")
#     return size_mb
#
# def print_model_statistics(model, input_res, device='cpu'):
#     model = model.to(device)
#     dummy_input = torch.randn(*input_res).to(device)
#
#     total_params, trainable_params = count_parameters(model)
#     size_mb = get_model_size_mb(model)
#
#     print("🚀 Model Summary:")
#     print(f"Total Parameters       : {total_params:,}")
#     print(f"Trainable Parameters   : {trainable_params:,}")
#     print(f"Model Size             : {size_mb:.2f} MB")
#
#     # FLOPs (forward pass complexity)
#     try:
#         macs, params = get_model_complexity_info(
#             model, input_res[1:], as_strings=True, print_per_layer_stat=False, verbose=False
#         )
#         print(f"FLOPs (approx MACs)    : {macs}")
#     except Exception as e:
#         print("FLOP Estimation Error:", e)


import torch
import torch.nn as nn
from torchinfo import summary
import numpy as np



def estimate_graph_flops(model, num_nodes, num_edges, in_features):
    """
    Estimate forward-pass FLOPs of a graph-based model by walking its submodules.

    Every submodule yielded by ``model.named_modules()`` except the root
    object itself is inspected:

    * ``nn.Linear`` layers count ``num_nodes * in_features * out_features``.
    * Any other module exposing both ``in_features`` and ``out_features`` is
      treated as an RBM layer and delegated to ``manual_flop_count_rbm_layer``.
    * All remaining modules contribute nothing.

    One line is printed per counted layer.

    Args:
        model: Root module to analyse.
        num_nodes: Number of graph nodes.
        num_edges: Number of graph edges (forwarded to the RBM estimate).
        in_features: Not used by the estimate; kept for caller compatibility.

    Returns:
        The sum of all per-layer FLOP estimates.
    """
    total_flops = 0

    for name, module in model.named_modules():
        # Skip only the root object. Comparing by class would also drop any
        # nested submodule that happens to share the root's type.
        if module is model:
            continue

        # nn.Linear must be tested first: it exposes in_features/out_features
        # too, so the duck-typed RBM test below would otherwise capture it.
        if isinstance(module, nn.Linear):
            layer_flops = num_nodes * module.in_features * module.out_features
            total_flops += layer_flops
            print(f"Linear layer {name}: {layer_flops:,} FLOPs")

        elif hasattr(module, 'in_features') and hasattr(module, 'out_features'):
            # Duck-typed match: presumably the RBM layer.
            layer_flops = manual_flop_count_rbm_layer(module, num_nodes, num_edges)
            total_flops += layer_flops
            print(f"Layer {name}: {layer_flops:,} FLOPs")

    return total_flops


def print_model_statistics_fixed(model, num_nodes, in_features, num_edges=None, device='cpu'):
    """
    Print parameter counts, serialized size and an estimated FLOP total.

    Args:
        model: Model to report on.
        num_nodes: Number of graph nodes used for the FLOP estimate.
        in_features: Input feature width, forwarded to ``estimate_graph_flops``.
        num_edges: Number of graph edges; ``None`` means ``num_nodes * 4``.
        device: Accepted for signature compatibility; not used here.
    """
    param_total, param_trainable = count_parameters(model)
    megabytes = get_model_size_mb(model)

    report_lines = (
        "🚀 Model Summary:",
        f"Total Parameters       : {param_total:,}",
        f"Trainable Parameters   : {param_trainable:,}",
        f"Model Size             : {megabytes:.2f} MB",
    )
    for line in report_lines:
        print(line)

    # Without an explicit edge count, assume an average degree of 4.
    edge_count = num_nodes * 4 if num_edges is None else num_edges

    flop_estimate = estimate_graph_flops(model, num_nodes, edge_count, in_features)
    print(f"Estimated FLOPs        : {flop_estimate:,}")
    print(f"FLOPs (in GFLOPs)      : {flop_estimate / 1e9:.2f}")




def count_parameters(model):
    """
    Count the scalar parameters of ``model``.

    Returns:
        A ``(total, trainable)`` tuple, where ``trainable`` only includes
        parameters whose ``requires_grad`` flag is set.
    """
    overall = 0
    with_grad = 0
    for param in model.parameters():
        size = param.numel()
        overall += size
        if param.requires_grad:
            with_grad += size
    return overall, with_grad


def get_model_size_mb(model):
    """
    Return the on-disk size of ``model.state_dict()`` in megabytes.

    The state dict is serialized with ``torch.save`` into a private temporary
    directory, measured, and discarded. Nothing is written to the current
    working directory, so an existing ``temp.p`` there is never overwritten
    or deleted, and the scratch file is removed even if saving fails.

    Args:
        model: Module whose ``state_dict()`` is measured.

    Returns:
        File size in MB, using 1 MB = 1e6 bytes.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Keep the historical file name; only its location changes.
        checkpoint_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint_path)
        return os.path.getsize(checkpoint_path) / 1e6


def manual_flop_count_rbm_layer(layer, num_nodes, num_edges):
    """
    Hand-written FLOP estimate for one RBM contrastive-divergence layer.

    The estimate adds up the cost modelled for each stage of the layer's
    forward pass (the layer implementation is not part of this file, so the
    stages below are the ones this estimator assumes):

    * optional dropout masking (training mode, when ``dropout_layer`` exists),
    * neighbour aggregation over the edge list,
    * blending of the input with the aggregated features,
    * the visible -> hidden RBM pass,
    * in training mode, ``layer.k_steps`` Gibbs steps plus a final
      negative-phase pass.

    Args:
        layer: Object exposing ``training``, ``in_features``, ``out_features``
            and, when ``training`` is true, ``k_steps``.
        num_nodes: Number of graph nodes.
        num_edges: Number of graph edges.

    Returns:
        The estimated FLOP count.
    """
    is_training = layer.training
    visible_dim = layer.in_features
    hidden_dim = layer.out_features

    visible_elems = num_nodes * visible_dim
    hidden_elems = num_nodes * hidden_dim
    matmul_cost = num_nodes * visible_dim * hidden_dim

    terms = [
        # Adjacency build, degree computation and normalisation are one unit
        # per edge each; adding self-connections is one unit per node.
        3 * num_edges + num_nodes,
        # Sparse aggregation touches every feature of every edge.
        num_edges * visible_dim,
        # residual * x + neighbor_weight * aggregated_x: two scalings, one add.
        3 * visible_elems,
        # Positive phase: matmul, bias add, then sigmoid + sampling.
        matmul_cost,
        hidden_elems,
        2 * hidden_elems,
    ]

    # Dropout is modelled as one masking operation per input element.
    if is_training and hasattr(layer, 'dropout_layer'):
        terms.append(visible_elems)

    if is_training:
        # One Gibbs step: hidden pass (matmul + bias + sampling) followed by
        # the visible reconstruction (matmul + bias + sampling).
        gibbs_step = 2 * matmul_cost + 2 * hidden_elems + 2 * visible_elems
        terms.extend(gibbs_step for _ in range(layer.k_steps))
        # Final hidden pass for the negative phase.
        terms.append(matmul_cost + 2 * hidden_elems)

    return sum(terms)


def _estimate_layer_flops(layer, num_nodes, num_edges):
    """Return ``(kind, flops)`` for one layer; ``(None, 0)`` if it is not counted."""
    # nn.Linear is tested first because it also exposes in_features and
    # out_features and would otherwise be mistaken for an RBM layer.
    if isinstance(layer, nn.Linear):
        return "Linear", num_nodes * layer.in_features * layer.out_features
    if hasattr(layer, 'in_features') and hasattr(layer, 'out_features'):
        return "RBM", manual_flop_count_rbm_layer(layer, num_nodes, num_edges)
    return None, 0


def estimate_full_model_flops(model, num_nodes, num_edges, in_features):
    """
    Estimate FLOPs for a model that exposes its layers as ``model.layers``.

    The estimate assumes every layer except the last is followed by a ReLU
    and the last layer by a log-softmax, each costing one operation per
    output element. ``nn.Linear`` layers cost
    ``num_nodes * in_features * out_features``; other layers exposing
    ``in_features``/``out_features`` are costed by
    ``manual_flop_count_rbm_layer``; anything else adds no layer cost and
    leaves the tracked feature width unchanged.

    A per-layer breakdown is printed along the way.

    Args:
        model: Object with an indexable, sliceable ``layers`` sequence.
        num_nodes: Number of graph nodes.
        num_edges: Number of graph edges.
        in_features: Feature width fed into the first layer.

    Returns:
        The total estimated FLOP count (0 when ``model.layers`` is empty).
    """
    total_flops = 0
    current_features = in_features

    print(f"📊 FLOP Analysis for {num_nodes} nodes, {num_edges} edges:")
    print("-" * 50)

    # Hidden layers: everything but the last entry.
    for i, layer in enumerate(model.layers[:-1]):
        kind, layer_flops = _estimate_layer_flops(layer, num_nodes, num_edges)
        if kind is not None:
            total_flops += layer_flops
            print(f"{kind} Layer {i + 1}: {layer_flops:,} FLOPs")
            current_features = layer.out_features

        # ReLU activation: one operation per output element.
        relu_flops = num_nodes * current_features
        total_flops += relu_flops
        print(f"  └─ ReLU activation: {relu_flops:,} FLOPs")

    # Output layer (last layer)
    if len(model.layers) > 0:
        last_layer = model.layers[-1]
        kind, layer_flops = _estimate_layer_flops(last_layer, num_nodes, num_edges)
        if kind is not None:
            total_flops += layer_flops
            print(f"Output {kind} Layer: {layer_flops:,} FLOPs")
            current_features = last_layer.out_features

        # Log softmax: one operation per output element. If the last layer
        # does not declare a width, fall back to the width flowing into it
        # instead of raising AttributeError.
        softmax_flops = num_nodes * current_features
        total_flops += softmax_flops
        print(f"  └─ Log Softmax: {softmax_flops:,} FLOPs")

    print("-" * 50)
    print(f"🎯 Total FLOPs: {total_flops:,}")

    return total_flops


def create_mock_data(num_nodes, in_features, num_edges=None):
    """
    Build a random stand-in graph for smoke tests and benchmarks.

    Args:
        num_nodes: Number of nodes (rows of ``x``).
        in_features: Feature width (columns of ``x``).
        num_edges: Number of edge columns; ``None`` means ``num_nodes * 4``.

    Returns:
        A lightweight object with ``x`` (``torch.randn`` features of shape
        ``(num_nodes, in_features)``) and ``edge_index`` (``torch.randint``
        node ids in ``[0, num_nodes)`` of shape ``(2, num_edges)``).
    """
    class MockData:
        """Minimal container exposing ``x`` and ``edge_index`` attributes."""

        def __init__(self, x, edge_index):
            self.x, self.edge_index = x, edge_index

    # Without an explicit edge count, assume an average degree of 4.
    edge_count = num_nodes * 4 if num_edges is None else num_edges

    # Features are drawn before edges so the RNG stream is consumed in a
    # fixed order.
    features = torch.randn(num_nodes, in_features)
    endpoints = torch.randint(0, num_nodes, (2, edge_count))

    return MockData(features, endpoints)


def print_model_statistics_graph(model, num_nodes, in_features, num_edges=None, device='cpu'):
    """
    Print statistics for graph-based models.

    Reports parameter counts, serialized size, an estimated FLOP total and the
    result of one forward pass on random mock data. FLOP-estimation and
    forward-pass failures are reported on stdout instead of being raised, so
    one failing section does not suppress the others.

    The forward pass runs in eval mode under ``torch.no_grad()``; the model's
    previous train/eval flag is restored afterwards so calling this function
    in the middle of training does not silently leave the model in eval mode.

    Args:
        model: Model to analyse; it is moved to ``device``.
        num_nodes: Number of nodes in the mock graph.
        in_features: Feature width of the mock node features.
        num_edges: Number of mock edges; ``None`` means ``num_nodes * 4``.
        device: Device on which the model and mock data are placed.
    """
    model = model.to(device)

    if num_edges is None:
        num_edges = num_nodes * 4  # Assume average degree of 4

    # Create mock data
    mock_data = create_mock_data(num_nodes, in_features, num_edges)
    mock_data.x = mock_data.x.to(device)
    mock_data.edge_index = mock_data.edge_index.to(device)

    total_params, trainable_params = count_parameters(model)
    size_mb = get_model_size_mb(model)

    print("🚀 Graph Model Summary:")
    print(f"Total Parameters       : {total_params:,}")
    print(f"Trainable Parameters   : {trainable_params:,}")
    print(f"Model Size             : {size_mb:.2f} MB")
    print(f"Graph Info             : {num_nodes} nodes, {num_edges} edges")

    # FLOP estimate. It is taken before switching to eval mode, so it reflects
    # the mode the caller handed the model over in.
    try:
        total_flops = estimate_full_model_flops(model, num_nodes, num_edges, in_features)
        print(f"Total Estimated FLOPs  : {total_flops:,}")
        print(f"FLOPs (in GFLOPs)      : {total_flops / 1e9:.2f}")
    except Exception as e:
        print("FLOP Estimation Error:", e)

    # Memory usage estimation
    was_training = model.training
    try:
        with torch.no_grad():
            model.eval()
            output = model(mock_data)
            # Reports everything currently allocated on the default CUDA
            # device (0 when CUDA is unavailable), not just this model.
            memory_mb = torch.cuda.memory_allocated() / 1e6 if torch.cuda.is_available() else 0
            print(f"GPU Memory (approx)    : {memory_mb:.2f} MB")
            print(f"Output shape           : {output.shape}")
    except Exception as e:
        print("Forward pass error:", e)
    finally:
        # Undo the temporary switch to eval mode.
        model.train(was_training)


def benchmark_model_speed(model, num_nodes, in_features, num_edges=None, device='cpu', num_runs=100):
    """
    Benchmark model inference speed on random mock graph data.

    Runs 10 untimed warm-up passes followed by ``num_runs`` timed passes in
    eval mode under ``torch.no_grad()`` and prints the average latency and
    throughput. When CUDA is available, the device is synchronised before
    starting and before stopping the clock. The model's previous train/eval
    flag is restored before returning.

    Args:
        model: Model to benchmark; it is moved to ``device``.
        num_nodes: Number of nodes in the mock graph.
        in_features: Feature width of the mock node features.
        num_edges: Number of mock edges; ``None`` lets ``create_mock_data``
            pick its default.
        device: Device on which the model and mock data are placed.
        num_runs: Number of timed forward passes; must be at least 1.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    import time

    # Guard the division below instead of failing with ZeroDivisionError
    # after the whole benchmark has already run.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs!r}")

    was_training = model.training
    model = model.to(device).eval()
    mock_data = create_mock_data(num_nodes, in_features, num_edges)
    mock_data.x = mock_data.x.to(device)
    mock_data.edge_index = mock_data.edge_index.to(device)

    cuda_available = torch.cuda.is_available()

    try:
        with torch.no_grad():
            # Warmup
            for _ in range(10):
                _ = model(mock_data)

            # Benchmark. perf_counter is monotonic and high-resolution,
            # unlike the wall clock returned by time.time().
            if cuda_available:
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            for _ in range(num_runs):
                _ = model(mock_data)

            if cuda_available:
                torch.cuda.synchronize()
            end_time = time.perf_counter()
    finally:
        # Undo the temporary switch to eval mode.
        model.train(was_training)

    avg_time = (end_time - start_time) / num_runs
    # A run faster than the timer resolution would otherwise divide by zero.
    throughput = 1 / avg_time if avg_time > 0 else float("inf")
    print(f"⏱️  Average inference time: {avg_time * 1000:.2f} ms")
    print(f"   Throughput: {throughput:.2f} inferences/sec")


# Usage example for your RBM model:
def analyze_rbm_model(model, num_nodes=1000, in_features=128, num_edges=None, device='cpu', num_runs=100):
    """
    Run the complete analysis of an RBM model: statistics, then a speed benchmark.

    Args:
        model: Model to analyse.
        num_nodes: Number of nodes in the mock graph.
        in_features: Feature width of the mock node features.
        num_edges: Number of mock edges; ``None`` lets the callees pick their
            default.
        device: Device used for both the statistics pass and the benchmark.
            Defaults to ``'cpu'``, matching the previous behaviour.
        num_runs: Number of timed passes in the benchmark. Defaults to 100,
            matching the previous behaviour.
    """
    print("=" * 60)
    print("RBM CONTRASTIVE DIVERGENCE MODEL ANALYSIS")
    print("=" * 60)

    # Model statistics
    print_model_statistics_graph(model, num_nodes, in_features, num_edges, device=device)

    print("\n" + "-" * 40)

    # Speed benchmark
    benchmark_model_speed(
        model, num_nodes, in_features, num_edges, device=device, num_runs=num_runs
    )

    print("=" * 60)


# Example usage:
if __name__ == "__main__":
    # Running this file directly only prints a usage hint; no model is built.
    #
    # With a model instance at hand, either run the complete analysis:
    #     model = YourRBMModel(...)
    #     analyze_rbm_model(model, num_nodes=1000, in_features=128)
    #
    # or print the statistics alone:
    #     print_model_statistics_graph(model, num_nodes=1000, in_features=128)

    for hint in (
        "FLOP counter functions created successfully!",
        "Use: analyze_rbm_model(your_model, num_nodes, in_features)",
    ):
        print(hint)



