import numpy as np
from scipy import stats
import pickle
import time
import math
import inspect
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from modeling_llama import LlamaForCausalLM as LlamaForCausalLMEdit
from captum.attr import LayerGradientXActivation
from loguru import logger
from tqdm import tqdm

from me_shared import DEVICE, MODEL_REGISTRY


def get_attr(mod: nn.Module, attrs: str):
    """Resolve a dotted attribute path starting at ``mod``.

    Args:
        mod: Object the lookup starts from.
        attrs: Dot-separated attribute names, e.g. ``"model.layers"``.

    Returns:
        The object reached after following every name in ``attrs`` in order.

    Raises:
        AttributeError: Propagated from ``getattr`` when a name is missing.
    """
    node = mod
    for part in attrs.split("."):
        node = getattr(node, part)
    return node


def set_attr(mod: nn.Module, attrs: str, new_mod: nn.Module):
    """Assign ``new_mod`` at the dotted attribute path ``attrs`` below ``mod``.

    Every name except the last is followed with ``getattr``; the last name is
    then set on the object that was reached.

    Args:
        mod: Object the path starts from.
        attrs: Dot-separated attribute names; the final one is the attribute to set.
        new_mod: Value stored at the final attribute.
    """
    *parent_names, leaf_name = attrs.split(".")
    owner = mod
    for parent_name in parent_names:
        owner = getattr(owner, parent_name)
    setattr(owner, leaf_name, new_mod)


# TODO ...
def freeze_params(model, attrs):
    """Leave only the feed-forward weight parameters trainable.

    A parameter keeps ``requires_grad=True`` only when its name contains one of
    ``attrs['ff_gate']``, ``attrs['ff_output']`` or ``attrs['ff_input']`` and
    also contains ``'weight'``; every other parameter is frozen.

    Args:
        model: Object exposing ``named_parameters()``.
        attrs: Mapping with the keys ``'ff_gate'``, ``'ff_output'`` and ``'ff_input'``.

    Returns:
        A list with one ``{'params': [param]}`` dict per parameter, frozen ones
        included, in ``named_parameters()`` order.
    """
    # Same lookup order as the feed-forward check has always used; `any` stops
    # at the first match, so later keys are only read when needed.
    ff_keys = ('ff_gate', 'ff_output', 'ff_input')
    param_groups = []
    for name, param in model.named_parameters():
        in_feed_forward = any(attrs[key] in name for key in ff_keys)
        param.requires_grad = in_feed_forward and 'weight' in name
        param_groups.append({'params': [param]})
    return param_groups


# TODO: formalize the computation of the semantic basis based on statistical analysis (the semantic definitions are still to be finalized)
# NOTE: in our experiments max-min scaling worked better than mean-std scaling; compute_sss below currently divides by the IQR instead.
def compute_sss(semantic_bases, target_label, argmax_label):
    """Build the IQR-scaled steering vector from the argmax basis to the target basis.

    The steer is ``semantic_bases[target_label] - semantic_bases[argmax_label]``
    (both detached), divided by its interquartile range plus ``1e-8``. No
    centering is applied.

    Args:
        semantic_bases: Tensor (or tensor-like) indexed by label.
        target_label: Index of the basis to steer towards.
        argmax_label: Index of the basis to steer away from.

    Returns:
        The scaled steer in the input dtype. A 1-D steer gets a leading
        dimension of size 1; other ranks are returned as they are.
    """
    output_target_embed = semantic_bases[target_label].detach()
    output_argmax_embed = semantic_bases[argmax_label].detach()
    semantic_steer = output_target_embed - output_argmax_embed

    # torch.quantile rejects anything but float32/float64, so half-precision
    # steers are scaled in float32 and cast back afterwards.
    low_precision = semantic_steer.dtype in (torch.float16, torch.bfloat16)
    work = semantic_steer.float() if low_precision else semantic_steer

    iqr = torch.quantile(work, 0.75) - torch.quantile(work, 0.25)
    # The epsilon is added before any down-cast so it cannot underflow to zero.
    scaled = work / (iqr + 1e-8)
    semantic_steer = scaled.to(semantic_steer.dtype) if low_precision else scaled

    if len(semantic_steer.shape) == 1:
        semantic_steer = semantic_steer.unsqueeze(0)
    # TODO we plan to no longer scale the ss, so the LR may be different...

    return semantic_steer

# ...
def obtain_loss(model, input_ids, labels, is_batch=True):
    """Compute the training loss for one prompt/target pair.

    Args:
        model: Callable model. With ``is_batch`` set it must accept ``labels``
            and return an object exposing ``.loss``; otherwise it must return
            an object exposing ``.logits``.
        input_ids: Prompt token ids; indexed as ``(batch, seq)`` here.
        labels: Target ids. With ``is_batch`` set they are unsqueezed at dim 0
            and appended to ``input_ids``; otherwise they are passed directly
            to the cross-entropy criterion.
        is_batch: Selects between the two loss paths described below.

    Returns:
        ``outputs.loss`` of the model over the appended target tokens when
        ``is_batch`` is truthy, else the cross-entropy of the last-position
        logits against ``labels``.
    """
    if not is_batch:
        outputs = model(input_ids, output_hidden_states=True, return_dict=True)
        # Only the logits of the final position are scored.
        # NOTE: a cosine-similarity based loss on the last hidden state was
        # explored here previously; cross-entropy is what is used.
        last_step_logits = outputs.logits[:, -1, :]
        return nn.CrossEntropyLoss()(last_step_logits, labels)

    prompt_len = input_ids.size(1)
    combined_ids = torch.cat([input_ids, labels.unsqueeze(0)], dim=-1)
    # -100 marks the prompt positions so they are ignored in the loss.
    ground_truth = combined_ids.clone()
    ground_truth[:, :prompt_len] = -100
    outputs = model(
        input_ids=combined_ids,
        labels=ground_truth,
        output_hidden_states=True,
        return_dict=True,
    )
    return outputs.loss


# TODO: currently this is not that useful...
# Example usage:
# Z = torch.randn(1, 400)
# S = torch.randn(1, 100)
# nonzero_rows = [0, 10, 20, 30]  # Rows of W that may be non-zero; every other row stays zero
# W = constrained_lstsq(Z, S, nonzero_rows)
def constrained_lstsq(Z, S, nonzero_rows: list=None):
    """Solve ``Z @ W = S`` by least squares with only selected rows of ``W`` free.

    Rows of ``W`` not listed in ``nonzero_rows`` are fixed to zero; the listed
    rows come from ``torch.linalg.lstsq`` on the matching columns of ``Z``.
    For example: Z shape = (1, 400), W shape = (400, 100), S shape = (1, 100).

    Args:
        Z: Left-hand matrix; its last dimension is the number of rows of ``W``.
        S: Right-hand side; its last dimension is the number of columns of ``W``.
        nonzero_rows: Row indices of ``W`` allowed to be non-zero. ``None``
            leaves every row free, i.e. the unconstrained problem.

    Returns:
        ``W`` on ``DEVICE``, in the dtype of the least-squares solution.
    """
    n_rows, n_cols = Z.shape[-1], S.shape[-1]
    if nonzero_rows is None:
        # Indexing with None would insert an axis instead of selecting columns.
        nonzero_rows = list(range(n_rows))

    # Solve only for the free rows.
    Z_reduced = Z[:, nonzero_rows]
    W_reduced = torch.linalg.lstsq(Z_reduced, S).solution.to(DEVICE)

    # Allocate W in the solution's dtype: indexed assignment requires matching
    # dtypes, so a default-float32 W would reject e.g. a float64 solution.
    W = torch.zeros((n_rows, n_cols), dtype=W_reduced.dtype, device=DEVICE)
    W[nonzero_rows, :] = W_reduced
    return W


###
def stabilize(reproducibility=True, seed=42):
    """Seed the random number generators and set the cuDNN determinism flags.

    Args:
        reproducibility: When truthy, cuDNN benchmarking is turned off and
            deterministic mode is turned on; otherwise the two flags are
            set the other way round.
        seed: Seed handed to Python's ``random``, NumPy, the torch CPU/MPS/CUDA
            seeding functions and ``transformers.set_seed``.
    """
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)

    # The two cuDNN flags are always opposites of each other.
    deterministic = bool(reproducibility)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic


def print_info(model):
    """Print ``model``, then the name and size of each ``model.base_model`` parameter."""
    print(model)
    named_params = model.base_model.named_parameters()
    for param_name, param in named_params:
        print(param_name, param.size())


def timeit(func):
    """Decorator that times each call of ``func`` with ``time.perf_counter``.

    The decorated callable returns ``(elapsed_time, results)`` instead of the
    bare result, where ``elapsed_time`` is the duration in seconds passed
    through ``format_score``. The duration is also logged at debug level.

    Args:
        func: Callable to wrap.

    Returns:
        The wrapping callable, carrying ``func``'s name and docstring.
    """
    import functools

    # functools.wraps keeps __name__/__doc__ of the decorated function so
    # logs, tracebacks and introspection do not all report "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        results = func(*args, **kwargs)
        end = time.perf_counter()
        elapsed_time = format_score(end - start)
        logger.debug(f"{func.__name__} takes {elapsed_time} seconds")
        return elapsed_time, results

    return wrapper


def format_score(datum):
    """Round ``datum`` to three decimal places using the built-in ``round``."""
    decimals = 3
    return round(datum, decimals)


def format_rate(datum):
    """Render a fraction as a percentage string, e.g. ``0.5`` -> ``'50.0%'``.

    The value is multiplied by 100 and rounded by ``format_score`` before the
    percent sign is appended.
    """
    percent = format_score(datum * 100)
    return f'{percent}%'


def format_ratio(pre_datum, post_datum):
    """Format the change from ``pre_datum`` to ``post_datum`` as a signed percentage string.

    The reported number is the absolute difference ``post_datum - pre_datum``
    times 100 (not a change relative to ``pre_datum``), rounded by
    ``format_score``. Non-negative changes get an explicit ``'+'`` prefix.
    """
    if post_datum >= pre_datum:
        sign_prefix = '+'
    else:
        sign_prefix = ''
    delta_percent = format_score((post_datum - pre_datum) * 100)
    return f'{sign_prefix}{delta_percent}%'


def calculate_statistics(data):
    """Summarise the distribution of ``data`` in a dict of rounded statistics.

    Args:
        data: Anything ``np.array`` can convert; all elements are pooled
            regardless of the array's shape.

    Returns:
        A dict with the keys ``mean``, ``std``, ``median``, ``IQR``, ``MAD``,
        ``pos_pct``, ``neg_pct``, ``zero_count``, ``pos_mean``, ``pos_std``,
        ``neg_mean`` and ``neg_std``. Percentages are strings produced by
        ``format_rate``; the ``pos_*``/``neg_*`` moments are ``np.nan`` when
        there are no positive/negative elements.
    """
    data = np.array(data)
    # Total element count. len(data) only counts the first axis, which made
    # the percentages wrong for inputs such as a (1, N) array.
    total = data.size
    mean_value = np.mean(data)

    stats_dict = dict()
    stats_dict['mean'] = format_score(mean_value)  # mean
    stats_dict['std'] = format_score(np.std(data))  # population standard deviation (np.std default, ddof=0)
    stats_dict['median'] = format_score(np.median(data))  # median
    stats_dict['IQR'] = format_score(stats.iqr(data))  # interquartile range (Q3 - Q1)
    stats_dict['MAD'] = format_score(np.mean(np.abs(data - mean_value)))  # mean absolute deviation

    # Boolean-mask indexing flattens, so these are 1-D whatever the input shape.
    pos_data = data[data > 0]
    neg_data = data[data < 0]
    zero_count = np.sum(data == 0)

    stats_dict['pos_pct'] = format_rate(pos_data.size / total)  # share of positive values
    stats_dict['neg_pct'] = format_rate(neg_data.size / total)  # share of negative values
    stats_dict['zero_count'] = zero_count

    stats_dict['pos_mean'] = format_score(np.mean(pos_data)) if pos_data.size > 0 else np.nan
    stats_dict['pos_std'] = format_score(np.std(pos_data)) if pos_data.size > 0 else np.nan

    stats_dict['neg_mean'] = format_score(np.mean(neg_data)) if neg_data.size > 0 else np.nan
    stats_dict['neg_std'] = format_score(np.std(neg_data)) if neg_data.size > 0 else np.nan

    return stats_dict


# import numpy as np
#
# def compute_kernel_bias(vecs):
#     vecs = np.concatenate(vecs, axis=0)
#     mu = vecs.mean(axis=0, keepdims=True)
#     cov = np.cov(vecs.T)
#     u, s, vh = np.linalg.svd(cov)
#     W = np.dot(u, np.diag(1 / np.sqrt(s)))
#     return W, -mu
#
#
# def transform_and_normalize(vecs, kernel=None, bias=None):
#     if not (kernel is None or bias is None):
#         vecs = (vecs + bias).dot(kernel)
#     return vecs / np.sqrt(np.square(vecs).sum(axis=1, keepdims=True))


# # this is the refined numpy version
# import numpy as np
#
# def compute_kernel_bias(vecs, method="pca", k=None, eps=1e-6, dtype=np.float64):
#     """
#     vecs: (N, D) array, or a list [ (n1,D), (n2,D), ... ]
#     method: "pca" -> Xc @ (U * 1/sqrt(s))
#             "zca" -> Xc @ (U * 1/sqrt(s)) @ U.T
#     Returns: kernel (D,D), bias = -mu (1,D)
#     """
#     X = np.concatenate(vecs, axis=0) if isinstance(vecs, (list, tuple)) else vecs
#     X = np.asarray(X, dtype=dtype)
#     mu = X.mean(axis=0, keepdims=True)
#     Xc = X - mu
#
#     # Covariance (sample covariance, divided by N-1)
#     n = max(Xc.shape[0] - 1, 1)
#     cov = (Xc.T @ Xc) / n
#
#     # Use eigh for the symmetric matrix (eigenvalues in ascending order)
#     s, U = np.linalg.eigh(cov)
#     # Sort by eigenvalue in descending order (optional)
#     idx = np.argsort(s)[::-1]
#     s = s[idx]; U = U[:, idx]
#     if k is not None:
#         method="pca"
#         U = U[:, :k]; s = s[:k]
#
#     # Guard against division by zero / very small values
#     d = 1.0 / np.sqrt(np.maximum(s, eps))
#
#     if method.lower() == "pca":
#         # Column scaling: U * d is equivalent to U @ diag(d) but uses less memory
#         kernel = U * d
#     elif method.lower() == "zca":
#         kernel = (U * d) @ U.T
#     else:
#         raise ValueError("method must be 'pca' or 'zca'")
#
#     bias = -mu
#     return kernel.astype(X.dtype), bias.astype(X.dtype)
#
#
# def transform_and_normalize(vecs, kernel=None, bias=None, l2norm=False, eps=1e-12):
#     X = np.asarray(vecs)
#     if kernel is not None and bias is not None:
#         X = (X + bias) @ kernel
#     if l2norm:
#         # Per-sample L2 normalization (optional)
#         denom = np.sqrt(np.sum(X * X, axis=1, keepdims=True))
#         X = X / np.maximum(denom, eps)
#     return X


import torch

@torch.no_grad()
def compute_kernel_bias_torch(vecs, method="pca", k=None, eps=1e-6, dtype=torch.float64):
    """Compute a whitening kernel and centering bias from row-vector samples.

    Args:
        vecs: Tensor of shape (N, D), or a list/tuple of (n_i, D) tensors that
            is concatenated along dim 0.
        method: ``"pca"`` gives ``U * 1/sqrt(s)``; ``"zca"`` gives
            ``(U * 1/sqrt(s)) @ U.T``. Case-insensitive. Overridden to
            ``"pca"`` whenever ``k`` is given.
        k: Number of leading eigen-directions to keep (``None`` keeps all).
        eps: Floor applied to the eigenvalues before the inverse square root.
        dtype: Dtype the statistics are computed in.

    Returns:
        ``(kernel, bias)`` as float32 tensors. ``kernel`` is (D, D), or (D, k)
        when ``k`` is given; ``bias`` is the negated mean with shape (1, D).
        Whitened data is ``(X + bias) @ kernel``.

    Raises:
        ValueError: If ``method`` is neither ``"pca"`` nor ``"zca"`` (and ``k`` is ``None``).
    """
    if isinstance(vecs, (list, tuple)):
        vecs = torch.cat(vecs, dim=0)
    samples = vecs.to(dtype)

    mean = samples.mean(dim=0, keepdim=True)
    centered = samples - mean

    # Sample covariance; the max() keeps a single-row input from dividing by zero.
    dof = max(centered.shape[0] - 1, 1)
    cov = (centered.T @ centered) / dof  # (D, D)

    # eigh returns eigenvalues in ascending order; reorder to descending so
    # truncation keeps the directions with the largest eigenvalues.
    eigvals, eigvecs = torch.linalg.eigh(cov)
    order = torch.argsort(eigvals, descending=True)
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]

    if k is not None:
        method = "pca"
        eigvals, eigvecs = eigvals[:k], eigvecs[:, :k]

    inv_std = torch.rsqrt(torch.clamp(eigvals, min=eps))
    mode = method.lower()
    if mode == "pca":  # friendly for dimension reduction
        kernel = eigvecs * inv_std  # scales each column of the eigenvector matrix
    elif mode == "zca":  # refined version of the given data
        kernel = (eigvecs * inv_std) @ eigvecs.T
    else:
        raise ValueError("method must be 'pca' or 'zca'")

    return kernel.to(torch.float32), (-mean).to(torch.float32)


def transform_and_normalize_torch(X, kernel=None, bias=None, l2norm=False, eps=1e-12):
    """Apply an affine transform and, optionally, row-wise L2 normalisation.

    Args:
        X: Row-vector data.
        kernel: Matrix right-multiplied onto the shifted data.
        bias: Offset added to ``X`` before the multiplication. The transform is
            skipped unless both ``kernel`` and ``bias`` are given.
        l2norm: When true, divide each row (dim 1) by its L2 norm. Be careful
            when enabling it: this damages the whitening effect.
        eps: Lower bound on the norm used as divisor.

    Returns:
        The transformed data; ``X`` itself when no step applies.
    """
    skip_transform = kernel is None or bias is None
    if not skip_transform:
        X = (X + bias) @ kernel
    if not l2norm:
        return X
    row_norms = torch.linalg.norm(X, dim=1, keepdim=True)
    return X / torch.clamp(row_norms, min=eps)


def whitening(vecs, k=None):
    """Fit a whitening transform on ``vecs`` and apply it to the same data.

    Args:
        vecs: Tensor of row vectors, or a list/tuple of such tensors that is
            concatenated along dim 0.
        k: Forwarded to ``compute_kernel_bias_torch``.

    Returns:
        The result of ``transform_and_normalize_torch`` on the (concatenated)
        input with the fitted kernel and bias; L2 normalisation is left off.
    """
    # The fit step accepts a list/tuple of tensors but the transform step needs
    # a single tensor (list + Tensor fails), so concatenate once up front.
    if isinstance(vecs, (list, tuple)):
        vecs = torch.cat(vecs, dim=0)
    kernel, bias = compute_kernel_bias_torch(vecs, k=k)
    return transform_and_normalize_torch(vecs, kernel, bias)


# TODO comment the code lines above
class Whitener:
    """Whitening transform that is fitted once and can then be applied to new data.

    ``fit`` estimates the mean and covariance of the training rows and stores a
    float32 ``kernel`` and ``bias``; ``transform`` computes
    ``(X + bias) @ kernel``.
    """

    def __init__(self, method="pca", k=None, l2norm=False, eps=1e-6, dtype=torch.float64):
        """
        method: "pca" or "zca" (case-insensitive; checked when fit() runs)
        k: number of principal components to keep (None = keep all)
        l2norm: whether to apply L2 normalization after whitening
        eps: floor for the eigenvalues in fit() and for the row norms in transform()
        dtype: dtype the statistics are computed in
        """
        self.method = method.lower()
        self.k = k
        self.l2norm = l2norm
        self.eps = eps
        self.dtype = dtype

        self.kernel = None
        self.bias = None
        self.fitted = False

    # Fitting is pure statistics. Without no_grad, fitting on tensors that
    # require grad would leave kernel/bias attached to an autograd graph.
    @torch.no_grad()
    def fit(self, vecs):
        """Fit whitening transform on training data.

        vecs: (N, D) tensor, or a list/tuple of (n_i, D) tensors (concatenated on dim 0).
        Returns self. Raises ValueError when method is neither "pca" nor "zca".
        """
        X = torch.cat(vecs, dim=0) if isinstance(vecs, (list, tuple)) else vecs
        X = X.to(self.dtype)

        mu = X.mean(dim=0, keepdim=True)
        Xc = X - mu

        # Sample covariance; max() avoids dividing by zero for a single row.
        n = max(Xc.shape[0] - 1, 1)
        cov = (Xc.T @ Xc) / n

        # Eigen decomposition, reordered so the largest eigenvalues come first
        s, U = torch.linalg.eigh(cov)
        idx = torch.argsort(s, descending=True)
        s = s[idx]
        U = U[:, idx]

        if self.k is not None:
            U = U[:, :self.k]
            s = s[:self.k]

        d = torch.rsqrt(torch.clamp(s, min=self.eps))
        if self.method == "pca":
            kernel = U * d
        elif self.method == "zca":
            kernel = (U * d) @ U.T
        else:
            raise ValueError("method must be 'pca' or 'zca'")

        self.kernel = kernel.to(torch.float32)
        self.bias = -mu.to(torch.float32)
        self.fitted = True
        return self

    def transform(self, X):
        """Apply fitted whitening transform to new data.

        Raises RuntimeError when called before fit().
        """
        if not self.fitted:
            raise RuntimeError("Whitener must be fitted before calling transform().")
        # kernel/bias are stored as float32; bring X to the same dtype so that
        # e.g. float64 input does not fail in the matmul.
        X = X.to(self.kernel.dtype)
        X = (X + self.bias) @ self.kernel
        if self.l2norm:
            denom = torch.linalg.norm(X, dim=1, keepdim=True)
            X = X / torch.clamp(denom, min=self.eps)
        return X

    def fit_transform(self, vecs):
        """Fit on data and return transformed result"""
        return self.fit(vecs).transform(vecs)


# TODO ...
class PCAProjector:
    """PCA projection that is fitted once and can then be applied to new data.

    ``fit`` stores the principal directions ``U``, their eigenvalues ``s`` and
    the negated mean ``bias`` as float32 tensors; ``transform`` computes
    ``(X + bias) @ U`` and optionally rescales each component.
    """

    def __init__(self, k=None, scale=False, eps=1e-6, dtype=torch.float64):
        """
        PCA projector with optional scaling.
        k: number of principal components to keep (None = keep all)
        scale: if True, divide each component by sqrt(eigenvalue)
        eps: floor applied to the eigenvalues before the inverse square root
        dtype: dtype the statistics are computed in
        """
        self.k = k
        self.scale = scale
        self.eps = eps
        self.dtype = dtype

        self.U = None       # principal directions
        self.s = None       # eigenvalues
        self.bias = None    # negated mean, added to the data for centering
        self.fitted = False

    # Fitting is pure statistics. Without no_grad, fitting on tensors that
    # require grad would leave U/s/bias attached to an autograd graph.
    @torch.no_grad()
    def fit(self, vecs):
        """Fit PCA projection on training data.

        vecs: (N, D) tensor, or a list/tuple of (n_i, D) tensors (concatenated on dim 0).
        Returns self.
        """
        X = torch.cat(vecs, dim=0) if isinstance(vecs, (list, tuple)) else vecs
        X = X.to(self.dtype)

        mu = X.mean(dim=0, keepdim=True)
        Xc = X - mu

        # Sample covariance; max() avoids dividing by zero for a single row.
        n = max(Xc.shape[0] - 1, 1)
        cov = (Xc.T @ Xc) / n

        # Eigen decomposition, reordered so the largest eigenvalues come first
        s, U = torch.linalg.eigh(cov)
        idx = torch.argsort(s, descending=True)
        s = s[idx]
        U = U[:, idx]

        if self.k is not None:
            U = U[:, :self.k]
            s = s[:self.k]

        self.U = U.to(torch.float32)
        self.s = s.to(torch.float32)
        self.bias = -mu.to(torch.float32)
        self.fitted = True
        return self

    def transform(self, X):
        """Apply PCA projection (with optional scaling).

        Raises RuntimeError when called before fit().
        """
        if not self.fitted:
            raise RuntimeError("PCAProjector must be fitted before calling transform().")
        # U/bias are stored as float32; bring X to the same dtype so that
        # e.g. float64 input does not fail in the matmul.
        X = X.to(self.U.dtype)
        X = (X + self.bias) @ self.U
        if self.scale:
            d = torch.rsqrt(torch.clamp(self.s, min=self.eps))  # 1/sqrt(s)
            X = X * d  # broadcast scaling along each component
        return X

    def fit_transform(self, vecs):
        """Fit on data and return projected result"""
        return self.fit(vecs).transform(vecs)


# ======================
# Attribution analysis
# ======================
def compute_layer_attributions(model, input_ids):
    """Score every entry of ``model.model.layers`` with ``LayerGradientXActivation``.

    The model is put in eval mode, ``input_ids`` are embedded through
    ``model.model.embed_tokens``, and for each layer the attribution is
    reduced to one number: the sum of its absolute values.

    Args:
        model: Model exposing ``model.embed_tokens`` and ``model.layers``.
        input_ids: Token ids; cast to ``torch.long`` before embedding.

    Returns:
        A list of floats, one score per layer, in layer order.
    """
    model.eval()
    # The embedding lookup needs integer (long) indices.
    token_ids = input_ids.to(torch.long)
    input_embeds = model.model.embed_tokens(token_ids)

    layers = model.model.layers
    layer_attributions = []
    for layer_idx in tqdm(range(len(layers)), desc="Analyzing layers"):
        attributor = LayerGradientXActivation(model, layers[layer_idx])
        # NOTE(review): the embeddings are handed to the model's forward as its
        # first input, and target=0 selects output index 0 - confirm both are
        # what the wrapped model expects.
        attr = attributor.attribute(
            inputs=input_embeds,
            additional_forward_args=(),
            target=0,
        )
        # NOTE(review): assumes `attr` is a single tensor rather than a tuple - TODO confirm.
        layer_attributions.append(attr.abs().sum().item())

    return layer_attributions


def find_most_attributed_layer(model, sample_text, tokenizer):
    """Return the index of the layer with the highest attribution score for ``sample_text``.

    Args:
        model: Model passed on to ``compute_layer_attributions``.
        sample_text: Text to tokenize (no truncation or max length is applied).
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``return_tensors="pt"``.

    Returns:
        The argmax over the per-layer scores, as a Python int.
    """
    encoded = tokenizer(
        sample_text,
        return_tensors="pt",
    )
    # Move to the working device and make the ids explicitly long.
    input_ids = encoded.input_ids.to(DEVICE).long()

    scores = torch.tensor(compute_layer_attributions(model, input_ids))
    return torch.argmax(scores).item()



# ....
# use whitening, etc
def align(X_highdim, Y_target, embeds_hat, reducer):
    """Fit ``reducer`` on ``X_highdim`` and project ``embeds_hat`` with the fitted reducer.

    The elapsed wall-clock time of both steps is printed.

    Args:
        X_highdim: Data the reducer is fitted on (via ``fit_transform``).
        Y_target: Not used by this function.
        embeds_hat: Data projected with ``reducer.transform`` after fitting;
            assumed to live in the same feature space as ``X_highdim``.
        reducer: Object providing ``fit_transform`` and ``transform``.

    Returns:
        ``(X_reduced, embeds_hat_reduced)``.
    """
    started_at = time.time()
    fitted_projection = reducer.fit_transform(X_highdim)
    query_projection = reducer.transform(embeds_hat)
    elapsed = time.time() - started_at
    print(f"Time = {elapsed:.2f}s")
    return fitted_projection, query_projection

def align2(embeds_hat, src_bases, tgt_bases):
    """Map embeddings into the target space through cosine similarity to the source bases.

    Both ``embeds_hat`` and ``src_bases`` are L2-normalised along the last
    dimension, so their product holds cosine similarities. Those similarities
    are then used as weights over the rows of ``tgt_bases``.

    Args:
        embeds_hat: Embeddings to map; a 3-D input is squeezed at dim 0 first.
        src_bases: Source bases, one per row, sharing the last dimension with ``embeds_hat``.
        tgt_bases: Target bases, one row per source basis.

    Returns:
        ``similarities @ tgt_bases``.
    """
    if embeds_hat.dim() == 3:
        embeds_hat = embeds_hat.squeeze(0)
        logger.debug(f'{embeds_hat.shape=}')

    # TODO validate whether the batched mode works well
    # Normalising embeds_hat seems optional; normalising src_bases is required,
    # or else the performance will drop.
    unit_embeds = F.normalize(embeds_hat, dim=-1)
    unit_sources = F.normalize(src_bases, dim=-1)
    # No softmax is applied on the similarities.
    simis = unit_embeds @ unit_sources.T

    embeds_hat_reduced = simis @ tgt_bases
    logger.debug(f'{embeds_hat_reduced.shape=}')

    return embeds_hat_reduced
