import torch
import torch.nn.functional as F
import os
import json
import torch


def apply_bw(x: torch.Tensor, out: torch.Tensor, bw: torch.Tensor):
    """
    Scale ``out`` by a weight vector whose layout is inferred from ``x``.

    The weight is matched purely by element count, in this order:

    1. ``bw.numel() == x.shape[1]``  -> reshaped to ``[1, x.shape[1], 1]``
    2. ``bw.numel() == x.shape[0]``  -> reshaped to ``[x.shape[0], 1, 1]``
    3. ``bw.numel()`` evenly divides ``x.shape[1]`` -> every weight is
       repeated ``x.shape[1] // bw.numel()`` times consecutively and then
       reshaped to ``[1, x.shape[1], 1]``

    Args:
        x (torch.Tensor): Reference tensor; only its first two dimensions,
            device and dtype are used.
        out (torch.Tensor): Tensor to scale. The reshaped weight is 3-D, so
            ``out`` is presumably ``[x.shape[0], x.shape[1], *]``
            (NOTE(review): confirm against callers).
        bw (torch.Tensor | None): Weights; ``None`` disables scaling.

    Returns:
        torch.Tensor: ``out`` itself when ``bw`` is None, else ``out * bw``.

    Raises:
        RuntimeError: If the number of weights fits none of the layouts
            above (including an empty weight tensor).
    """
    if bw is None:
        return out

    bw = bw.to(device=x.device, dtype=x.dtype)

    if bw.dim() == 0:
        bw = bw.unsqueeze(0)

    n_weights = bw.numel()
    dim0, dim1 = x.shape[0], x.shape[1]

    if n_weights == dim1:
        bw = bw.reshape(1, dim1, 1)
    elif n_weights == dim0:
        bw = bw.reshape(dim0, 1, 1)
    # Guard n_weights > 0 so an empty weight tensor reaches the RuntimeError
    # below instead of failing with ZeroDivisionError in the modulo.
    elif n_weights > 0 and dim1 % n_weights == 0:
        repeats = dim1 // n_weights
        bw = bw.repeat_interleave(repeats).reshape(1, dim1, 1)
    else:
        raise RuntimeError(
            f"batch_weight length {bw.numel()} doesn't match batch dim of x {x.shape[:2]}"
        )

    return out * bw

def pool_features(
    x: torch.Tensor, 
    K: int, Lp: int, 
    pooling_type: str = "cls", 
    prepend: bool = False,
    has_prompt: bool = True,
) -> torch.Tensor:
    """
    Reduce a token sequence to one feature vector per row.

    The prompt block occupies ``K * Lp`` tokens: the first ``K * Lp``
    positions when ``prepend`` is True, otherwise the last ``K * Lp``.
    The CLS token sits right after a prepended prompt block (index
    ``K * Lp``) or at index 0 in every other case.

    Args:
        x (torch.Tensor): Input tensor of shape [B*P, L+K*Lp, D]
                          (or [B, L, D] if has_prompt=False).
        K (int): Number of prompt groups per partition.
        Lp (int): Length of each prompt group.
        pooling_type (str): One of
            "cls"    - the CLS token,
            "prompt" - mean over the prompt tokens,
            "fusion" - CLS token plus prompt mean (CLS only when no
                       prompts were added).
        prepend (bool): Whether prompts are prepended.
        has_prompt (bool): Whether prompts were added (prom_seq is not None).

    Returns:
        torch.Tensor: Pooled feature of shape [B*P, D].

    Raises:
        ValueError: If ``pooling_type`` is unknown, or "prompt" pooling is
            requested while ``has_prompt`` is False.
    """
    if pooling_type not in ("cls", "prompt", "fusion"):
        raise ValueError(f"Invalid pooling_type: {pooling_type}")

    # Without prompts only the CLS token at position 0 is available.
    if not has_prompt:
        if pooling_type == "prompt":
            raise ValueError("Prompt pooling requested but no prompts were added.")
        return x[:, 0, :]

    if pooling_type == "cls":
        return x[:, K * Lp, :] if prepend else x[:, 0, :]

    n_prompt = K * Lp
    if prepend:
        cls_pos, prompt_tokens = n_prompt, x[:, :n_prompt, :]
    else:
        cls_pos, prompt_tokens = 0, x[:, -n_prompt:, :]

    if pooling_type == "prompt":
        return prompt_tokens.mean(dim=1)

    # "fusion": element-wise sum of the CLS token and the prompt mean.
    cls_token = x[:, cls_pos, :]
    return cls_token + prompt_tokens.mean(dim=1)

def compute_prompt_weights(sim_sel: torch.Tensor, 
                           method: str = "softmax", 
                           tau: float = 1.0, 
                           threshold: float = 0.0) -> torch.Tensor:
    """
    Convert similarity scores into prompt weights.

    Args:
        sim_sel (torch.Tensor): Similarity scores.
        method (str): Weighting scheme:
            "softmax" - softmax of ``sim_sel / tau`` over the last dimension,
            "sigmoid" - element-wise sigmoid of ``sim_sel``,
            "raw"     - ReLU of ``sim_sel``; when ``threshold > 0`` values
                        below ``threshold`` are additionally set to zero.
        tau (float): Temperature; only used by "softmax".
        threshold (float): Cut-off; only used by "raw".

    Returns:
        torch.Tensor: Weights with the same shape as ``sim_sel``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "softmax":
        return F.softmax(sim_sel / tau, dim=-1)

    if method == "sigmoid":
        return torch.sigmoid(sim_sel)

    if method != "raw":
        raise ValueError(f"Unknown method: {method}. Choose from ['softmax', 'sigmoid', 'raw'].")

    # "raw": keep non-negative similarities, optionally zeroing small ones.
    clipped = torch.relu(sim_sel)
    if threshold > 0.0:
        too_small = clipped < threshold
        clipped = torch.where(too_small, torch.zeros_like(clipped), clipped)
    return clipped

def modulate_prompts(prom_seq_raw: torch.Tensor,
                     sim_weight: torch.Tensor,
                     Lp: int,
                     apply_modulation: bool = True) -> torch.Tensor:
    """
    Optionally scale each prompt group by its weight and flatten the result.

    Args:
        prom_seq_raw (torch.Tensor): Prompt tokens with ``B * P * K * Lp * D``
            elements whose last dimension is ``D`` (e.g. ``[B, P*K*Lp, D]``
            or ``[B, P, K, Lp, D]``). Non-contiguous tensors are accepted.
        sim_weight (torch.Tensor): Per-group weights of shape ``[B, P, K]``.
        Lp (int): Number of tokens in each prompt group.
        apply_modulation (bool): When True, every token of group
            ``(b, p, k)`` is multiplied by ``sim_weight[b, p, k]``; when
            False the prompts are only reshaped.

    Returns:
        torch.Tensor: Prompt sequence of shape ``[B, P*K*Lp, D]``.
    """
    B, P, K = sim_weight.shape
    D = prom_seq_raw.size(-1)

    if apply_modulation:
        # [B, P, K] -> [B, P, K, 1, 1]; broadcasting covers the Lp and D axes.
        scale = sim_weight.unsqueeze(-1).unsqueeze(-1)
        prom_seq_raw = prom_seq_raw.reshape(B, P, K, Lp, D) * scale

    # reshape (not view): the input, and an element-wise product derived from
    # it, may be non-contiguous, in which case view would raise.
    return prom_seq_raw.reshape(B, P * K * Lp, D)

def log_sim_weight(sim_weight: torch.Tensor, train_task_id: int, test_task_id: int, log_path: str):
    """
    Append the per-partition mean of ``sim_weight`` to a JSON log file.

    The log is a JSON object mapping ``"{train_task_id}-{test_task_id}"`` to
    a list of entries; each call appends one entry (a list of floats) under
    that key and rewrites the whole file.

    Args:
        sim_weight (torch.Tensor): Weights of shape ``[B, P, K]``; averaged
            over dimensions 0 and 2, giving one value per ``P``.
        train_task_id (int): First half of the log key.
        test_task_id (int): Second half of the log key.
        log_path (str): Path of the JSON file. Missing parent directories
            are created; a bare filename (no directory part) is allowed.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # [B, P, K] -> [P]
    sim_pool_mean = sim_weight.mean(dim=(0, 2)).detach().cpu().tolist()

    key = f"{train_task_id}-{test_task_id}"

    if os.path.exists(log_path):
        with open(log_path, "r") as f:
            log_data = json.load(f)
    else:
        log_data = {}

    log_data.setdefault(key, []).append(sim_pool_mean)

    with open(log_path, "w") as f:
        json.dump(log_data, f, indent=4)