"""
Slice Training Implementation for Efficient Fine-tuning of Large Pretrained Models

This module implements slice training, a parameter-efficient fine-tuning technique that selectively
trains only a subset (slice) of weight matrix rows or columns while keeping the rest frozen.

Based on the paper: "RoCoFT: Efficient Finetuning of Large Language Models with Row-Column Updates"
https://arxiv.org/abs/2410.10075
RoCoFT trains a static slice: the slice position does not change during training. This
implementation extends it by allowing the slice position to change during training.

Key Features:
- Dynamic row/column updates with configurable scheduling
- Support for both row-wise and column-wise slicing modes
- Alternating row-column training for comprehensive coverage
- Memory-efficient implementation with proper cleanup

Components:
- SliceLinear: Custom nn.Module that partitions linear layers into trainable and frozen sections
- inject_peft: Replaces standard nn.Linear layers with SliceLinear layers
- setup_slice_training: Configures slice training parameters for each training phase
- merge_slice_linear: Converts SliceLinear layers back to standard nn.Linear
- restore_layers: Batch restoration of all SliceLinear layers in a model

Usage:
The slice training process involves:
1. Injecting SliceLinear layers into the model via inject_peft()
2. Training with only the sliced parameters being updated
3. Optionally changing slice positions/modes between training phases
4. Restoring to standard layers via restore_layers() when complete

"""

import gc
import logging
from collections import deque
from typing import Tuple

import torch
import torch.nn as nn

log = logging.getLogger(__name__)


class SliceLinear(nn.Module):
    """
    A parameter-efficient linear layer that trains only a slice of the weight matrix.

    The weight of the wrapped ``nn.Linear`` is copied and split along one axis
    into three parameters:
    - part_A: Frozen parameters (before trainable slice)
    - part_T: Trainable parameters (the slice being updated)
    - part_B: Frozen parameters (after trainable slice)

    If ``position + rank`` exceeds the size of the sliced axis, the position is
    moved back so the slice ends at the last row/column and a warning is logged.

    Args:
        F: Original nn.Linear layer to be sliced
        rank: Width of the trainable slice
        position: Starting position of the trainable slice
        bias: Whether to make bias parameters trainable
        mode: Slicing direction ('column' for input features, 'row' for output features)

    Raises:
        ValueError: If ``mode`` is not 'column' or 'row', if ``rank`` is not
            positive, or if ``position`` is negative.
    """

    def __init__(
        self,
        F: nn.Linear,
        rank: int = 1,
        position: int = 0,
        bias: bool = True,
        mode: str = "column",
    ):
        super().__init__()
        if mode not in ("column", "row"):
            raise ValueError(f"mode must be 'column' or 'row', got '{mode}'")
        if rank <= 0:
            raise ValueError(f"rank must be positive, got {rank}")
        # A negative start would be interpreted as "from the end" by slicing and
        # yield an empty trainable part plus a duplicated frozen part.
        if position < 0:
            raise ValueError(f"position must be non-negative, got {position}")

        F.eval()
        self.mode = mode
        self.original_shape = F.weight.shape

        # Store metadata for debugging and inspection
        self.rank = rank
        self.position = position

        if mode == "column":
            self._init_column_mode(F, rank, position, bias)
        else:  # mode == "row"
            self._init_row_mode(F, rank, position, bias)

        # The bias is copied whole; only its trainability is configurable.
        if F.bias is not None:
            self.bias = nn.Parameter(F.bias.detach().clone(), requires_grad=bias)
        else:
            self.bias = None

    def _split_weight(
        self, F: nn.Linear, rank: int, position: int, axis: int, unit: str
    ) -> None:
        """Copy ``F.weight`` into part_A / part_T / part_B along ``axis``.

        ``unit`` is only used in the warning message ("rows" or "columns").
        """
        total_weights = F.weight.shape[axis]

        # Adjust position if it would exceed bounds
        if position + rank > total_weights:
            position = max(0, total_weights - rank)
            log.warning(
                f"Adjusted position to {position} to fit rank {rank} in {total_weights} {unit}"
            )

        self.position = position  # Update with adjusted position
        self.axis = axis

        source = F.weight.detach()

        def cut(start, stop):
            # Slice only along ``axis``; the other axis is taken whole.
            index = [slice(None), slice(None)]
            index[axis] = slice(start, stop)
            return source[tuple(index)].clone()

        slice_end = position + rank
        self.part_A = nn.Parameter(cut(None, position), requires_grad=False)
        self.part_T = nn.Parameter(cut(position, slice_end), requires_grad=True)
        self.part_B = nn.Parameter(cut(slice_end, None), requires_grad=False)

    def _init_column_mode(self, F: nn.Linear, rank: int, position: int, bias: bool):
        """Initialize column-wise slicing (slice input features)."""
        self._split_weight(F, rank, position, axis=1, unit="columns")

        # Cache slice boundaries so forward() can split the input features
        self.a_end = self.part_A.shape[1]
        self.t_end = self.a_end + self.part_T.shape[1]

    def _init_row_mode(self, F: nn.Linear, rank: int, position: int, bias: bool):
        """Initialize row-wise slicing (slice output features)."""
        self._split_weight(F, rank, position, axis=0, unit="rows")

    @property
    def weight(self):
        """Full weight matrix rebuilt by concatenating the three parts."""
        return torch.cat([self.part_A, self.part_T, self.part_B], dim=self.axis)

    @property
    def part_A_weight(self):
        """Raw tensor of the frozen part before the trainable slice."""
        return self.part_A.data

    @property
    def part_B_weight(self):
        """Raw tensor of the frozen part after the trainable slice."""
        return self.part_B.data

    @property
    def part_T_weight(self):
        """Raw tensor of the trainable slice."""
        return self.part_T.data

    def forward(self, x):
        """Apply the linear map using the three weight parts separately."""
        linear = torch.nn.functional.linear

        if self.mode == "column":
            # Each part multiplies its own span of input features; the partial
            # products are summed. The bias is added once, with the trainable part.
            return (
                linear(x[..., : self.a_end], self.part_A)
                + linear(x[..., self.a_end : self.t_end], self.part_T, self.bias)
                + linear(x[..., self.t_end :], self.part_B)
            )

        # Row mode: each part produces its own span of output features.
        out = torch.cat(
            [
                linear(x, self.part_A),
                linear(x, self.part_T),
                linear(x, self.part_B),
            ],
            dim=-1,
        )
        if self.bias is not None:
            out += self.bias
        return out

    def __repr__(self):
        return (
            f"SliceLinear(mode={self.mode}, trainable={tuple(self.part_T.shape)}, "
            f"frozen_A={tuple(self.part_A.shape)}, frozen_B={tuple(self.part_B.shape)})"
        )


def inject_peft(
    model: nn.Module,
    rank: int = 1,
    position: int = 0,
    bias: bool = False,
    mode: str = "column",
) -> None:
    """
    Replace nn.Linear layers with SliceLinear PEFT layers.

    Only nn.Linear modules that currently have at least one trainable parameter
    are replaced. If any are found, every parameter of ``model`` is frozen
    before the replacement, so afterwards only the parameters that the new
    SliceLinear layers mark as trainable require gradients. If none are found,
    the model is left untouched.

    Args:
        model: The model to inject SliceLinear layers into
        rank: Width of the trainable slice
        position: Starting position of the trainable slice
        bias: Whether to make bias parameters trainable
        mode: Slicing mode ('row' or 'column')

    Raises:
        ValueError: If mode is not 'row' or 'column'
        ValueError: If rank <= 0
    """
    if mode not in ("row", "column"):
        raise ValueError(f"mode must be 'row' or 'column', got '{mode}'")
    if rank <= 0:
        raise ValueError(f"rank must be positive, got {rank}")

    # Collect (name, module) pairs in a single traversal. This must happen
    # before freezing, because trainability is the selection criterion.
    modules_to_replace = []
    skipped_module_count = 0
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if any(param.requires_grad for param in module.parameters()):
            modules_to_replace.append((name, module))
        else:
            skipped_module_count += 1

    log.info(f"Found {len(modules_to_replace)} trainable linear modules")
    log.info(f"Skipped {skipped_module_count} frozen linear modules")

    if not modules_to_replace:
        log.warning("No trainable linear modules found to inject SliceLinear layers")
        return

    # Freeze all parameters first
    for param in model.parameters():
        param.requires_grad = False

    # Replace linear modules with SliceLinear
    for name, module in modules_to_replace:
        try:
            # "a.b.c" -> parent path "a.b", attribute "c"; a top-level child
            # has an empty parent path and is set directly on the model.
            parent_name, _, attr_name = name.rpartition(".")
            parent = model
            if parent_name:
                for attr in parent_name.split("."):
                    parent = getattr(parent, attr)

            new_layer = SliceLinear(
                module, rank=rank, position=position, bias=bias, mode=mode
            )
            setattr(parent, attr_name, new_layer)
            log.debug(f"Replaced {name} with SliceLinear")

        except Exception as e:
            log.error(f"Failed to replace module {name}: {e}")
            raise


def merge_slice_linear(layer: "SliceLinear") -> nn.Linear:
    """
    Merge a SliceLinear (row/column) layer back into nn.Linear.

    The three weight parts are concatenated along ``layer.axis`` and copied
    into a fresh ``nn.Linear`` created on the same device and with the same
    dtype as the trainable slice, so the merged layer can be put back into the
    model without an extra device or dtype move.

    Args:
        layer: The sliced layer to merge. It must expose ``part_A``,
            ``part_T``, ``part_B``, ``axis`` and ``bias``.

    Returns:
        An ``nn.Linear`` holding the merged weight (and bias, if the sliced
        layer has one). Its weight requires gradients.
    """
    # Preserve the original dtype
    original_dtype = layer.part_T.dtype

    weight = torch.cat(
        [layer.part_A, layer.part_T, layer.part_B], dim=layer.axis
    ).detach()
    out_features, in_features = weight.shape
    has_bias = layer.bias is not None

    # Build the layer where the slices live; building it on the default device
    # would leave the merged layer on a different device than the source.
    new_linear = nn.Linear(
        in_features,
        out_features,
        bias=has_bias,
        device=weight.device,
        dtype=original_dtype,
    )

    # copy_ converts dtype as needed; no_grad keeps the copy out of autograd.
    with torch.no_grad():
        new_linear.weight.copy_(weight)
        if has_bias:
            new_linear.bias.copy_(layer.bias.detach())

    # The merged weight is trainable as a whole again.
    new_linear.weight.requires_grad_(True)

    return new_linear


def setup_slice_training(
    model: nn.Module,
    training_args,
    split_index: int,
    static_flag: bool,
    row_update_count: int,
    column_update_count: int,
) -> Tuple[bool, int, int]:
    """
    Setup slice training for the current split by determining parameters and injecting PEFT.

    Slice selection depends on ``training_args.slice_train_mode``:
    - 'row' / 'column': position is
      ``(split_index % slice_train_repeat_phase) * slice_train_rank``.
    - 'row_column': the mode switches between 'row' and 'column' every
      ``slice_train_change_phase`` splits; the position is the respective
      update counter times the rank, and that counter is then incremented.

    Args:
        model: The model to setup slice training for
        training_args: Training arguments containing slice training configuration
            (reads slice_train_rank, slice_train_mode, slice_train_bias,
            slice_train_static and, depending on the mode,
            slice_train_repeat_phase or slice_train_change_phase)
        split_index: Current training split index
        static_flag: Whether slice training is in static mode
        row_update_count: Current count of row updates
        column_update_count: Current count of column updates

    Returns:
        tuple: (updated_static_flag, updated_row_update_count, updated_column_update_count)

    Raises:
        ValueError: If the rank or the phase length used by the selected mode
            is not positive, or if slice_train_mode is not recognised
    """

    if static_flag:
        log.info("Slice training is in static mode, skipping setup")
        return static_flag, row_update_count, column_update_count

    rank = training_args.slice_train_rank
    if rank <= 0:
        raise ValueError(f"slice_train_rank must be positive, got {rank}")

    # Determine slice position and mode based on training configuration
    slice_mode = training_args.slice_train_mode
    if slice_mode in ("row", "column"):
        repeat_phase = training_args.slice_train_repeat_phase
        # Zero would divide by zero; a negative value would produce a
        # negative slice position.
        if repeat_phase <= 0:
            raise ValueError(
                f"slice_train_repeat_phase must be positive, got {repeat_phase}"
            )
        position = (split_index % repeat_phase) * rank
        mode = slice_mode
    elif slice_mode == "row_column":
        change_phase = training_args.slice_train_change_phase
        if change_phase <= 0:
            raise ValueError(
                f"slice_train_change_phase must be positive, got {change_phase}"
            )
        # Alternate between row and column modes
        if (split_index // change_phase) % 2 == 0:
            mode = "row"
            position = row_update_count * rank
            row_update_count += 1
        else:
            mode = "column"
            position = column_update_count * rank
            column_update_count += 1
    else:
        raise ValueError(f"Invalid slice_train_mode: {training_args.slice_train_mode}")

    log.info(
        f"Split {split_index}: Configuring slice training with rank={rank}, "
        f"position={position}, mode={mode}"
    )

    try:
        inject_peft(
            model,
            rank=rank,
            position=position,
            bias=training_args.slice_train_bias,
            mode=mode,
        )
    except Exception as e:
        log.error(f"Failed to inject PEFT for split {split_index}: {e}")
        raise

    # Update static flag if configured for static slice training
    if training_args.slice_train_static:
        static_flag = True
        log.info("Slice training set to static mode")

    return static_flag, row_update_count, column_update_count


def restore_layers(model: nn.Module):
    """
    Replace all SliceLinear layers with merged nn.Linear layers.

    The module tree is walked breadth-first starting at ``model``. Every
    direct child that is a SliceLinear is swapped, on its parent, for the
    result of ``merge_slice_linear``; every other child is queued so that its
    own children are inspected later. A garbage collection pass is triggered
    once the walk is finished.

    Args:
        model: Root module whose descendants are restored in place.
    """
    seen_ids = set()
    pending = deque([model])

    while pending:
        parent = pending.popleft()

        # A module reachable through several parents is handled only once.
        if id(parent) in seen_ids:
            continue
        seen_ids.add(id(parent))

        # Snapshot the children because the loop body replaces entries.
        for name, child_module in list(parent.named_children()):
            if not isinstance(child_module, SliceLinear):
                # Descend later, unless this module was already processed.
                if id(child_module) not in seen_ids:
                    pending.append(child_module)
                continue

            log.info(f"Restoring SliceLinear layer: {name}")
            setattr(parent, name, merge_slice_linear(child_module))
            # Drop the local reference to the replaced layer.
            del child_module

    # Force garbage collection to free up memory
    gc.collect()
