from copy import deepcopy
from typing import Any, Dict, List, Tuple

import math
import torch

from .lora_layers import LoRALinear
from .more_hparams import MOREMultimodalHyperParams


def _move_to_device(batch, device):
    moved = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            moved[k] = v.to(device)
        else:
            moved[k] = v
    return moved


class MOREMultimodal:
    """Online LoRA editor ("M-ORE") for multimodal language models.

    On construction the editor:

    1. wraps the MLP projections of the last ``config.n_last_layers`` decoder
       layers, plus the vision-to-language projection, in ``LoRALinear``
       adapters (``_inject_lora``);
    2. indexes the adapters by their ``more_group`` tag (``_collect_groups``);
    3. freezes every parameter that is not a LoRA factor
       (``_freeze_non_lora``).

    ``edit`` then performs one update per request on the ``config.top_k``
    adapter groups whose LoRA gradients have the largest norm.
    """

    def __init__(self, model, config: MOREMultimodalHyperParams, device: str):
        self.model = model
        self.config = config
        self.device = device
        self._inject_lora()
        self.group_map = self._collect_groups()
        self._freeze_non_lora()

    def _wrap_linear(self, linear, group, name):
        """Wrap ``linear`` in a ``LoRALinear`` tagged with ``group`` and ``name``."""
        lora = LoRALinear(
            linear,
            rank=self.config.rank,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            more_nonlinear=getattr(self.config, "more_nonlinear", "none"),
            more_layer_norm=getattr(self.config, "more_layer_norm", False),
            more_layer_norm_eps=getattr(self.config, "more_layer_norm_eps", 1e-5),
        )
        lora.more_group = group
        lora.more_name = name
        return lora

    def _wrap_nested_linears(self, module, module_name):
        """Recursively replace ``nn.Linear`` descendants of ``module`` with adapters.

        Children that are already ``LoRALinear`` are left alone. Returns ``True``
        if at least one linear layer was wrapped anywhere below ``module``.
        """
        replaced = False
        # Snapshot the children: we reassign attributes while walking them.
        for child_name, child in list(module.named_children()):
            child_path = f"{module_name}.{child_name}"
            if isinstance(child, LoRALinear):
                continue
            if isinstance(child, torch.nn.Linear):
                setattr(
                    module,
                    child_name,
                    self._wrap_linear(child, "vision_proj", child_path),
                )
                replaced = True
            elif self._wrap_nested_linears(child, child_path):
                replaced = True
        return replaced

    def _inject_lora(self):
        """Insert LoRA adapters into the language model and the projection.

        Supports models exposing ``llama_model``/``llama_proj`` or
        ``opt_model``/``opt_proj``. The call is a no-op when the model is
        already flagged with ``_more_lora_injected``.

        Raises:
            ValueError: if the model layout is not one of the supported
                structures, or the projection contains no linear layer.
        """
        if getattr(self.model, "_more_lora_injected", False):
            return

        if hasattr(self.model, "llama_model") and hasattr(self.model, "llama_proj"):
            lm_model = self.model.llama_model
            proj_attr = "llama_proj"
            group_prefix = "llama.layer"
            layer_attr = "mlp"
            proj_names = ("up_proj", "down_proj")
            if hasattr(lm_model, "model") and hasattr(lm_model.model, "layers"):
                layers = lm_model.model.layers
                name_prefix = "llama_model.model.layers"
            elif hasattr(lm_model, "layers"):
                layers = lm_model.layers
                name_prefix = "llama_model.layers"
            else:
                raise ValueError("Unsupported LLaMA language model structure for M-ORE.")
        elif hasattr(self.model, "opt_model") and hasattr(self.model, "opt_proj"):
            lm_model = self.model.opt_model
            proj_attr = "opt_proj"
            group_prefix = "opt.layer"
            layer_attr = None  # fc1/fc2 live directly on the decoder layer
            proj_names = ("fc1", "fc2")
            if (
                hasattr(lm_model, "model")
                and hasattr(lm_model.model, "decoder")
                and hasattr(lm_model.model.decoder, "layers")
            ):
                layers = lm_model.model.decoder.layers
                name_prefix = "opt_model.model.decoder.layers"
            elif hasattr(lm_model, "decoder") and hasattr(lm_model.decoder, "layers"):
                layers = lm_model.decoder.layers
                name_prefix = "opt_model.decoder.layers"
            else:
                raise ValueError("Unsupported OPT language model structure for M-ORE.")
        else:
            raise ValueError(
                "M-ORE only supports models with llama_model/llama_proj (MiniGPT4/LLaVA) "
                "or opt_model/opt_proj (BLIP-2 OPT)."
            )

        proj = getattr(self.model, proj_attr)
        n_layers = len(layers)
        start_idx = max(0, n_layers - self.config.n_last_layers)

        for idx in range(start_idx, n_layers):
            layer = layers[idx]
            # The projections sit either on a sub-module (e.g. ``layer.mlp``)
            # or directly on the layer, depending on the architecture.
            if layer_attr:
                parent = getattr(layer, layer_attr)
                stem = f"{name_prefix}.{idx}.{layer_attr}"
            else:
                parent = layer
                stem = f"{name_prefix}.{idx}"
            group = f"{group_prefix}.{idx}"
            for proj_name in proj_names:
                lora = self._wrap_linear(
                    getattr(parent, proj_name), group, f"{stem}.{proj_name}"
                )
                setattr(parent, proj_name, lora)

        # LoRALinear is checked first so an existing adapter is never re-wrapped.
        if isinstance(proj, LoRALinear):
            pass
        elif isinstance(proj, torch.nn.Linear):
            setattr(
                self.model,
                proj_attr,
                self._wrap_linear(proj, "vision_proj", proj_attr),
            )
        elif not self._wrap_nested_linears(proj, proj_attr):
            raise ValueError(f"Unsupported {proj_attr} type for M-ORE LoRA injection.")

        self.model._more_lora_injected = True

    def _collect_groups(self):
        """Map each ``more_group`` tag to the list of adapters carrying it."""
        group_map = {}
        for module in self.model.modules():
            if isinstance(module, LoRALinear) and module.more_group is not None:
                group_map.setdefault(module.more_group, []).append(module)
        return group_map

    def _freeze_non_lora(self):
        """Set ``requires_grad`` so only LoRA factors can receive gradients.

        ``lora_B`` parameters are always trainable; ``lora_A`` parameters are
        trainable only when ``config.more_freeze_A`` is false (default: frozen).
        Selection is by substring match on the parameter name.
        """
        freeze_a = getattr(self.config, "more_freeze_A", True)
        for name, p in self.model.named_parameters():
            if "lora_A" in name:
                p.requires_grad = not freeze_a
            elif "lora_B" in name:
                p.requires_grad = True
            else:
                p.requires_grad = False

    def _build_batch(self, request, tok):
        """Turn a single edit request into a model batch.

        Args:
            request: mapping with ``"prompt"`` and ``"target"`` strings and an
                optional ``"image"`` tensor (a 3-D image gets a leading batch
                dimension added).
            tok: tokenizer providing ``encode`` and ``__call__``.

        Returns:
            ``(batch, target_lens)`` where ``batch`` has its tensor values on
            ``self.device`` and ``target_lens`` is a one-element list with the
            number of tokens the target adds on top of the prompt.
        """
        prompt = request["prompt"]
        target = request["target"]
        # Ensure a separating space between prompt and target.
        if target and not target.startswith(" "):
            target = " " + target

        full_text = prompt + target
        prompt_len = len(tok.encode(prompt, add_special_tokens=False))
        text_len = len(tok.encode(full_text, add_special_tokens=False))
        target_lens = [max(0, text_len - prompt_len)]
        labels = tok([target], add_special_tokens=False, return_tensors="pt")["input_ids"]

        image = request.get("image", None)
        if image is not None and image.dim() == 3:
            image = image.unsqueeze(0)

        batch = {
            "text_input": [full_text],
            "labels": labels,
            "prompts_len": [prompt_len],
            "image": image,
        }
        return _move_to_device(batch, self.device), target_lens

    def _top_groups(self, scores):
        """Return the names of the ``config.top_k`` highest-scoring groups.

        The result is ordered by descending score. A non-positive ``top_k``
        selects nothing (a negative value must not be used as a slice bound,
        which would silently drop groups from the end instead).
        """
        if not scores:
            return []
        k = max(0, min(self.config.top_k, len(scores)))
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return [g for g, _ in ranked[:k]]

    def _compute_rls_delta(self, module, eta, rls_lambda):
        """Update ``module.more_P`` and return preconditioned LoRA steps.

        With ``z = module.more_last_z`` and ``P = module.more_P`` the update is
        ``P <- P - (P z)(P z)^T / (rls_lambda + z^T P z)``, written back into
        ``module.more_P`` in place. The returned steps are
        ``delta_A = eta * P @ grad_A`` and ``delta_B = eta * grad_B @ P``
        (float32), using the updated ``P``.

        Returns:
            ``(delta_A, delta_B)``; an entry is ``None`` when the matching
            gradient is missing. ``(None, None)`` when both gradients are
            missing, ``more_last_z`` is ``None``, or the denominator is
            numerically zero (in which case ``more_P`` is left untouched).
        """
        gA = module.lora_A.weight.grad
        gB = module.lora_B.weight.grad
        if gA is None and gB is None:
            return None, None

        z = module.more_last_z
        if z is None:
            return None, None

        # The caller applies the deltas detached, so no autograd graph is
        # needed here; disabling it also keeps the persistent ``more_P`` free
        # of any history ``z`` might carry.
        with torch.no_grad():
            z = z.to(dtype=torch.float32)
            P = module.more_P.to(dtype=torch.float32)

            Pz = P @ z
            denom = rls_lambda + torch.dot(z, Pz)
            if denom.abs().item() < 1e-12:
                return None, None

            P = P - torch.outer(Pz, Pz) / denom
            module.more_P.copy_(P)

            delta_A = (P @ gA.float()) * eta if gA is not None else None
            delta_B = (gB.float() @ P) * eta if gB is not None else None
        return delta_A, delta_B

    def _score_groups(self):
        """Score every adapter group by the L2 norms of its LoRA gradients.

        A module contributes ``||grad_A|| + ||grad_B||``; when
        ``config.more_score_norm`` is ``"param"`` the sum is divided by the
        square root of the number of parameters that received a gradient.
        """
        score_norm = getattr(self.config, "more_score_norm", "none")
        score_norm = score_norm.lower() if isinstance(score_norm, str) else "none"
        scores = {}
        for group, modules in self.group_map.items():
            score = 0.0
            for module in modules:
                gA = module.lora_A.weight.grad
                gB = module.lora_B.weight.grad
                if gA is None and gB is None:
                    continue
                denom = 1.0
                if score_norm == "param":
                    numel = 0
                    if gA is not None:
                        numel += module.lora_A.weight.numel()
                    if gB is not None:
                        numel += module.lora_B.weight.numel()
                    denom = math.sqrt(numel) if numel > 0 else 1.0
                gA_norm = gA.float().norm(p=2).item() if gA is not None else 0.0
                gB_norm = gB.float().norm(p=2).item() if gB is not None else 0.0
                score += (gA_norm + gB_norm) / denom
            scores[group] = score
        return scores

    def edit(self, request, tok):
        """Apply one edit request to the model in place.

        Runs a forward/backward pass on the request, ranks the adapter groups
        by gradient norm, and subtracts the step from ``_compute_rls_delta``
        from the LoRA weights of the top-k groups. The ``vision_proj`` group
        uses ``config.more_eta_vision`` / ``config.more_rls_lambda_vision``
        when those are set. The model is left in eval mode with cleared
        gradients.
        """
        self.model.train()
        self.model.zero_grad(set_to_none=True)

        batch, target_lens = self._build_batch(request, tok)
        use_masked = request.get("more_masked_z", True)
        # These are class-level switches read by the adapters during the
        # forward pass; always restore the defaults, even if the forward
        # raises, so a failed edit cannot leak state into later calls.
        LoRALinear.more_use_masked_z = bool(use_masked)
        LoRALinear.more_target_lens = target_lens
        LoRALinear.more_batch_size = (
            len(target_lens) if isinstance(target_lens, (list, tuple)) else None
        )
        try:
            outputs = self.model(batch)
        finally:
            LoRALinear.more_target_lens = None
            LoRALinear.more_batch_size = None
            LoRALinear.more_use_masked_z = True
        loss = outputs.loss if hasattr(outputs, "loss") else outputs[0]
        loss.backward()

        scores = self._score_groups()
        active_groups = self._top_groups(scores)
        if active_groups:
            topk_msg = ", ".join([f"{g}={scores[g]:.4f}" for g in active_groups])
            print(f"[M-ORE] Top-k groups: {topk_msg}")

        eta_vision = getattr(self.config, "more_eta_vision", None)
        rls_vision = getattr(self.config, "more_rls_lambda_vision", None)
        for group in active_groups:
            for module in self.group_map.get(group, []):
                is_vision = module.more_group == "vision_proj"
                eta = (
                    eta_vision
                    if (is_vision and eta_vision is not None)
                    else self.config.eta
                )
                rls = (
                    rls_vision
                    if (is_vision and rls_vision is not None)
                    else self.config.rls_lambda
                )
                delta_A, delta_B = self._compute_rls_delta(module, eta, rls)
                if delta_A is None and delta_B is None:
                    continue
                with torch.no_grad():
                    if delta_A is not None:
                        module.lora_A.weight.add_(-delta_A.detach())
                    if delta_B is not None:
                        module.lora_B.weight.add_(-delta_B.detach())

        self.model.zero_grad(set_to_none=True)
        self.model.eval()


def apply_more_to_multimodal_model(
    model,
    tok,
    requests: List[Dict],
    hparams: MOREMultimodalHyperParams,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
    """Run M-ORE edits for ``requests`` on ``model`` and return the edited model.

    Args:
        model: the multimodal model to edit. An editor is created on first use
            and cached on the model as ``_more_editor``; later calls reuse it.
        tok: tokenizer forwarded to the editor.
        requests: a list of request dicts; a single dict is also accepted.
        hparams: hyper-parameters; ``hparams.device`` is formatted into a
            ``cuda:<device>`` string.
        copy: when true, edit a deep copy of ``model``. Note that the model is
            moved to the target device only on this path.
        return_orig_weights, keep_original_weight, **kwargs: accepted but not
            used by this function.

    Returns:
        ``(model, {})`` - the edited model and an empty dict.
    """
    target_device = f"cuda:{hparams.device}"

    if copy:
        model = deepcopy(model)
        model.to(target_device)

    if not hasattr(model, "_more_editor"):
        model._more_editor = MOREMultimodal(
            model=model, config=hparams, device=target_device
        )

    # Accept a bare request dict as shorthand for a one-element list.
    pending = [requests] if isinstance(requests, dict) else requests
    for single_request in pending:
        model._more_editor.edit(single_request, tok)

    return model, {}
