import torch
import matplotlib.pyplot as plt
import os


# --- Feature Explained Variance Plotting Utility ---
def plot_feature_explained_variance_dino(
    fs,
    save_name='feature_explained_variance_dino',
    save_dir='svd_features',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    center=False,
):
    """
    Plot the cumulative explained variance of a feature-map batch.

    Every (batch, y, x) location is treated as one sample of a C-dimensional
    feature vector: the tensor is flattened to a [C, B*H*W] matrix, its
    singular values are squared and normalised, and their cumulative sum is
    plotted against the 1-based component index.

    Args:
        fs: tuple ``(f, t)`` where ``f`` is a [B, C, H, W] tensor; ``t`` is
            unused.
        save_name: base filename; ``_batch{B}.png`` is appended.
        save_dir: output directory, created if it does not exist.
        device: device used for the SVD. The default is chosen once, when
            the module is imported.
        center: if True, subtract the per-channel mean before the SVD
            (PCA-style). Defaults to False, i.e. the uncentred energy spectrum.

    Raises:
        ValueError: if ``f`` is not 4-D, or if it is all zeros (the variance
            ratio would be 0/0).
    """
    os.makedirs(save_dir, exist_ok=True)
    f, _ = fs
    # Features taken straight from a model may still require grad, which makes
    # .numpy() fail; detach first. svdvals only supports float32/float64 (and
    # complex), so upcast half/bfloat16/integer inputs.
    f = f.detach().to(device)
    if f.dtype not in (torch.float32, torch.float64):
        f = f.float()
    B, C, H, W = f.shape

    # Flatten batch and spatial dims: column j is the feature vector at one location.
    mat = f.permute(1, 0, 2, 3).reshape(C, B * H * W)
    if center:
        mat = mat - mat.mean(dim=1, keepdim=True)

    # Explained variance ratio from squared singular values.
    energy = torch.linalg.svdvals(mat) ** 2
    total = float(energy.sum())
    if total <= 0:
        raise ValueError(
            'feature map has zero energy; explained variance is undefined'
        )
    cum_var = (energy / total).cumsum(dim=0).cpu().numpy()
    components = range(1, len(cum_var) + 1)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        ax.plot(components, cum_var, marker='o', linestyle='-')
        ax.set_xlabel('Component index')
        ax.set_ylabel('Cumulative explained variance')
        ax.set_title('Feature Map Explained Variance (DINOv2)')
        ax.grid(True)
        fig.tight_layout()
        # Batch size in the filename keeps runs with different B apart.
        fig.savefig(os.path.join(save_dir, f'{save_name}_batch{B}.png'))
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def plot_feature_explained_variance_comparison(
    fs_full,
    fs_adapter,
    save_name='feature_explained_variance_comparison',
    save_dir='svd_features',
    labels=(
        'Partial Tuned-2',
        'Freeze + Partial Adapter-2'
        ),
    device='cuda' if torch.cuda.is_available() else 'cpu',
    threshold=0.9,
    center=False,
):
    """
    Compare cumulative explained variance for two feature map sets.

    Each feature tensor is flattened to a [C, B*H*W] matrix; the normalised
    cumulative sum of its squared singular values is plotted against the
    1-based component index. A dashed vertical line marks, for each set, the
    first component count at which the curve reaches ``threshold``.

    Args:
        fs_full: tuple(f_full, t_full) where f_full is a [B, C, H, W] tensor
            from the full fine-tune model; t_full is unused.
        fs_adapter: tuple(f_adapter, t_adapter) where f_adapter is a
            [B, C, H, W] tensor from the freeze+adapter model; t_adapter is
            unused.
        save_name: base filename for saving the plot (``.png`` is appended).
        save_dir: directory to save plots, created if missing.
        labels: tuple of two labels for the legend.
        device: device used for the SVD ('cpu' or 'cuda'). The default is
            chosen once, when the module is imported.
        threshold: cumulative-variance level for the rank marker, in (0, 1].
            Defaults to 0.9 ("Rank@90%").
        center: if True, subtract the per-channel mean before the SVD.

    Raises:
        ValueError: if ``threshold`` is outside (0, 1], a feature tensor is
            not 4-D, or a feature tensor is all zeros.
    """
    if not 0 < threshold <= 1:
        raise ValueError(f'threshold must be in (0, 1], got {threshold}')
    os.makedirs(save_dir, exist_ok=True)

    def compute_cumvar(f):
        """Return the cumulative explained-variance ratio of ``f`` as a NumPy array."""
        # Detach so .numpy() works on grad-tracking features; svdvals has no
        # half-precision kernels, so upcast anything that isn't float32/64.
        f = f.detach().to(device)
        if f.dtype not in (torch.float32, torch.float64):
            f = f.float()
        B, C, H, W = f.shape
        # Flatten batch and spatial dims: [C, B*H*W]
        mat = f.permute(1, 0, 2, 3).reshape(C, B * H * W)
        if center:
            mat = mat - mat.mean(dim=1, keepdim=True)
        energy = torch.linalg.svdvals(mat) ** 2
        total = float(energy.sum())
        if total <= 0:
            raise ValueError(
                'feature map has zero energy; explained variance is undefined'
            )
        return (energy / total).cumsum(dim=0).cpu().numpy()

    def rank_at_threshold(cum_var):
        """Return the 1-based number of components needed to reach ``threshold``."""
        reached = cum_var >= threshold
        # Float round-off can leave the final value a hair below 1.0; if the
        # threshold is never reached, all components are needed.
        if not reached.any():
            return len(cum_var)
        return int(reached.argmax()) + 1

    # Extract feature maps (the second tuple element is unused).
    f_full, _ = fs_full
    f_adapter, _ = fs_adapter

    cum_full = compute_cumvar(f_full)
    cum_adapter = compute_cumvar(f_adapter)
    idx_full = range(1, len(cum_full) + 1)
    idx_adapter = range(1, len(cum_adapter) + 1)

    rank_full = rank_at_threshold(cum_full)
    rank_adapter = rank_at_threshold(cum_adapter)
    pct = f'{threshold:.0%}'  # 0.9 -> "90%"

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        (line_full,) = ax.plot(idx_full, cum_full, marker='o', markersize=3,
                               label=f'{labels[0]}')
        (line_adapter,) = ax.plot(idx_adapter, cum_adapter, marker='s',
                                  markersize=3, label=f'{labels[1]}')
        # Reuse each curve's colour so the rank marker visibly belongs to it
        # regardless of the active matplotlib colour cycle.
        ax.axvline(rank_full, color=line_full.get_color(), linestyle='--',
                   label=f'Rank@{pct} ({labels[0]}) = {rank_full}')
        ax.axvline(rank_adapter, color=line_adapter.get_color(), linestyle='--',
                   label=f'Rank@{pct} ({labels[1]}) = {rank_adapter}')
        ax.set_xlabel('Component Index')
        ax.set_ylabel('Cumulative Explained Variance')
        ax.set_title('Feature Map Explained Variance Comparison')
        ax.legend(loc='lower right')
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f'{save_name}.png'))
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)