import math
from typing import Union, Tuple
import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode
from modules.complexPyTorch.complexLayers import ComplexConv2d, ComplexLinear


def count_model_params(model: nn.Module):
    """
    Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` contribute; frozen
    parameters are skipped. Each parameter contributes ``numel()`` scalars.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen weights are excluded from the count
        total += param.numel()
    return total

def count_model_layers(model: nn.Module):
    """
    Counts the convolution and linear layers in a model.

    Walks the module tree below ``model`` (the root module itself is not
    inspected) and tallies layers by category:

    - ``'dwconv'``: modules whose class is named ``DWConv2d``.
    - ``'conv'``: ``nn.Conv2d`` whose kernel is larger than 1 in any dimension.
    - ``'1x1'``: ``nn.Conv2d`` with a kernel of 1 in every dimension, and
      ``nn.Linear``.
    - ``'complex_conv'``: ``ComplexConv2d`` modules.
    - ``'complex_1x1'``: ``ComplexLinear`` modules.

    A module that matches one of the categories is counted and NOT descended
    into. Any other module is treated as a container and searched recursively.

    Returns:
        dict with one count per category above, plus ``'total'``, the sum of
        all category counts.
    """

    counts = {
        'dwconv': 0,
        'conv': 0,
        '1x1': 0,
        'complex_conv': 0,
        'complex_1x1': 0,
    }

    # recurse through the model and all layers of submodules
    def count_layers(module):
        for layer in module.children():
            if layer.__class__.__name__ == 'DWConv2d':
                counts['dwconv'] += 1
            elif isinstance(layer, nn.Conv2d):
                # Check every kernel dimension so asymmetric kernels such as
                # (1, 3) count as spatial convs rather than pointwise ones.
                if any(k > 1 for k in layer.kernel_size):
                    counts['conv'] += 1
                else:
                    counts['1x1'] += 1
            elif isinstance(layer, nn.Linear):
                counts['1x1'] += 1
            elif isinstance(layer, ComplexConv2d):
                counts['complex_conv'] += 1
            elif isinstance(layer, ComplexLinear):
                counts['complex_1x1'] += 1
            else:
                # Unrecognised module: treat it as a container and descend.
                count_layers(layer)

    count_layers(model)
    # Computed before 'total' is inserted, so it sums only the categories.
    counts['total'] = sum(counts.values())

    return counts

def count_flops(model, inp: Union[torch.Tensor, Tuple], with_backward=False, device='cpu'):
    """
    Measures the number of FLOPs in a model.

    Based on the function from https://alessiodevoto.github.io/Compute-Flops-with-Pytorch-built-in-flops-counter/
    """

    istrain = model.training
    model.eval()

    inp = inp if isinstance(inp, torch.Tensor) else torch.randn(inp)
    inp = inp.to(device)

    flop_counter = FlopCounterMode(mods=model, display=False, depth=None)
    with flop_counter:
        if with_backward:
            model(inp).sum().backward()
        else:
            model(inp)
    total_flops = flop_counter.get_total_flops()
    if istrain:
        model.train()
    return total_flops
