from typing import Literal
import torch.nn as nn
import torch.fx as fx
import torch
import copy
from ._utils import recursively_find_named_children
from transformers.utils.fx import symbolic_trace as transformer_symbolic_trace
from transformers import PreTrainedModel


def _fold_batch_norm(conv_or_fc_layer, bn_layer):
    """Folds the batch normalization parameters into the convolutional or linear layer.

    Args:
        conv_or_fc_layer (nn.Module): The convolutional or linear layer to fold BN into.
        bn_layer (nn.BatchNorm2d or nn.BatchNorm1d): The batch normalization layer to be folded.
    """
    with torch.no_grad():
        if isinstance(conv_or_fc_layer, nn.Conv2d):
            w = conv_or_fc_layer.weight
            b = (
                conv_or_fc_layer.bias
                if conv_or_fc_layer.bias is not None
                else torch.zeros(w.size(0))
            )

            bn_mean = bn_layer.running_mean
            bn_var = bn_layer.running_var
            bn_weight = bn_layer.weight
            bn_bias = bn_layer.bias
            bn_eps = bn_layer.eps

            bn_std = torch.sqrt(bn_var + bn_eps)
            scale = bn_weight / bn_std
            bias = bn_bias - bn_mean * bn_weight / bn_std

            w = w * scale.view(-1, 1, 1, 1)
            b = b * scale + bias

            conv_or_fc_layer.weight.copy_(w)
            conv_or_fc_layer.bias = nn.Parameter(b)

        elif isinstance(conv_or_fc_layer, nn.Linear):
            w = conv_or_fc_layer.weight
            b = (
                conv_or_fc_layer.bias
                if conv_or_fc_layer.bias is not None
                else torch.zeros(w.size(0))
            )

            bn_mean = bn_layer.running_mean
            bn_var = bn_layer.running_var
            bn_weight = bn_layer.weight
            bn_bias = bn_layer.bias
            bn_eps = bn_layer.eps

            bn_std = torch.sqrt(bn_var + bn_eps)
            scale = bn_weight / bn_std
            bias = bn_bias - bn_mean * bn_weight / bn_std

            w = w * scale.view(-1, 1)
            b = b * scale + bias

            conv_or_fc_layer.weight.copy_(w)
            conv_or_fc_layer.bias = nn.Parameter(b)


def batch_norm_folding(model: nn.Module | fx.GraphModule):
    """Folds all batch normalization layers in the model into their preceding layers.

    A BatchNorm layer is folded only when all of the following hold:

    * it has running statistics to fold;
    * its first positional input is the output of a ``call_module`` node;
    * that module is an ``nn.Conv2d`` feeding an ``nn.BatchNorm2d``, or an
      ``nn.Linear`` feeding an ``nn.BatchNorm1d``;
    * the preceding output is consumed by nothing but this BatchNorm, and the
      preceding module is called only once in the graph (otherwise rescaling
      its weights would change the other consumers as well).

    Folding uses the running statistics, so the result reproduces the
    eval-mode behaviour of the original model.

    Args:
        model (nn.Module): The model containing batch normalization layers to be folded.
            A plain ``nn.Module`` is deep-copied and symbolically traced; an
            ``fx.GraphModule`` is modified in place.

    Returns:
        fx.GraphModule: The modified model with batch normalization layers folded.
            Folded BatchNorm modules that are no longer called stay attached,
            but have their running statistics cleared and raise ``ValueError``
            if called.
    """
    if isinstance(model, fx.GraphModule):
        graph_module = model
    else:
        graph_module = fx.symbolic_trace(copy.deepcopy(model))
    modules = dict(graph_module.named_modules())
    graph = graph_module.graph

    # How often each submodule is called; a module shared between several call
    # sites must not be rewritten on behalf of just one of them.
    call_counts = {}
    for node in graph.nodes:
        if node.op == "call_module":
            call_counts[node.target] = call_counts.get(node.target, 0) + 1

    folded_targets = set()
    for node in graph.nodes:
        if node.op != "call_module":
            continue
        layer = modules[node.target]
        if not isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d)):
            continue
        if layer.running_mean is None or layer.running_var is None:
            # No stored statistics, hence nothing to fold.
            continue
        if not node.args or not isinstance(node.args[0], fx.Node):
            continue

        prev_node = node.args[0]
        if prev_node.op != "call_module":
            continue
        if len(prev_node.users) > 1 or call_counts[prev_node.target] > 1:
            continue

        prev_layer = modules[prev_node.target]
        is_matching_pair = (
            isinstance(prev_layer, nn.Conv2d) and isinstance(layer, nn.BatchNorm2d)
        ) or (isinstance(prev_layer, nn.Linear) and isinstance(layer, nn.BatchNorm1d))
        if not is_matching_pair:
            continue

        _fold_batch_norm(prev_layer, layer)
        node.replace_all_uses_with(prev_node)
        graph.erase_node(node)
        folded_targets.add(node.target)

    def _dummy_forward(*args, **kwargs):
        raise ValueError(
            "Batch Norm has been folded, this module should not be called."
        )

    graph_module.recompile()

    # Only disable BatchNorm modules that were folded and are no longer called.
    # Layers that could not be folded are still part of the graph and must
    # keep working.
    still_called = {
        node.target for node in graph.nodes if node.op == "call_module"
    }
    for target in folded_targets - still_called:
        folded_bn = modules[target]
        folded_bn.running_mean = None
        folded_bn.running_var = None
        folded_bn.forward = _dummy_forward

    return graph_module


def extract_normalisation_gammas(
    model: nn.Module | PreTrainedModel, kind: Literal["torch", "transformer"] = "torch"
) -> dict[str, torch.Tensor]:
    """Extracts the gamma values from the batch normalization layers in the model.

    It holds that the layer-wise Hessian of a layer that is immediately followed by
    batchnorm / layernorm is given by modifying the layerwise Hessian H_0 by

        H = 2 gamma^2 / (var + eps) H_0

    The model is symbolically traced. For every ``nn.LayerNorm`` /
    ``nn.BatchNorm2d`` call, the gamma is stored under the name of the first
    input node that is a module, function or method call. For
    ``nn.BatchNorm2d`` the stored value is ``weight / sqrt(running_var + eps)``;
    for ``nn.LayerNorm`` it is the plain ``weight``. A norm layer without a
    learnable weight is treated as having a weight of ones. ``nn.BatchNorm1d``
    is not implemented and only reported on stdout.

    Args:
        model: The model to inspect.
        kind: ``"transformer"`` traces with the ``transformers`` fx tracer (the
            model must then be a ``PreTrainedModel``); any other value uses
            ``torch.fx.symbolic_trace``.

    Returns:
        A dict from name to gamma. Every name yielded by
        ``recursively_find_named_children(model)`` that received no gamma is
        mapped to the plain integer ``1``.
    """
    if kind == "transformer":
        assert isinstance(model, PreTrainedModel)
        graph_module = transformer_symbolic_trace(model)
    else:
        graph_module = fx.symbolic_trace(model)

    # Build the name -> module lookup once instead of once per graph node.
    modules = dict(model.named_modules())
    preceding_ops = ("call_module", "call_function", "call_method")

    gammas = {}

    for node in graph_module.graph.nodes:
        if node.op != "call_module":
            continue
        target_module = modules[node.target]
        if isinstance(target_module, nn.BatchNorm1d):
            print("Found BatchNorm1d, not implemented. Skipping.")
            continue
        if not isinstance(target_module, (nn.LayerNorm, nn.BatchNorm2d)):
            continue

        gamma = target_module.weight
        if isinstance(target_module, nn.BatchNorm2d):
            assert target_module.running_var is not None
            if gamma is None:
                # affine=False: no learnable scale, i.e. gamma == 1.
                gamma = torch.ones_like(target_module.running_var)
            gamma = gamma / torch.sqrt(
                target_module.running_var + target_module.eps
            )
        elif gamma is None:
            # elementwise_affine=False: no learnable scale, i.e. gamma == 1.
            gamma = torch.ones(target_module.normalized_shape)

        # Find the preceding node's name
        preceding_node = next(
            (
                input_node
                for input_node in node.all_input_nodes
                if input_node.op in preceding_ops
            ),
            None,
        )
        if preceding_node is not None:
            # NOTE(review): keys here are fx node names, while the defaults
            # below use the names from recursively_find_named_children -
            # confirm both naming schemes agree for nested submodules.
            gammas[preceding_node.name] = gamma

    for n, _ in recursively_find_named_children(model):
        gammas.setdefault(n, 1)
    return gammas
