import wandb
import matplotlib.pyplot as plt
import numpy as np
import torch

def evaluate_traj_loss(pred_traj, targ_traj):
    """
    Measure how far a predicted trajectory is from its target.

    Two quantities are produced: a relative L2 error per time step and an
    absolute L2 error over the whole trajectory.

    Args:
        pred_traj (torch.Tensor): Predicted trajectory of shape (batch_size, n_steps, feature_dim).
        targ_traj (torch.Tensor): Target trajectory of shape (batch_size, n_steps, feature_dim).

    Returns:
        tuple: (
            step_loss (torch.Tensor): Per-step L2 error divided by the L2 norm of the
                                      target at that step, averaged over the batch,
                                      shape (n_steps,)
            traj_loss (torch.Tensor): 0-dim tensor holding the L2 norm of the error taken
                                      over all non-batch dimensions, averaged over the batch
        )
    """
    n_traj, n_steps, _ = pred_traj.shape
    n_rows = n_traj * n_steps

    # Collapse batch and time so that every row is one (trajectory, step) state vector
    pred_rows = pred_traj.reshape(n_rows, -1)
    targ_rows = targ_traj.reshape(n_rows, -1)

    # Absolute error and target magnitude for every row
    abs_err = torch.norm(pred_rows - targ_rows, 2, dim=-1)
    targ_mag = torch.norm(targ_rows, 2, dim=-1)

    # The small constant keeps the ratio finite when the target vector is all zeros
    rel_err = abs_err / (targ_mag + 1e-7)

    # Restore the (batch, step) layout and average out the batch dimension
    step_loss = rel_err.reshape(n_traj, n_steps).mean(dim=0)

    # One norm per trajectory over every non-batch dimension, then the batch mean
    non_batch_dims = tuple(range(1, pred_traj.dim()))
    traj_loss = torch.norm(pred_traj - targ_traj, 2, dim=non_batch_dims).mean()

    return step_loss, traj_loss

def evaluate_traj_corr(pred_traj, targ_traj):
    """
    Correlation-based comparison of predicted and target trajectories.

    For every (trajectory, step) pair the Pearson correlation between the
    predicted and target feature vectors is computed. From that table the
    function derives:
    1. The batch-mean correlation at each time step
    2. For 51 evenly spaced thresholds in [0.5, 1.0], the per-trajectory value
       returned by `find_corr_threshold_index`, reduced over the batch by
       mean and by min

    Args:
        pred_traj (torch.Tensor): Predicted trajectory of shape (batch_size, n_steps, feature_dim).
        targ_traj (torch.Tensor): Target trajectory of shape (batch_size, n_steps, feature_dim).

    Returns:
        tuple: (
            step_corr (torch.Tensor): Mean correlation at each time step, shape (n_steps,),
            steps_above_threshold_avg (torch.Tensor): Batch mean of the threshold index,
                                                 one entry per threshold, shape (51,),
            steps_above_threshold_min (torch.Tensor): Batch minimum of the threshold index,
                                              one entry per threshold, shape (51,)
        )
    """
    batch_size, n_steps, _ = pred_traj.shape
    device = pred_traj.device

    # One row per (trajectory, step); the helper returns a (rows, 1) column that is
    # folded back into a (batch_size, n_steps) table
    pred_rows = pred_traj.reshape(-1, pred_traj.shape[-1])
    targ_rows = targ_traj.reshape(-1, targ_traj.shape[-1])
    corr = batched_corrcoef(pred_rows, targ_rows).reshape(batch_size, n_steps)

    # Batch-averaged correlation per step
    step_corr = corr.mean(dim=0)

    # Thresholds cover only the upper half of the correlation range
    corr_thresholds = torch.linspace(0.5, 1, 51, device=device)

    # Output buffers share shape, dtype and device with the threshold grid
    steps_above_threshold_avg = torch.zeros_like(corr_thresholds)
    steps_above_threshold_min = torch.zeros_like(corr_thresholds)

    for k in range(corr_thresholds.shape[0]):
        # Per-trajectory index of the last step before correlation first drops
        per_traj = find_corr_threshold_index(corr, corr_thresholds[k])
        steps_above_threshold_avg[k] = per_traj.mean()
        steps_above_threshold_min[k] = per_traj.min()

    return step_corr, steps_above_threshold_avg, steps_above_threshold_min

def batched_corrcoef(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """
    Row-wise Pearson correlation between two 2-D tensors.

    Each row of `pred` is correlated with the matching row of `targ` across
    the feature dimension. Rows where either input has zero variance get a
    correlation of 0 instead of NaN.

    Args:
        pred: Tensor of shape (batch_size, n_features)
        targ: Tensor of shape (batch_size, n_features)

    Returns:
        Tensor of shape (batch_size, 1) containing correlation coefficients
    """
    # Deviations from the per-row mean
    pred_dev = pred - pred.mean(dim=1, keepdim=True)
    targ_dev = targ - targ.mean(dim=1, keepdim=True)

    # Un-normalised covariance; the 1/n factors cancel against the denominator
    cov = (pred_dev * targ_dev).sum(dim=1, keepdim=True)

    # Product of the un-normalised standard deviations
    pred_scale = pred_dev.pow(2).sum(dim=1, keepdim=True).sqrt()
    targ_scale = targ_dev.pow(2).sum(dim=1, keepdim=True).sqrt()
    scale = pred_scale * targ_scale

    # The division may yield NaN/inf for constant rows; those entries are replaced by 0
    ratio = cov / scale
    return torch.where(scale > 0, ratio, torch.zeros_like(cov))




def find_corr_threshold_index(corr, threshold):
    """Locate, per trajectory, the last step before correlation first falls below a threshold.

    Fully vectorized; the input tensor is not modified.

    Args:
        corr: Correlation tensor of shape (n_trajectories, n_timesteps)
        threshold: Correlation threshold value
    Returns:
        Float tensor of shape (n_trajectories,) holding, for each trajectory:
        - the index of the step just before the first step with corr < threshold
        - n_timesteps - 1 if correlation never drops below the threshold
        - 0 if the very first step is already below the threshold (the same
          value is produced when only step 0 is above the threshold)
    """
    # 1.0 wherever the correlation is under the threshold, 0.0 elsewhere
    below = (corr < threshold).float()

    # Trajectories that are already under the threshold at step 0
    bad_from_start = below[:, 0] == 1

    # argmax gives the first 1.0 in each row (or 0 for an all-zero row);
    # stepping back by one gives the last step before the drop
    last_good = below.argmax(dim=1) - 1

    # A value of -1 means argmax was 0: either the row never drops (use the final
    # index) or it starts below (overridden to 0 by the second fill)
    final_index = corr.shape[1] - 1
    last_good = last_good.masked_fill(last_good == -1, final_index)
    last_good = last_good.masked_fill(bad_from_start, 0)

    return last_good.float()

def evaluate_traj(pred_traj, targ_traj, visualize=False, max_plots=10):
    """
    Comprehensive evaluation of trajectory prediction performance.

    Thin wrapper that runs the loss-based metrics (`evaluate_traj_loss`) and the
    correlation-based metrics (`evaluate_traj_corr`) on the same inputs and,
    optionally, renders a plot for the first few trajectories of the batch.

    Args:
        pred_traj (torch.Tensor): Predicted trajectory of shape (batch_size, n_steps, feature_dim).
        targ_traj (torch.Tensor): Target trajectory of shape (batch_size, n_steps, feature_dim).
        visualize (bool, optional): If True, generates visualization plots. Default is False.
        max_plots (int, optional): Upper bound on the number of trajectories that are
                                   plotted when visualize=True. Default is 10.

    Returns:
        tuple: (
            step_loss (torch.Tensor): Batch-averaged normalized loss at each time step,
                                      shape (n_steps,),
            traj_loss (torch.Tensor): 0-dim tensor with the batch-averaged L2 error over
                                      the whole trajectory,
            step_corr (torch.Tensor): Mean correlation at each time step,
            steps_above_threshold_avg (torch.Tensor): Batch-average threshold index for
                                             each correlation threshold,
            steps_above_threshold_min (torch.Tensor): Batch-minimum threshold index for
                                             each correlation threshold,
            plots (list or None): List of wandb.Image objects if visualize=True, None otherwise
        )
    """
    # Loss-based metrics
    step_loss, traj_loss = evaluate_traj_loss(pred_traj, targ_traj)

    # Correlation-based metrics
    step_corr, steps_above_threshold_avg, steps_above_threshold_min = evaluate_traj_corr(pred_traj, targ_traj)

    # Plots are only produced on request
    plots = None
    if visualize:
        plots = []
        n_plots = min(max_plots, pred_traj.shape[0])
        for i in range(n_plots):
            # NOTE(review): visualize_traj is defined outside this block; it is
            # presumed to return a matplotlib figure -- confirm against its definition.
            fig = visualize_traj(
                pred_traj=pred_traj[i, :, :],
                targ_traj=targ_traj[i, :, :]
            )
            plots.append(wandb.Image(fig))
            # Release the figure once it has been wrapped; otherwise pyplot keeps
            # every figure alive and repeated evaluation calls accumulate memory.
            if isinstance(fig, plt.Figure):
                plt.close(fig)

    return step_loss, traj_loss, step_corr, steps_above_threshold_avg, steps_above_threshold_min, plots
