from functools import partial
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
from loguru import logger

from me_shared import DEVICE

# ----------------------------
# Config
# ----------------------------
# Loss & masking options
MASKING_ENABLED = False         # gradient-masking switch; currently disabled, set True to enable

# Which leaf linear layers to optimize inside each decoder block.
# Keys use the same "attn.<name>" / "mlp.<name>" naming that list_target_linears() produces;
# a True value marks that projection as trainable.
OPTIMIZE = {
    "attn.q_proj": False,  # alternative setting: True,
    "attn.k_proj": False,
    "attn.v_proj": False,
    "attn.o_proj": False,  # alternative setting: True,
    "mlp.up_proj": True,
    "mlp.gate_proj": True,
    "mlp.down_proj": True,
}


#######################


import torch
import torch.nn.functional as F

# ----------------------------
# Utilities
# ----------------------------

def get_blocks(model) -> List[nn.Module]:
    """Return the decoder layers of ``model`` as a new Python list.

    Expects a wrapper that exposes its layer container at ``model.model.layers``.
    """
    decoder = model.model
    return [layer for layer in decoder.layers]


def get_attn(block):
    """Return the attention submodule of a decoder block, or None if absent.

    ``self_attn`` takes precedence; ``attn`` is used only when ``self_attn`` does not exist.
    """
    fallback = getattr(block, "attn", None)
    return getattr(block, "self_attn", fallback)


def get_mlp(block):
    """Return the ``mlp`` submodule of ``block``, or None when it has no such attribute."""
    try:
        return block.mlp
    except AttributeError:
        return None


def list_target_linears(block) -> List[Tuple[str, nn.Linear]]:
    """Collect the projection layers of a decoder block that are plain ``nn.Linear`` modules.

    Returns ``(qualified_name, module)`` pairs in a fixed order: attention
    q/k/v/o projections first ("attn.<name>"), then MLP up/gate/down projections
    ("mlp.<name>"). Missing submodules, and attributes that are not ``nn.Linear``,
    are silently skipped.
    """
    groups = (
        ("attn", get_attn(block), ("q_proj", "k_proj", "v_proj", "o_proj")),
        ("mlp", get_mlp(block), ("up_proj", "gate_proj", "down_proj")),
    )
    found: List[Tuple[str, nn.Linear]] = []
    for prefix, parent, names in groups:
        if parent is None:
            continue
        for name in names:
            candidate = getattr(parent, name, None)
            if isinstance(candidate, nn.Linear):
                found.append((f"{prefix}.{name}", candidate))
    return found


# --- HF-style input prep -----------------------------------------------------

def make_position_ids(attention_mask: torch.Tensor, B: int, L: int):
    """Build position ids on ``DEVICE``.

    With a mask (may be None), each position id is the running count of mask
    entries up to and including that token, minus one, clamped at 0. Without a
    mask, every row is simply ``0 .. L-1`` broadcast to ``B`` rows.
    """
    if attention_mask is None:
        return torch.arange(L, device=DEVICE).unsqueeze(0).expand(B, -1)
    mask_long = attention_mask.to(device=DEVICE, dtype=torch.long)
    positions = mask_long.cumsum(dim=-1) - 1
    return positions.clamp_(min=0)


def to_causal_4d(attention_mask: torch.Tensor, B: int, L: int, dtype, device=None):
    """Build a 4D additive attention mask of shape [B, 1, L, L], like HF's create_causal_mask.

    - Causal: a query at position i may not attend to a key at position j > i.
    - Padding: if ``attention_mask`` is given (0 marks a pad token), pad keys are blocked too.

    Blocked entries hold ``torch.finfo(dtype).min`` and allowed entries hold 0. A large
    finite value is used instead of ``-inf`` (as Hugging Face does). A row where every
    key is blocked, such as a pad query under left padding, then gives a finite softmax.
    With ``-inf`` that row would be NaN, and the NaN can later reach real positions
    through attention (0 * NaN = NaN).

    Args:
        attention_mask: optional [B, L] tensor; entries equal to 0 are treated as padding.
        B: batch size.
        L: sequence length.
        dtype: floating-point dtype of the returned mask.
        device: device for the mask; defaults to the module-level ``DEVICE``.

    Returns:
        Tensor of shape [B, 1, L, L] with dtype ``dtype``.
    """
    if device is None:
        device = DEVICE
    neg = torch.finfo(dtype).min

    idx = torch.arange(L, device=device)
    # blocked[i, j] is True when key j lies in the future of query i.
    blocked = (idx.view(1, L) > idx.view(L, 1)).view(1, 1, L, L)

    if attention_mask is not None:
        key_pad = (attention_mask.to(device) == 0).reshape(B, 1, 1, L)  # broadcast over queries
        blocked = blocked | key_pad  # -> [B, 1, L, L]

    mask = torch.zeros((B, 1, L, L), dtype=dtype, device=device)
    return mask.masked_fill(blocked, neg)


def hf_prepare_inputs(model, batch, h_in: torch.Tensor):
    """Return ``(attention_mask_prepared, position_ids)`` for calling decoder blocks directly.

    If the decoder (``model.model``) has the legacy ``_prepare_decoder_attention_mask``
    helper, it builds the mask. Otherwise a 4D additive causal mask is built locally
    with ``to_causal_4d``. The mask dtype is ``h_in``'s dtype when it is floating point,
    else float32.
    """
    decoder = model.model
    batch_size, seq_len, _ = h_in.shape
    raw_am = batch.get("attention_mask", None)

    position_ids = make_position_ids(raw_am, batch_size, seq_len)

    if not hasattr(decoder, "_prepare_decoder_attention_mask"):
        mask_dtype = torch.float32
        if torch.is_floating_point(h_in):
            mask_dtype = h_in.dtype
        am_on_device = None if raw_am is None else raw_am.to(DEVICE)
        return to_causal_4d(am_on_device, batch_size, seq_len, mask_dtype), position_ids

    prepared = decoder._prepare_decoder_attention_mask(
        raw_am, (batch_size, seq_len), h_in, past_key_values_length=0
    )
    return prepared, position_ids


def build_position_embeddings(model, hidden_states: torch.Tensor, position_ids: torch.Tensor):
    """Compute the rotary position embeddings once per batch with the model's own RoPE module.

    Uses ``model.model.rotary_emb`` when it exists and is not None. Otherwise it falls
    back to the ``rotary_emb`` of the first block's attention module (older layouts).

    Raises:
        RuntimeError: if neither location provides a rotary embedding module.
    """
    rope = getattr(model.model, "rotary_emb", None)
    if rope is None:
        # Legacy layout: the RoPE module lives on each attention module.
        first_attn = get_attn(get_blocks(model)[0])
        rope = getattr(first_attn, "rotary_emb", None)
    if rope is None:
        raise RuntimeError("No rotary_emb found on model or first attention block; cannot build position embeddings.")
    return rope(hidden_states, position_ids)


# def inner_optimize_block(block,
#                          h_in: torch.Tensor,
#                          tt: torch.Tensor,
#                          a_outputs,
#                          b_outputs,
#                          src_lm_head,
#                          tgt_lm_head,
#                          target_labels,
#                          truncation_lens: int,
#                          kwargs_blk=None):
#     from torch.func import functional_call
#
#     # 输入可导
#     hh = h_in.clone().requires_grad_(True)
#     # overrides = {n: p.detach() for n, p in block.named_parameters()}
#     # yy = functional_call(block, overrides, (hh,), kwargs=kwargs_blk)
#     yy = block(hh, **kwargs_blk)
#     yy = yy[0] if isinstance(yy, (tuple, list)) else yy
#     yy = yy.squeeze(0)
#     # yy = yy[:, -truncation_lens:, :]
#     yy = yy[-truncation_lens:, :]
#     # tt = tt[:, -truncation_lens:, :]
#     tt = tt[-truncation_lens:, :]
#
#     a_outputs = a_outputs.squeeze(0)
#     b_outputs = b_outputs.squeeze(0)
#     a_outputs = a_outputs[-truncation_lens:, :]
#     b_outputs = b_outputs[-truncation_lens:, :]
#
#     # # TODO try it for no latent-space mapping
#     # cka_loss = 1 - linear_cka(yy, tt)
#     # logger.warning(f'AAA {cka_loss.detach().item()=}')
#     # cka_loss.backward()
#
#     # cosine loss # TODO
#     # 计算 loss 并 backward
#     # local_loss = 1 - F.cosine_similarity(yy, tt, dim=-1).mean()  # a better choice for optimization on multiple matrices
#     # logger.warning(f'AAA {local_loss.detach().item()=}')
#     # local_loss.backward()
#
#     # KL loss
#     # # Step 1: 变成概率分布（dim=-1 表示按最后一个维度做 softmax）
#     # p_student = F.log_softmax(yy @ lm_head_matrix, dim=-1)  # 学生的 log 概率
#     # p_teacher = F.softmax(tt @ lm_head_matrix, dim=-1)  # 教师的概率
#     # # Step 2: KL 散度
#     # kl_loss = F.kl_div(p_student, p_teacher, reduction="batchmean")
#     # logger.warning(f'AAA {kl_loss.detach().item()=}')
#     # kl_loss.backward()
#
#     # # ... soft label
#     # p_teacher = F.softmax(tt, dim=-1)  # teacher 分布
#     # log_p_student = F.log_softmax(yy, dim=-1)
#     # ce_loss = -(p_teacher * log_p_student).sum(dim=-1).mean()
#     # logger.warning(f'AAA {ce_loss.detach().item()=}')
#     # ce_loss.backward()
#
#     # ...traditional
#     tt = a_outputs @ src_lm_head
#     zz = b_outputs @ tgt_lm_head
#     # tt = tt @ src_lm_head
#     yy = yy @ tgt_lm_head
#     labels = tt.argmax(dim=-1)  # (66,) 伪标签
#
#     logger.debug(f'{yy.shape=} {tt.shape=} {labels.shape=}')
#     # ce_loss = F.cross_entropy(yy, labels)
#     ce_loss = F.cross_entropy(yy, torch.tensor(target_labels, dtype=torch.long))
#     logger.warning(f'AAA {ce_loss.detach().item()=}')
#     ce_loss.backward()
#
#     ce_loss = F.cross_entropy(zz, labels)
#     logger.warning(f'BBB {ce_loss.detach().item()=}')
#
#     # 直接用 SGD 更新需要优化的 block 参数
#     # pick modules to adapt
#     name2mod = {name: mod for name, mod in list_target_linears(block) if OPTIMIZE.get(name, False)}
#     params_to_update = [getattr(mod, 'weight') for mod in name2mod.values() if hasattr(mod, 'weight')]
#     if params_to_update:
#         # TODO how to find a proper LR ... (xxx)
#         # optimizer = torch.optim.SGD(params_to_update, lr=5e-2)
#         optimizer = torch.optim.SGD(params_to_update, lr=1e-2)
#         optimizer.step()
#         optimizer.zero_grad()


# Simplified, active version of inner_optimize_block (the fuller variant above is commented out); kept minimal for debugging.
def inner_optimize_block(block,
                         h_in: torch.Tensor,
                         tt: torch.Tensor,
                         a_outputs,
                         b_outputs,
                         src_lm_head,
                         tgt_lm_head,
                         target_labels,
                         truncation_lens: int,
                         kwargs_blk=None):
    """Run ``block`` on ``h_in`` and backpropagate a cosine-distance loss toward ``tt``.

    The loss is ``1 - mean(cosine_similarity(yy, tt, dim=-1))`` over the last
    ``truncation_lens`` positions of the block output. ``backward()`` is called on it,
    so gradients accumulate in the ``.grad`` of the block parameters that require grad.
    No optimizer step and no ``zero_grad`` happen here; the caller owns both.

    Args:
        block: decoder block; called as ``block(hh, **kwargs_blk)``. If it returns a
            tuple/list, its first element is taken as the hidden states.
        h_in: block input hidden states. A clone with ``requires_grad=True`` is fed in.
        tt: target hidden states, compared position-wise with the truncated output.
            NOTE(review): if ``tt`` requires grad, gradients also flow into it.
        a_outputs, b_outputs, src_lm_head, tgt_lm_head, target_labels: unused here.
            They are kept so the signature matches the earlier (commented-out)
            logit-space variant.
        truncation_lens: number of trailing positions used for the loss.
            NOTE(review): ``0`` keeps every position, because ``-0 == 0`` in slicing.
        kwargs_blk: extra keyword arguments for the block call (e.g. from
            ``obtain_kwargs_blk``). ``None`` means no extra arguments.

    Returns:
        The truncated block output ``yy`` (still carrying its autograd history).
    """
    # A None default would otherwise crash on `**None`.
    if kwargs_blk is None:
        kwargs_blk = {}

    # Make the input differentiable (gradients w.r.t. the input are computed but not used here).
    hh = h_in.clone().requires_grad_(True)
    yy = block(hh, **kwargs_blk)
    yy = yy[0] if isinstance(yy, (tuple, list)) else yy
    # Drops a leading size-1 dim; presumably the batch dim with batch size 1 -- TODO confirm.
    yy = yy.squeeze(0)
    yy = yy[-truncation_lens:, :]

    local_loss = 1 - F.cosine_similarity(yy, tt, dim=-1).mean()  # a better choice for optimization on multiple matrices
    logger.warning(f'AAA {local_loss.detach().item()=}')
    local_loss.backward()
    return yy


# TODO: for batched inputs, compute the loss per sample and average it, or vectorize the whole computation with tensor ops.
def linear_cka(yy, tt):
    """Return the linear CKA (centered kernel alignment) similarity of two 2-D representations.

    Args:
        yy, tt: 2-D tensors laid out as either [hidden_dim, seq_len] or
            [seq_len, hidden_dim]. Each is transposed independently so that its
            smaller dimension comes first. This assumes seq_len < hidden_dim; if not,
            the heuristic picks the wrong axis. After orientation, both must have the
            same number of rows (samples) so their Gram matrices align.

    Returns:
        Scalar tensor ``<K_X, K_Y>_F / (||K_X||_F * ||K_Y||_F + 1e-8)``. Here ``K_X`` and
        ``K_Y`` are the Gram matrices of the column-centered inputs.
    """
    logger.debug(f'{yy.shape=}, {tt.shape=}')

    # Orient as [samples, features]: if the first dimension is larger, treat it as
    # hidden_dim and transpose.
    if yy.shape[0] > yy.shape[1]:
        yy = yy.T
    if tt.shape[0] > tt.shape[1]:
        tt = tt.T

    # Center each feature across samples.
    X = yy - yy.mean(dim=0, keepdim=True)
    Y = tt - tt.mean(dim=0, keepdim=True)

    # Gram matrices over samples.
    K_X = X @ X.T  # shape [seq_len, seq_len]
    K_Y = Y @ Y.T

    # HSIC (Frobenius inner product of the Gram matrices) and its normalizers.
    hsic = (K_X * K_Y).sum()
    norm_X = (K_X * K_X).sum().sqrt()
    norm_Y = (K_Y * K_Y).sum().sqrt()

    # The small epsilon guards against division by zero for constant inputs.
    cka_score = hsic / (norm_X * norm_Y + 1e-8)
    return cka_score


def obtain_kwargs_blk(model, batch):
    """Build the keyword arguments used when calling a single decoder block directly.

    Under ``torch.no_grad()`` it embeds ``batch["input_ids"]`` and prepares the
    attention mask and position ids (``hf_prepare_inputs``). It also computes the
    rotary position embeddings (``build_position_embeddings``). The result is a dict
    ready to be splatted into ``block(hidden, **kwargs)``. The KV cache and attention
    outputs are disabled.

    Raises:
        ValueError: if position ids or position embeddings are missing.
    """
    with torch.no_grad():
        h_embed = model.model.embed_tokens(batch["input_ids"])  # [B, L, H]
        attn_mask_prepared, position_ids = hf_prepare_inputs(model, batch, h_embed)
        position_embeddings = build_position_embeddings(model, h_embed, position_ids)
        n_batch, n_tokens, _ = h_embed.shape
        assert position_ids.shape == (n_batch, n_tokens)

    missing = position_ids is None or position_embeddings is None
    if missing:
        raise ValueError("inner_optimize_block requires position_ids and position_embeddings")

    return {
        "attention_mask": attn_mask_prepared,
        "position_ids": position_ids,
        "past_key_value": None,
        "output_attentions": False,
        "use_cache": False,
        "position_embeddings": position_embeddings,
    }

# ----------------------------
# SFT batching (prompt+completion) — simple list[(prompt, completion)]
# ----------------------------

def _encode_sft_pair(tokenizer, prompt: str, completion: str,
                     max_length: Optional[int]) -> Tuple[List[int], List[int]]:
    """Tokenize one (prompt, completion) pair into ``(input_ids, labels)`` lists.

    Prompt tokens are labelled -100 so only completion tokens are supervised. When
    ``max_length`` is set, the completion is clipped by the tokenizer and the
    concatenation is then left-truncated. Whatever survives of the prompt is
    therefore the part immediately preceding the completion.
    """
    # The prompt is deliberately NOT truncated by the tokenizer. With right-side
    # truncation (the usual default) that would drop the prompt's tail, so the
    # left-trim below would splice a non-adjacent prompt chunk onto the completion.
    ids_p = list(tokenizer(prompt, add_special_tokens=False)["input_ids"])
    if max_length is None:
        enc_c = tokenizer(completion, add_special_tokens=False)
    else:
        enc_c = tokenizer(completion, add_special_tokens=False, truncation=True, max_length=max_length)
    ids_c = list(enc_c["input_ids"])  # no EOS appended

    ids = ids_p + ids_c
    if max_length is not None and len(ids) > max_length:
        ids = ids[-max_length:]  # keep the tail: end of prompt + completion
        prompt_len = max(0, len(ids) - len(ids_c))
    else:
        prompt_len = len(ids_p)
    labels = [-100] * prompt_len + ids[prompt_len:]
    return ids, labels


def _right_pad(t: torch.Tensor, value, length: int) -> torch.Tensor:
    """Right-pad 1-D tensor ``t`` with ``value`` up to ``length`` (clip it if longer)."""
    pad = length - t.size(0)
    if pad <= 0:
        return t[:length]
    return F.pad(t, (0, pad), value=value)


def make_sft_batch(tokenizer,
                   pairs: List[Tuple[str, str]],
                   device,
                   max_length: Optional[int] = None):
    """Build an SFT batch from (prompt, completion) pairs with a **consistent tensor format**.

    Behavior:
    - If `max_length` is None and **len(pairs) == 1**: no truncation & no padding.
      Return tensors with batch dim 1 (shape [1, L]).
    - If `max_length` is None and **len(pairs) > 1**: no truncation; **dynamically pad**
      (on the right) all samples to the batch's max sequence length.
    - If `max_length` is provided: truncate to this length and **pad to exactly this length**.
      Truncation keeps the completion (clipped to `max_length` by the tokenizer) plus as
      much of the *end* of the prompt as fits, so the kept tokens stay contiguous.

    Prompt positions and padding are labelled -100. Padding uses `pad_token_id`, or
    `eos_token_id` when the tokenizer has no pad token. `attention_mask` is 1 on real
    tokens and 0 on padding.

    Always returns a dict of **tensors** (dtype long, on `device`):
    {"input_ids", "attention_mask", "labels"}.

    Raises:
        ValueError: if `pairs` is empty.
    """
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

    encoded = [_encode_sft_pair(tokenizer, prompt, completion, max_length)
               for prompt, completion in pairs]
    if not encoded:
        raise ValueError("make_sft_batch requires at least one (prompt, completion) pair")

    # A single unpadded sample is just the dynamic case with nothing to pad.
    target_len = max_length if max_length is not None else max(len(ids) for ids, _ in encoded)

    def stack(rows, value):
        return torch.stack(
            [_right_pad(torch.tensor(row, dtype=torch.long), value, target_len) for row in rows], 0
        ).to(device)

    input_ids = stack([ids for ids, _ in encoded], pad_id)
    labels = stack([labs for _, labs in encoded], -100)
    attention_mask = stack([[1] * len(ids) for ids, _ in encoded], 0)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
