from itertools import repeat
import collections.abc

import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d

"""
This module contains utility functions for handling PyTorch model layers.
Main features include:
- Converting a model's `BatchNorm2d` and `SyncBatchNorm` layers to `FrozenBatchNorm2d` layers.
- Expanding scalar arguments into fixed-length tuples (`to_1tuple` ... `to_4tuple`, `to_ntuple`).
- Replacing selected linear layers in a model with a specified linear layer implementation.
- Preparing int8 models for inference.
- Computing top-k accuracy for model predictions.
"""

def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts `BatchNorm2d` and `SyncBatchNorm` layers of the provided module into `FrozenBatchNorm2d`.

    If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm` (and passes the name filter), a
    new `FrozenBatchNorm2d` is built from it and returned. Otherwise, the children are walked recursively and
    every converted child is re-registered on its parent under the same name, i.e. the parent is modified in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict | None): Full module names to freeze. Only truthiness and membership
            (`name in module_match`) are used, so any container of names works. All batch norm
            layers are frozen when it is empty or None.
        name (str): Full (dotted) name of `module`; used as the prefix when building child names.

    Returns:
        torch.nn.Module: The `FrozenBatchNorm2d` replacement when `module` was a matching batch norm
        layer, otherwise `module` itself.

    """
    # An empty or missing filter means "freeze every batch norm layer".
    is_match = name in module_match if module_match else True
    batch_norm_types = (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)

    if is_match and isinstance(module, batch_norm_types):
        frozen = FrozenBatchNorm2d(module.num_features)
        frozen.num_features = module.num_features
        frozen.affine = module.affine
        if module.affine:
            # Affine parameters are copied; without affine, the frozen layer keeps
            # whatever weight/bias it was constructed with.
            frozen.weight.data = module.weight.data.clone().detach()
            frozen.bias.data = module.bias.data.clone().detach()
        # NOTE: the running statistics are assigned without a copy (unlike the
        # affine parameters above), so they refer to the source layer's tensors.
        frozen.running_mean.data = module.running_mean.data
        frozen.running_var.data = module.running_var.data
        frozen.eps = module.eps
        return frozen

    for child_name, child in module.named_children():
        full_child_name = '.'.join([name, child_name]) if name else child_name
        new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
        # Only re-register children that were actually swapped for a frozen layer.
        if new_child is not child:
            module.add_module(child_name, new_child)
    return module


# From PyTorch internals
def _ntuple(n):
    """
    Create a function that expands a single value into an n-tuple.

    Args:
        n (int): Number of repeated elements in the tuple.

    Returns:
        function: `parse(x)`, which returns any non-string iterable `x` unchanged (no length check
        and no conversion to tuple) and otherwise returns a tuple holding `x` repeated `n` times.
    """
    def parse(x):
        # A str is an Iterable too, but returning it as-is would hand back a bare
        # string instead of a tuple, so strings are repeated like any other scalar.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


# Fixed-arity helpers built from the factory above.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Expand `x` with `_ntuple(n)`, for callers that only know `n` at call time."""
    return _ntuple(n)(x)

def replace_linear(model, linear_replacement, include_modules=('c_fc', 'c_proj'), copy_weights=True):
    """
    Replace selected linear layers in the model with a specified linear layer implementation.

    Recursively walks through the model and replaces every `torch.nn.Linear` child whose
    attribute name (the last component of its dotted name) is in `include_modules` with an
    instance of `linear_replacement`. The model is modified in place.

    Args:
        model (torch.nn.Module): The model to modify.
        linear_replacement (class): The linear layer implementation to use as replacement. It is
            called as `linear_replacement(in_features, out_features, has_bias)`.
        include_modules (collection of str): Child names eligible for replacement.
        copy_weights (bool): Whether to copy weight (and bias, when both layers have one) from the
            original linear layers into the replacements.

    Returns:
        torch.nn.Module: The same `model` object, after modification.
    """
    for name, module in model.named_children():
        # Descend first so nested linears are handled before this child is examined.
        if any(True for _ in module.children()):
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if not (isinstance(module, torch.nn.Linear) and name in include_modules):
            continue

        new_module = linear_replacement(
            module.in_features,
            module.out_features,
            module.bias is not None,
        )
        if copy_weights:
            new_module.weight.data.copy_(module.weight.data)
            # Copy the bias only when both sides have one; go through `.data` on
            # both tensors so the copy is never tracked by autograd.
            if new_module.bias is not None and module.bias is not None:
                new_module.bias.data.copy_(module.bias.data)
        # Overwriting the value of an existing key is safe while iterating.
        model._modules[name] = new_module

    return model

def convert_int8_model_to_inference_mode(model):
    """
    Switch every submodule that supports it into its evaluation form.

    Each module reachable through `model.modules()` that exposes a `prepare_for_eval`
    attribute has that method called. The dtype its `weight` had just before the call
    is stored on the module as `int8_original_dtype`.

    Args:
        model (torch.nn.Module): The int8 model to convert to inference mode.
    """
    for submodule in model.modules():
        if not hasattr(submodule, 'prepare_for_eval'):
            continue
        # Capture the dtype first: prepare_for_eval() may change the weight.
        dtype_before = submodule.weight.dtype
        submodule.prepare_for_eval()
        submodule.int8_original_dtype = dtype_before
            

def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy.

    Args:
        output (torch.Tensor): Model output logits with shape (N, C) where N is the number
                            of examples and C is the number of classes.
        target (torch.Tensor): Ground truth class indices with shape (N,) where N is the
                            number of examples.
        topk (tuple): Which top-k accuracies to compute, e.g., (1,5) will compute
                    top-1 and top-5 accuracies.

    Returns:
        list: List of top-k accuracies (Python floats in [0, 1]) in the same order as `topk`.
    """
    # Indices of the max(topk) highest-scoring classes per example, transposed so
    # that row i holds every example's (i + 1)-th best guess.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    n = len(target)
    # .item() yields a Python float directly; converting a shape-(1,) NumPy array
    # with float() is deprecated in recent NumPy releases and needs a NumPy round trip.
    return [correct[:k].float().sum().item() / n for k in topk]
