# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from argparse import Namespace
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.scatter_gather import Gather

from typing import Any, Callable, Iterable, Iterator, List, Optional, Sequence, Union

import torch
import torch.nn.functional as F
import transformers


def replace_parameter_(module: nn.Module, name: str, new_value: torch.Tensor):
    """
    Overwrite the entry ``name`` of ``module`` with ``new_value``, in place.

    When ``name`` is a key of ``module._parameters`` the slot in that dict is overwritten
    directly, so ``new_value`` may be a plain (non-parameter) tensor. Otherwise the value is
    assigned with a regular ``setattr``. This is a hack: the module is not meant to be used
    as an ordinary parameterized module afterwards.
    """
    registered = module._parameters
    if name not in registered:
        # Not a registered parameter: ordinary attribute assignment is enough.
        setattr(module, name, new_value)
        return
    registered[name] = new_value


def iterate_minibatches(
    *tensors: torch.Tensor,
    batch_size: int,
    allow_incomplete: bool = True,
    device: Optional[torch.device] = None,
    callback: Callable[[Sequence[torch.Tensor]], Sequence[torch.Tensor]] = lambda x: x,
) -> Iterator[Sequence[torch.Tensor]]:
    """
    Samples data points *forever*, in random order, with less overhead than DataLoader;
    Adapted from https://github.com/stanis-morozov/unq/blob/master/lib/utils.py
    probably implemented many times in transformers, torch, etc.

    The random permutation is drawn once, before the first epoch, and the same order is
    replayed on every subsequent pass over the data.

    Args:
        tensors: one or more tensors with the same 0-th dimension
        batch_size: sample this many points with each yield; must be positive
        allow_incomplete: if True and dataset size is not divisible by batch size,
            the last batch may have fewer than batch_size samples to cover entire dataset.
            If False, last batch is dropped.
        device: every batch slice is moved here with ``.to(device, non_blocking=True)``
        callback: optional function called on each batch before yielding

    Returns:
        Generator yielding minibatches of tensors, matching the length of input tensors.
        If batch contains only one tensor, yields that tensor directly (not wrapped in tuple).

    Raises:
        ValueError: (on the first ``next``) if no tensors are given, ``batch_size`` is not
            positive, or not a single batch can be formed from the data. Previously the last
            case made the generator yield ``None`` forever.
    """
    if not tensors:
        raise ValueError("iterate_minibatches needs at least one tensor")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = len(tensors[0])
    assert all(len(x) == num_samples for x in tensors)

    full_batches, remainder = divmod(num_samples, batch_size)
    batches_per_epoch = full_batches + (1 if (allow_incomplete and remainder) else 0)
    if batches_per_epoch == 0:
        raise ValueError(
            f"cannot form a batch of size {batch_size} from {num_samples} samples "
            f"(allow_incomplete={allow_incomplete})"
        )

    indices = torch.randperm(num_samples, device=tensors[0].device)
    while True:
        prev_batch = None
        for batch_start in range(0, num_samples, batch_size):
            if not allow_incomplete and batch_start + batch_size > num_samples:
                break
            batch_ix = indices[batch_start : batch_start + batch_size]
            batch = callback(tuple(tensor[batch_ix].to(device, non_blocking=True) for tensor in tensors))
            if prev_batch is not None:
                yield prev_batch
            # Yield the current batch on next iteration to reduce overhead:
            # the (possibly non-blocking) transfer of this batch is issued before the
            # consumer starts working on the previous one.
            prev_batch = batch if (isinstance(batch, (list, tuple)) and len(tensors) > 1) else batch[0]
            del batch
        yield prev_batch


@torch.enable_grad()
def update_groupwise(
    *,
    layer: nn.Module,
    train_inps: Sequence[torch.Tensor],
    train_outs: Sequence[torch.Tensor],
    args: Namespace,
    valid_inps: Sequence[torch.Tensor] = None,
    valid_outs: Sequence[torch.Tensor] = None,
    verbose: bool = True,
    **kwargs,
) -> Tuple[nn.Module, bool]:
    """
    Update the trainable parameters of a module with Adam so that its outputs match the
    recorded target outputs in mean squared error.
    Note: This function is for parameter updating, not fine-tuning.

    Args:
        layer: Trainable module; every parameter with ``requires_grad=True`` is optimized.
        train_inps: One input activation tensor per device (a list, or a tensor whose first
            dimension indexes devices); each is indexed by sample along its 0-th dimension.
        train_outs: Target output activations, laid out like ``train_inps``.
        args: Namespace of hyperparameters. Attributes read here: ``devices``,
            ``offload_activations``, ``update_lr``, ``update_adam_beta1``,
            ``update_adam_beta2``, ``update_batch_size``, ``local_batch_size``,
            ``update_max_epochs``, ``update_early_stop``. The namespace is not modified.
        valid_inps: Optional validation inputs, laid out like ``train_inps``.
        valid_outs: Optional validation targets, laid out like ``train_outs``.
        verbose: Whether to print the per-epoch loss summary.
        kwargs: Additional keyword arguments passed to each forward pass.

    Returns:
        Tuple ``(layer, is_better)``. ``is_better`` is True when the mean training loss of any
        epoch after the first was lower than the mean training loss of the first epoch.
        If validation ran, the parameters with the best validation loss are restored first.

    Raises:
        ValueError: if the layer has no trainable parameters, the batch size is not positive,
            there is not enough training data for one batch, or a loss is not finite.

    Batch-size semantics: if ``args.local_batch_size`` is set it is the per-device batch and one
    optimizer step is taken per batch (the original code achieved this by overwriting
    ``args.update_batch_size``, which crashed for ``None`` and for more than one device).
    If it is ``None``, the per-device batch is ``args.update_batch_size // len(args.devices)``.
    """
    num_devices = len(args.devices)
    is_better_flag = False  # becomes True once a later epoch beats the first epoch's train loss

    # Check input and output tensors for each device
    print("Checking input and output tensors on each device")
    for i in range(num_devices):
        print("Checking types:", type(train_inps[i]), type(train_outs[i]))

        assert isinstance(train_inps[i], torch.Tensor) and isinstance(train_outs[i], torch.Tensor)
        if not args.offload_activations:
            assert train_inps[i].device == train_outs[i].device == args.devices[i], (
                train_inps[i].device,
                train_outs[i].device,
                args.devices,
            )
        else:
            # Offloaded activations live in pinned host memory.
            assert train_inps[i].device == train_outs[i].device == torch.device("cpu")
            assert train_inps[i].is_pinned() and train_outs[i].is_pinned()

    # Replicate the module (and tensor kwargs) to each device when training data-parallel
    replicas = kwargs_by_device = replacement_tables = None
    if num_devices > 1:
        replicas = torch.nn.parallel.replicate(layer, args.devices)
        replicas[0] = layer  # the first "replica" is the master module itself
        kwargs_by_device = [
            {k: (v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}
            for device in args.devices
        ]

    # Collect trainable parameters on the main device
    differentiable_parameters_by_name = {name: param for name, param in layer.named_parameters() if param.requires_grad}
    if not differentiable_parameters_by_name:
        raise ValueError("layer has no parameters with requires_grad=True; nothing to update")
    param_names, differentiable_parameters = zip(*differentiable_parameters_by_name.items())
    differentiable_parameters = nn.ParameterList(differentiable_parameters)

    # Initialize gradients to zero
    for param in differentiable_parameters:
        param.grad = torch.zeros_like(param)

    if replicas:
        replacement_tables = _make_parameter_replacement_tables(layer, replicas, param_names, differentiable_parameters)

    print(f"update {sum(param.numel() for param in differentiable_parameters)} parameters")

    # Setup optimizer
    opt = torch.optim.Adam(
        differentiable_parameters, lr=args.update_lr, betas=(args.update_adam_beta1, args.update_adam_beta2)
    )

    # Validate update batch size divisibility
    assert args.update_batch_size % num_devices == 0, "batch_size must be divisible by number of GPUs"

    # len(train_inps[0]) works both for a list of tensors and for a stacked tensor.
    num_samples_per_device = len(train_inps[0])
    print(f"Number of samples per device: {num_samples_per_device}")

    local_batch_size = args.local_batch_size
    if local_batch_size is None:
        local_batch_size = args.update_batch_size // num_devices
        effective_batch_size = args.update_batch_size
    else:
        # An explicit local batch size overrides the global one; args itself is left untouched.
        effective_batch_size = local_batch_size * num_devices
    if local_batch_size <= 0:
        raise ValueError(f"local batch size must be positive, got {local_batch_size}")

    assert all(len(inps_tensor) == num_samples_per_device for inps_tensor in train_inps)

    # Compute number of accumulation steps (evaluates to 1 for both derivations above)
    num_accumulation_steps = effective_batch_size // (local_batch_size * num_devices)
    assert num_samples_per_device % (local_batch_size * num_accumulation_steps) == 0, (
        num_samples_per_device,
        local_batch_size,
    )

    train_batches_per_epoch = num_samples_per_device // local_batch_size
    if train_batches_per_epoch == 0:
        raise ValueError(
            f"not enough training samples ({num_samples_per_device}) for one batch of size {local_batch_size}"
        )
    print(f"Training with {num_devices} devices, {num_samples_per_device} samples per device, ")
    print(f"batch size {local_batch_size}, {train_batches_per_epoch} batches per epoch, {num_accumulation_steps} accumulation steps")

    train_batch_iterators = [
        iterate_minibatches(train_inps[i], train_outs[i], batch_size=local_batch_size, device=args.devices[i])
        for i in range(num_devices)
    ]

    run_validation = False
    valid_batches_per_epoch = 0
    valid_batch_iterators = None
    if valid_inps is not None and valid_outs is not None and len(valid_inps) > 0 and len(valid_outs) > 0:
        num_valid_samples_per_device = len(valid_inps[0])
        valid_batches_per_epoch = num_valid_samples_per_device // local_batch_size
        if valid_batches_per_epoch > 0:
            run_validation = True
            valid_batch_iterators = [
                iterate_minibatches(valid_inps[i], valid_outs[i], batch_size=local_batch_size, device=args.devices[i])
                for i in range(num_devices)
            ]
        else:
            warnings.warn(
                f"validation skipped: {num_valid_samples_per_device} samples do not fill "
                f"one batch of size {local_batch_size}"
            )

    def compute_loss(batch_iterators) -> torch.Tensor:
        # One loss evaluation on the next batch of every device.
        if num_devices == 1:
            return _compute_mse_on_batch(layer, batch_iterators[0], **kwargs)
        return _compute_mse_parallel(
            args.devices,
            replicas,
            differentiable_parameters,
            replacement_tables,
            batch_iterators,
            kwargs_by_device,
        )

    def mean_validation_loss() -> float:
        # Average loss over one validation epoch, in eval mode and without gradients.
        layer.eval()
        total = 0.0
        with torch.no_grad():
            for _ in range(valid_batches_per_epoch):
                total += compute_loss(valid_batch_iterators).item()
        return total / valid_batches_per_epoch

    if run_validation:
        # Evaluate before training
        print("Evaluating before training")
        valid_loss_epoch = mean_validation_loss()
        print("Evaluation before training.")
        print(f"Valid loss={valid_loss_epoch:.2e}\t")
        best_loss = valid_loss_epoch
        best_parameters_by_name = deepcopy(differentiable_parameters_by_name)
        worse_count = 0

    steps_accumulated = 0
    initial_loss = None
    print(f"Starting training for {args.update_max_epochs} epochs")
    for epoch in range(args.update_max_epochs):
        layer.train()
        print(f"Training epoch {epoch + 1}/{args.update_max_epochs}")

        loss_total = 0.0
        for _ in range(train_batches_per_epoch):
            loss = compute_loss(train_batch_iterators)

            # Fail before a non-finite loss is back-propagated into the accumulated gradients.
            if not torch.isfinite(loss).item():
                raise ValueError(f"update loss is {loss}")

            (loss / num_accumulation_steps).backward()
            steps_accumulated += 1

            if steps_accumulated >= num_accumulation_steps:
                opt.step()
                opt.zero_grad()
                steps_accumulated = 0

            loss_total += loss.item()
        train_loss_epoch = loss_total / train_batches_per_epoch

        if run_validation:
            valid_loss_epoch = mean_validation_loss()

        # Track improvement regardless of verbosity (it used to be updated only when verbose).
        if epoch == 0:
            initial_loss = train_loss_epoch
        elif train_loss_epoch < initial_loss:
            is_better_flag = True

        # Log losses at the end of epoch
        if verbose:
            print("-" * 10)
            print(f"epoch={epoch}")
            print(f"train loss={train_loss_epoch:.2e}\t")
            if run_validation:
                print(f"valid loss={valid_loss_epoch:.2e}\t")

        if run_validation:
            if valid_loss_epoch < best_loss:
                print(f"New best loss {valid_loss_epoch:.2e} on epoch {epoch}")
                best_loss = valid_loss_epoch
                best_parameters_by_name = deepcopy(differentiable_parameters_by_name)
                worse_count = 0
            else:
                worse_count += 1
                if worse_count >= args.update_early_stop:
                    break  # Early stopping

    if run_validation:
        layer.load_state_dict(best_parameters_by_name, strict=False)  # Restore best parameters

    return layer, is_better_flag  # Return updated model and improvement flag


@torch.enable_grad()
def update_groupwise_true(
    *,
    layer: nn.Module,
    train_inps: Sequence[torch.Tensor],
    train_outs: Sequence[torch.Tensor],
    args: Namespace,
    valid_inps: Sequence[torch.Tensor] = None,
    valid_outs: Sequence[torch.Tensor] = None,
    verbose: bool = True,
    **kwargs,
) -> Tuple[nn.Module, bool]:
    """
    Update the trainable parameters of a module with Adam against recorded target outputs.
    Note: This function is for parameter updating, not fine-tuning.

    With a single device the loss is ``_compute_nll_on_batch``; with several devices it is
    ``_compute_mse_parallel``.
    NOTE(review): the two paths optimize different objectives - confirm this is intended.

    Args:
        layer: Trainable module; every parameter with ``requires_grad=True`` is optimized.
        train_inps: Input activations. Despite the annotation, ``.unsqueeze(0)`` is called on
            this argument, so a single tensor indexed by sample along dim 0 is required.
        train_outs: Target outputs, same layout as ``train_inps``.
        args: Hyperparameters. Attributes read here: ``devices``, ``offload_activations``,
            ``update_lr``, ``update_adam_beta1``, ``update_adam_beta2``, ``update_batch_size``,
            ``local_batch_size``, ``update_max_epochs``, ``update_early_stop``.
        valid_inps: Optional validation inputs (a tensor, like ``train_inps``).
        valid_outs: Optional validation targets (a tensor, like ``train_outs``).
        verbose: Whether to print the per-epoch loss summary.
        kwargs: Additional keyword arguments passed to each forward pass.

    Returns:
        Tuple ``(layer, is_better)``. ``is_better`` is True when the mean training loss of any
        epoch after the first was lower than the mean training loss of the first epoch.
        If validation ran, the parameters with the best validation loss are restored first.

    Raises:
        ValueError: if the layer has no trainable parameters, there is not enough training
            data for one batch, or a loss is not finite.
    """
    num_devices = len(args.devices)
    is_better_flag = False  # becomes True once a later epoch beats the first epoch's train loss

    # Check input and output tensors on each device.
    # NOTE(review): at this point the inputs are not yet unsqueezed, so train_inps[i]
    # selects the i-th sample rather than a per-device tensor.
    print("Checking input and output tensors on each device")
    for i in range(num_devices):
        print("Checking types:", type(train_inps[i]), type(train_outs[i]))

        assert isinstance(train_inps[i], torch.Tensor) and isinstance(train_outs[i], torch.Tensor)
        if not args.offload_activations:
            assert train_inps[i].device == train_outs[i].device == args.devices[i], (
                train_inps[i].device,
                train_outs[i].device,
                args.devices,
            )
        else:
            assert train_inps[i].device == train_outs[i].device == torch.device("cpu")
            assert train_inps[i].is_pinned() and train_outs[i].is_pinned()

    # Replicate model (and tensor kwargs) to each GPU if multiple devices
    replicas = kwargs_by_device = replacement_tables = None
    if num_devices > 1:
        replicas = torch.nn.parallel.replicate(layer, args.devices)
        replicas[0] = layer  # the first "replica" is the master module itself
        kwargs_by_device = [
            {k: (v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}
            for device in args.devices
        ]

    # Collect trainable parameters on the main device
    differentiable_parameters_by_name = {name: param for name, param in layer.named_parameters() if param.requires_grad}
    if not differentiable_parameters_by_name:
        raise ValueError("layer has no parameters with requires_grad=True; nothing to update")
    param_names, differentiable_parameters = zip(*differentiable_parameters_by_name.items())
    differentiable_parameters = nn.ParameterList(differentiable_parameters)

    # Initialize gradients to zero
    for param in differentiable_parameters:
        param.grad = torch.zeros_like(param)

    if replicas:
        replacement_tables = _make_parameter_replacement_tables(layer, replicas, param_names, differentiable_parameters)

    print(f"update {sum(param.numel() for param in differentiable_parameters)} parameters")

    # Setup optimizer
    opt = torch.optim.Adam(
        differentiable_parameters, lr=args.update_lr, betas=(args.update_adam_beta1, args.update_adam_beta2)
    )

    # Check if update batch size is divisible by number of GPUs
    assert args.update_batch_size % num_devices == 0, "batch_size must be divisible by the number of GPUs"

    # Add a leading dimension of size 1 so the tensors can be indexed like a per-device list.
    # NOTE(review): the leading dimension is 1, so only index 0 exists afterwards.
    train_inps = train_inps.unsqueeze(0)
    train_outs = train_outs.unsqueeze(0)
    print("train_inps shape:", train_inps.shape)

    num_samples_per_device = train_inps.shape[1]  # Number of samples per device
    print(f"Number of samples per device: {num_samples_per_device}")

    local_batch_size = args.local_batch_size
    if local_batch_size is None:
        local_batch_size = args.update_batch_size // num_devices

    assert all(len(inps_tensor) == num_samples_per_device for inps_tensor in train_inps)
    assert args.update_batch_size % (local_batch_size * num_devices) == 0, (
        "update_batch_size must be divisible by local_batch_size * number of devices"
    )

    num_accumulation_steps = args.update_batch_size // (local_batch_size * num_devices)
    assert num_samples_per_device % (local_batch_size * num_accumulation_steps) == 0, (
        num_samples_per_device,
        local_batch_size,
    )

    train_batches_per_epoch = num_samples_per_device // local_batch_size
    if train_batches_per_epoch == 0:
        raise ValueError(
            f"not enough training samples ({num_samples_per_device}) for one batch of size {local_batch_size}"
        )
    print(f"Training with {num_devices} devices, {num_samples_per_device} samples per device, ")
    print(f"batch size {local_batch_size}, {train_batches_per_epoch} batches per epoch, {num_accumulation_steps} accumulation steps")

    train_batch_iterators = [
        iterate_minibatches(train_inps[i], train_outs[i], batch_size=local_batch_size, device=args.devices[i])
        for i in range(num_devices)
    ]

    # Diagnostic peek at the first batch; note that this consumes one batch of the iterator.
    inputs, targets = next(train_batch_iterators[0])
    print("Inputs shape:", inputs.shape)
    print("Targets shape:", targets.shape)
    print("Inputs sample:", inputs[0])  # Print first sample
    print("Targets sample:", targets[0])

    run_validation = False
    valid_batches_per_epoch = 0
    valid_batch_iterators = None
    # Validation data is optional (defaults are None); it used to be unsqueezed unconditionally.
    if valid_inps is not None and valid_outs is not None:
        valid_inps = valid_inps.unsqueeze(0)
        valid_outs = valid_outs.unsqueeze(0)
        if valid_inps.any() and valid_outs.any():
            num_valid_samples_per_device = len(valid_inps[0])
            valid_batches_per_epoch = num_valid_samples_per_device // local_batch_size
            if valid_batches_per_epoch > 0:
                run_validation = True
                valid_batch_iterators = [
                    iterate_minibatches(
                        valid_inps[i], valid_outs[i], batch_size=local_batch_size, device=args.devices[i]
                    )
                    for i in range(num_devices)
                ]
            else:
                warnings.warn(
                    f"validation skipped: {num_valid_samples_per_device} samples do not fill "
                    f"one batch of size {local_batch_size}"
                )

    def compute_loss(batch_iterators) -> torch.Tensor:
        # One loss evaluation on the next batch of every device.
        if num_devices == 1:
            return _compute_nll_on_batch(layer, batch_iterators[0], **kwargs)
        return _compute_mse_parallel(
            args.devices,
            replicas,
            differentiable_parameters,
            replacement_tables,
            batch_iterators,
            kwargs_by_device,
        )

    def mean_validation_loss() -> float:
        # Average loss over one validation epoch, in eval mode and without gradients.
        layer.eval()
        total = 0.0
        with torch.no_grad():
            for _ in range(valid_batches_per_epoch):
                total += compute_loss(valid_batch_iterators).item()
        return total / valid_batches_per_epoch

    if run_validation:
        # Evaluation before training
        print("Evaluating before training")
        valid_loss_epoch = mean_validation_loss()
        print("Evaluation before training.")
        print(f"Valid loss={valid_loss_epoch:.2e}\t")
        best_loss = valid_loss_epoch
        best_parameters_by_name = deepcopy(differentiable_parameters_by_name)
        worse_count = 0

    steps_accumulated = 0
    initial_loss = None
    print(f"Starting training for {args.update_max_epochs} epochs")
    for epoch in range(args.update_max_epochs):
        layer.train()

        print(f"Training epoch {epoch + 1}/{args.update_max_epochs}")

        loss_total = 0.0
        for _ in range(train_batches_per_epoch):
            loss = compute_loss(train_batch_iterators)

            # Fail before a non-finite loss is back-propagated into the accumulated gradients.
            if not torch.isfinite(loss).item():
                raise ValueError(f"update loss is {loss}")

            (loss / num_accumulation_steps).backward()
            steps_accumulated += 1

            if steps_accumulated >= num_accumulation_steps:
                opt.step()
                opt.zero_grad()
                steps_accumulated = 0

            loss_total += loss.item()
        train_loss_epoch = loss_total / train_batches_per_epoch

        if run_validation:
            valid_loss_epoch = mean_validation_loss()

        # Track improvement regardless of verbosity (it used to be updated only when verbose).
        if epoch == 0:
            initial_loss = train_loss_epoch
        elif train_loss_epoch < initial_loss:
            is_better_flag = True

        # Log losses at epoch end
        if verbose:
            print("-" * 10)
            print(f"epoch={epoch}")
            print(f"train loss={train_loss_epoch:.2e}\t")
            if run_validation:
                print(f"valid loss={valid_loss_epoch:.2e}\t")

        if run_validation:
            if valid_loss_epoch < best_loss:
                print(f"New best loss {valid_loss_epoch:.2e} on epoch {epoch}")
                best_loss = valid_loss_epoch
                best_parameters_by_name = deepcopy(differentiable_parameters_by_name)
                worse_count = 0
            else:
                worse_count += 1
                if worse_count >= args.update_early_stop:
                    break  # Early stopping

    if run_validation:
        layer.load_state_dict(best_parameters_by_name, strict=False)  # Restore best parameters

    return layer, is_better_flag  # Return updated model and improvement flag



def _make_parameter_replacement_tables(
    layer: nn.Module, replicas: Sequence[nn.Module], param_names: Sequence[str], parameters: nn.ParameterList
) -> Sequence[List[Sequence[Tuple[nn.Module, str]]]]:
    """
    Prepare auxiliary data structures for quickly copying parameters to replicas for data-parallel training.
    """
    assert len(param_names) == len(parameters)
    assert len(replicas) > 1
    assert replicas[0] is layer

    parameters_by_name = dict(zip(param_names, parameters))

    param_to_name = {param: name for name, param in parameters_by_name.items()}
    param_occurences = defaultdict(list)  # param_name -> List [ Tuple [submodule name, attr name] ]
    for submodule_name, submodule in layer.named_modules():
        for attr_name, param in submodule.named_parameters(recurse=False):  # immediate params only
            if param in param_to_name:
                param_name = param_to_name[param]
                param_occurences[param_name].append((submodule_name, attr_name))
    assert len(param_occurences) == len(parameters), "internal error: not all parameters were found"

    replacement_tables = []
    for replica in replicas:
        replacement_table = list()  # for each master param -> List[ Tuple[replica submodule, attr name] ]
        replica_modules_by_name: Dict[str, nn.Module] = dict(replica.named_modules())

        for param_name, master_param in zip(param_names, parameters):
            param_replacements = list()
            for submodule_name, attr_name in param_occurences[param_name]:
                param_replacements.append((replica_modules_by_name[submodule_name], attr_name))
            replacement_table.append(param_replacements)
        replacement_tables.append(replacement_table)
    return replacement_tables



def _compute_mse_on_batch(
    layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:
    inps_batch, outs_batch = next(batch_iter)

    outs_prediction = layer(inps_batch, **kwargs)

    assert outs_prediction.shape == outs_batch.shape
    loss = F.mse_loss(outs_prediction, outs_batch)

    return loss


def _compute_kl_on_batch(
    layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:
    inps_batch, outs_batch = next(batch_iter)

    if inps_batch.shape[0] != 1:
        for name, value in list(kwargs.items()):
            if isinstance(value, torch.Tensor) and value.shape[0] == 1:
                if name not in ("attention_mask", "position_ids"):
                    warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")
                repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]
                kwargs[name] = value.tile(*repeats)

    outs_prediction = layer(inps_batch, **kwargs)
    log_probs_pred = F.log_softmax(outs_prediction, dim=-1)

    probs_target = F.softmax(outs_batch, dim=-1)

    log_probs_pred = log_probs_pred.float()
    probs_target = probs_target.float()

    loss = F.kl_div(log_probs_pred, probs_target, reduction='batchmean')

    return loss.to(torch.float16)


def _compute_cross_on_batch(
    layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:
    inps_batch, outs_batch = next(batch_iter)

    if inps_batch.shape[0] != 1:
        for name, value in list(kwargs.items()):
            if isinstance(value, torch.Tensor) and value.shape[0] == 1:
                if name not in ("attention_mask", "position_ids"):
                    warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")
                repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]
                kwargs[name] = value.tile(*repeats)

    outs_prediction = layer(inps_batch, **kwargs)

    labels = outs_batch.argmax(dim=-1)

    logits = outs_prediction.permute(0, 2, 1)

    loss = F.cross_entropy(logits, labels)

    print(f"loss: {loss.item()}")
    return loss


def _compute_nll_on_batch(
    layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:
    inps_batch, outs_batch = next(batch_iter)

    log_probs_out = outs_batch.argmax(dim=2)
    print("log_probs_out shape:", log_probs_out.shape)

    if inps_batch.shape[0] != 1:
        for name, value in list(kwargs.items()):
            if isinstance(value, torch.Tensor) and value.shape[0] == 1:
                if name not in ("attention_mask", "position_ids"):
                    warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")
                repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]
                kwargs[name] = value.tile(*repeats)

    outs_prediction, *_unused = layer(inps_batch, **kwargs)

    outs_prediction = outs_prediction.unsqueeze(0)

    w = outs_batch - outs_prediction
    w = w.float()

    mean_w = w.mean().item()

    threshold = torch.quantile(w, 0.99).item()

    top_1_percent_values = w[w >= threshold]

    top_1_percent_values_sorted, _ = torch.sort(top_1_percent_values, descending=True)

    print("Mean of w:", mean_w)
    print("Top 1% largest values in w (sorted descending):")
    print(top_1_percent_values_sorted)

    log_probs = F.log_softmax(outs_prediction, dim=2)
    log_probs = log_probs.permute(0, 2, 1)

    loss = F.nll_loss(log_probs, log_probs_out)
    print(f"{loss.item()}")

    return loss



def _compute_mse_parallel(
    devices: Sequence[torch.device],
    replicas: Sequence[nn.Module],
    parameters_to_replicate: nn.ParameterList,
    replacement_tables: Sequence[List[Sequence[Tuple[nn.Module, str]]]],
    batch_iterators: Sequence[Iterator[Tuple[torch.Tensor, torch.Tensor]]],
    kwargs_by_device: Sequence[Dict[str, Any]],
) -> torch.Tensor:
    """
    Evaluate ``_compute_mse_on_batch`` on every replica in parallel and average the results.

    The trainable parameters are replicated to all devices with ``detach=False`` and written
    into the non-master replicas via ``replacement_tables``. Each replica then draws one batch
    from its own iterator. The per-device losses are gathered on ``devices[0]`` and averaged.

    NOTE(review): an identical definition of this function appears again later in this file
    and is the one bound to the name after import.
    """
    param_copies = torch.nn.parallel.replicate(parameters_to_replicate, devices, detach=False)

    # Replica 0 is the master module and keeps its own parameters.
    for device_index in range(1, len(devices)):
        for fresh_param, slots in zip(param_copies[device_index], replacement_tables[device_index]):
            for submodule, attribute in slots:
                replace_parameter_(submodule, attribute, fresh_param)

    job_inputs = [(replicas[device_index], batch_iterators[device_index]) for device_index in range(len(devices))]
    job_functions = [_compute_mse_on_batch] * len(replicas)

    per_device_losses = torch.nn.parallel.parallel_apply(
        job_functions, job_inputs, kwargs_by_device, devices=devices
    )
    gathered = Gather.apply(devices[0], 0, *(device_loss.view(1) for device_loss in per_device_losses))
    return gathered.mean()


def _compute_mse_parallel(
        devices: Sequence[torch.device],
        replicas: Sequence[nn.Module],
        parameters_to_replicate: nn.ParameterList,
        replacement_tables: Sequence[List[Sequence[Tuple[nn.Module, str]]]],
        batch_iterators: Sequence[Iterator[Tuple[torch.Tensor, torch.Tensor]]],
        kwargs_by_device: Sequence[Dict[str, Any]],
) -> torch.Tensor:
    """
    Data-parallel MSE: one batch per device, losses averaged on the first device.

    Steps: replicate the trainable parameters to all devices (``detach=False``), install the
    copies into every replica except the master (index 0) using ``replacement_tables``, run
    ``_compute_mse_on_batch`` on all replicas with ``parallel_apply``, then gather the scalar
    losses on ``devices[0]`` and return their mean.
    """
    num_devices = len(devices)
    replicated = torch.nn.parallel.replicate(parameters_to_replicate, devices, detach=False)

    work_items = []
    for index, (replica, batches) in enumerate(zip(replicas[:num_devices], batch_iterators[:num_devices])):
        if index > 0:  # the master module already owns the original parameters
            for param_copy, destinations in zip(replicated[index], replacement_tables[index]):
                for owner, attr_name in destinations:
                    replace_parameter_(owner, attr_name, param_copy)
        work_items.append((replica, batches))

    losses = torch.nn.parallel.parallel_apply(
        [_compute_mse_on_batch for _ in replicas], work_items, kwargs_by_device, devices=devices
    )
    stacked = Gather.apply(devices[0], 0, *[item.view(1) for item in losses])
    return stacked.mean()


def _compute_outliers_on_batch(
        layer: nn.Module,
        batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]],
        *,
        num_outliers: int = 14336 * 6,
        **kwargs,
) -> torch.Tensor:
    """
    Locate the largest output errors of `layer` on one batch and map them to weight coordinates.

    Args:
        layer: callable module; its return value is unpacked as a sequence and the first element
            is taken as the prediction.
        batch_iter: iterator yielding (inputs, outputs) pairs; one pair is consumed per call.
            After dropping a leading size-1 batch dimension the error map must be 2-D.
        num_outliers: how many of the largest absolute errors to keep. It is clamped to the number
            of elements in the error map, so small batches no longer make `torch.topk` fail.
        kwargs: forwarded unchanged to the forward pass of `layer`.

    Returns:
        Long tensor of shape [k, 2] on CPU, where k = min(num_outliers, number of error elements).
        Column 0 is the output-feature index of each large error; column 1 is the index of the
        largest input feature in the row (token) where that error occurred.
    """
    inps_batch, outs_batch = next(batch_iter)  # outs_batch: [1, seq_len, hidden_size]

    # Remove batch dimension
    inps_batch = inps_batch.squeeze(0)  # [seq_len, input_dim]
    outs_batch = outs_batch.squeeze(0)  # [seq_len, hidden_size]

    # Only the first returned element is the prediction; anything else is ignored.
    outs_prediction, *_unused = layer(inps_batch, **kwargs)  # [seq_len, hidden_size]

    error = (outs_batch - outs_prediction).float().abs()  # [seq_len, hidden_size]
    _seq_len, hidden_size = error.shape

    # reshape (not view) so that a non-contiguous error tensor is handled as well
    error_flat = error.reshape(-1)

    # topk raises if k exceeds the number of elements, so never ask for more than exist
    k = min(num_outliers, error_flat.numel())
    _topk_values, topk_indices = torch.topk(error_flat, k)

    # Flat index -> (row, col) of the error map; results are kept on CPU
    topk_indices = topk_indices.cpu()
    topk_rows = topk_indices // hidden_size
    topk_cols = topk_indices % hidden_size

    # For every selected row, find the input feature with the largest value
    weight_cols = torch.argmax(inps_batch[topk_rows, :], dim=1)  # [k]
    weight_cols = weight_cols.to(topk_cols.device)

    # (topk_cols, weight_cols): output-feature index paired with input-feature index
    return torch.stack([topk_cols, weight_cols], dim=1)  # [k, 2]


@torch.enable_grad()
def compensation_groupwise(
        *,
        layer: nn.Module,
        train_inps: Sequence[torch.Tensor],
        train_outs: Sequence[torch.Tensor],
        args: Namespace,
        valid_inps: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
        valid_outs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
        verbose: bool = True,
        **kwargs,
) -> torch.Tensor:
    """
    Find the weight coordinates that most often coincide with the largest output errors of `layer`.

    The validation data is run through `layer` batch by batch; every batch contributes the weight
    coordinates of its largest errors, and the coordinates that occur most frequently are returned.
    No parameters are updated and the training tensors are not used by this function.

    Args:
        layer: module to analyse; it is switched to eval mode. It must expose `original_shape`,
            which is unpacked as (hidden_size, input_dim).
            NOTE(review): the layer is assumed to live on args.devices[0] - confirm with callers.
        train_inps: unused here; kept for interface compatibility.
        train_outs: unused here; kept for interface compatibility.
        args: hyperparameters; reads `local_batch_size`, `update_batch_size` and `devices`.
        valid_inps: validation inputs, either a single tensor [nsamples, seq_len, hidden_size]
            or a sequence of such tensors (shards). None or empty disables the analysis.
        valid_outs: validation outputs, same layout as `valid_inps`.
        verbose: whether to print progress information.
        kwargs: additional keyword arguments passed to each forward pass.

    Returns:
        Long tensor of shape [N, 2] with (row, col) weight coordinates, N <= 14336 * 6.
        An empty [0, 2] tensor is returned when there is no validation data or no full batch.
    """
    num_top_indices = 14336 * 6

    local_batch_size = args.local_batch_size
    if local_batch_size is None:
        local_batch_size = args.update_batch_size // len(args.devices)

    # Presence check: missing or empty validation data means there is nothing to analyse.
    if valid_inps is None or valid_outs is None or len(valid_inps) == 0 or len(valid_outs) == 0:
        return torch.empty(0, 2, dtype=torch.long)

    # A bare tensor is treated as a single shard.
    if isinstance(valid_inps, torch.Tensor):
        valid_inps = [valid_inps]
    if isinstance(valid_outs, torch.Tensor):
        valid_outs = [valid_outs]

    if verbose:
        print("Evaluating before training")
    layer.eval()

    # There is no multi-device path for outlier collection, so every shard is evaluated
    # sequentially on the first device instead of being silently skipped.
    primary_device = args.devices[0]
    weight_indices_list = []
    with torch.no_grad():
        for shard_inps, shard_outs in zip(valid_inps, valid_outs):
            batches_in_shard = len(shard_inps) // local_batch_size
            if batches_in_shard == 0:
                continue
            batch_iterator = iterate_minibatches(
                shard_inps, shard_outs, batch_size=local_batch_size, device=primary_device
            )
            for _ in range(batches_in_shard):
                weight_indices = _compute_outliers_on_batch(layer, batch_iterator, **kwargs)
                if verbose:
                    print("weight_indices shape", weight_indices.shape)
                weight_indices_list.append(weight_indices)

    # torch.cat rejects an empty list, e.g. when fewer samples than one batch were supplied.
    if not weight_indices_list:
        return torch.empty(0, 2, dtype=torch.long)

    all_indices = torch.cat(weight_indices_list, dim=0)  # [total_outliers, 2]

    # Encode every (row, col) pair as a single integer so that occurrences can be counted.
    _hidden_size, input_dim = layer.original_shape
    linear_indices = all_indices[:, 0] * input_dim + all_indices[:, 1]
    unique_indices, counts = torch.unique(linear_indices, return_counts=True)

    # topk raises if k exceeds the number of candidates, so clamp it.
    k = min(num_top_indices, counts.numel())
    _top_counts, top_positions = torch.topk(counts, k)
    top_linear_indices = unique_indices[top_positions]

    # Decode the linear indices back into (row, col) coordinates.
    top_rows = top_linear_indices // input_dim
    top_cols = top_linear_indices % input_dim
    topk_weight_indices = torch.stack([top_rows, top_cols], dim=1)  # [k, 2]

    if verbose:
        print(f"Top {k} frequent weight indices shape: {topk_weight_indices.shape}")
        print("Evaluation before training.")

    return topk_weight_indices


