import torch
import torch.nn as nn

import torch.nn.functional as F
import logging

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass

from .base_logging_models import Logger

from typing import Optional, List, Union, Any, Dict


# Module-level logger named after this module; `reinit_weights_in_place` uses
# it to warn when it has to fall back to per-module parameter resets.
logger = logging.getLogger(__name__)


def _fallback_reset(m: nn.Module):
    """Call ``m.reset_parameters()`` when the module defines it; otherwise do nothing.

    Intended to be passed to ``nn.Module.apply`` so that every submodule
    exposing ``reset_parameters`` gets re-initialized.
    """
    if not hasattr(m, "reset_parameters"):
        return
    m.reset_parameters()


def reinit_weights_in_place(model):
    """Re-initialize ``model``'s weights in place and return the same object.

    Initialization strategy, in priority order:

    1. ``model.apply(model._init_weights)`` when ``_init_weights`` exists.
    2. ``model.init_weights()`` when only ``init_weights`` exists.
    3. Otherwise a warning is logged and ``_fallback_reset`` is applied to
       every submodule.

    Afterwards ``model.tie_weights()`` is called if the model provides it.
    """
    # A private sentinel distinguishes "attribute missing" from "attribute is None".
    missing = object()

    per_module_init = getattr(model, "_init_weights", missing)
    if per_module_init is not missing:
        model.apply(per_module_init)
    else:
        whole_model_init = getattr(model, "init_weights", missing)
        if whole_model_init is not missing:
            whole_model_init()
        else:
            logger.warning(
                "Model does not have `_init_weights` or `init_weights` method. "
                "Using fallback reset for all modules."
            )
            model.apply(_fallback_reset)

    tie = getattr(model, "tie_weights", missing)
    if tie is not missing:
        tie()
    return model


class PreLayerLossLogger(Logger):
    """Logger that remembers the per-layer losses of the most recent step.

    Only the latest value is kept; nothing is accumulated or averaged.
    """

    def __init__(self, model=None):
        super().__init__(model)
        # Per-layer losses captured by the last `on_step_end` call (None until then).
        self.per_layer_losses = None

    def on_step_end(self, model_outputs):
        """Store the per-layer losses found on ``model_outputs``.

        Reads the ``per_layer_losses`` attribute when the outputs object has
        one; otherwise falls back to the last element, ``model_outputs[-1]``.
        """
        if hasattr(model_outputs, "per_layer_losses"):
            latest = model_outputs.per_layer_losses
        else:
            latest = model_outputs[-1]
        self.per_layer_losses = latest

    def on_log(self):
        """Return whatever the last ``on_step_end`` call stored."""
        return self.per_layer_losses


class SequenceMixingLogger(Logger):
    """Accumulate sequence-mixing losses across steps and report their means.

    ``on_step_end`` adds the scalar losses of one step to running sums and
    ``on_log`` divides those sums by the number of accumulated steps.
    Optional per-layer losses are averaged per layer, counting only the steps
    on which that layer's value was finite.
    """

    def __init__(self, model: Optional[Any] = None):
        super().__init__(model)
        # Running sums of the scalar losses since the last flush.
        self.loss_acc = None
        self.sm_layer_acc = None
        # Layer names seen on the first step with per-layer losses. They are
        # NOT cleared on flush, so later steps are validated against them.
        self.sm_layer_keys = None
        self.sm_distill_acc = None
        self.sm_next_token_acc = None
        self.steps = 0

        # Per-layer running sum and count of finite observations.
        self.sm_layer_sum = None
        self.sm_layer_count = None

    @staticmethod
    def _accumulate(
        running: Optional[torch.Tensor], value: torch.Tensor
    ) -> torch.Tensor:
        """Add a detached float copy of ``value`` to ``running`` (or start a new sum)."""
        value = value.detach().clone().to(torch.float)
        return value if running is None else running + value

    def _ensure_layer_bufs(self, n_layers: int, keys: Optional[List[str]]) -> None:
        """Allocate the per-layer buffers if needed and validate the incoming layout.

        Args:
            n_layers: Number of per-layer loss values in the current step.
            keys: Names of those values, or None when they are unnamed.

        Raises:
            ValueError: If ``keys`` differs from the keys recorded earlier, or
                if ``n_layers`` differs from the size of the existing buffers.
        """
        device = (
            self.loss_acc.device
            if isinstance(self.loss_acc, torch.Tensor)
            else torch.device("cpu")
        )
        if self.sm_layer_sum is None:
            self.sm_layer_sum = torch.zeros(n_layers, device=device)
            self.sm_layer_count = torch.zeros(n_layers, device=device)

        if keys is not None:
            if self.sm_layer_keys is None:
                self.sm_layer_keys = keys
            elif self.sm_layer_keys != keys:
                raise ValueError(
                    f"Expected layer keys {self.sm_layer_keys}, but got {keys}."
                )

        # Always check the size, even when keys are given: a mismatch would
        # otherwise only surface later as an opaque indexing error.
        expected_n = self.sm_layer_sum.numel()
        if expected_n != n_layers:
            raise ValueError(
                f"Expected {expected_n} per-layer losses, but got {n_layers}."
            )

    def on_step_end(self, model_outputs) -> None:
        """Fold the losses of one step into the running sums.

        Reads ``loss``, ``sequence_mixing_layer_loss``,
        ``sequence_mixing_distillation_loss`` and
        ``sequence_mixing_next_token_loss`` from ``model_outputs``; all four
        must be tensors. If ``sequence_mixing_layer_losses`` is present and not
        None it is read as a mapping from layer key to loss tensor.
        """
        with torch.no_grad():
            self.loss_acc = self._accumulate(self.loss_acc, model_outputs.loss)
            self.sm_layer_acc = self._accumulate(
                self.sm_layer_acc, model_outputs.sequence_mixing_layer_loss
            )
            self.sm_distill_acc = self._accumulate(
                self.sm_distill_acc, model_outputs.sequence_mixing_distillation_loss
            )
            self.sm_next_token_acc = self._accumulate(
                self.sm_next_token_acc, model_outputs.sequence_mixing_next_token_loss
            )

            sm_layers = getattr(model_outputs, "sequence_mixing_layer_losses", None)
            if sm_layers is not None:
                layer_keys = list(sm_layers.keys())
                # torch.stack already allocates fresh storage, so no extra
                # clone is needed to decouple from the model outputs.
                flat = (
                    torch.stack([v.detach() for v in sm_layers.values()], dim=0)
                    .to(torch.float)
                    .reshape(-1)
                )
                self._ensure_layer_bufs(flat.numel(), keys=layer_keys)

                # Non-finite entries are skipped so that one bad step does not
                # poison that layer's running average.
                finite = torch.isfinite(flat)
                self.sm_layer_sum[finite] += flat[finite]
                self.sm_layer_count += finite.to(self.sm_layer_count.dtype)

            self.steps += 1

    def on_log(self, flush: bool = True) -> Dict[str, float]:
        """Return the mean of every accumulated loss as plain Python floats.

        Args:
            flush: When True (default) the accumulators are reset afterwards.
                The recorded layer keys are kept.

        Returns:
            An empty dict if nothing was accumulated; otherwise the four scalar
            means plus one entry per layer that had at least one finite value.
        """
        with torch.no_grad():
            if self.loss_acc is None or self.steps == 0:
                return {}

            out = {
                "loss": (self.loss_acc / self.steps).item(),
                "sequence_mixing_layer_loss": (self.sm_layer_acc / self.steps).item(),
                "sequence_mixing_distillation_loss": (
                    self.sm_distill_acc / self.steps
                ).item(),
                "sequence_mixing_next_token_loss": (
                    self.sm_next_token_acc / self.steps
                ).item(),
            }

            if self.sm_layer_sum is not None and self.sm_layer_count is not None:
                # Convert to Python lists once instead of calling .item() per
                # layer. clamp(min=1) only guards the division for unseen
                # layers, which are skipped below anyway.
                counts = self.sm_layer_count.tolist()
                averages = (
                    self.sm_layer_sum / self.sm_layer_count.clamp(min=1)
                ).tolist()
                for i, (count, avg) in enumerate(zip(counts, averages)):
                    if count <= 0:
                        continue
                    # `i` is a plain int here, so the fallback name reads
                    # "sm_loss_layer_3" rather than "sm_loss_layer_tensor(3)".
                    key = (
                        self.sm_layer_keys[i]
                        if self.sm_layer_keys
                        else f"sm_loss_layer_{i}"
                    )
                    out[key] = float(avg)

            if flush:
                self.loss_acc = None
                self.sm_layer_acc = None
                self.sm_distill_acc = None
                self.sm_next_token_acc = None
                self.steps = 0
                self.sm_layer_sum = None
                self.sm_layer_count = None

            return out


class SequenceMixingConfig(PretrainedConfig):
    """Configuration registered under ``model_type = "sequence_mixing_causal_lm"``.

    The constructor forwards ``**kwargs`` to ``PretrainedConfig`` and then
    stores every named argument as an attribute of the same name, with three
    normalizations:

    * a falsy ``attention_module_paths`` becomes an empty list;
    * falsy ``context_extension_attn_params`` / ``context_extension_nope_params``
      become empty dicts;
    * a ``torch.dtype`` passed as ``torch_dtype`` is stored as its string name
      without the ``"torch."`` prefix (e.g. ``"bfloat16"``); any other value
      is stored unchanged.
    """

    model_type = "sequence_mixing_causal_lm"

    def __init__(
        self,
        base_model_name_or_path: str = None,
        attention_module_paths: List[str] = None,
        sequence_mixing_type: str = "qwen3_attention",
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        context_extension_attn_impl: Optional[str] = None,
        context_extension_attn_params: Optional[Dict[str, Any]] = None,
        context_extension_nope_params: Optional[Dict[str, Any]] = None,
        loss_type: str = "mse",
        per_layer_loss_coef: float = 0.0,
        distillation_loss_coef: float = 0.0,
        next_token_loss_coef: float = 1.0,
        mask_per_layer_losses: bool = True,
        copy_attention_weights: bool = False,
        use_additional_features: bool = False,
        forward_mode: str = "sequence_mixing",
        reinitialize_base_model: bool = False,
        freeze_base_model: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Replace missing/empty containers with fresh empty ones.
        module_paths = attention_module_paths if attention_module_paths else []
        attn_params = (
            context_extension_attn_params if context_extension_attn_params else {}
        )
        nope_params = (
            context_extension_nope_params if context_extension_nope_params else {}
        )

        # Store dtypes by name: str(torch.float16) == "torch.float16".
        dtype_name = torch_dtype
        if isinstance(dtype_name, torch.dtype):
            dtype_name = str(dtype_name).replace("torch.", "")

        # Base model and sequence-mixing module selection.
        self.base_model_name_or_path = base_model_name_or_path
        self.attention_module_paths = module_paths
        self.sequence_mixing_type = sequence_mixing_type
        self.context_extension_attn_impl = context_extension_attn_impl
        self.context_extension_attn_params = attn_params
        self.context_extension_nope_params = nope_params

        self.torch_dtype = dtype_name

        # Loss settings, stored verbatim.
        self.loss_type = loss_type
        self.per_layer_loss_coef = per_layer_loss_coef
        self.distillation_loss_coef = distillation_loss_coef
        self.next_token_loss_coef = next_token_loss_coef
        self.mask_per_layer_losses = mask_per_layer_losses

        # Remaining behaviour flags, stored verbatim.
        self.copy_attention_weights = copy_attention_weights
        self.use_additional_features = use_additional_features
        self.forward_mode = forward_mode
        self.reinitialize_base_model = reinitialize_base_model
        self.freeze_base_model = freeze_base_model


@dataclass
class SequenceMixingCausalLMOutputWithPast(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with sequence-mixing loss fields.

    All extra fields default to None.
    """

    sequence_mixing_layer_loss: Optional[torch.FloatTensor] = None
    # Per-layer losses. `SequenceMixingLogger.on_step_end` in this file reads
    # this field through `.keys()` / `.values()`, i.e. as a mapping from layer
    # key to loss tensor, so the mapping form is part of the declared type.
    # The tensor and list forms are kept for backward compatibility.
    sequence_mixing_layer_losses: Optional[
        Union[
            torch.FloatTensor,
            List[Optional[torch.FloatTensor]],
            Dict[str, torch.FloatTensor],
        ]
    ] = None
    sequence_mixing_distillation_loss: Optional[torch.FloatTensor] = None
    sequence_mixing_next_token_loss: Optional[torch.FloatTensor] = None


def soft_cross_entropy(logits, target_logits, loss_mask=None):
    """Mean KL divergence between the softmax of ``target_logits`` and of ``logits``.

    The last position along dim 1 of both logit tensors is dropped. The KL
    term (target distribution relative to the predicted one) is summed over
    the last dimension and then averaged over positions.

    Args:
        logits: Predicted logits; indexed as ``[:, :-1]`` and softmaxed over
            the last dim.
        target_logits: Target logits, handled the same way.
        loss_mask: Optional mask indexed as ``[:, 1:]`` so it lines up with the
            kept positions. Non-zero / True entries select the positions that
            contribute. Any dtype is accepted; it is converted to bool.

    Returns:
        A scalar tensor. If the mask selects no position, a zero that is still
        attached to the autograd graph is returned instead of NaN.
    """
    log_p = F.log_softmax(logits[:, :-1], dim=-1)
    log_q = F.log_softmax(target_logits[:, :-1], dim=-1)
    per_position = F.kl_div(log_p, log_q, reduction="none", log_target=True).sum(
        dim=-1
    )
    if loss_mask is None:
        return per_position.mean()

    # Force boolean-mask semantics: indexing with a 0/1 integer tensor would
    # gather rows 0 and 1 instead of selecting the masked positions.
    keep = loss_mask[:, 1:].to(torch.bool)
    selected = per_position[keep]
    if selected.numel() == 0:
        # The mean of an empty selection is NaN; return a graph-connected zero.
        return per_position.sum() * 0.0
    return selected.mean()


def fixed_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[Union[int, torch.Tensor]] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    """Cross-entropy with an optional externally supplied normalizer.

    Args:
        source: Logits passed to ``cross_entropy`` as its input.
        target: Class indices; entries equal to ``ignore_index`` are skipped.
        num_items_in_batch: When None, the loss is the mean over non-ignored
            targets. Otherwise the summed loss is divided by this value, which
            may be an int or a tensor.
        ignore_index: Target value that does not contribute to the loss.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A scalar loss tensor.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(
            source, target, ignore_index=ignore_index, reduction="mean"
        )

    loss = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction="sum"
    )
    # A tensor normalizer may live on another device than the loss (for
    # example when it is computed by the training loop), so co-locate it first.
    if torch.is_tensor(num_items_in_batch):
        num_items_in_batch = num_items_in_batch.to(loss.device)
    return loss / num_items_in_batch


def get_causal_lm_masked_loss_fn(
    unmasked_idxs: Optional[Union[torch.Tensor, int]] = None,
):
    """Build a causal-LM loss that only scores selected sequence positions.

    Args:
        unmasked_idxs: Which positions along the sequence axis contribute.

            * None: every position is scored.
            * int ``k``: positions are sliced with ``[-k:]``, i.e. the last
              ``k`` positions. ``k == 0`` keeps every position because
              ``-0`` slices from the start. A negative value makes the
              returned function raise ``ValueError`` when it is called.
            * anything else (e.g. an index tensor): used directly to index
              the sequence axis of both logits and labels.

    Returns:
        A function ``(logits, labels, vocab_size, num_items_in_batch=None,
        ignore_index=-100, shift_labels=None, **kwargs)`` returning a scalar
        loss tensor.
    """

    def ForCausalLMMaskedLoss(
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        shift_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Compute the loss in float32 regardless of the model dtype.
        logits = logits.float()

        if shift_labels is None:
            # Pad one ignored label on the right, then drop the first label,
            # so that position t is scored against token t + 1.
            labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
            shift_labels = labels[..., 1:].contiguous()

        # Co-locate labels with logits BEFORE indexing: an index tensor that
        # works on `logits` is then guaranteed to work on the labels too.
        shift_labels = shift_labels.to(logits.device)

        if unmasked_idxs is None:
            # No position selection requested; handled explicitly instead of
            # relying on `tensor[..., None, :]` inserting a throwaway axis.
            pass
        elif isinstance(unmasked_idxs, int):
            if unmasked_idxs < 0:
                raise ValueError("unmasked_idxs must be non-negative")
            logits = logits[..., -unmasked_idxs:, :]
            shift_labels = shift_labels[..., -unmasked_idxs:]
        else:
            logits = logits[..., unmasked_idxs, :]
            shift_labels = shift_labels[..., unmasked_idxs]

        # NOTE(review): `num_items_in_batch` is forwarded unchanged; if the
        # caller computed it over all positions it will not match the number
        # of positions kept here — confirm this is the intended normalizer.
        logits = logits.reshape(-1, vocab_size)
        shift_labels = shift_labels.reshape(-1)
        return fixed_cross_entropy(
            logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )

    return ForCausalLMMaskedLoss
