# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
import os.path as osp
import os
import numpy as np
from .utils import compute_prompt_weights, modulate_prompts, log_sim_weight

class MultiPromptPool(nn.Module):
    def __init__(
        self,
        num_pools: int = 5,
        M: int = 32,
        Lp: int = 4,
        D: int = 768,          
        Dk: int = None,       
        init_scale: float = 0.02,
        pool_names: List[str] = None,
        device: torch.device = torch.device("cuda")
    ):
        super().__init__()
        self.num_pools = num_pools
        self.M = M
        self.Lp = Lp
        self.D = D
        self.Dk = D if Dk is None else Dk
        self.pool_names = pool_names if pool_names is not None else POOL_NAMES[:num_pools]
        self.max_tasks_count = 11

        # keys: [num_pools, M, Dk]
        self.keys = nn.Parameter(torch.randn(num_pools, M, self.Dk, device=device) * init_scale)

        # values: [num_pools, M, Lp, D]
        # self.values = nn.Parameter(torch.randn(num_pools, M, Lp, D, device=device) * init_scale)
        self.values = nn.Parameter(torch.zeros(num_pools, M, Lp, D, device=device))
        self.task_values_log = []  
        self.task_ids = [] 

        self.task_values_dict: Dict[int, torch.Tensor] = {}
        self.task_keys_dict: Dict[int, torch.Tensor] = {}

        self.q_proj = None
        if self.Dk != D:
            self.q_proj = nn.Linear(D, self.Dk, bias=False).to(device)

        self.register_buffer(
            "select_counts",
            torch.zeros(self.max_tasks_count, self.num_pools, M, dtype=torch.float32, device=device)
        )

        self.register_buffer(
            "task_select_counts",
            torch.zeros(self.max_tasks_count, self.max_tasks_count, num_pools, self.max_tasks_count * M, dtype=torch.float32, device=device)
        )

        self.register_buffer(
            "sim_weight_sums",
            torch.zeros(self.max_tasks_count, self.max_tasks_count, num_pools, dtype=torch.float32, device=device)
        )
        self.register_buffer(
            "sim_weight_counts",
            torch.zeros(self.max_tasks_count, self.max_tasks_count, dtype=torch.int32, device=device)
        )

        self.register_buffer("diversity_log", torch.zeros(0))   # 存储每个任务的 diversity

        self.register_buffer(
            "task_select_sims",
            torch.zeros(self.max_tasks_count, self.max_tasks_count, num_pools, self.max_tasks_count * M, dtype=torch.float32, device=device)
        )
        self.task_ids = []  

    def save_task_values(self, task_id: int):
        self.task_values_dict[task_id] = self.values.detach().clone()
        self.task_keys_dict[task_id] = self.keys.detach().clone()

    def load_task_values(self, task_id: int):
        if task_id not in self.task_values_dict or task_id not in self.task_keys_dict:
            raise ValueError(f"Task {task_id} not found in task_values_dict or task_keys_dict")
        self.values.data = self.task_values_dict[task_id].clone()
        self.keys.data = self.task_keys_dict[task_id].clone()

    def fuse_values(
        self,
        task_id_prev: int,
        task_id_curr: int,
        alpha: float = 0.5,
        strategy: str = 'inherit',
        fuse_keys: bool = True
    ):

        V_prev = self.task_values_dict[task_id_prev]
        V_curr = self.task_values_dict[task_id_curr]

        self.values.data = alpha * V_prev + (1 - alpha) * V_curr

        if fuse_keys:
            K_prev = self.task_keys_dict[task_id_prev]
            K_curr = self.task_keys_dict[task_id_curr]
            self.keys.data = alpha * K_prev + (1 - alpha) * K_curr

    def _project_q(self, q: torch.Tensor) -> torch.Tensor:
        """
        q: [B, D] -> [B, Dk]
        """
        if self.q_proj is not None:
            return self.q_proj(q)
        return q

    def topk_select(
        self,
        q: torch.Tensor,
        topk: int = 2,
        cur_train_task_id: int = -1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pick the ``topk`` prompts per pool for each query.

        Selection minimises the cosine distance between the (projected,
        L2-normalised) query and the normalised keys, scaled by each prompt's
        per-pool selection frequency for ``cur_train_task_id``.

        Args:
            q: Query features, [B, D].
            topk: Number of prompts selected in every pool.
            cur_train_task_id: Row of ``select_counts`` used for the frequency
                scaling. The default ``-1`` indexes the last row of the buffer.

        Returns:
            idx_sel: Selected prompt indices, [B, num_pools, topk].
            prompt_loss: Scalar, negative mean similarity of the selected prompts.
            sim_weight: Output of ``compute_prompt_weights`` on the selected
                similarities (semantics defined by that helper).
        """
        qk = F.normalize(self._project_q(q), dim=-1)    # [B, Dk]
        k = F.normalize(self.keys, dim=-1)              # [num_pools, M, Dk]

        # The frequency table does not change inside the loop, so build it once
        # instead of once per pool.
        freq_all = self.get_select_freq(task_id=cur_train_task_id)  # [num_pools, M]

        dists = []
        sims = []
        for p in range(self.num_pools):
            kp = k[p].to(qk.dtype)                      # [M, Dk]
            sim = qk @ kp.t()                           # [B, M]
            dist = 1 - sim                              # [B, M]

            freq_norm = freq_all[p]                     # [M]

            # Sparse debug logging (about 1% of pool iterations).
            # NOTE: this draws from the global torch RNG on every iteration.
            if torch.rand(1).item() < 0.01:
                print(f"[Pool {p}] freq_norm: min={freq_norm.min().item():.4f}, "
                    f"max={freq_norm.max().item():.4f}, mean={freq_norm.mean().item():.4f}")
                print(f"[Pool {p}] dist before: {dist[0,:5].detach().cpu().numpy()}, "
                    f"after: {(dist*freq_norm.unsqueeze(0))[0,:5].detach().cpu().numpy()}")

            # Scale distances by selection frequency: frequently chosen prompts
            # look farther away, never-chosen prompts get distance zero.
            dist = dist * freq_norm.unsqueeze(0)

            dists.append(dist)
            sims.append(sim)

        dists = torch.stack(dists, dim=1)                # [B, num_pools, M]
        sims = torch.stack(sims, dim=1)                  # [B, num_pools, M]

        # Smallest scaled distances win.
        _, idx_sel = torch.topk(dists, k=topk, dim=-1, largest=False)

        # Raw (unscaled) similarities of the winners, [B, num_pools, topk].
        sim_sel = torch.gather(sims, dim=-1, index=idx_sel)

        sim_weight = compute_prompt_weights(sim_sel, method="raw", threshold=0.01)

        prompt_loss = -sim_sel.mean()

        return idx_sel, prompt_loss, sim_weight

    def gather_values(
        self,
        idx_sel: torch.Tensor
    ) -> torch.Tensor:

        B, num_pools, topk = idx_sel.shape
        assert num_pools == self.num_pools

        # values: [num_pools, M, Lp, D]
        v = self.values

        out_list = []
        for b in range(B):
            seq_list_b = []
            for p in range(num_pools):
                for t in range(topk):
                    m_idx = idx_sel[b, p, t].item()
                    vp = v[p, m_idx]          # [Lp, D]
                    seq_list_b.append(vp)
            if len(seq_list_b) > 0:
                seq_b = torch.cat(seq_list_b, dim=0)      # [num_pools*topk*Lp, D]
            else:
                seq_b = torch.empty(0, self.D, device=v.device)
            out_list.append(seq_b)
        prompts = torch.stack(out_list, dim=0)             # [B, num_pools*topk*Lp, D]
        return prompts

    def unique_indices_across_batch_per_pool(
        self,
        idx_sel: torch.Tensor
    ) -> List[torch.Tensor]:

        B, num_pools, topk = idx_sel.shape
        uniq = []
        for p in range(num_pools):
            cand = idx_sel[:, p, :].reshape(-1)  # [B*topk]
            uniq_p = torch.unique(cand)          # [U_p]
            uniq.append(uniq_p)
        return uniq

    def gather_values_per_pool_for_text(
        self,
        uniq_idx_per_pool: List[torch.Tensor]
    ) -> List[torch.Tensor]:

        per_pool_prompts = []
        for p in range(self.num_pools):
            idxs = uniq_idx_per_pool[p].tolist()
            if len(idxs) == 0:
                per_pool_prompts.append(torch.empty(0, self.D, device=self.values.device))
                continue
            seq_list = [self.values[p, m] for m in idxs]  
            per_pool_prompts.append(torch.cat(seq_list, dim=0))  
        return per_pool_prompts

    @torch.no_grad()
    def update_select_counts(self, idx_sel: torch.Tensor, sim_sel: torch.Tensor = None,
                            train_task_id: int = -1, test_task_id: int = -1, sample_weights: torch.Tensor = None, cur_train_task_id: int = -1):

        assert idx_sel.dim() == 3 and idx_sel.size(1) == self.num_pools, \
            f"idx_sel shape {idx_sel.shape} not match num_pools={self.num_pools}"

        B, P, K = idx_sel.shape

        if train_task_id == -1 and test_task_id == -1:
            one_hot = F.one_hot(idx_sel.to(torch.long), num_classes=self.M)
            counts_pm = one_hot.sum(dim=(0, 2)).to(self.select_counts.dtype)
            self.select_counts[cur_train_task_id] += counts_pm 
        elif train_task_id >= 0 and test_task_id >= 0:
            num_prompts_total = (train_task_id + 1) * self.M
            
            max_idx = idx_sel.max().item() if idx_sel.numel() > 0 else -1
            assert max_idx < num_prompts_total, \
                f"Test time index out of bounds. Max index in idx_sel is {max_idx}, but num_prompts_total is {num_prompts_total}"
            
            one_hot_counts = F.one_hot(idx_sel.to(torch.long), num_classes=num_prompts_total)
            counts_pm = one_hot_counts.sum(dim=(0, 2)).to(self.task_select_counts.dtype)
            self.task_select_counts[train_task_id, test_task_id, :, :num_prompts_total] += counts_pm
            one_hot_counts_float = one_hot_counts.to(torch.float32)

            sim_scores_expanded = sim_sel.unsqueeze(-1)

            sim_scores_per_prompt = (one_hot_counts_float * sim_scores_expanded).sum(dim=(0, 2))

            self.task_select_sims[train_task_id, test_task_id, :, :num_prompts_total] += sim_scores_per_prompt      
        else:
            pass

    @torch.no_grad()
    def get_select_freq(self, task_id: int, eps: float = 1e-12, normalize: str = "per_pool") -> torch.Tensor:
        freq = self.select_counts[task_id]  # [P, M]

        if normalize == "per_pool":
            denom = freq.sum(dim=1, keepdim=True).clamp_min(eps)  # [P,1]
            return freq / denom
        elif normalize == "global":
            denom = freq.sum().clamp_min(eps)
            return freq / denom
        else:
            raise ValueError(f"unknown normalize={normalize}")

    def get_task_counts(
        self,
        train_task_id: int = -1,
        test_task_id: int = -1,
        normalize: bool = True
    ):
        if train_task_id >= 0 and test_task_id >= 0:
            num_prompts = (train_task_id + 1) * self.M
            counts = self.task_select_counts[train_task_id, test_task_id, :, :num_prompts].detach().cpu().numpy()
            sim_sum = self.task_select_sims[train_task_id, test_task_id, :, :num_prompts].detach().cpu().numpy()
            avg_sims = np.divide(sim_sum, counts, out=np.zeros_like(sim_sum), where=counts!=0)
        else:
            counts = self.select_counts.detach().cpu().numpy() 
            avg_sims = None 

        if normalize:
            row_sums = counts.sum(axis=1, keepdims=True) + 1e-6
            counts = counts / row_sums

        return counts, avg_sims

    def plot_task_all_pools(self, train_task_id: int, test_task_id: int, normalize: bool = True, log_path: str = None):
        counts, avg_sims = self.get_task_counts(train_task_id=train_task_id, test_task_id=test_task_id, normalize=normalize)
        
        if counts is None or avg_sims is None:
            print("No valid counts or similarities for plotting.")
            return

        num_pools, num_prompts_total = counts.shape

        fig, axes = plt.subplots(num_pools, 1, figsize=(20, 3 * num_pools), sharex=True)
        if num_pools == 1:
            axes = [axes]

        for pool_id, ax in enumerate(axes):
            ax.bar(range(num_prompts_total), counts[pool_id], color='dodgerblue', alpha=0.7, label='Selection Frequency')
            
            ax.set_ylabel("Frequency", fontsize=14)
            ax.tick_params(axis='y', labelsize=12) 
            ax.set_ylim(0, counts.max() * 1.2) 

        axes[-1].set_xlabel("Prompt ID", fontsize=14) 
        axes[-1].tick_params(axis='x', labelsize=12) 
        plt.tight_layout()

        if log_path is not None:
            save_dir = osp.join(log_path, "figs")
            os.makedirs(save_dir, exist_ok=True)
            save_path = osp.join(save_dir, f"train_{train_task_id}_test_{test_task_id}_all_pools_with_sims.pdf")
            plt.savefig(save_path)
        plt.close()
    
    def update_sim_weight_stats(self, sim_weight: torch.Tensor, train_task_id: int, test_task_id: int):
        if train_task_id < 0 or test_task_id < 0:
            return
        
        # [B, P, K] -> [P]
        sim_pool_mean = sim_weight.mean(dim=(0, 2))  # [P]

        self.sim_weight_sums[train_task_id, test_task_id] += sim_pool_mean.detach()
        self.sim_weight_counts[train_task_id, test_task_id] += 1

    def plot_sim_weight_heatmaps(self, save_path: str = None):
        """Plot, for every pool, a train-task x test-task heatmap of mean similarity weights.

        Args:
            save_path: When given, the figure is written to
                ``<save_path>/figs/heatmaps.png``; otherwise nothing is saved.
        """
        weight_sums = self.sim_weight_sums.detach().cpu().numpy()
        update_counts = self.sim_weight_counts.detach().cpu().numpy()

        # Avoid division by zero for task pairs that were never updated.
        safe_counts = np.maximum(update_counts[..., None], 1)
        # [train, test, P] -> [P, train, test]
        heatmaps = (weight_sums / safe_counts).transpose(2, 0, 1)

        num_pools, nb_task, _ = heatmaps.shape

        fig, axes = plt.subplots(1, num_pools, figsize=(5*num_pools, 5), constrained_layout=True)
        if num_pools == 1:
            # A single subplot is returned bare; wrap it for uniform indexing.
            axes = [axes]

        for pool_id in range(num_pools):
            ax = axes[pool_id]
            sns.heatmap(
                heatmaps[pool_id],
                xticklabels=list(range(nb_task)),
                yticklabels=list(range(nb_task)),
                annot=True, fmt=".2f", cmap="viridis",
                ax=ax, cbar=True
            )
            ax.set_title(f"Pool {pool_id}")
            ax.set_xlabel("Test Task ID")
            ax.set_ylabel("Train Task ID")

        if save_path is not None:
            figs_dir = osp.join(save_path, "figs")
            os.makedirs(figs_dir, exist_ok=True)
            plt.savefig(osp.join(figs_dir, "heatmaps.png"))
        plt.close()

    def compute_prompt_diversity(self, prompts: torch.Tensor) -> List[float]:
        num_pools, num_prompts_per_pool, Lp, D = prompts.shape
        diversities = []

        for p in range(num_pools):
            pool_prompts = prompts[p].reshape(num_prompts_per_pool * Lp, D).detach()
            singular_vals = torch.linalg.svdvals(pool_prompts)
            div = singular_vals.sum().item()
            diversities.append(div)

        return diversities

    def log_diversity(self, task_id: int):
        self.task_values_log.append(self.values.detach().clone())

        all_prompts = torch.cat(self.task_values_log, dim=1)

        divs = self.compute_prompt_diversity(prompts=all_prompts)
        
        divs_tensor = torch.tensor(divs, device=self.diversity_log.device).unsqueeze(0)

        if self.diversity_log.numel() == 0:
            self.diversity_log = divs_tensor
        else:
            self.diversity_log = torch.cat([self.diversity_log, divs_tensor], dim=0)

        self.task_ids.append(task_id)
        print(f"[Task {task_id}] Diversities = {divs}")

    def plot_diversity(self, save_path: str = None):

        if len(self.task_ids) == 0:
            print("No diversity log found!")
            return

        plt.figure()
        fig, ax = plt.subplots(figsize=(5, 4))
        num_pools = self.diversity_log.shape[1]
        for p in range(num_pools):
            plt.plot(
                self.task_ids,
                self.diversity_log[:, p].cpu().numpy(),
                marker="o",
                # label=f"Pool {p}"
            )

        plt.xlabel("Task ID")
        plt.ylabel("Prompt Diversity")
        plt.grid(True)

        if save_path is not None:
            save_dir = osp.join(save_path, "figs")
            os.makedirs(save_dir, exist_ok=True)
            save_path = osp.join(save_dir, f"diversity_curve.pdf")
            plt.savefig(save_path, bbox_inches="tight")
        plt.close(fig)

class ImagePromptAugmentor(nn.Module):
    """Select prompts from a ``MultiPromptPool`` for image queries."""

    def __init__(self, pool: MultiPromptPool, topk: int = 2):
        """
        Args:
            pool: The prompt pool that owns keys, values and statistics.
            topk: Default number of prompts selected per pool.
        """
        super().__init__()
        self.pool = pool
        self.topk = topk

    def forward(self, x_q: torch.Tensor, train_task_id: int = -1, test_task_id: int = -1, cur_train_task_id: int = -1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Select prompts from the live pool and build the prompt sequence.

        Args:
            x_q: Query features, [B, D] with ``D == pool.D``.
            train_task_id: Training-stage index used for statistics, or -1.
            test_task_id: Evaluated-task index used for statistics, or -1.
            cur_train_task_id: Task whose selection frequencies drive selection.

        Returns:
            prom_seq: Output of ``modulate_prompts`` on the gathered prompts
                ([B, num_pools*topk*Lp, D] before modulation).
            idx_sel: Selected prompt indices, [B, num_pools, topk].
            reg_loss: Scalar loss returned by ``pool.topk_select``.
        """
        _, D = x_q.shape
        assert D == self.pool.D, f"Token dim mismatch: x_e({D}) vs pool.D({self.pool.D})"

        # topk_select returns (indices, scalar prompt loss, similarity weights).
        idx_sel, prompt_loss, sim_weight = self.pool.topk_select(
            x_q, topk=self.topk, cur_train_task_id=cur_train_task_id
        )

        # NOTE(review): no per-prompt similarities are available here, so
        # sim_sel is passed as None.
        self.pool.update_select_counts(idx_sel, None, train_task_id, test_task_id, sample_weights=None, cur_train_task_id=cur_train_task_id)

        prom_seq_raw = self.pool.gather_values(idx_sel)  # [B, P*K*Lp, D]

        self.pool.update_sim_weight_stats(sim_weight, train_task_id, test_task_id)

        # Modulation is only requested when both task ids are valid.
        apply_modulation = (train_task_id >= 0 and test_task_id >= 0)
        prom_seq = modulate_prompts(prom_seq_raw, sim_weight, self.pool.Lp, apply_modulation=apply_modulation)

        reg_loss = prompt_loss.mean()

        return prom_seq, idx_sel, reg_loss

    def select_from_all_tasks(
        self,
        x_q: torch.Tensor,
        test_cur_train_task_id: int,
        test_cur_test_task_id: int,
        topk: int = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select prompts from the saved snapshots of all tasks up to a training stage.

        The saved keys/values of tasks ``0 .. test_cur_train_task_id`` are
        concatenated along the prompt axis and the ``topk`` most similar
        prompts per pool are picked by cosine similarity.

        Args:
            x_q: Query features, [B, D] with ``D == pool.D``.
            test_cur_train_task_id: Last training stage whose snapshot is used.
            test_cur_test_task_id: Evaluated-task index (statistics only).
            topk: Prompts per pool; defaults to ``self.topk``.

        Returns:
            prom_seq: [B, num_pools*topk*Lp, D]; each prompt is scaled by its
                similarity when both task ids are ``>= 0``.
            prom_weights: [B], mean similarity of the selected prompts.

        Raises:
            ValueError: If no snapshot exists for any task in range.
        """
        if topk is None:
            topk = self.topk
        B, D = x_q.shape
        assert D == self.pool.D, f"x_q dim {D} vs pool.D {self.pool.D}"

        saved_tasks = [
            tid for tid in range(test_cur_train_task_id + 1)
            if tid in self.pool.task_keys_dict
        ]
        if len(saved_tasks) == 0:
            raise ValueError(f"No saved keys/values found up to task {test_cur_train_task_id}")

        all_keys = torch.cat([self.pool.task_keys_dict[tid] for tid in saved_tasks], dim=1)      # [num_pools, M_total, Dk]
        all_values = torch.cat([self.pool.task_values_dict[tid] for tid in saved_tasks], dim=1)  # [num_pools, M_total, Lp, D]

        Lp = self.pool.Lp

        # Project queries to the key space exactly as topk_select does. This is
        # the identity when Dk == D; without it the matmul fails when Dk != D.
        q_norm = F.normalize(self.pool._project_q(x_q), dim=-1)   # [B, Dk]

        prom_seqs = []
        sim_scores_all = []
        idx_all = []
        for p in range(self.pool.num_pools):
            keys_p = F.normalize(all_keys[p].to(q_norm.dtype), dim=-1)   # [M_total, Dk]
            sim = torch.matmul(q_norm, keys_p.t())                       # [B, M_total]

            sim_val, idx_sel = torch.topk(sim, k=topk, dim=-1)           # [B, topk] each

            # Batched gather (also valid for B == 0): [B, topk, Lp, D]
            vals = all_values[p][idx_sel]
            prom_seqs.append(vals.reshape(B, topk * Lp, D))

            sim_scores_all.append(sim_val)
            idx_all.append(idx_sel)

        prom_seq = torch.cat(prom_seqs, dim=1)              # [B, num_pools*topk*Lp, D]
        sim_scores = torch.cat(sim_scores_all, dim=1)       # [B, num_pools*topk]
        prom_weights = sim_scores.mean(dim=-1)              # [B]

        idx_all = torch.stack(idx_all, dim=1)               # [B, num_pools, topk]
        sim_all = torch.stack(sim_scores_all, dim=1)        # [B, num_pools, topk]

        self.pool.update_select_counts(idx_all, sim_all, test_cur_train_task_id, test_cur_test_task_id, sample_weights=None)

        apply_modulation = (test_cur_train_task_id >= 0 and test_cur_test_task_id >= 0)
        if apply_modulation:
            # Scale all Lp tokens of a prompt by that prompt's similarity. The
            # prompt length comes from the pool instead of a hard-coded 4.
            token_weights = sim_scores.repeat_interleave(Lp, dim=1).unsqueeze(-1)   # [B, N*Lp, 1]
            return prom_seq * token_weights, prom_weights

        return prom_seq, prom_weights

class TextPromptAugmentor(nn.Module):
    """Append the image-selected prompts of each pool to that pool's text descriptions."""

    def __init__(self, pool: MultiPromptPool):
        super().__init__()
        self.pool = pool

    def forward(
        self,
        desc_per_pool: Dict[str, torch.Tensor],
        idx_sel: torch.Tensor,
        pool_order: List[str] = None
    ) -> Dict[str, torch.Tensor]:
        """Concatenate unique selected prompts after each pool's text tokens.

        Args:
            desc_per_pool: Maps pool name to text tokens [C, L_txt_p, D].
            idx_sel: Selected prompt indices, [B, num_pools, topk].
            pool_order: Pool names in pool-index order; defaults to
                ``pool.pool_names``.

        Returns:
            Dict mapping each pool name to [C, L_txt_p + U_p*Lp, D], where U_p
            is the number of distinct prompts selected in that pool. The
            description is returned unchanged when no prompt was selected.
        """
        if not pool_order:
            pool_order = self.pool.pool_names

        unique_ids = self.pool.unique_indices_across_batch_per_pool(idx_sel)     # list of [U_p]
        pool_prompts = self.pool.gather_values_per_pool_for_text(unique_ids)     # list of [U_p*Lp, D]

        augmented = {}
        for position, pool_name in enumerate(pool_order):
            assert pool_name in desc_per_pool, f"Missing text description for pool '{pool_name}'"
            text_tokens = desc_per_pool[pool_name]              # [C, L_txt_p, D]
            prompt_tokens = pool_prompts[position]
            num_classes, _, _ = text_tokens.shape

            if prompt_tokens.numel() == 0:
                augmented[pool_name] = text_tokens
                continue

            # Share the same prompts across all C descriptions.
            shared = prompt_tokens.unsqueeze(0).expand(num_classes, -1, -1)   # [C, U_p*Lp, D]
            augmented[pool_name] = torch.cat([text_tokens, shared], dim=1)
        return augmented

import torch

def expand_and_add_prompts(
    x: torch.Tensor, 
    prom_seq: torch.Tensor, 
    P: int, K: int, Lp: int, 
    prepend: bool = False,
    use_prompt: bool = True 
) -> torch.Tensor:
    """
    Expand each sample in the batch into P copies and add corresponding prompts.

    Args:
        x (torch.Tensor): Input tensor of shape [B, L, D].
        prom_seq (torch.Tensor): Prompt tensor of shape [B, P*K*Lp, D].
            Not inspected (may be None) when use_prompt is False.
        P (int): Number of prompt partitions per sample.
        K (int): Number of prompt groups per partition.
        Lp (int): Length of each prompt group.
        prepend (bool): If True, prepend prompts before x. 
                        If False, append prompts after x.
        use_prompt (bool): If False, ignore prom_seq and return [B*P, L, D].

    Returns:
        torch.Tensor: 
            If use_prompt=True: [B*P, L+K*Lp, D]  
            If use_prompt=False: [B*P, L, D]

    Raises:
        AssertionError: If use_prompt is True and prom_seq's second dimension
            is not P*K*Lp.
    """
    B, L, D = x.shape

    # Repeat every sample P times: row b*P + p is a copy of sample b.
    x_expanded = x.unsqueeze(1).expand(B, P, L, D).reshape(B * P, L, D)

    if not use_prompt:
        # The prompts are ignored as documented, so they are not validated
        # (previously prom_seq=None crashed here on ``.shape``).
        return x_expanded

    _, total_prompts, _ = prom_seq.shape
    assert total_prompts == P * K * Lp, \
        f"prom_seq second dim should be P*K*Lp={P*K*Lp}, but got {total_prompts}"

    # Split the prompts into P chunks per sample: [B, P*K*Lp, D] -> [B*P, K*Lp, D]
    prom_seq_split = prom_seq.reshape(B * P, K * Lp, D)

    parts = (prom_seq_split, x_expanded) if prepend else (x_expanded, prom_seq_split)
    return torch.cat(parts, dim=1)
