import math
import torch

from .editable_model import EditableModel
from ..utils import _logits
from ...util.lora_layers import LoRALinear


class MORE(EditableModel):
    def __init__(self, model, config, model_constructor):
        """Wrap ``model`` for M-ORE editing.

        After the base-class initialisation this normalises ``config.device``,
        injects ``LoRALinear`` adapters into the wrapped model, indexes the
        adapters by group, and freezes every parameter that is not a LoRA factor.

        Args:
            model: The model to edit; forwarded to ``EditableModel``.
            config: Editing configuration; ``config.device`` may be rewritten here.
            model_constructor: Forwarded unchanged to ``EditableModel``.
        """
        super().__init__(model, config, model_constructor)
        # Only a bare GPU index (0, "1", ...) is expanded to "cuda:<idx>".
        # Values that already name a device type ("cuda:0", "cpu", "mps") are
        # left alone; blindly prefixing them would yield invalid strings such
        # as "cuda:cpu".
        device = str(self.config.device)
        if device.isdigit():
            self.config.device = f"cuda:{device}"
        self._inject_lora()
        # Mapping: group name -> list of LoRALinear modules carrying that group tag.
        self.group_map = self._collect_groups()
        self._freeze_non_lora()

    def _inject_lora(self):
        if getattr(self.model, "_more_lora_injected", False):
            return

        more_nonlinear = getattr(self.config, "more_nonlinear", "none")
        more_layer_norm = getattr(self.config, "more_layer_norm", False)
        more_layer_norm_eps = getattr(self.config, "more_layer_norm_eps", 1e-5)

        if hasattr(self.model, "llama_model") and hasattr(self.model, "llama_proj"):
            lm_model = self.model.llama_model
            proj = self.model.llama_proj
            proj_attr = "llama_proj"
            proj_base = "llama_proj"
            group_prefix = "llama.layer"
            layer_attr = "mlp"
            proj_names = ("up_proj", "down_proj")
            if hasattr(lm_model, "model") and hasattr(lm_model.model, "layers"):
                layers = lm_model.model.layers
                name_prefix = "llama_model.model.layers"
            elif hasattr(lm_model, "layers"):
                layers = lm_model.layers
                name_prefix = "llama_model.layers"
            else:
                raise ValueError("Unsupported LLaMA language model structure for M-ORE.")
        elif hasattr(self.model, "opt_model") and hasattr(self.model, "opt_proj"):
            lm_model = self.model.opt_model
            proj = self.model.opt_proj
            proj_attr = "opt_proj"
            proj_base = "opt_proj"
            group_prefix = "opt.layer"
            layer_attr = None
            proj_names = ("fc1", "fc2")
            if (
                hasattr(lm_model, "model")
                and hasattr(lm_model.model, "decoder")
                and hasattr(lm_model.model.decoder, "layers")
            ):
                layers = lm_model.model.decoder.layers
                name_prefix = "opt_model.model.decoder.layers"
            elif hasattr(lm_model, "decoder") and hasattr(lm_model.decoder, "layers"):
                layers = lm_model.decoder.layers
                name_prefix = "opt_model.decoder.layers"
            else:
                raise ValueError("Unsupported OPT language model structure for M-ORE.")
        else:
            raise ValueError(
                "M-ORE editor supports models with llama_model/llama_proj (MiniGPT4/LLaVA) "
                "or opt_model/opt_proj (BLIP-2 OPT)."
            )

        n_layers = len(layers)
        n_last = max(0, int(self.config.n_last_layers))
        start_idx = max(0, n_layers - n_last)

        for idx in range(start_idx, n_layers):
            layer = layers[idx]
            for proj_name in proj_names:
                linear = getattr(layer.mlp, proj_name) if layer_attr else getattr(layer, proj_name)
                if isinstance(linear, LoRALinear):
                    continue
                lora = LoRALinear(
                    linear,
                    rank=self.config.rank,
                    lora_alpha=self.config.lora_alpha,
                    lora_dropout=self.config.lora_dropout,
                    more_nonlinear=more_nonlinear,
                    more_layer_norm=more_layer_norm,
                    more_layer_norm_eps=more_layer_norm_eps,
                )
                lora.more_group = f"{group_prefix}.{idx}"
                if layer_attr:
                    lora.more_name = f"{name_prefix}.{idx}.mlp.{proj_name}"
                    setattr(layer.mlp, proj_name, lora)
                else:
                    lora.more_name = f"{name_prefix}.{idx}.{proj_name}"
                    setattr(layer, proj_name, lora)

        if isinstance(proj, LoRALinear):
            pass
        elif isinstance(proj, torch.nn.Linear):
            lora_proj = LoRALinear(
                proj,
                rank=self.config.rank,
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
                more_nonlinear=more_nonlinear,
                more_layer_norm=more_layer_norm,
                more_layer_norm_eps=more_layer_norm_eps,
            )
            lora_proj.more_group = "vision_proj"
            lora_proj.more_name = proj_base
            setattr(self.model, proj_attr, lora_proj)
        else:
            replaced = False

            def _wrap_proj(module, prefix=""):
                nonlocal replaced
                for child_name, child in module.named_children():
                    full_name = f"{prefix}.{child_name}" if prefix else child_name
                    if isinstance(child, LoRALinear):
                        continue
                    if isinstance(child, torch.nn.Linear):
                        lora_child = LoRALinear(
                            child,
                            rank=self.config.rank,
                            lora_alpha=self.config.lora_alpha,
                            lora_dropout=self.config.lora_dropout,
                            more_nonlinear=more_nonlinear,
                            more_layer_norm=more_layer_norm,
                            more_layer_norm_eps=more_layer_norm_eps,
                        )
                        lora_child.more_group = "vision_proj"
                        lora_child.more_name = f"{proj_base}.{full_name}"
                        setattr(module, child_name, lora_child)
                        replaced = True
                    else:
                        _wrap_proj(child, full_name)

            _wrap_proj(proj)
            if not replaced:
                raise ValueError(f"Unsupported {proj_base} type for M-ORE LoRA injection.")

        self.model._more_lora_injected = True

    def _collect_groups(self):
        group_map = {}
        for module in self.model.modules():
            if isinstance(module, LoRALinear) and module.more_group is not None:
                group_map.setdefault(module.more_group, []).append(module)
        return group_map

    def _freeze_non_lora(self):
        freeze_a = getattr(self.config, "more_freeze_A", True)
        for name, p in self.model.named_parameters():
            if "lora_A" in name:
                p.requires_grad = not freeze_a
            elif "lora_B" in name:
                p.requires_grad = True
            else:
                p.requires_grad = False

    def outer_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]

    def _compute_target_lens(self, batch):
        if hasattr(self.model, "compute_more_target_lens"):
            lens = self.model.compute_more_target_lens(batch)
            if lens is not None:
                return lens
        text_inputs = batch.get("text_input", None)
        prompts_len = batch.get("prompts_len", None)
        if (
            isinstance(text_inputs, list)
            and isinstance(prompts_len, list)
            and len(text_inputs) == len(prompts_len)
            and hasattr(self.model, "llama_tokenizer")
        ):
            lens = []
            for text, plen in zip(text_inputs, prompts_len):
                tlen = len(self.model.llama_tokenizer.encode(text, add_special_tokens=False))
                lens.append(max(0, tlen - int(plen)))
            return lens

        labels = batch.get("labels", None)
        if isinstance(labels, torch.Tensor):
            tlen = labels.size(1)
            return [tlen for _ in range(labels.size(0))]
        return None

    def _compute_rls_delta(self, module, eta, rls_lambda):
        gA = module.lora_A.weight.grad
        gB = module.lora_B.weight.grad
        if gA is None and gB is None:
            return None, None

        z = module.more_last_z
        if z is None:
            return None, None

        P = module.more_P
        z = z.to(dtype=torch.float32)
        P = P.to(dtype=torch.float32)

        denom = rls_lambda + torch.dot(z, P @ z)
        if denom.abs().item() < 1e-12:
            return None, None

        Pz = P @ z
        P = P - torch.outer(Pz, Pz) / denom
        module.more_P.copy_(P)

        delta_A = (P @ gA.float()) * eta if gA is not None else None
        delta_B = (gB.float() @ P) * eta if gB is not None else None
        return delta_A, delta_B

    def forward(self, *inputs, **kwargs):
        if ("minigpt4" in self.config.model_name.lower()
                or "blip" in self.config.model_name.lower()
                or "llava" in self.config.model_name.lower()):
            outputs = self.model(*inputs, **kwargs)
        elif "gpt" in self.config.model_name.lower():
            outputs = _logits(self.model(input_ids=kwargs["input_ids"], attention_mask=kwargs["attention_mask"]))
        elif "llama" in self.config.model_name.lower():
            outputs = _logits(self.model(input_ids=kwargs["input_ids"], attention_mask=kwargs["attention_mask"]))
        elif "qwen" in self.config.model_name.lower():
            outputs = _logits(self.model(input_ids=kwargs["input_ids"], attention_mask=kwargs["attention_mask"]))
        else:
            outputs = _logits(self.model(**kwargs))
        return outputs

    def edit(self, batch, condition=None, detach_history=False, return_factors=False, **kwargs):
        """Apply one M-ORE edit step to the LoRA adapters using ``batch``.

        The step proceeds as follows:

        1. Run the wrapped model on ``batch`` while the class-level
           ``LoRALinear`` attributes ``more_use_masked_z``, ``more_target_lens``
           and ``more_batch_size`` are set for this call.
        2. Back-propagate the ``"nll"`` entry returned by ``self.edit_loss_fn``.
        3. Score every adapter group by the summed L2 norms of its ``lora_A`` /
           ``lora_B`` weight gradients (optionally divided by the square root of
           the parameter count when ``config.more_score_norm == "param"``) and
           keep the ``config.top_k`` highest-scoring groups.
        4. For the kept groups, obtain deltas from ``_compute_rls_delta`` and
           subtract them from the LoRA weights in place.
        5. Clear all gradients on the wrapped model.

        Args:
            batch: Input passed directly to ``self.model``; also consulted for
                labels when the model returns a bare tensor.
            condition: Accepted but not used by this method.
            detach_history: Accepted but not used by this method.
            return_factors: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            tuple: ``(self, info_dict)``. ``info_dict["more/top_groups"]`` lists the
            updated groups. When ``config.more_update_mode`` is ``"temporary"`` and
            at least one module was snapshotted, ``info_dict["restore_fn"]`` is a
            zero-argument callable that restores the pre-edit LoRA weights (and
            ``more_P`` unless ``config.more_restore_p`` is falsy).
        """
        use_masked = getattr(self.config, "more_use_masked_z", True)
        target_lens = self._compute_target_lens(batch)

        # Per-call settings are handed to the adapters through class attributes
        # (presumably read inside LoRALinear.forward -- verify against that
        # class). They are reset in ``finally`` so that a failing forward pass
        # cannot leak this call's state into later, unrelated forward passes.
        LoRALinear.more_use_masked_z = bool(use_masked)
        LoRALinear.more_target_lens = target_lens
        LoRALinear.more_batch_size = (
            len(target_lens) if isinstance(target_lens, (list, tuple)) else None
        )
        try:
            outputs = self.model(batch)
        finally:
            LoRALinear.more_target_lens = None
            LoRALinear.more_batch_size = None
            LoRALinear.more_use_masked_z = True

        # The model may return either a bare logits tensor or an object that
        # carries both logits and labels.
        if not isinstance(outputs, torch.Tensor):
            logits = outputs.logits
            labels = outputs.labels
        else:
            logits = outputs
            labels = batch["labels"]

        loss = self.edit_loss_fn(self.config, logits, labels, multimodal=True)["nll"]
        loss.backward()

        # --- Score each group by the gradient magnitude of its adapters. ---
        score_norm = getattr(self.config, "more_score_norm", "none")
        score_norm = score_norm.lower() if isinstance(score_norm, str) else "none"
        scores = {}
        for group, modules in self.group_map.items():
            score = 0.0
            for module in modules:
                gA = module.lora_A.weight.grad
                gB = module.lora_B.weight.grad
                if gA is None and gB is None:
                    continue
                denom = 1.0
                if score_norm == "param":
                    # Normalise by sqrt(#parameters that actually received a gradient).
                    numel = 0
                    if gA is not None:
                        numel += module.lora_A.weight.numel()
                    if gB is not None:
                        numel += module.lora_B.weight.numel()
                    denom = math.sqrt(numel) if numel > 0 else 1.0
                gA_norm = gA.float().norm(p=2).item() if gA is not None else 0.0
                gB_norm = gB.float().norm(p=2).item() if gB is not None else 0.0
                score += (gA_norm + gB_norm) / denom
            scores[group] = score

        # --- Keep the top-k groups; a non-positive k selects nothing. ---
        k = min(self.config.top_k, len(scores))
        active_groups = (
            [g for g, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]]
            if k > 0
            else []
        )

        # --- In "temporary" mode, snapshot state before anything is modified. ---
        update_mode = getattr(self.config, "more_update_mode", "online").lower()
        restore_data = []
        restore_p = getattr(self.config, "more_restore_p", True)
        if update_mode == "temporary":
            for group in active_groups:
                for module in self.group_map.get(group, []):
                    restore_data.append(
                        (
                            module,
                            module.lora_A.weight.detach().clone(),
                            module.lora_B.weight.detach().clone(),
                            # more_P is mutated by _compute_rls_delta, so it is
                            # captured here, before the deltas are computed.
                            module.more_P.detach().clone() if restore_p else None,
                        )
                    )

        # --- Compute deltas; the vision projector may use its own eta / lambda. ---
        eta_vision = getattr(self.config, "more_eta_vision", None)
        rls_vision = getattr(self.config, "more_rls_lambda_vision", None)
        delta_map = {}
        for group in active_groups:
            for module in self.group_map.get(group, []):
                is_vision = module.more_group == "vision_proj"
                eta = eta_vision if (is_vision and eta_vision is not None) else self.config.eta
                rls = (
                    rls_vision
                    if (is_vision and rls_vision is not None)
                    else self.config.rls_lambda
                )
                delta_A, delta_B = self._compute_rls_delta(module, eta, rls)
                if delta_A is None and delta_B is None:
                    continue
                if delta_A is not None:
                    delta_map[f"{module.more_name}.lora_A.weight"] = delta_A
                if delta_B is not None:
                    delta_map[f"{module.more_name}.lora_B.weight"] = delta_B

        # --- Apply the deltas as a descent step on the LoRA factors. ---
        for group in active_groups:
            for module in self.group_map.get(group, []):
                delta_A = delta_map.get(f"{module.more_name}.lora_A.weight")
                delta_B = delta_map.get(f"{module.more_name}.lora_B.weight")
                with torch.no_grad():
                    if delta_A is not None:
                        module.lora_A.weight.add_(-delta_A.detach())
                    if delta_B is not None:
                        module.lora_B.weight.add_(-delta_B.detach())

        # Drop gradients so they do not accumulate into the next edit.
        self.model.zero_grad(set_to_none=True)

        info_dict = {"more/top_groups": active_groups}
        if update_mode == "temporary" and restore_data:
            def _restore():
                # Copy the snapshots back in place, undoing this edit.
                for module, a_w, b_w, p_w in restore_data:
                    module.lora_A.weight.data.copy_(a_w)
                    module.lora_B.weight.data.copy_(b_w)
                    if p_w is not None:
                        module.more_P.data.copy_(p_w)

            info_dict["restore_fn"] = _restore

        return self, info_dict
