import math
import re
import warnings
from collections import defaultdict

import torch
from torch import nn
from transformers import Conv1D

from .layer import Linear, GroupLoRALinear
from peft.tuners.lora import LoraModel
from peft.tuners.tuners_utils import BaseTunerLayer
from .config import GroupLoraConfig

def extract_after_digit(key):
    """Return the part of ``key`` that follows its first ``.<digits>.`` segment.

    Examples:
        ``"model.layers.3.self_attn.q_proj"`` -> ``"self_attn.q_proj"``
        ``"lm_head"`` -> ``"lm_head"`` (no dot-delimited index: returned as-is)
    """
    found = re.search(r'\.\d+\.(.*)', key)
    if found is None:
        return key
    return found.group(1)


class GroupLoraModel(LoraModel):
    """LoRA model whose adapter weights are shared by groups of layers.

    Before injection, one ``nn.ModuleList`` of ``GroupLoRALinear`` modules is
    created per targeted sub-module path and stored in ``self.lora_groups``.
    Each list holds ``ceil(config.num_layers / config.group_size)`` entries.
    Every replaced layer whose path has the same suffix receives the same
    ``ModuleList`` plus a running ``layer_idx``. Picking the group from that
    index presumably happens inside ``Linear`` (not visible here).
    """

    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

    def _pre_injection_hook(self, model: nn.Module, config: GroupLoraConfig, adapter_name: str) -> None:
        """Rebuild the shared LoRA groups and reset the per-key layer counters."""
        self.lora_groups = nn.ModuleDict()
        self._init_shared_lora_groups(model, config)
        self.layer_counters = defaultdict(int)

    @staticmethod
    def _module_key(name):
        """Map a module path to its ``lora_groups`` key.

        Drops everything up to and including the first ``.<index>.`` segment,
        then replaces the remaining dots with underscores, because
        ``nn.ModuleDict`` rejects keys containing '.'. For example,
        ``model.layers.3.self_attn.q_proj`` becomes ``self_attn_q_proj``.
        Both group creation and module replacement use this helper, so the
        two always agree on the key.
        """
        return extract_after_digit(name).replace('.', '_')

    @staticmethod
    def _target_matcher(target_modules):
        """Return a predicate that reports whether a module path is targeted.

        A string is treated as a regular expression that must match the whole
        path. Any other iterable is treated as a collection of substrings.
        """
        if isinstance(target_modules, str):
            # Iterating a string would yield single characters that match
            # almost every module path, so compile it as a regex instead.
            pattern = re.compile(target_modules)
            return lambda name: pattern.fullmatch(name) is not None
        targets = set(target_modules)
        return lambda name: any(key in name for key in targets)

    def _init_shared_lora_groups(self, model, config):
        """Create the shared ``GroupLoRALinear`` groups for every targeted path.

        Every ``nn.Linear`` / ``Conv1D`` whose path contains an index segment
        (``.<digits>.``) and matches ``config.target_modules`` contributes its
        ``(in_features, out_features)`` under ``_module_key(path)``. All blocks
        are scanned, not only block 0, so that configurations excluding layer 0
        still get their groups. When several blocks share a key, the first
        shape seen wins.
        """
        is_target = self._target_matcher(config.target_modules)
        feature_dims = {}
        for name, module in model.named_modules():
            if extract_after_digit(name) == name:
                # No ``.<index>.`` segment, so not inside a repeated block
                # (e.g. ``lm_head``).
                continue
            if not is_target(name):
                continue
            if isinstance(module, nn.Linear):
                dims = (module.in_features, module.out_features)
            elif isinstance(module, Conv1D):
                # transformers' Conv1D stores its weight as (in_features, out_features).
                dims = tuple(module.weight.shape)
            else:
                continue
            feature_dims.setdefault(self._module_key(name), dims)

        num_groups = math.ceil(config.num_layers / config.group_size)
        for module_key, (in_feat, out_feat) in feature_dims.items():
            self.lora_groups[module_key] = nn.ModuleList(
                GroupLoRALinear(in_feat, out_feat, config.r) for _ in range(num_groups)
            )

    def _create_and_replace(
        self,
        lora_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
    ):
        """Replace ``target`` with a group-LoRA ``Linear`` wired to its shared group.

        ``layer_idx`` is the number of modules already replaced under the same
        key. It therefore equals the block index only when every block is
        targeted and modules are visited in block order.

        Raises:
            KeyError: if no shared group exists for ``current_key``.
        """
        module_key = self._module_key(current_key)
        try:
            shared_group = self.lora_groups[module_key]
        except KeyError:
            raise KeyError(
                f"No shared LoRA group for module '{current_key}' (key '{module_key}'); "
                f"available keys: {sorted(self.lora_groups.keys())}"
            ) from None
        kwargs = {
            "shared_group": shared_group,
            "layer_idx": self.layer_counters[module_key],
        }
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
        if adapter_name not in self.active_adapters:
            # Adding an adapter: it is not automatically trainable.
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)
        self.layer_counters[module_key] += 1

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
        """Build the group-LoRA ``Linear`` wrapper for ``target``.

        ``lora_config.fan_in_fan_out`` is corrected in place to match the base
        layer type: ``False`` for ``torch.nn.Linear``, ``True`` for ``Conv1D``.

        Raises:
            ValueError: if the base layer is neither ``torch.nn.Linear`` nor ``Conv1D``.
        """
        if isinstance(target, BaseTunerLayer):
            target_base_layer = target.get_base_layer()
        else:
            target_base_layer = target

        if isinstance(target_base_layer, torch.nn.Linear):
            if lora_config.fan_in_fan_out:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                lora_config.fan_in_fan_out = False
        elif isinstance(target_base_layer, Conv1D):
            if not lora_config.fan_in_fan_out:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                    "Setting fan_in_fan_out to True."
                )
                lora_config.fan_in_fan_out = True
        else:
            raise ValueError(
                f"Target module {target} is not supported. "
                f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
            )
        # A constructor call never returns None, so no further check is needed.
        return Linear(target, adapter_name, lora_config, **kwargs)

    def _prepare_model(self, peft_config: GroupLoraConfig, model: nn.Module):
        """Intentionally a no-op.

        NOTE(review): this skips ``LoraModel._prepare_model`` entirely (in
        current peft that is where ``layer_replication`` is handled). Confirm
        that this is intended.
        """
        pass

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
                raise
            return getattr(self.model, name)

