from typing import Dict, Any, Tuple

import torch
from torch import nn

from quantization.quantization_flows.gptq import gptq_quantization
from quantization.quantization_flows.rtn import rtn_quantization
from run_config import RunConfig
from enums import PTQAlg
from quantization.quant_config import QuantConfig
from transform_optimization.opt_config import OptimizationConfig


def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Recursively collect submodules whose exact type is in ``layers``.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Collection of module classes to match. Matching uses the exact
            type (``type(m) in layers``), so subclasses of these classes are
            NOT matched. Defaults to ``(nn.Conv2d, nn.Linear)``; a tuple is
            used so the default cannot be mutated and shared across calls.
        name: Dotted name prefix for ``module``; ``''`` for the root.

    Returns:
        Dict mapping dotted submodule names built from ``named_children``
        (e.g. ``'2.0'``) to the matching modules. If ``module`` itself
        matches, it is returned alone under ``name`` and its children are
        not searched.
    """
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        # Root children get a bare name; deeper ones are joined with '.'.
        full_name = f"{name}.{child_name}" if name else child_name
        res.update(find_layers(child, layers=layers, name=full_name))
    return res


def quantize_model(
        model: nn.Module,
        run_config: RunConfig,
        quant_config: QuantConfig,
        opt_config: OptimizationConfig = None,
        calibration_data: torch.utils.data.DataLoader = None,
) -> Tuple[nn.Module, Dict[str, Any]]:
    """Quantize ``model`` with the post-training algorithm in ``quant_config``.

    Dispatches on ``quant_config.ptq_alg`` to the GPTQ or RTN quantization
    flow. The returned model is the same object that was passed in. Any
    quantization is therefore presumably applied in place by the flow
    (confirm in ``gptq_quantization`` / ``rtn_quantization``).

    Args:
        model: Model to quantize; passed to the selected flow.
        run_config: Run-level configuration, forwarded to the flow.
        quant_config: Quantization configuration; ``ptq_alg`` selects the flow.
        opt_config: Optional transform-optimization config, forwarded as-is.
        calibration_data: Optional calibration data loader, forwarded as-is.
            NOTE(review): whether it may be ``None`` presumably depends on
            the algorithm — confirm against each flow.

    Returns:
        ``(model, quant_info)``, where ``quant_info`` is the first element of
        the 2-tuple returned by the flow.

    Raises:
        ValueError: If ``quant_config.ptq_alg`` is neither GPTQ nor RTN.
    """
    if quant_config.ptq_alg == PTQAlg.GPTQ:
        quant_fn = gptq_quantization
    elif quant_config.ptq_alg == PTQAlg.RTN:
        quant_fn = rtn_quantization
    else:
        # The former Microsoft MX path was removed; only GPTQ and RTN remain.
        raise ValueError(f"Unsupported quantization alg: {quant_config.ptq_alg}")

    # Flows return a 2-tuple. Only the first element (quant info) is surfaced
    # to callers. The second (a non-quantized state dict) is intentionally
    # dropped.
    quant_info, _ = quant_fn(
        model,
        calibration_data,
        quant_config,
        run_config,
        opt_config,
    )
    return model, quant_info
