import torch
import torch.nn as nn
import torch.nn.functional as F
import inspect
import logging
import os
import safetensors

from transformers import AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

from .sequence_mixing_layers import create_sequence_mixing_layer
from .sequence_mixing_utils import (
    SequenceMixingConfig,
    soft_cross_entropy,
    SequenceMixingCausalLMOutputWithPast,
)
from .base_logging_models import Logger, LoggingModel

from typing import Optional, List, Tuple, Union, Dict, Callable, Any


logger = logging.getLogger(__name__)


class SequenceMixingForCausalLM(Qwen2ForCausalLM, LoggingModel):
    """Causal LM whose attention modules are replaced by sequence-mixing layers.

    A Qwen2 / Qwen3 / Llama-family model is loaded into ``self._base_model``.
    For each dotted path in ``config.attention_module_paths`` a sequence-mixing
    layer is built with ``create_sequence_mixing_layer``, stored in
    ``self.sequence_mixing_layers`` (keyed by the path with ``.`` replaced by
    ``_``) and installed into the base model in place of the module found at
    that path. The replaced modules are remembered in
    ``self._original_attention_modules`` so they can be put back with
    ``_restore_attention_modules``. ``forward`` delegates to the modified base
    model.
    """

    config_class = SequenceMixingConfig

    def __init__(
        self,
        config: SequenceMixingConfig,
        metrics_logger: Optional[Logger] = None,
        custom_loss_fn: Optional[Callable] = None,
    ):
        """Load the base model and install the sequence-mixing layers.

        Args:
            config: Model configuration. Read for ``base_model_name_or_path``,
                ``attention_module_paths``, ``sequence_mixing_type``, the loss
                coefficients, freezing/copying flags and the context-extension
                settings.
            metrics_logger: Optional logger; ``register_model(self)`` is called
                on it when given.
            custom_loss_fn: Optional callable assigned to the base model's
                ``loss_function``.

        Raises:
            ValueError: If ``config.forward_mode`` is the deprecated
                ``"sequence_mixing_training"``.
            NotImplementedError: If the base model family is not supported.
        """
        # Deliberately bypass Qwen2ForCausalLM.__init__: the language model
        # lives in ``self._base_model`` instead of being built by the parent.
        PreTrainedModel.__init__(self=self, config=config)

        # Fail fast on the deprecated mode before paying for base-model loading.
        self.forward_mode = config.forward_mode
        if config.forward_mode == "sequence_mixing_training":
            raise ValueError(
                "sequence_mixing_training is deprecated due to training time "
                "inefficiencies"
            )

        resolved_dtype = self._resolve_dtype(config.torch_dtype)

        # Must be set before _load_base_model, which reads it.
        self.reinitialize_base_model = config.reinitialize_base_model
        self._base_model = self._load_base_model(
            config.base_model_name_or_path, resolved_dtype
        )

        self.attention_module_paths = list(config.attention_module_paths)

        self.freeze_base_model = config.freeze_base_model
        self.loss_type = config.loss_type
        self.per_layer_loss_coef = config.per_layer_loss_coef
        self.distillation_loss_coef = config.distillation_loss_coef
        self.next_token_loss_coef = config.next_token_loss_coef
        self.mask_per_layer_losses = config.mask_per_layer_losses
        self.use_additional_features = config.use_additional_features

        # nn.ModuleDict keys may not contain ".", so dotted paths are sanitized.
        self.module_path_to_key = {
            module_path: module_path.replace(".", "_")
            for module_path in self.attention_module_paths
        }

        base_cfg = self._base_model.config
        self.sequence_mixing_layers = nn.ModuleDict()
        for layer_idx, module_path in enumerate(self.attention_module_paths):
            sanitized_key = self.module_path_to_key[module_path]
            self.sequence_mixing_layers[sanitized_key] = create_sequence_mixing_layer(
                base_cfg,
                config.sequence_mixing_type,
                layer_idx,
                add_features_dim=(
                    base_cfg.hidden_size if self.use_additional_features else None
                ),
                context_extension_attn_impl=config.context_extension_attn_impl,
                context_extension_attn_params=config.context_extension_attn_params,
                context_extension_nope_params=config.context_extension_nope_params,
            )

        # Snapshot the original modules before _replace_attention_modules runs.
        self._original_attention_modules = self._get_original_attention_modules(
            self.attention_module_paths
        )

        if config.copy_attention_weights:
            self._copy_attention_weights()

        # Freezing happens BEFORE the layers are installed so that only the
        # original base-model parameters are frozen; the sequence-mixing layers
        # are not yet reachable through self._base_model at this point.
        if self.freeze_base_model:
            logger.info("Freezing base model parameters...")
            self._freeze_base_model()

        self._replace_attention_modules()

        if custom_loss_fn is not None:
            self._base_model.loss_function = custom_loss_fn

        self.metrics_logger = metrics_logger
        if self.metrics_logger is not None:
            self.metrics_logger.register_model(self)

    @staticmethod
    def _resolve_dtype(torch_dtype: Any) -> Optional[torch.dtype]:
        """Map a config ``torch_dtype`` (``str`` or ``torch.dtype``) to a dtype.

        Returns ``None`` when the value is missing or does not name a
        ``torch.dtype`` attribute of ``torch`` (e.g. ``"auto"``).
        """
        if isinstance(torch_dtype, torch.dtype):
            return torch_dtype
        if isinstance(torch_dtype, str):
            candidate = getattr(torch, torch_dtype, None)
            if isinstance(candidate, torch.dtype):
                return candidate
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Rebuild the model from a local directory.

        The config is read with ``AutoConfig`` (extra ``kwargs`` are forwarded
        to it), a fresh model is constructed from it, and the weights in
        ``<pretrained_model_name_or_path>/model.safetensors`` are loaded via
        :meth:`load_state_dict`. Finally the sequence-mixing layers are cast to
        the config's dtype when it resolves to one.

        Only a local directory holding a single ``model.safetensors`` file is
        supported.
        """
        # ``import safetensors`` alone does not import the ``torch`` submodule.
        import safetensors.torch

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        state_dict_path = os.path.join(
            pretrained_model_name_or_path, "model.safetensors"
        )
        state_dict = safetensors.torch.load_file(state_dict_path)

        # Frozen layout: strict load into the sequence-mixing layers only.
        # Unfrozen layout: keys are split between base model and layers and
        # loaded non-strictly.
        model.load_state_dict(state_dict, strict=model.freeze_base_model)

        resolved_dtype = cls._resolve_dtype(config.torch_dtype)
        if resolved_dtype is not None:
            model.sequence_mixing_layers.to(resolved_dtype)

        return model

    def set_context_extension_mode(
        self,
        context_extension_attn_impl,
        context_extension_attn_params,
        context_extension_nope_params,
    ):
        """Forward the context-extension settings to every sequence-mixing layer."""
        for layer in self.sequence_mixing_layers.values():
            layer.set_context_extension_mode(
                context_extension_attn_impl,
                context_extension_attn_params,
                context_extension_nope_params,
            )

    def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
        """Skip the base-class validation of ``generate`` kwargs (intentional no-op)."""
        pass

    def set_softmax_scale(self, softmax_scale: float):
        """Set ``softmax_scale`` on every sequence-mixing layer."""
        for layer in self.sequence_mixing_layers.values():
            layer.set_softmax_scale(softmax_scale)

    def get_input_embeddings(self):
        """Return the base model's input embedding module."""
        return self._base_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the base model's input embedding module."""
        return self._base_model.set_input_embeddings(value)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate generation input preparation to the base model."""
        return self._base_model.prepare_inputs_for_generation(*args, **kwargs)

    @classmethod
    def can_generate(cls) -> bool:
        """Report that this model supports ``generate``.

        A classmethod, matching ``PreTrainedModel.can_generate``; it can still
        be called on an instance.
        """
        return True

    def state_dict(self, *args, **kwargs):
        """Return the entries to checkpoint.

        With a frozen base model only ``sequence_mixing_layers`` entries are
        returned. Otherwise the base model's entries are merged with the
        ``sequence_mixing_layers`` entries. After replacement, the base model's
        entries include the sequence-mixing layers under their original dotted
        paths.
        """
        if self.freeze_base_model:
            return self.sequence_mixing_layers.state_dict(*args, **kwargs)

        # The sequence-mixing weights are reachable both through the base model
        # and through the ModuleDict, so the same tensors would appear twice.
        # Cloning the base entries breaks that aliasing.
        # NOTE(review): presumably needed so safetensors serialization does not
        # reject shared tensors — confirm. This allocates a full copy of the
        # base-model weights on every call.
        combined_state_dict = {
            key: value.clone()
            for key, value in self._base_model.state_dict(*args, **kwargs).items()
        }
        combined_state_dict.update(
            self.sequence_mixing_layers.state_dict(*args, **kwargs)
        )
        return combined_state_dict

    def _split_state_dict(
        self, state_dict
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Route checkpoint entries to the base model or the sequence-mixing layers.

        Membership is decided against each submodule's own ``state_dict()``
        keys. This means persistent buffers and aliased entries (the same keys
        :meth:`state_dict` emits) are routed too. Keys belonging to neither are
        logged and dropped.

        Returns:
            ``(base_model_state_dict, seq_mix_state_dict)``.
        """
        base_keys = set(self._base_model.state_dict().keys())
        seq_mix_keys = set(self.sequence_mixing_layers.state_dict().keys())

        base_model_state_dict = {}
        seq_mix_state_dict = {}
        for key, value in state_dict.items():
            if key in base_keys:
                base_model_state_dict[key] = value
            elif key in seq_mix_keys:
                seq_mix_state_dict[key] = value
            else:
                logger.warning("Unknown parameter key: %s", key)
        return base_model_state_dict, seq_mix_state_dict

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load entries produced by :meth:`state_dict`.

        With a frozen base model, ``state_dict`` is loaded into
        ``sequence_mixing_layers`` honoring ``strict``. Otherwise entries are
        split between the base model and ``sequence_mixing_layers``, and both
        halves are loaded non-strictly. ``strict`` is not applied in that case.

        Returns:
            The result of the last underlying ``load_state_dict`` call (the
            base model's result in the unfrozen case).
        """
        if self.freeze_base_model:
            return self.sequence_mixing_layers.load_state_dict(
                state_dict, strict, assign
            )

        base_model_state_dict, seq_mix_state_dict = self._split_state_dict(state_dict)

        # NOTE(review): ``strict`` is deliberately not forwarded on this path
        # (pre-existing behavior callers may rely on).
        self.sequence_mixing_layers.load_state_dict(
            seq_mix_state_dict, strict=False, assign=assign
        )
        return self._base_model.load_state_dict(
            base_model_state_dict, strict=False, assign=assign
        )

    def save_base_model(self, save_directory: str, **kwargs):
        """``save_pretrained`` the base model with whatever modules it currently holds."""
        self._base_model.save_pretrained(save_directory, **kwargs)

    def forward(
        self,
        *args,
        forward_mode: Optional[str] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the (modified) base model.

        ``forward_mode`` is accepted for call compatibility and is ignored.
        """
        return self._base_model(*args, **kwargs)

    @staticmethod
    def _resolve_base_model_classes(base_model_name_or_path: str):
        """Return ``(model_class, config_class)`` for a supported base-model name.

        Families are matched by substring on the lower-cased name, in order:
        ``qwen2``, ``qwen3``, then ``llama`` / ``smollm``.

        Raises:
            NotImplementedError: If no supported family matches.
        """
        name = base_model_name_or_path.lower()
        # Imported lazily: only the matching family's modules are imported.
        if "qwen2" in name:
            from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
            from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

            return Qwen2ForCausalLM, Qwen2Config
        if "qwen3" in name:
            from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
            from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

            return Qwen3ForCausalLM, Qwen3Config
        if "llama" in name or "smollm" in name:
            from transformers.models.llama.modeling_llama import LlamaForCausalLM
            from transformers.models.llama.configuration_llama import LlamaConfig

            return LlamaForCausalLM, LlamaConfig
        raise NotImplementedError(
            f"Base Model {base_model_name_or_path} is currently not supported"
        )

    def _load_base_model(
        self, base_model_name_or_path: str, resolved_dtype: Optional[torch.dtype] = None
    ) -> PreTrainedModel:
        """Instantiate the wrapped causal LM.

        When ``self.reinitialize_base_model`` is set, the model is built from
        its config with fresh weights and cast to ``resolved_dtype`` (if given).
        Otherwise pretrained weights are loaded with ``torch_dtype=resolved_dtype``.

        Raises:
            NotImplementedError: If the base model family is not supported.
        """
        model_cls, config_cls = self._resolve_base_model_classes(
            base_model_name_or_path
        )
        config = config_cls.from_pretrained(
            base_model_name_or_path, trust_remote_code=False
        )

        if self.reinitialize_base_model:
            logger.info("Initializing base model from scratch...")
            base_model = model_cls(config=config)
            if resolved_dtype is not None:
                base_model = base_model.to(resolved_dtype)
            return base_model

        logger.info("Loading base model from pretrained weights...")
        return model_cls.from_pretrained(
            base_model_name_or_path,
            config=config,
            torch_dtype=resolved_dtype,
            trust_remote_code=False,
        )

    def _locate_in_base_model(self, module_path: str):
        """Return ``(parent_module, attr_name)`` for a dotted path in the base model."""
        *path_parts, attr_name = module_path.split(".")
        parent = self._base_model
        for part in path_parts:
            parent = getattr(parent, part)
        return parent, attr_name

    def _replace_attention_modules(self):
        """Install each sequence-mixing layer at its module path in the base model."""
        for module_path in self.attention_module_paths:
            parent, attr_name = self._locate_in_base_model(module_path)
            sanitized_key = self.module_path_to_key[module_path]
            setattr(parent, attr_name, self.sequence_mixing_layers[sanitized_key])

    def _restore_attention_modules(self):
        """Put the original attention modules back at their module paths."""
        for module_path in self.attention_module_paths:
            parent, attr_name = self._locate_in_base_model(module_path)
            sanitized_key = self.module_path_to_key[module_path]
            setattr(parent, attr_name, self._original_attention_modules[sanitized_key])

    def _get_original_attention_modules(self, attention_module_paths: List[str]):
        """Map sanitized keys to the modules currently found at each path.

        Must be called before :meth:`_replace_attention_modules`, otherwise it
        would capture the sequence-mixing layers themselves. The result is a
        plain dict, so these modules are not registered as submodules.
        """
        original_attention_modules = {}
        for module_path in attention_module_paths:
            parent, attr_name = self._locate_in_base_model(module_path)
            original_attention_modules[self.module_path_to_key[module_path]] = getattr(
                parent, attr_name
            )
        return original_attention_modules

    def _freeze_base_model(self):
        """Set ``requires_grad = False`` on every parameter currently in the base model."""
        for param in self._base_model.parameters():
            param.requires_grad = False

    def _copy_attention_weights(self):
        """Call ``copy_attention_weights`` on each sequence-mixing layer.

        The source is the module originally found at the same path, taken from
        ``self._original_attention_modules``. This keeps the copy reading from
        the original attention even after the layers have been installed.
        """
        logger.info(
            "Copying attention weights from base model to sequence mixing layers..."
        )
        for module_path in self.attention_module_paths:
            sanitized_key = self.module_path_to_key[module_path]
            self.sequence_mixing_layers[sanitized_key].copy_attention_weights(
                self._original_attention_modules[sanitized_key]
            )


def create_sequence_mixing_model(
    config: SequenceMixingConfig,
    metrics_logger: Optional[Logger] = None,
    custom_loss_fn: Optional[Callable] = None,
):
    """Construct a :class:`SequenceMixingForCausalLM`.

    Args:
        config: Configuration passed to the model constructor.
        metrics_logger: Optional metrics logger to register the model with.
        custom_loss_fn: Optional callable installed as the base model's
            ``loss_function``. Defaults to ``None`` (keep the base model's own).

    Returns:
        The newly constructed model.
    """
    return SequenceMixingForCausalLM(
        config=config,
        metrics_logger=metrics_logger,
        custom_loss_fn=custom_loss_fn,
    )
