import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


# ---------------------------
#  MoSA Autograd (presort + sum with fallback)
# ---------------------------

class MoSASeededFunction(torch.autograd.Function):
    """
    Linear map whose effective weight is ``weight + deltas[group_indices]``.

    ``deltas`` holds one scalar per group; ``group_indices`` assigns every element
    of the weight matrix to a group, so all elements of a group share the same
    delta and its gradient is the sum of the weight gradient over that group.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, deltas, group_indices, sorted_order):
        """
        x:               [..., in_features]
        weight:          [out_features, in_features]  (frozen)
        bias:            [out_features] or None
        deltas:          [K]  (trainable scalars per group)
        group_indices:   [out_features, in_features] (long in [0, K-1])
        sorted_order:    [N] argsort of group_indices.flatten(), so that same-group
                         elements become contiguous after reindexing.

        Returns ``F.linear(x, weight, bias) + F.linear(x, deltas[group_indices])``.
        """
        ctx.save_for_backward(x, weight, bias, deltas, group_indices, sorted_order)

        main_output = F.linear(x, weight, bias)
        delta_W = deltas[group_indices.to(torch.long)]
        delta_output = F.linear(x, delta_W)
        return main_output + delta_output

    @staticmethod
    def _group_sum(grad_w, group_indices, sorted_order, K):
        """
        Per-group sum of ``grad_w``: ``out[k] = grad_w[group_indices == k].sum()``.

        Fast path: when group sizes follow exactly the layout produced by
        ``_create_balanced_group_indices`` (with ``B, r = divmod(N, K)``, groups
        ``[0, r)`` hold ``B + 1`` elements and groups ``[r, K)`` hold ``B``),
        gathering with ``sorted_order`` (assumed to be an argsort of the flattened
        group ids) makes every group one contiguous run, so two reshapes and a row
        sum suffice. Any other layout falls back to ``index_add_``.

        Accumulation happens in at least float32 (float64 stays float64); the
        returned tensor has that accumulation dtype.

        Raises:
            RuntimeError: if ``K <= 0``.
        """
        if K <= 0:
            raise RuntimeError("Number of MoSA groups (K) must be > 0.")

        acc_dtype = torch.promote_types(grad_w.dtype, torch.float32)
        flat_grad = grad_w.reshape(-1).to(acc_dtype)              # [N]
        flat_groups = group_indices.reshape(-1).to(torch.long)    # [N]
        B, r = divmod(int(flat_grad.numel()), K)

        counts = torch.bincount(flat_groups, minlength=K)
        expected = torch.full((K,), B, dtype=counts.dtype, device=counts.device)
        expected[:r] += 1
        # Counts merely lying in {B, B+1} is NOT sufficient: the reshape below
        # assumes the (B+1)-sized groups are precisely ids [0, r). A grouping that
        # violates this (e.g. random assignment) would otherwise have its gradients
        # attributed to the wrong deltas. torch.equal is also False when counts is
        # longer than K (out-of-range ids), which routes to index_add_ and errors.
        if torch.equal(counts, expected):
            grad_sorted = flat_grad.index_select(0, sorted_order)
            head_len = (B + 1) * r
            head = grad_sorted[:head_len].view(r, B + 1).sum(dim=1)
            # r = N % K < K, so the tail always has K - r >= 1 rows.
            tail = grad_sorted[head_len:].view(K - r, B).sum(dim=1)
            return torch.cat([head, tail], dim=0)

        summed = torch.zeros(K, device=flat_grad.device, dtype=acc_dtype)
        summed.index_add_(0, flat_groups, flat_grad)
        return summed

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return gradients for (x, weight, bias, deltas, group_indices, sorted_order).

        The two index tensors never receive a gradient (``None``).
        """
        x, weight, bias, deltas, group_indices, sorted_order = ctx.saved_tensors
        grad_x = grad_weight = grad_bias = grad_deltas = None

        # dL/dx = grad_output @ (weight + delta_W)
        if ctx.needs_input_grad[0]:
            delta_W = deltas[group_indices.to(torch.long)]
            w_eff_t = (weight + delta_W).t()
            grad_x = F.linear(grad_output.to(w_eff_t.dtype), w_eff_t)

        # dL/dW_eff = grad_output^T @ x. The base weight and delta_W enter the
        # forward pass identically, so one product serves both gradients.
        if ctx.needs_input_grad[1] or ctx.needs_input_grad[3]:
            go2d = grad_output.reshape(-1, grad_output.shape[-1])  # [B*, out]
            x2d = x.reshape(-1, x.shape[-1])                        # [B*, in]
            grad_w_eff = go2d.t().matmul(x2d)                       # [out, in]

            # dL/dW (base weight is normally frozen; still computed for completeness)
            if ctx.needs_input_grad[1]:
                grad_weight = grad_w_eff

            # dL/ddeltas: sum of dL/dW_eff over the elements of each group
            if ctx.needs_input_grad[3]:
                grad_deltas = MoSASeededFunction._group_sum(
                    grad_w_eff, group_indices, sorted_order, int(deltas.numel())
                ).to(deltas.dtype)

        # dL/dbias
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(dim=0)

        return grad_x, grad_weight, grad_bias, grad_deltas, None, None


# ---------------------------
#  Config & helpers
# ---------------------------

@dataclass
class MoSAConfig:
    """
    Settings for attaching MoSA (group-shared weight deltas) to a model.

    Each field stores a human-readable ``help`` string in its dataclass
    ``metadata``.
    """

    # Parameter budget expressed as an equivalent LoRA rank.
    mosa_equivalent_rank: int = field(
        default=8,
        metadata={"help": "Equivalent LoRA rank 'r' to align parameter counts."},
    )
    # Suffixes matched against module names; None means no suffix filtering.
    target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Module name suffixes to apply MoSA to."},
    )
    merge_weights: bool = field(
        default=False,
        metadata={"help": "Whether to merge MoSA weights with the base model in eval mode."},
    )
    fan_in_fan_out: bool = field(
        default=False,
        metadata={"help": "True if the layer stores weight as (fan_in, fan_out)."},
    )
    # One of "none", "all", "mosa_only".
    bias: str = field(
        default="none",
        metadata={"help": "Bias type: 'none' | 'all' | 'mosa_only'."},
    )
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Extra modules to keep trainable besides MoSA."},
    )
    grouping_strategy: str = field(
        default='balanced_seeded',
        metadata={"help": "Grouping strategy: 'balanced_seeded' (recommended) or 'random'."},
    )
    grouping_seed: Optional[int] = field(
        default=42,
        metadata={"help": "Random seed for seeded grouping strategies."},
    )


def mark_only_mosa_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """
    Freeze all parameters except MoSA deltas (and optionally bias).

    Args:
        model: module whose parameters' ``requires_grad`` flags are set in place.
        bias: which biases stay trainable in addition to the deltas:
            ``"none"``      -> no biases;
            ``"all"``       -> every parameter whose name contains ``"bias"``;
            ``"mosa_only"`` -> only the ``bias`` of ``MoSALayer`` instances.

    Raises:
        NotImplementedError: if ``bias`` is not one of the values above. The
            check runs before any flag is touched, so an invalid value leaves
            ``model`` unchanged.
    """
    if bias not in ("none", "all", "mosa_only"):
        raise NotImplementedError("Unsupported bias type.")

    # Substring match on the qualified parameter name: anything containing
    # "deltas" is trainable, everything else is frozen.
    for name, param in model.named_parameters():
        param.requires_grad = "deltas" in name

    if bias == "all":
        for name, param in model.named_parameters():
            if "bias" in name:
                param.requires_grad = True
    elif bias == "mosa_only":
        for module in model.modules():
            if isinstance(module, MoSALayer) and getattr(module, "bias", None) is not None:
                module.bias.requires_grad = True


def _create_balanced_group_indices(
    shape: Tuple[int, int], num_groups: int, generator: torch.Generator, device: torch.device
) -> torch.LongTensor:
    """
    Create a 2D tensor of group indices on target device with near-equal counts
    (each group has either B or B+1 elements; B = floor(N/K)).
    We form an index multiset with desired counts, then globally shuffle.
    After argsort by group id, the first r groups have length B+1, others B.
    """
    num_elements = shape[0] * shape[1]
    if num_groups <= 0:
        raise ValueError("num_groups must be > 0.")
    elements_per_group = num_elements // num_groups

    base_indices = torch.arange(num_groups, device=device).repeat_interleave(elements_per_group)
    remainder = num_elements % num_groups
    if remainder > 0:
        remainder_indices = torch.arange(remainder, device=device)
        full_indices_1d = torch.cat([base_indices, remainder_indices])
    else:
        full_indices_1d = base_indices

    shuffled = torch.randperm(num_elements, generator=generator, device=device)
    shuffled_groups_1d = full_indices_1d[shuffled]
    return shuffled_groups_1d.view(shape).to(torch.long)


# ---------------------------
#  Layers
# ---------------------------

class MoSALayer:
    """
    Plain mixin (deliberately not an ``nn.Module``) that stores the ``merged``
    and ``merge_weights`` flags on MoSA layers.

    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, merge_weights: bool, **kwargs):
        # Layers start unmerged; merge_weights records the caller's request.
        self.merged, self.merge_weights = False, merge_weights


class Linear(nn.Linear, MoSALayer):
    """
    ``nn.Linear`` whose effective weight is ``weight + deltas[group_indices]``.

    The computation is delegated to ``MoSASeededFunction``. At construction the
    base ``weight`` is frozen and ``deltas`` (one scalar per group) is the
    trainable parameter, zero-initialised so the layer initially computes the
    same output as the plain ``nn.Linear``.

    Args:
        in_features, out_features: forwarded to ``nn.Linear``.
        num_groups: number of delta scalars (K).
        group_indices: group id of every weight element; expected to match the
            shape of ``weight`` ([out_features, in_features]).
        sorted_order: argsort of ``group_indices.flatten()``, consumed by the
            backward pass.
        fan_in_fan_out: if True, the freshly initialised weight is transposed.
        merge_weights: stored through ``MoSALayer``.
        **kwargs: forwarded to ``nn.Linear`` (e.g. ``bias``, ``device``, ``dtype``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_groups: int,
        group_indices: torch.LongTensor,
        sorted_order: torch.LongTensor,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        MoSALayer.__init__(self, merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out

        # Non-persistent buffers follow .to()/.cuda() but are excluded from
        # state_dict (presumably the grouping is regenerated from the seed on
        # reload; confirm against the loading code).
        self.register_buffer('group_indices', group_indices.to(torch.long), persistent=False)
        self.register_buffer('sorted_order',  sorted_order.to(torch.long),  persistent=False)

        # Allocate deltas on the weight's device as well as its dtype: a `device=`
        # passed through **kwargs previously left deltas on the default device.
        self.deltas = nn.Parameter(
            torch.zeros(num_groups, dtype=self.weight.dtype, device=self.weight.device),
            requires_grad=True,
        )
        self.weight.requires_grad = False

        # NOTE(review): forward passes self.weight to MoSASeededFunction without
        # transposing back, so with fan_in_fan_out=True and in_features !=
        # out_features the weight no longer matches group_indices' [out, in]
        # shape. Inside MoSA, _replace_module overwrites weight with the original
        # layer's weight afterwards, which masks this.
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def forward(self, x: torch.Tensor):
        """Apply ``MoSASeededFunction`` with this layer's weight, bias, deltas and grouping."""
        return MoSASeededFunction.apply(
            x, self.weight, self.bias, self.deltas, self.group_indices, self.sorted_order
        )


# ---------------------------
#  MoSA wrapper
# ---------------------------

class MoSA(torch.nn.Module):
    """
    Main class to apply Group-Share Adaptation to a given model.

    Each targeted ``nn.Linear`` is replaced by a MoSA ``Linear``. The new layer
    adopts the original weight and bias and adds
    ``max(1, r * (in_features + out_features))`` delta scalars, where ``r`` is
    ``config.mosa_equivalent_rank``.

    - ``'balanced_seeded'``: layers with the same ``(shape, num_groups)`` share
      one ``group_indices`` / ``sorted_order`` pair generated from
      ``config.grouping_seed``.
    - ``'random'``: each layer draws its own grouping from the global RNG.

    Afterwards only the deltas (and biases, per ``config.bias``) stay trainable.
    """

    def __init__(self, model: nn.Module, config: MoSAConfig):
        super().__init__()
        self.peft_config = config
        self.model = model
        self._find_and_replace()
        mark_only_mosa_as_trainable(self.model, self.peft_config.bias)

    def _collect_layer_configs(self):
        """Return ``{module_name: ((out_features, in_features), num_groups)}`` for every target Linear."""
        targets = self.peft_config.target_modules
        rank = self.peft_config.mosa_equivalent_rank
        layer_configs = {}
        for key, module in self.model.named_modules():
            if not key:
                # The root module has no parent it could be re-attached to.
                continue
            if targets and not any(key.endswith(t) for t in targets):
                continue
            if isinstance(module, torch.nn.Linear):
                out_features, in_features = module.out_features, module.in_features
                num_groups = max(1, rank * (in_features + out_features))
                layer_configs[key] = ((out_features, in_features), num_groups)
        return layer_configs

    def _build_seeded_bank(self, unique_configs, device):
        """Generate one shared ``(group_indices, sorted_order)`` pair per unique ``(shape, num_groups)``."""
        if self.peft_config.grouping_seed is None:
            raise ValueError("`grouping_seed` must be provided for 'balanced_seeded' strategy.")
        generator = torch.Generator(device=device).manual_seed(self.peft_config.grouping_seed)

        bank = {}
        for shape, num_groups in unique_configs:
            group_indices = _create_balanced_group_indices(shape, num_groups, generator, device=device)
            with torch.no_grad():
                _, order = torch.sort(group_indices.flatten())
            bank[(shape, num_groups)] = (group_indices, order)
        return bank

    def _find_and_replace(self):
        """
        Replace every targeted ``nn.Linear`` in ``self.model`` with a MoSA ``Linear``.

        Raises:
            ValueError: strategy ``'balanced_seeded'`` without ``grouping_seed``.
            NotImplementedError: unknown ``grouping_strategy``. The error is
                raised when the first targeted layer is processed.

        Emits a ``UserWarning`` when nothing was replaced.
        """
        # Group indices are created on the device of the model's first parameter.
        try:
            device = next(self.model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")

        layer_configs = self._collect_layer_configs()
        strategy = self.peft_config.grouping_strategy

        bank = {}
        if strategy == 'balanced_seeded':
            # Sorting fixes the order in which the generator is consumed, so the
            # grouping depends only on the set of configs, not on module order.
            unique_configs = sorted(set(layer_configs.values()))
            bank = self._build_seeded_bank(unique_configs, device)

        for key, (shape, layer_num_groups) in layer_configs.items():
            parent, target, target_name = self._get_submodules(key)
            bias_flag = target.bias is not None
            out_features, in_features = shape

            if strategy == 'random':
                group_indices = torch.randint(
                    0, layer_num_groups, size=shape, dtype=torch.long, device=device
                )
                with torch.no_grad():
                    _, order = torch.sort(group_indices.flatten())
            elif strategy == 'balanced_seeded':
                group_indices, order = bank[(shape, layer_num_groups)]
            else:
                raise NotImplementedError(f"Grouping strategy '{self.peft_config.grouping_strategy}' is not implemented.")

            new_module = Linear(
                in_features=in_features,
                out_features=out_features,
                num_groups=layer_num_groups,
                group_indices=group_indices,
                sorted_order=order,
                bias=bias_flag,
                fan_in_fan_out=self.peft_config.fan_in_fan_out,
                merge_weights=self.peft_config.merge_weights
            )
            self._replace_module(parent, target_name, new_module, target)

        if not layer_configs:
            warnings.warn(
                "No target modules were replaced by MoSA. "
                "Check `config.target_modules` suffixes."
            )

    def _get_submodules(self, key):
        """Return ``(parent_module, target_module, target_attribute_name)`` for dotted ``key``."""
        parent_name, _, target_name = key.rpartition(".")
        parent = self.model.get_submodule(parent_name) if parent_name else self.model
        target = self.model.get_submodule(key)
        return parent, target, target_name

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` under ``parent_module.child_name``, adopting ``old_module``'s weight/bias."""
        setattr(parent_module, child_name, new_module)

        weight = old_module.weight
        new_module.weight = weight
        if old_module.bias is not None:
            new_module.bias = old_module.bias

        new_module.to(weight.device)

        # The new layer created its deltas in its own default dtype (normally
        # fp32). Align them with the adopted base weight so forward does not
        # mix dtypes (e.g. bf16 weight + fp32 deltas fails in F.linear without
        # autocast). Non-floating weights are left alone because a Parameter
        # that requires grad must be floating point.
        if weight.dtype.is_floating_point and new_module.deltas.dtype != weight.dtype:
            new_module.deltas.data = new_module.deltas.data.to(weight.dtype)

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state

    def __getattr__(self, name: str):
        """Fall back to attributes of the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "model":
                # Avoid infinite recursion when accessed before `self.model` exists.
                raise
            return getattr(self.model, name)

    def forward(self, *args, **kwargs):
        """Call the wrapped model, passing positional and keyword arguments through."""
        return self.model(*args, **kwargs)
