import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import spearmanr
from typing import Union, Dict


# --- Compute MSE Loss (Masked) ---
def masked_mse(pred, target, mask):
    """
    Compute Masked Mean Squared Error (MSE).

    Only positions where ``mask`` is non-zero contribute. Masked-out positions
    are excluded by selection rather than by multiplication, so NaN/Inf values
    sitting in padded positions cannot leak into the result.

    Args:
        pred (torch.Tensor): Predicted values (B, T) or (T,).
        target (torch.Tensor): Ground truth values, same shape as ``pred``.
        mask (torch.Tensor): Binary mask (1 = valid, 0 = invalid). May be
            bool or numeric; it must be broadcastable against ``pred``.
    Returns:
        mse_score (float): Mean squared error over the valid positions, or
            0.0 when no position is valid.
    """
    # Broadcast explicitly so the valid-element count matches the number of
    # error terms that are actually summed.
    diff, valid = torch.broadcast_tensors(pred - target, mask.to(torch.bool))

    # Zero out invalid entries by selection: `nan * 0` would still be NaN.
    sq_err = (diff ** 2).masked_fill(~valid, 0)

    # Clamp avoids division by zero when the mask is entirely empty.
    n_valid = valid.sum().clamp(min=1)
    return (sq_err.sum() / n_valid).item()


# --- Compute MAE Loss (Masked) ---
def masked_mae(pred, target, mask):
    """
    Compute Masked Mean Absolute Error (MAE).

    Only positions where ``mask`` is non-zero contribute. Masked-out positions
    are excluded by selection rather than by multiplication, so NaN/Inf values
    sitting in padded positions cannot leak into the result.

    Args:
        pred (torch.Tensor): Predicted values (B, T) or (T,).
        target (torch.Tensor): Ground truth values, same shape as ``pred``.
        mask (torch.Tensor): Binary mask (1 = valid, 0 = invalid). May be
            bool or numeric; it must be broadcastable against ``pred``.
    Returns:
        mae_score (float): Mean absolute error over the valid positions, or
            0.0 when no position is valid.
    """
    # Broadcast explicitly so the valid-element count matches the number of
    # error terms that are actually summed.
    diff, valid = torch.broadcast_tensors(pred - target, mask.to(torch.bool))

    # Zero out invalid entries by selection: `nan * 0` would still be NaN.
    abs_err = diff.abs().masked_fill(~valid, 0)

    # Clamp avoids division by zero when the mask is entirely empty.
    n_valid = valid.sum().clamp(min=1)
    return (abs_err.sum() / n_valid).item()


def masked_voc(pred, mask):
    """
    Compute Masked Value-Order Correlation (VOC).

    VOC is the Spearman rank correlation between the predicted values and the
    timestep at which they occur. The last axis of ``pred`` is treated as
    time; every leading axis indexes an independent sequence. The correlation
    is computed per sequence against the within-sequence timestep and then
    averaged, so sequence boundaries in a batch do not distort the ranking.

    Args:
        pred (torch.Tensor | np.ndarray): Predicted values (B, T) or (T,).
        mask (torch.Tensor | np.ndarray): Binary mask (1 = valid, 0 = invalid),
            with the same number of elements as ``pred`` or broadcastable to it.
    Returns:
        voc_score (float): Mean Value-Order Correlation in [-1, 1] over the
            sequences for which it is defined. NaN when no sequence has at
            least two valid points with a defined correlation (for example
            constant predictions).
    """
    if isinstance(pred, torch.Tensor):
        pred = pred.detach().cpu().numpy()
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    pred = np.atleast_1d(np.asarray(pred))
    mask = np.asarray(mask)
    if pred.size == 0:
        return float("nan")

    # Same element count: align layouts (covers a pre-flattened mask).
    # Otherwise fall back to broadcasting, e.g. one (T,) mask for every row.
    if mask.size == pred.size:
        mask = mask.reshape(pred.shape)
    else:
        mask = np.broadcast_to(mask, pred.shape)

    num_steps = pred.shape[-1]
    sequences = pred.reshape(-1, num_steps)
    valid = mask.reshape(-1, num_steps) > 0
    timesteps = np.arange(num_steps)

    scores = []
    for values, keep in zip(sequences, valid):
        if keep.sum() < 2:
            continue  # Not enough points to compute a correlation
        score, _ = spearmanr(values[keep], timesteps[keep])
        # spearmanr yields NaN for constant input; skip undefined sequences
        # so they do not turn the batch average into NaN.
        if not np.isnan(score):
            scores.append(score)

    return float(np.mean(scores)) if scores else float("nan")


def compute_goal_baseline_clip_reward(clip_feature_extractor, frames, goals, baselines, alpha=0.5):
    """
    VLM-RM baseline reward function.

    Each frame embedding is partially projected onto the unit direction
    pointing from the baseline text embedding to the goal text embedding, and
    the reward is ``1 - 0.5 * ||regularised_embedding - goal||^2``.

    Args:
        clip_feature_extractor: has .encode_image() and .encode_text()
        frames (Tensor): Shape (B, T, C, H, W). Need not be contiguous.
        goals (List[str]): List of goal texts.
        baselines (List[str]): List of baseline texts.
        alpha (float): Interpolation strength between projected and raw embeddings (0-1).
            0 keeps the raw image embedding, 1 uses only the projection.
    Returns:
        reward (Tensor): Reward scores per timestep (B, T).
    """
    B, T, C, H, W = frames.shape

    # 1) Encode images. `reshape` (not `view`) so frames produced by
    #    transpose/permute are accepted instead of raising.
    flat_frames = frames.reshape(B * T, C, H, W)
    img_emb = clip_feature_extractor.encode_image(flat_frames)      # (B*T, D)
    img_emb = F.normalize(img_emb, dim=-1).reshape(B, T, -1)        # (B, T, D)

    # 2) Encode texts (goal and baseline) as unit vectors
    goal_emb = F.normalize(clip_feature_extractor.encode_text(goals), dim=-1).unsqueeze(1)      # (B, 1, D)
    base_emb = F.normalize(clip_feature_extractor.encode_text(baselines), dim=-1).unsqueeze(1)  # (B, 1, D)

    # 3) Unit direction from baseline to goal in the embedding space
    direction = F.normalize(goal_emb - base_emb, dim=-1)            # (B, 1, D)

    # 4) Project every frame embedding onto that direction
    # NOTE(review): this projection passes through the origin; the published
    # VLM-RM formulation appears to apply the projection relative to the goal
    # embedding. Confirm that this form is intended.
    coeff = (img_emb * direction).sum(dim=-1, keepdim=True)         # (B, T, 1)
    projection = coeff * direction                                  # (B, T, D)

    # 5) Interpolate between raw and projected embeddings
    img_reg = alpha * projection + (1 - alpha) * img_emb            # (B, T, D)

    # 6) Squared distance to the goal; goal_emb broadcasts over time
    dist2 = (img_reg - goal_emb).pow(2).sum(dim=-1)                 # (B, T)

    # 7) Convert distance to reward
    return 1 - 0.5 * dist2                                          # (B, T)


def compute_contrastive_clip_reward(model, frames, goals, baselines, tau=0.01, beta=0.5):
    """
    Compute Contrastive CLIP reward based on similarity between frames and goal texts.

    For every sample the goal text competes against that sample's baseline
    texts: similarities are divided by ``tau`` and soft-maxed across the
    candidates, and the probability assigned to the goal is returned for
    every timestep.

    Args:
        model: Model exposing `clip_feature_extractor.visual_encoder` (used to
            pick the device) and `compute_clip_similarity_score(frames, texts)`.
        frames (Tensor): Shape (B, T, C, H, W).
        goals (str | List[str]): One goal text per sample; a single string is
            shared by every sample.
        baselines (str | List[str] | List[List[str]]): Baseline texts. A single
            string is shared by every sample; a list of strings gives one
            baseline per sample; a list of lists gives several per sample.
        tau (float): Temperature for similarity scaling.
        beta (float): Threshold for binary reward assignment. Not used by the
            returned probabilities; kept for interface compatibility. Callers
            wanting binary rewards can apply ``goal_prob > beta``.

    Returns:
        goal_prob (Tensor): Probability of the goal being the correct choice (B, T).
    """
    B, T = frames.shape[:2]

    if isinstance(goals, str):          # Single goal string -> share across the batch
        goals = [goals] * B

    if isinstance(baselines, str):      # Single baseline string -> use same for every sample
        baselines = [[baselines] for _ in range(B)]
    elif all(isinstance(b, str) for b in baselines):
        baselines = [[b] for b in baselines]  # Wrap each baseline string
    # Otherwise: already List[List[str]]

    device = next(model.clip_feature_extractor.visual_encoder.parameters()).device
    frames = frames.to(device)

    per_sample_probs = []
    for i in range(B):
        # Goal goes first so that index 0 of the softmax is the goal probability.
        candidates = [goals[i], *baselines[i]]
        # One copy of the clip per candidate text.
        seq = frames[i:i + 1].repeat(len(candidates), 1, 1, 1, 1)
        sims = model.compute_clip_similarity_score(seq, candidates)  # (K, T)

        probs = F.softmax(sims / tau, dim=0)  # Softmax across candidate dimension
        per_sample_probs.append(probs[0])     # (T,)

    if not per_sample_probs:
        # Empty batch: keep the documented (B, T) shape.
        return torch.zeros(B, T, device=device)

    return torch.stack(per_sample_probs, dim=0)  # (B, T)


def compute_clip_similarity_score(self, frames, goals):
    """
    Compute CLIP similarity score between image frames and goal texts.

    Args:
        self: Object exposing `clip_feature_extractor` with `.encode_image()`
            and `.encode_text()`.
        frames (Tensor): Shape (B, T, C, H, W). Need not be contiguous.
        goals (List[str]): List of goal texts, one per sample.

    Returns:
        clip_scores (Tensor): Cosine similarity between each frame and its
            sample's goal text, rescaled from [-1, 1] to [0, 1]. Shape (B, T).
    """
    B, T, C, H, W = frames.shape
    extractor = self.clip_feature_extractor

    # `reshape` (not `view`) so frames produced by transpose/permute are
    # accepted instead of raising.
    flat_frames = frames.reshape(B * T, C, H, W)
    img_emb = extractor.encode_image(flat_frames).reshape(B, T, -1)  # (B, T, D)
    txt_emb = extractor.encode_text(goals).unsqueeze(1)              # (B, 1, D)

    # Unit-normalise both sides; the text embedding is normalised once and
    # then broadcast over the time axis.
    img_unit = F.normalize(img_emb, dim=-1)
    txt_unit = F.normalize(txt_emb, dim=-1)

    # Cosine similarity along the feature dimension, values in [-1, 1]
    cosine = (img_unit * txt_unit).sum(dim=-1)  # (B, T)

    # Rescale to [0, 1] for consistency
    return (cosine + 1.0) / 2.0
