import re
import gc
import argparse
from typing import List

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from .qlinear import QLinear
from .quantizer import Quantizer
from .quant_ops import pack_fp4_to_uint8, cast_scales_to_eXmY
from ..transforms.transforms import build_transform, get_transform_matrix
from ..utils.common_utils import to, maybe_first_element
from ..utils.model_utils import InputCollector, ForwardInterrupt, get_attention_layer, get_mlp_layer

try:
    import wandb
except ImportError:
    wandb = None


def blockwise_qat_quantization(
    model: AutoModelForCausalLM, 
    calibration_data: List[torch.Tensor],
    args: argparse.Namespace, 
    device: torch.device
) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
    """Quantize ``model`` block by block with quantization-aware training.

    For every block in ``model.model.layers`` the function:

    1. records the outputs of the original block on the calibration inputs;
    2. builds four input transforms (qkv, o, gate/up, down);
    3. swaps ``self_attn`` / ``mlp`` for the quantized layers returned by
       ``get_attention_layer`` / ``get_mlp_layer`` and loads the original
       weights into them (``strict=False``);
    4. trains the block with Adam to minimise the MSE between its outputs and
       the recorded targets;
    5. if ``args.real_quant`` is set, exports packed weights, scales and
       transform matrices for every ``QLinear`` in the block;
    6. fixes/removes parametrizations, casts the block back to the original
       dtype and replaces the stored block inputs with the outputs of the
       quantized block, so the next block is calibrated on them.

    ``model`` is modified in place.

    Args:
        model: Causal LM exposing ``model.model.layers``, ``model.config`` and
            ``get_input_embeddings()``.
        calibration_data: Samples that are fed to ``model`` one at a time to
            capture the inputs of the first block.
        args: Run configuration. Fields read here: ``dtype``,
            ``cpu_offload_activations``, ``cpu_offload_modules``, ``init``,
            ``parametrization``, ``hadamard_group_size``, ``lora_rank``,
            ``w_bits``, ``a_bits``, ``format``, ``w_granularity``,
            ``a_granularity``, ``w_observer``, ``a_observer``,
            ``w_group_size``, ``a_group_size``, ``scale_precision``,
            ``transform_class``, ``train_original_parameters``, ``lr``,
            ``lr_decay``, ``amp``, ``epochs``, ``fuse_global_scale``,
            ``log_wandb`` and ``real_quant``.
        device: Device on which blocks are evaluated and trained.

    Returns:
        A pair ``(quantized_state_dict, non_quantized_state_dict)``. The first
        maps ``"model.layers.{i}.{layer_name}"`` to a dict with the keys
        ``qweight``, ``scales``, ``forward_hadamard_matrix``,
        ``backward_hadamard_matrix``, ``weight_global_scale`` and
        ``act_global_scale``. The second maps the remaining block state-dict
        keys (prefixed with ``"model.layers.{i}."``) to CPU tensors. Both are
        empty unless ``args.real_quant`` is set.

    Raises:
        ValueError: If a block's stored inputs contain neither a positional
            argument nor a ``hidden_states`` keyword argument.
    """
    print("Blockwise QAT quantization...")
    orig_dtype = model.config.torch_dtype if args.dtype == "auto" else args.dtype
    act_offload_device = "cpu" if args.cpu_offload_activations else device

    # The module-level import falls back to `wandb = None` when the package is
    # missing. Degrade to console-only logging instead of failing with an
    # AttributeError after the first (potentially long) training epoch.
    log_to_wandb = bool(args.log_wandb) and wandb is not None
    if args.log_wandb and wandb is None:
        print("Warning: args.log_wandb is set but wandb is not installed; skipping wandb logging.")

    # State dict with quantized weights, scales and hadamards
    quantized_state_dict = {}
    non_quantized_state_dict = {}
    # Define common transform kwargs
    transform_kwargs = dict(
        init=args.init, 
        parametrization=args.parametrization, 
        device=device,
        group_size=args.hadamard_group_size,
        rank=args.lora_rank
    )
    # Init quantizer kwargs (None means "leave unquantized")
    weight_quantizer_kwargs = None
    if args.w_bits < 16:
        weight_quantizer_kwargs = dict(
            bits=args.w_bits,
            symmetric=True, 
            format=args.format,
            granularity=args.w_granularity,
            observer=args.w_observer, 
            group_size=args.w_group_size,
            scale_precision=args.scale_precision
        )
    act_quantizer_kwargs = None
    if args.a_bits < 16:
        act_quantizer_kwargs = dict(
            bits=args.a_bits, 
            symmetric=True, 
            format=args.format,
            granularity=args.a_granularity,
            observer=args.a_observer, 
            group_size=args.a_group_size,
            scale_precision=args.scale_precision
        )

    # Temporarily wrap the first block to capture its inputs.
    blocks = model.model.layers
    blocks[0] = InputCollector(blocks[0], cpu_offload=args.cpu_offload_activations)
    if args.cpu_offload_modules:
        model.get_input_embeddings().to(device)
        blocks[0] = blocks[0].to(device)
    collector = blocks[0]

    try:
        for sample in calibration_data:
            try:
                with torch.no_grad():
                    model(sample.to(device=device))
            except ForwardInterrupt:
                # Expected: the forward pass is cut short once the inputs of
                # the first block have been recorded.
                pass
    finally:
        # Always unwrap the first block, so an unexpected error during
        # calibration does not leave the caller's model with the collector
        # still installed.
        blocks[0] = collector.module

    input_args = collector.input_args
    input_kwargs = collector.input_kwargs

    if args.cpu_offload_modules:
        model.get_input_embeddings().cpu()

    # Iterate over transformer blocks
    for block_idx, block in enumerate(blocks):
        print(f"Processing block {block_idx}...")
        if args.cpu_offload_modules:
            block.to(device)
        # 1. Init targets
        targets = []
        # Collect original model outputs
        for inp_args, inp_kwargs in zip(input_args, input_kwargs):
            with torch.no_grad():
                out = block(*to(inp_args, device=device), **to(inp_kwargs, device=device))
            out = maybe_first_element(out)
            targets.append(out.float().to(act_offload_device))

        # 2. Init transforms
        qkv_in_transform = build_transform(args.transform_class, size=model.config.hidden_size, **transform_kwargs)
        o_in_transform = build_transform(args.transform_class, size=model.config.hidden_size, **transform_kwargs)
        gate_up_in_transform = build_transform(args.transform_class, size=model.config.hidden_size, **transform_kwargs)
        down_in_transform = build_transform(args.transform_class, size=model.config.intermediate_size, **transform_kwargs)

        # 3. Replace blocks with quantized versions
        quantized_attn = get_attention_layer(model.config)(
            model.config,
            layer_idx=block_idx,
            weight_quantizer_kwargs=weight_quantizer_kwargs,
            act_quantizer_kwargs=act_quantizer_kwargs,
            qkv_in_transform=qkv_in_transform,
            o_in_transform=o_in_transform
        )
        quantized_mlp = get_mlp_layer(model.config)(
            model.config,
            weight_quantizer_kwargs=weight_quantizer_kwargs,
            act_quantizer_kwargs=act_quantizer_kwargs,
            gate_up_in_transform=gate_up_in_transform,
            down_in_transform=down_in_transform
        )

        # strict=False: the quantized layers may hold extra state that the
        # original layers do not provide.
        quantized_attn.load_state_dict(block.self_attn.state_dict(), strict=False)
        quantized_mlp.load_state_dict(block.mlp.state_dict(), strict=False)

        block.self_attn = quantized_attn
        block.mlp = quantized_mlp

        # Cast block to float (to be compatible with trainable transforms)
        block = block.float()
        # Make sure that all params are on the same device
        block.to(device)
        # Unlock gradient for all parameters
        if args.train_original_parameters:
            block.requires_grad_(True)

        optimizer = torch.optim.Adam(block.parameters(), lr=args.lr)
        # step_size=1: the learning rate is multiplied by lr_decay every epoch.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=1,
            gamma=args.lr_decay,
        )
        scaler = torch.amp.GradScaler(enabled=args.amp)

        for epoch in range(args.epochs):
            print(f"Epoch {epoch}")
            # Visit the calibration samples in a fresh random order
            batch_ids = torch.randperm(len(input_args)).tolist()

            # Mean loss over the samples of this epoch
            train_loss = 0.0
            for idx in batch_ids:
                with torch.amp.autocast(device_type="cuda", enabled=args.amp):
                    out = block(*to(input_args[idx], device=device), **to(input_kwargs[idx], device=device))
                    out = maybe_first_element(out)
                    loss = F.mse_loss(out, targets[idx].to(device))
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                train_loss += loss.item() / len(input_args)

                if args.fuse_global_scale:
                    # qkv fusion: q/k/v share the smallest of their global scales
                    qkv_global_scale = min(
                        quantized_attn.q_proj.weight_quantizer.global_scale,
                        quantized_attn.k_proj.weight_quantizer.global_scale,
                        quantized_attn.v_proj.weight_quantizer.global_scale,
                    )
                    quantized_attn.q_proj.weight_quantizer.global_scale = qkv_global_scale
                    quantized_attn.k_proj.weight_quantizer.global_scale = qkv_global_scale
                    quantized_attn.v_proj.weight_quantizer.global_scale = qkv_global_scale
                    # gate_up fusion: gate/up share the smallest of their global scales
                    gate_up_global_scale = min(
                        quantized_mlp.gate_proj.weight_quantizer.global_scale,
                        quantized_mlp.up_proj.weight_quantizer.global_scale
                    )
                    quantized_mlp.gate_proj.weight_quantizer.global_scale = gate_up_global_scale
                    quantized_mlp.up_proj.weight_quantizer.global_scale = gate_up_global_scale

            scheduler.step()
            print(f"Loss: {train_loss:.2e}")
            if log_to_wandb:
                wandb.log({f"train/block{block_idx}/loss": train_loss})

        # Disable scale tracking
        for layer_name, layer in block.named_modules():
            if isinstance(layer, QLinear):
                if layer.weight_quantizer:
                    layer.weight_quantizer._track_global_scale = False
                if layer.act_quantizer:
                    layer.act_quantizer._track_global_scale = False

        # Delete optimizer state before export to release memory
        del optimizer
        del scheduler
        del scaler
        gc.collect()
        torch.cuda.empty_cache()

        # 4. Save stuff (if real_quant)
        if args.real_quant:
            # NOTE(review): this snapshot is taken before the parametrizations
            # are fixed/removed below -- confirm the exported keys are the
            # intended ones.
            non_quantized_block_state_dict = block.state_dict()
            for layer_name, layer in block.named_modules():
                if isinstance(layer, QLinear):
                    with torch.no_grad():
                        # Pick the transform that feeds this projection
                        if re.search("(q|k|v)_proj", layer_name):
                            layer_transform = qkv_in_transform
                        elif re.search("o_proj", layer_name):
                            layer_transform = o_in_transform
                        elif re.search("(gate|up)_proj", layer_name):
                            layer_transform = gate_up_in_transform
                        else:
                            layer_transform = down_in_transform

                        # NOTE(review): both quantizers are used unconditionally
                        # here; presumably real_quant is only run with
                        # w_bits < 16 and a_bits < 16 -- confirm with callers.
                        weight = layer_transform(layer.weight, inv_t=True)
                        scales, zeros = layer.weight_quantizer.get_quantization_params(weight)
                        qweight = layer.weight_quantizer.quantize(weight, scales, zeros)

                    weight_global_scale = layer.weight_quantizer.global_scale.to(scales.device)
                    act_global_scale = layer.act_quantizer.global_scale

                    quantized_state_dict[f"model.layers.{block_idx}.{layer_name}"] = {
                        "qweight": pack_fp4_to_uint8(qweight).cpu(),
                        "scales": cast_scales_to_eXmY(scales * weight_global_scale, args.scale_precision).cpu(),
                        "forward_hadamard_matrix": get_transform_matrix(args.transform_class, args.hadamard_group_size, device, orig_dtype).cpu(),
                        "backward_hadamard_matrix": get_transform_matrix(args.transform_class, args.hadamard_group_size, device, orig_dtype).cpu(),
                        "weight_global_scale": weight_global_scale.clone(),
                        "act_global_scale": act_global_scale.clone()
                    }

                    # The weight is exported in quantized form above, so drop
                    # it from the non-quantized block state dict
                    del non_quantized_block_state_dict[f"{layer_name}.weight"]
            # Add all remaining stuff to non_quantized_state_dict
            non_quantized_state_dict.update(
                {f"model.layers.{block_idx}.{k}": v.cpu() for k, v in non_quantized_block_state_dict.items()}
            )

        # 5. Fix model parametrization
        quantized_attn.fix_parametrization()
        quantized_mlp.fix_parametrization()
        # 6. Fix transforms and remove parametrizations
        qkv_in_transform.remove_parametrizations()
        o_in_transform.remove_parametrizations()
        gate_up_in_transform.remove_parametrizations()
        down_in_transform.remove_parametrizations()

        # Cast to original dtype
        block = block.to(orig_dtype)

        # 7. Propagate: the outputs of the quantized block become the inputs
        # of the next block
        for inp_args, inp_kwargs in zip(input_args, input_kwargs):
            with torch.no_grad():
                out = block(*to(inp_args, device=device), **to(inp_kwargs, device=device))
            out = maybe_first_element(out).to(act_offload_device)
            # change only first input argument
            if len(inp_args) > 0:
                # Rebind .data so the stored tensor is updated in place
                inp_args[0].data = out
            elif "hidden_states" in inp_kwargs:
                inp_kwargs["hidden_states"] = out
            else:
                raise ValueError("Unsupported block input format.")

        if args.cpu_offload_modules:
            block.cpu()

        # Empty cache to free GPU memory
        gc.collect()
        torch.cuda.empty_cache()

    gc.collect()
    torch.cuda.empty_cache()

    return quantized_state_dict, non_quantized_state_dict
