"""
Adapter utilities for HR2R wrapper.

This module centralizes all adapter-related logic (e.g., LoRA/multi-LoRA/cascade)
that previously lived inside `recurrent_transformer.py`, without changing any
existing runtime behavior.
"""

from __future__ import annotations

import os
from typing import Optional

import torch
import torch.nn as nn

from transformers.utils import logging

from hr2r.model.utils import get_attr_recursive

# PEFT is an optional dependency: imported eagerly, with placeholders when it is not installed
try:
    from peft import LoraConfig, get_peft_model
    from peft.tuners.lora.layer import LoraLayer
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    LoraConfig = None  # type: ignore
    get_peft_model = None  # type: ignore


logger = logging.get_logger(__name__)


def _find_layers_parent(model_obj):
    """Locate the ``nn.ModuleList`` of transformer layers on ``model_obj``.

    Probes a fixed list of ``(parent attribute, layers attribute)`` pairs.

    Returns:
        ``(parent, layers_attr_name, layers)`` for the first pair whose layers
        attribute is an ``nn.ModuleList``, or ``(None, None, None)`` if no
        candidate matches.
    """
    candidates = [
        ("model", "layers"),
        ("transformer", "layers"),
        ("transformer", "h"),
        ("decoder", "layers"),
        ("backbone", "layers"),
    ]
    for parent_name, layers_name in candidates:
        parent = getattr(model_obj, parent_name, None)
        layers = getattr(parent, layers_name, None)
        if isinstance(layers, nn.ModuleList):
            return parent, layers_name, layers
    return None, None, None


def _trim_layers(model_obj, n: int, *, from_end: bool) -> int:
    """Drop ``n`` layers from one end of ``model_obj``'s layer list, in place.

    Args:
        model_obj: Model whose layer list is found via ``_find_layers_parent``.
        n: Number of layers to remove; clamped to the number of layers present.
        from_end: ``True`` removes the last ``n`` layers, ``False`` the first ``n``.

    Returns:
        The number of layers actually removed (0 when ``n <= 0`` or when no
        layer list could be located).
    """
    if n <= 0:
        return 0
    parent, layers_attr, layers = _find_layers_parent(model_obj)
    if layers is None:
        side = "last" if from_end else "first"
        logger.warning(f"cascaded_lora: could not locate transformer layers to trim {side} N")
        return 0
    n = min(n, len(layers))
    # n can only be 0 here for an empty layer list; the guard also keeps the
    # `[:-0]` slice (which would drop everything) from ever being evaluated.
    if n > 0:
        kept = list(layers)[:-n] if from_end else list(layers)[n:]
        setattr(parent, layers_attr, nn.ModuleList(kept))
    return n


def setup_adapter(wrapper, config) -> None:
    """Initialize the adapter selected by ``config.adapter`` and attach it to ``wrapper``.

    Always sets ``wrapper.adapter``, ``wrapper.adapter_config`` and
    ``wrapper.lora_iter_to_adapter``. Depending on the mode it additionally:

    * ``"lora"`` / ``"multilora"``: wraps ``wrapper.simple_base_model`` with
      ``get_peft_model``. For ``"multilora"`` iteration 1 maps to the
      ``"default"`` adapter and iterations ``2 .. wrapper.max_iter - 1`` each
      get an ``"iter_<i>"`` adapter.
    * ``"cascade"``: deep-copies ``wrapper.simple_base_model`` into
      ``wrapper.cascade_model`` and sets ``requires_grad`` on the primary model.
    * ``"cascaded_lora"``: deep-copies the primary model, trims layers from
      both copies, and wraps only the copy with LoRA.
    * anything else: logs that the adapter is disabled.

    The non-LoRA keys (``base_grad``, ``adapter_grad``, ``base_trim_last_n``,
    ``cascade_trim_first_n``) are removed from ``config.adapter_kwargs`` only
    while ``LoraConfig`` is built and are written back afterwards, including
    when construction raises.

    Raises:
        ImportError: if a LoRA-based mode is requested and PEFT is not installed.
    """
    wrapper.adapter = config.adapter
    wrapper.adapter_config = None
    wrapper.lora_iter_to_adapter = {}

    if wrapper.adapter in ["lora", "multilora"]:
        if not PEFT_AVAILABLE:
            raise ImportError(
                "PEFT library is required for LoRA support. Please install with: pip install peft"
            )
        base_grad = config.adapter_kwargs.pop("base_grad", True)
        adapter_grad = config.adapter_kwargs.pop("adapter_grad", True)
        try:
            wrapper.adapter_config = LoraConfig(**config.adapter_kwargs)  # type: ignore
            wrapper.simple_base_model = get_peft_model(wrapper.simple_base_model, wrapper.adapter_config)  # type: ignore
        finally:
            # Restore the keys even if PEFT rejects the config, so the caller's
            # config object is never left partially stripped.
            config.adapter_kwargs["base_grad"] = base_grad
            config.adapter_kwargs["adapter_grad"] = adapter_grad

        if wrapper.adapter == "multilora":
            # Iteration 1 reuses the default adapter created above; deeper iterations receive their own adapters.
            wrapper.lora_iter_to_adapter = {1: "default"}
            logger.info("Created LoRA adapter for iteration 1: default")

            for i in range(2, wrapper.max_iter):
                adapter_name = f"iter_{i}"
                if adapter_name not in getattr(wrapper.simple_base_model, "peft_config", {}):
                    wrapper.simple_base_model.add_adapter(adapter_name, wrapper.adapter_config)
                wrapper.lora_iter_to_adapter[i] = adapter_name
                logger.info(f"Created LoRA adapter for iteration {i}: {adapter_name}")

        # With base_grad=False and adapter_grad=True the requires_grad flags are
        # left exactly as get_peft_model/add_adapter set them.
        if base_grad or not adapter_grad:
            for name, p in wrapper.simple_base_model.base_model.named_parameters():
                if "lora" in name.lower():
                    p.requires_grad = adapter_grad
                else:
                    p.requires_grad = base_grad

        logger.info(f"LoRA enabled with config: {config.adapter_kwargs}")

    elif wrapper.adapter == "cascade":
        import copy
        # Whether to allow gradients on the primary simple_base_model in cascade mode
        # Default True to preserve existing behavior
        # NOTE(review): unlike the LoRA branches, base_grad is popped here and never
        # written back, so it is absent from config.adapter_kwargs afterwards.
        # load_adapter reads the same key - confirm whether this is intended.
        base_grad = config.adapter_kwargs.pop("base_grad", True)
        # Deep copy to create the secondary model used for subsequent iterations
        wrapper.cascade_model = copy.deepcopy(wrapper.simple_base_model)
        # Apply freezing policy on the primary model parameters
        for param in wrapper.simple_base_model.parameters():
            param.requires_grad = base_grad
        logger.info(
            f"Cascade adapter enabled: created separate models for subsequent iterations; base_grad={base_grad}"
        )
    elif wrapper.adapter == "cascaded_lora":
        import copy
        if not PEFT_AVAILABLE:
            raise ImportError(
                "PEFT library is required for cascaded_lora support. Please install with: pip install peft"
            )
        adapter_kwargs = config.adapter_kwargs
        # Control gradient requirement on the primary (iteration 0) model
        base_grad = adapter_kwargs.pop("base_grad", True)
        # adapter_grad is only stripped so it does not reach LoraConfig; it is not applied in this mode.
        adapter_grad = adapter_kwargs.pop("adapter_grad", True)
        base_trim_last_n = adapter_kwargs.pop("base_trim_last_n", 0)
        cascade_trim_first_n = adapter_kwargs.pop("cascade_trim_first_n", 0)
        try:
            base_trim_last_n = int(base_trim_last_n)
            cascade_trim_first_n = int(cascade_trim_first_n)

            # 1) Copy the UNTRIMMED primary model to create the secondary model.
            # NOTE(review): load_adapter copies the primary AFTER trimming it;
            # confirm which order is intended.
            wrapper.cascade_model = copy.deepcopy(wrapper.simple_base_model)

            # 2) Trim the primary model's last X layers (lm_head remains untouched)
            trimmed_last = _trim_layers(wrapper.simple_base_model, base_trim_last_n, from_end=True)

            # 3) Trim the secondary model's first Y layers
            trimmed_first = _trim_layers(wrapper.cascade_model, cascade_trim_first_n, from_end=False)

            # Freeze/enable primary model params based on base_grad
            for param in wrapper.simple_base_model.parameters():
                param.requires_grad = base_grad

            # Build LoRA config from remaining kwargs and wrap ONLY the secondary model
            wrapper.adapter_config = LoraConfig(**adapter_kwargs)  # type: ignore
            wrapper.cascade_model = get_peft_model(wrapper.cascade_model, wrapper.adapter_config)  # type: ignore
        finally:
            # Write the non-LoRA keys back whether or not setup succeeded.
            adapter_kwargs["base_grad"] = base_grad
            adapter_kwargs["adapter_grad"] = adapter_grad
            adapter_kwargs["base_trim_last_n"] = base_trim_last_n
            adapter_kwargs["cascade_trim_first_n"] = cascade_trim_first_n

        logger.info(
            f"Cascaded-LoRA adapter enabled: primary trimmed_last={trimmed_last}, secondary trimmed_first={trimmed_first}; primary base_grad={base_grad}"
        )
    else:
        logger.info("Adapter disabled")


def configure_lora_for_iteration(wrapper, iter_depth: int) -> None:
    """Toggle LoRA layers, and pick the active adapter, for one iteration depth.

    Only acts when ``wrapper.adapter`` is ``"lora"`` or ``"multilora"``:

    * ``iter_depth == 0``: adapter layers are disabled.
    * ``iter_depth > 0``: adapter layers are enabled; for ``"multilora"`` the
      adapter mapped to ``iter_depth`` in ``wrapper.lora_iter_to_adapter`` is
      then activated, falling back to the iteration-1 entry (or ``"default"``).
    * any other depth: nothing happens.
    """
    mode = wrapper.adapter
    if mode not in ["lora", "multilora"]:
        return

    if iter_depth == 0:
        # First pass runs without any LoRA contribution.
        wrapper.simple_base_model.base_model.disable_adapter_layers()
        return
    if not iter_depth > 0:
        return

    # Both LoRA modes switch the adapter layers on for later iterations.
    wrapper.simple_base_model.base_model.enable_adapter_layers()
    if mode == "multilora":
        iter_map = wrapper.lora_iter_to_adapter
        fallback = iter_map.get(1, "default")
        wrapper.simple_base_model.set_adapter(iter_map.get(iter_depth, fallback))


def get_embed_tokens(wrapper):
    """Resolve ``wrapper.embedding_key`` on the base model and return the result.

    For the ``"lora"`` / ``"multilora"`` adapters the lookup starts at
    ``wrapper.simple_base_model.base_model.model`` (presumably the module
    underneath the PEFT wrapper) rather than at ``wrapper.simple_base_model``.
    """
    lora_wrapped = wrapper.adapter in ["lora", "multilora"]
    root = wrapper.simple_base_model
    if lora_wrapped:
        root = root.base_model.model
    return get_attr_recursive(root, wrapper.embedding_key)


def select_model_for_iteration(wrapper, iter_depth: int):
    """Pick the model that should run at ``iter_depth``.

    In the ``"cascade"`` and ``"cascaded_lora"`` modes every iteration other
    than 0 uses ``wrapper.cascade_model``; in all other cases
    ``wrapper.simple_base_model`` is returned.
    """
    uses_secondary = wrapper.adapter in ["cascade", "cascaded_lora"] and iter_depth != 0
    if uses_secondary:
        return wrapper.cascade_model
    return wrapper.simple_base_model


def move_to(wrapper, *args, **kwargs) -> None:
    """Apply ``.to(*args, **kwargs)`` to ``wrapper.cascade_model`` when one is set.

    The attribute is rebound to whatever ``.to()`` returns. Nothing happens if
    the wrapper has no ``cascade_model`` attribute or it is ``None``.
    """
    secondary = getattr(wrapper, "cascade_model", None)
    if secondary is None:
        return
    wrapper.cascade_model = secondary.to(*args, **kwargs)


def cuda(wrapper, device: Optional[int] = None) -> None:
    """Apply ``.cuda(device)`` to ``wrapper.cascade_model`` when one is set.

    The attribute is rebound to whatever ``.cuda()`` returns. Nothing happens
    if the wrapper has no ``cascade_model`` attribute or it is ``None``.
    """
    secondary = getattr(wrapper, "cascade_model", None)
    if secondary is None:
        return
    wrapper.cascade_model = secondary.cuda(device)


def cpu(wrapper) -> None:
    """Apply ``.cpu()`` to ``wrapper.cascade_model`` when one is set.

    The attribute is rebound to whatever ``.cpu()`` returns. Nothing happens
    if the wrapper has no ``cascade_model`` attribute or it is ``None``.
    """
    secondary = getattr(wrapper, "cascade_model", None)
    if secondary is None:
        return
    wrapper.cascade_model = secondary.cpu()


def save_adapter(wrapper, save_directory: str, **kwargs) -> None:
    """Write the wrapper's model(s) and adapter weights under ``save_directory``.

    What is written depends on ``wrapper.adapter``:

    * ``"lora"`` / ``"multilora"``: ``simple_base_model.save_pretrained`` into
      ``<save_directory>/lora``, then the inner base model into
      ``save_directory`` with LoRA entries dropped and ``.base_layer`` removed
      from the state-dict keys.
    * ``"cascade"``: the primary model into ``save_directory`` and the
      secondary model into ``<save_directory>/cascade_model``.
    * ``"cascaded_lora"``: the primary model into ``save_directory`` and the
      secondary model into ``<save_directory>/cascade_lora``; a failure of the
      latter is logged as a warning instead of raised.
    * otherwise: the base model into ``save_directory``.

    Extra ``kwargs`` are forwarded to every ``save_pretrained`` call.
    """
    mode = wrapper.adapter

    if mode in ["lora", "multilora"]:
        # Adapter weights go into their own sub-directory.
        adapter_dir = os.path.join(save_directory, "lora")
        os.makedirs(adapter_dir, exist_ok=True)
        wrapper.simple_base_model.save_pretrained(adapter_dir, **kwargs)

        # Temporarily swap in a state_dict that hides the LoRA wrapping so the
        # base checkpoint is written with plain parameter names.
        inner_model = wrapper.simple_base_model.base_model.model
        stock_state_dict = inner_model.state_dict

        def cleaned_state_dict():
            """Return the state dict without LoRA entries and without `.base_layer` in keys."""
            return {
                name.replace('.base_layer', ''): tensor
                for name, tensor in stock_state_dict().items()
                if 'lora' not in name.lower()
            }

        inner_model.state_dict = cleaned_state_dict
        try:
            inner_model.save_pretrained(save_directory, **kwargs)
        finally:
            # Always put the original method back, even if saving failed.
            inner_model.state_dict = stock_state_dict

        logger.info(f"Saving LoRA adapter and cleaned base model to {save_directory}")
        return

    if mode == "cascade":
        primary_model_dir = save_directory
        cascade_model_dir = os.path.join(save_directory, "cascade_model")
        for target_dir in (primary_model_dir, cascade_model_dir):
            os.makedirs(target_dir, exist_ok=True)

        # Primary model serves iteration 0, the secondary model the later ones.
        wrapper.simple_base_model.save_pretrained(primary_model_dir, **kwargs)
        wrapper.cascade_model.save_pretrained(cascade_model_dir, **kwargs)

        logger.info(f"Saved primary model to {primary_model_dir}")
        logger.info(f"Saved secondary model to {cascade_model_dir}")
        return

    if mode == "cascaded_lora":
        primary_model_dir = save_directory
        cascade_lora_dir = os.path.join(save_directory, "cascade_lora")
        for target_dir in (primary_model_dir, cascade_lora_dir):
            os.makedirs(target_dir, exist_ok=True)

        # The primary model is stored in full.
        wrapper.simple_base_model.save_pretrained(primary_model_dir, **kwargs)

        # Saving the secondary model is best effort; failures are only logged.
        try:
            wrapper.cascade_model.save_pretrained(cascade_lora_dir, **kwargs)
            logger.info(f"Saved cascaded LoRA adapter to {cascade_lora_dir}")
        except Exception as e:
            logger.warning(f"Failed to save cascaded LoRA adapter: {e}")

        logger.info(f"Saved primary model to {primary_model_dir}")
        return

    # No adapter configured: the base model is all there is to save.
    wrapper.simple_base_model.save_pretrained(save_directory, **kwargs)
    logger.info(f"Saving base model to {save_directory}")


def load_adapter(wrapper, pretrained_model_name_or_path: str, final_config, *args, **kwargs) -> None:
    """Reload adapter-specific weights during ``from_pretrained``.

    Args:
        wrapper: Wrapper whose ``simple_base_model`` (and, for cascade modes,
            ``cascade_model``) is populated.
        pretrained_model_name_or_path: Checkpoint directory; adapter artefacts
            are read from its ``lora``, ``cascade_model`` or ``cascade_lora``
            sub-directory depending on the mode.
        final_config: Config providing ``adapter`` and ``adapter_kwargs``.
            ``adapter_kwargs`` is only read here, never modified.
        *args, **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``
            when the cascade secondary model is loaded.

    Failures while loading the cascade / cascaded-LoRA secondary model, or an
    individual multi-LoRA adapter, are logged as warnings rather than raised.
    """
    from transformers import AutoModelForCausalLM

    # Reload LoRA weights if needed
    if wrapper.adapter in ["lora", "multilora"]:
        logger.info("Reloading LoRA adapters from checkpoint after initialization")
        adapter_path = os.path.join(pretrained_model_name_or_path, "lora")
        # Read without popping so the config keeps these keys for later saves/reloads.
        base_grad = final_config.adapter_kwargs.get("base_grad", True)
        adapter_grad = final_config.adapter_kwargs.get("adapter_grad", True)
        if wrapper.adapter == "lora":
            wrapper.simple_base_model.load_adapter(adapter_path, adapter_name="default")
            logger.info(f"Reloaded LoRA adapter from {adapter_path}")
        elif wrapper.adapter == "multilora":
            for iter_num, adapter_name in wrapper.lora_iter_to_adapter.items():
                # The default adapter lives directly in the lora directory;
                # every other adapter is looked up in a sub-directory named after it.
                this_adapter_path = adapter_path
                if adapter_name != "default":
                    this_adapter_path = os.path.join(adapter_path, adapter_name)
                try:
                    wrapper.simple_base_model.load_adapter(this_adapter_path, adapter_name=adapter_name)
                    logger.info(f"Reloaded LoRA adapter '{adapter_name}' for iteration {iter_num}")
                except Exception as e:
                    logger.warning(f"Failed to reload adapter '{adapter_name}': {e}")
        # Set gradients based on parameter names: LoRA params get adapter_grad, others get base_grad
        for name, p in wrapper.simple_base_model.named_parameters():
            if "lora" in name.lower():
                p.requires_grad = adapter_grad
            else:
                p.requires_grad = base_grad

    # Handle cascade adapter mode - load secondary model
    if final_config.adapter == "cascade":
        # Computed outside the try so the warning below can always reference it.
        cascade_model_path = os.path.join(pretrained_model_name_or_path, "cascade_model")
        try:
            cascade_model = AutoModelForCausalLM.from_pretrained(
                cascade_model_path, *args, **kwargs
            )
            wrapper.cascade_model = cascade_model
            logger.info(f"Loaded secondary model from {cascade_model_path} for cascade adapter mode")
            # NOTE(review): setup_adapter applies base_grad to the PRIMARY model in
            # cascade mode, while here it is applied to the secondary model - confirm.
            base_grad = final_config.adapter_kwargs.get("base_grad", True)
            for p in wrapper.cascade_model.parameters():
                p.requires_grad = base_grad
        except Exception as e:
            logger.warning(f"Failed to load secondary model from {cascade_model_path} for cascade adapter mode: {e}")
    elif final_config.adapter == "cascaded_lora":
        # Reconstruct the secondary model, wrap with LoRA, then load the saved adapter
        try:
            import copy

            # Work on a copy so final_config.adapter_kwargs itself is left untouched.
            raw_kwargs = dict(getattr(final_config, "adapter_kwargs", None) or {})
            base_trim_last_n = int(raw_kwargs.get("base_trim_last_n", 0))
            cascade_trim_first_n = int(raw_kwargs.get("cascade_trim_first_n", 0))

            def _find_layers_parent(model_obj):
                """Return (parent, attr_name, layers) for the first known layer-list location."""
                candidates = [
                    ("model", "layers"),
                    ("transformer", "layers"),
                    ("transformer", "h"),
                    ("decoder", "layers"),
                    ("backbone", "layers"),
                ]
                for parent_name, layers_name in candidates:
                    parent = getattr(model_obj, parent_name, None)
                    layers = getattr(parent, layers_name, None)
                    if isinstance(layers, nn.ModuleList):
                        return parent, layers_name, layers
                return None, None, None

            def _trim_layers(model_obj, n: int, from_end: bool) -> int:
                """Drop n layers from the end (or start) of the layer list; return how many were dropped."""
                if n <= 0:
                    return 0
                parent, layers_attr, layers = _find_layers_parent(model_obj)
                if layers is None:
                    return 0
                n = min(n, len(layers))
                # n is 0 only for an empty layer list; `[:-0]` must not be evaluated.
                if n > 0:
                    kept = list(layers)[:-n] if from_end else list(layers)[n:]
                    setattr(parent, layers_attr, nn.ModuleList(kept))
                return n

            # Trim the primary model's last N layers.
            # NOTE(review): if setup_adapter already trimmed this model, this trims
            # again; and setup_adapter copies the secondary model BEFORE trimming the
            # primary, whereas here the copy is taken AFTER - confirm the intended flow.
            _trim_layers(wrapper.simple_base_model, base_trim_last_n, True)

            # Recreate the secondary model as a deep copy of the (trimmed) primary, then trim first N
            wrapper.cascade_model = copy.deepcopy(wrapper.simple_base_model)
            _trim_layers(wrapper.cascade_model, cascade_trim_first_n, False)

            if not PEFT_AVAILABLE:
                raise ImportError(
                    "PEFT library is required for cascaded_lora support. Please install with: pip install peft"
                )

            # Strip every key that setup_adapter stores next to the LoRA options.
            # adapter_grad must go too: setup_adapter always writes it back into
            # adapter_kwargs, so it would otherwise be passed on to LoraConfig.
            non_lora_keys = ("base_grad", "adapter_grad", "base_trim_last_n", "cascade_trim_first_n")
            lora_kwargs = {k: v for k, v in raw_kwargs.items() if k not in non_lora_keys}

            lora_cfg = LoraConfig(**lora_kwargs)  # type: ignore
            wrapper.cascade_model = get_peft_model(wrapper.cascade_model, lora_cfg)  # type: ignore

            # Freeze the parameters reachable through `base_model`
            if hasattr(wrapper.cascade_model, "base_model"):
                try:
                    for p in wrapper.cascade_model.base_model.parameters():
                        p.requires_grad = False
                except Exception as e:
                    # Best effort: keep going, but leave a trace instead of failing silently.
                    logger.warning(f"Failed to freeze cascaded LoRA base parameters: {e}")

            # Load the adapter weights
            adapter_path = os.path.join(pretrained_model_name_or_path, "cascade_lora")
            try:
                wrapper.cascade_model.load_adapter(adapter_path, adapter_name="default")
                logger.info(f"Reloaded cascaded LoRA adapter from {adapter_path}")
            except Exception as e:
                logger.warning(f"Failed to reload cascaded LoRA adapter from {adapter_path}: {e}")
        except Exception as e:
            logger.warning(f"Failed to reconstruct cascaded LoRA secondary model: {e}")


