import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import os
from typing import List, Dict
from datetime import datetime

@torch.no_grad()
def Router_Rank_Avg(self, route_aggregation, params: List[Dict[str, torch.Tensor]]):
    """Aggregate the clients' ``lora_B0`` tensors into one shared tensor per key.

    For every key containing ``'lora_B0'`` the result is::

        avg_B0 + diff_lambda * sum_i (avg_B0 - B0_i) @ P_i

    where ``P_i`` projects onto the complement of the top ``subspace_rank``
    left-singular directions of the mean ``A @ A.T`` taken over the ``topk``
    clients whose flattened B0 parameters have the lowest cosine similarity to
    client ``i``. Keys without a matching ``lora_A0`` entry use the
    un-projected difference.

    Configuration is read from ``self`` with defaults: ``subspace_rank`` (1),
    ``diff_lambda`` (1.0), ``topk`` (2), ``log_flag`` (False) and
    ``personal_dir`` ("./logs"). ``self.device`` is required.

    Args:
        route_aggregation: Not used by this function; kept for interface
            compatibility.
        params: One ``{name: tensor}`` dict per client.

    Returns:
        ``[]`` when ``params`` is empty. When client 0 has no ``lora_B0`` key,
        the per-client dicts moved to ``self.device`` unchanged. Otherwise one
        dict per client in which every ``lora_B0`` entry is the shared
        aggregated tensor (the same tensor object for all clients, cast to the
        dtype of ``params[0][key]``) and every other entry is the client's own
        tensor on ``self.device``.

    Side effects:
        Stores clones of the aggregated B0 tensors in ``self._prev_B0_cache``
        and sets ``self.log_flag = False``. When ``log_flag`` was true, writes
        similarity matrices (.xlsx) and heatmaps (.png) into ``personal_dir``.
    """
    subspace_rank = getattr(self, 'subspace_rank', 1)
    diff_lambda = getattr(self, 'diff_lambda', 1.)
    log_flag = getattr(self, 'log_flag', False)
    topk = getattr(self, 'topk', 2)
    save_dir = getattr(self, 'personal_dir', "./logs")

    if not params:
        return []

    # ---- Move everything to the working device ----
    gpu_params = [{k: v.to(self.device) for k, v in cp.items()} for cp in params]
    num_clients = len(gpu_params)

    def is_lora_A0(k): return 'lora_A0' in k
    def is_lora_B0(k): return 'lora_B0' in k

    # Client 0 defines which B0 keys are aggregated.
    B0_keys = [k for k in gpu_params[0] if is_lora_B0(k)]
    if not B0_keys:
        return gpu_params

    # ---- Per-client float copies of A0 / B0 ----
    # B0 is collected in the shared B0_keys order so the flattened vectors of
    # different clients are comparable element by element.
    B0_per_client = [{k: cp[k].detach().float() for k in B0_keys} for cp in gpu_params]
    A0_per_client = [
        {k: v.detach().float() for k, v in cp.items() if is_lora_A0(k)}
        for cp in gpu_params
    ]

    # ---- Plain FedAvg over B0 ----
    avg_B0 = {k: sum(b0[k] for b0 in B0_per_client) / num_clients for k in B0_keys}

    def cosine_similarity_matrix(per_client: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Return the [N, N] cosine similarity of each client's flattened tensors."""
        flat = torch.stack(
            [torch.cat([t.flatten() for t in tensors.values()], dim=0) for tensors in per_client]
        )
        unit = F.normalize(flat, p=2, dim=1)
        return unit @ unit.T

    sim_matrix = cosine_similarity_matrix(B0_per_client)

    # ---- Optional one-shot logging of the similarity matrices ----
    if log_flag:
        os.makedirs(save_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        labels = [f"C{i}" for i in range(num_clients)]

        def save_heatmap(matrix, title, filename):
            plt.figure(figsize=(6, 5))
            plt.imshow(matrix, cmap='coolwarm', interpolation='nearest')
            plt.colorbar()
            plt.title(title)
            plt.xlabel("Client")
            plt.ylabel("Client")
            for i in range(num_clients):
                for j in range(num_clients):
                    plt.text(j, i, f"{matrix[i,j]:.2f}", ha='center', va='center', color='black', fontsize=8)
            plt.tight_layout()
            plt.savefig(filename, dpi=300)
            plt.close()

        to_log = [("B0", sim_matrix.cpu().numpy())]
        # The A0 similarity is only needed for logging, and it can only be
        # built when every client actually carries A0 tensors.
        if all(A0_per_client):
            to_log.append(("A0", cosine_similarity_matrix(A0_per_client).cpu().numpy()))

        for tag, matrix in to_log:
            frame = pd.DataFrame(matrix, index=labels, columns=labels)
            frame.to_excel(os.path.join(save_dir, f"similarity_{tag}_{timestamp}.xlsx"))
        for tag, matrix in to_log:
            save_heatmap(matrix, f"LoRA-{tag} Cosine Similarity",
                         os.path.join(save_dir, f"similarity_{tag}_{timestamp}.png"))

        print(f"[OnlyRank_B_Avg] 相似度矩阵已保存到: {save_dir}")

    # ---- Projection built from the least-similar *other* clients ----
    # Never ask for more neighbours than there are other clients: torch.topk
    # raises when k exceeds the row length, and k == num_clients would pull
    # the client itself back in despite the +inf mask below.
    num_neighbours = min(topk, num_clients - 1)

    def build_weak_space_for_except(i: int) -> Dict[str, torch.Tensor]:
        """Return {B0 key: projector} for client ``i``; empty if it has no peers."""
        if num_neighbours <= 0:
            return {}

        sim_i = sim_matrix[i].clone()
        sim_i[i] = float('inf')  # exclude the client itself from the selection
        least_sim_indices = torch.topk(-sim_i, k=num_neighbours, largest=True).indices.tolist()

        # Mean of A @ A.T over the selected clients, per A0 key.
        # NOTE(review): presumably A0 is (rank, in_features), making this a
        # (rank, rank) matrix that can right-multiply B0 — confirm with caller.
        gram_mean: Dict[str, torch.Tensor] = {}
        for j in least_sim_indices:
            for k, A in A0_per_client[j].items():
                M = A @ A.T
                gram_mean[k] = gram_mean.get(k, torch.zeros_like(M)) + M / len(least_sim_indices)

        proj = {}
        for kA, M in gram_mean.items():
            U, _, _ = torch.linalg.svd(M, full_matrices=False)
            if subspace_rank < U.shape[1]:
                # Keep only the directions beyond the top `subspace_rank` ones.
                tail = U[:, subspace_rank:]
                P = tail @ tail.T
            else:
                # Nothing is left after removing the dominant directions.
                P = torch.zeros_like(M)
            proj[kA.replace('lora_A0', 'lora_B0')] = P
        return proj

    # ---- Accumulate the projected (avg - client) differences ----
    diff_B0_sum = {k: torch.zeros_like(v) for k, v in avg_B0.items()}
    for i, B0_i in enumerate(B0_per_client):
        proj_dict = build_weak_space_for_except(i)
        for k in B0_keys:
            diff = avg_B0[k] - B0_i[k]
            diff_B0_sum[k] += diff @ proj_dict[k] if k in proj_dict else diff

    # ---- Combine and hand the same B0 to every client ----
    final_B0 = {
        k: (avg_B0[k] + diff_lambda * diff_B0_sum[k]).to(dtype=params[0][k].dtype)
        for k in B0_keys
    }
    aggregated_results = [
        {k: (final_B0[k] if is_lora_B0(k) else v) for k, v in cp.items()}
        for cp in gpu_params
    ]

    self._prev_B0_cache = {k: t.clone() for k, t in final_B0.items()}
    # Only log/plot once; later rounds skip it.
    self.log_flag = False
    return aggregated_results
