import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from generate_lap import generate_lap_dataset
from frobenius.LFRCov_Fro import LFRCovFrobenius_torch

def frobenius_distance(S1, S2):
    """
    Frobenius distance between the symmetric parts of two matrices:
    d(S1,S2) = ||sym(S1) - sym(S2)||_F,  with sym(A) = (A + A^T) / 2

    Intended for SPD inputs, where sym(A) equals A up to rounding. The
    transpose acts on the last two dimensions, and the norm is taken over
    every element of the difference. Returns a 0-d tensor.
    """
    def _symmetric_part(A):
        # Average with the transpose to remove any numerical asymmetry.
        return (A + A.transpose(-1, -2)) * 0.5

    delta = _symmetric_part(S1) - _symmetric_part(S2)
    return torch.norm(delta, p='fro')

def plot_matrix_heatmaps(predicted_matrices, conditional_means, n_show=10):
    """Show predicted matrices above their ground-truth counterparts as annotated heatmaps.

    Column ``i`` of a 2-row figure shows ``predicted_matrices[i]`` (top row)
    and ``conditional_means[i]`` (bottom row), for the first ``n_show`` pairs.

    Args:
        predicted_matrices: Sequence of 2-D torch tensors (predictions).
        conditional_means: Sequence of 2-D torch tensors (ground truth),
            aligned index-by-index with ``predicted_matrices``.
        n_show: Maximum number of pairs to draw.

    Returns:
        None. Nothing is drawn when there is nothing to show.
    """
    # Never index past the end of either sequence.
    n_show = min(n_show, len(predicted_matrices), len(conditional_means))
    if n_show <= 0:
        # plt.subplots rejects a zero/negative column count, so bail out early.
        return

    # squeeze=False keeps `axes` 2-D even when n_show == 1, so axes[row, col] always works.
    fig, axes = plt.subplots(2, n_show, figsize=(3 * n_show, 6), squeeze=False)
    fig.suptitle('Matrix Comparison: Predicted vs Truth (Conditional Mean)', fontsize=16)

    heatmap_kwargs = dict(
        cmap='Greys',
        cbar=False,
        xticklabels=False,
        yticklabels=False,
        annot=True,
        fmt='.3f',
        annot_kws={"size": 8},
    )

    for i in range(n_show):
        panels = (
            (axes[0, i], predicted_matrices[i], 'Predicted'),
            (axes[1, i], conditional_means[i], 'Truth'),
        )
        for ax, matrix, label in panels:
            # detach()/cpu() so tensors that track gradients or live on a GPU can be plotted.
            sns.heatmap(matrix.detach().cpu().numpy(), ax=ax, **heatmap_kwargs)
            ax.set_title(f'{label} {i}', fontsize=10)

    plt.tight_layout()
    plt.show()

def _report_prediction_errors(predicted_matrices, truth_matrices, output_indices):
    """Print per-point and summary Frobenius distances between predictions and truth.

    Args:
        predicted_matrices: Predicted matrices, one per output point.
        truth_matrices: Ground-truth matrices aligned with ``predicted_matrices``.
        output_indices: Original dataset index of each output point (used in the log).

    Returns:
        List of distances (Python floats), one per output point.

    Raises:
        ValueError: if the number of predictions and truth matrices differ.
    """
    if len(predicted_matrices) != len(truth_matrices):
        raise ValueError(
            f"Got {len(predicted_matrices)} predictions for {len(truth_matrices)} output points"
        )

    print(f"\n" + "="*60)
    print("COMPARISON: LFR Predictions vs Conditional Means")
    print("="*60)

    distances = []
    pairs = zip(predicted_matrices, truth_matrices)
    for i, (predicted_matrix, conditional_mean_matrix) in enumerate(pairs):
        distance = frobenius_distance(predicted_matrix, conditional_mean_matrix)
        distances.append(distance.item())

        print(f"Output point {i} (original index {output_indices[i]}):")
        print(f"  Frobenius distance: {distance:.6f}")

        # Only dump spectra for the first few points to keep the log readable.
        if i < 3:
            print(f"  LFR Prediction eigenvalues: {torch.linalg.eigvals(predicted_matrix).real}")
            print(f"  Conditional Mean eigenvalues: {torch.linalg.eigvals(conditional_mean_matrix).real}")
            print(f"  Difference (Frobenius norm): {torch.norm(predicted_matrix - conditional_mean_matrix, p='fro'):.6f}")
        print()

    print(f"Summary of Frobenius Distances:")
    if not distances:
        print("  (no output points)")
        return distances
    print(f"  Mean distance: {sum(distances)/len(distances):.6f}")
    print(f"  Min distance: {min(distances):.6f}")
    print(f"  Max distance: {max(distances):.6f}")
    # The sample std of a single value is nan; report zero spread instead.
    std = torch.std(torch.tensor(distances)).item() if len(distances) > 1 else 0.0
    print(f"  Std distance: {std:.6f}")
    return distances


def run_local_frechet_regression(n=100, m=1000, d=3, random_seed=1):
    """Fit Local Fréchet Regression (Frobenius metric) on a synthetic SPD dataset.

    Steps:

    1. Generate ``n`` predictor points via ``generate_lap_dataset`` (``m``
       matrices per point, matrix dimension ``d``).
    2. Hold out every 10th point as an output point.
    3. Project the predictors onto a fixed unit direction ``theta_0``.
    4. Run ``LFRCovFrobenius_torch`` on the remaining training points.
    5. Report and plot how far each prediction lies from the dataset's
       conditional mean at the held-out point.

    Args:
        n: Number of predictor points to generate.
        m: Matrices generated per predictor point (forwarded to the generator).
        d: Matrix dimension (forwarded to the generator as ``q``).
        random_seed: Seed forwarded to the generator.

    Returns:
        ``(predictors, matrices_list, conditional_means, fit)``.

        On success, ``fit`` is ``(xout, Mout)``. ``xout`` is the regression's
        output grid as a NumPy array, and ``Mout`` is the regression's list
        of predicted matrices. If any step of the regression/report/plot
        stage raised, ``fit`` is ``None`` and the error is printed with its
        traceback.
    """
    import traceback

    print("Local Fréchet Regression with Generated SPD Matrices")
    print("=" * 60)

    predictors, matrices_list, conditional_means = generate_lap_dataset(n=n, m=m, q=d, random_seed=random_seed, dtype=torch.float32)

    print(f"\nGenerated data:")
    print(f"  Predictors: {predictors.shape}")
    print(f"  Matrix sets: {len(matrices_list)}")
    print(f"  Conditional means: {len(conditional_means)}")

    # Hold out every 10th point as an output point (ceil(n / 10) of them; 10 when n == 100).
    output_indices = list(range(0, n, 10))
    held_out = set(output_indices)  # set membership keeps the split O(n) instead of O(n^2)
    training_indices = [i for i in range(n) if i not in held_out]

    print(f"\nData split:")
    print(f"  Training points: {len(training_indices)} (indices: {training_indices[:5]}...{training_indices[-5:]})")
    print(f"  Output points: {len(output_indices)} (indices: {output_indices})")

    # Split data
    x_train = predictors[training_indices]
    x_out = predictors[output_indices]
    M_train = [matrices_list[i] for i in training_indices]
    conditional_means_out = [conditional_means[i] for i in output_indices]

    print(f"  Training data: {x_train.shape}")
    print(f"  Output data: {x_out.shape}")

    # Project predictors and output points onto the unit direction theta_0.
    # Match the predictors' dtype/device so the matmul works whatever the generator returns.
    theta_0 = torch.tensor([0.5, 0.1, 0.0, -0.5], dtype=x_train.dtype, device=x_train.device)
    theta_0 = theta_0 / torch.norm(theta_0)

    x_train_projected = x_train @ theta_0
    x_out_projected = x_out @ theta_0

    print(f"\nProjected data:")
    print(f"  x_train_projected: {x_train_projected.shape} (range: [{x_train_projected.min():.4f}, {x_train_projected.max():.4f}])")
    print(f"  x_out_projected: {x_out_projected.shape}")

    # Set bandwidth for LFR
    bwCov = torch.tensor([0.1])

    print(f"Bandwidth: {bwCov}")
    print(f"Theta_0: {theta_0}")

    # Run Local Fréchet Regression
    print(f"\nRunning Local Fréchet Regression with Frobenius metric...")

    try:
        # Pass float64 torch tensors to LFRCovFrobenius_torch: column-vector predictors,
        # a plain float bandwidth, and one float64 tensor per training matrix set.
        x_tensor = x_train_projected.reshape(-1, 1).to(dtype=torch.float64)
        xout_tensor = x_out_projected.reshape(-1, 1).to(dtype=torch.float64)
        h_float = float(bwCov.flatten()[0])
        Y_tensors = [
            M.detach().to(dtype=torch.float64) if isinstance(M, torch.Tensor)
            else torch.from_numpy(M).to(dtype=torch.float64)
            for M in M_train
        ]

        result = LFRCovFrobenius_torch(
            x=x_tensor,
            M=Y_tensors,
            xout=xout_tensor,
            h=h_float,
            kernel='gauss',
            dtype=torch.float64,
            device=x_tensor.device
        )

        print(f"LFR completed successfully!")
        xout_res = result["xout"].detach().cpu().numpy()
        S_hat_list = result["Mout"]
        print(f"Output points: {xout_res.shape}")

        # Detached CPU copies for downstream metrics/plots.
        predicted_matrices = [S.detach().cpu() for S in S_hat_list]
        print(f"Predicted matrices: {len(predicted_matrices)}")

        _report_prediction_errors(predicted_matrices, conditional_means_out, output_indices)

        # Show heatmaps for (up to) 10 predicted vs truth matrices
        n_plot = min(10, len(predicted_matrices))
        print(f"\nGenerating heatmaps for {n_plot} predicted vs truth matrices...")
        plot_matrix_heatmaps(predicted_matrices[:10], conditional_means_out[:10])

        return predictors, matrices_list, conditional_means, (xout_res, S_hat_list)

    except Exception as e:
        # Top-level boundary of this demo: report the failure, keep the traceback
        # so the cause is not lost, and signal failure with a None fit.
        print(f"Error in LFR algorithm: {e}")
        traceback.print_exc()
        return predictors, matrices_list, conditional_means, None

if __name__ == "__main__":
    n = 100
    m = 1000
    d = 10
    seed = 1

    results = run_local_frechet_regression(n=n, m=m, d=d, random_seed=seed)
    fit = results[3]

    if fit is not None:
        # Mirror the split used by run_local_frechet_regression: every 10th point is held out.
        n_out = len(range(0, n, 10))
        _, fitted_matrices = fit
        # Check the predictions instead of assuming they are SPD: the symmetric part
        # of each predicted matrix must have strictly positive eigenvalues.
        all_spd = all(
            bool((torch.linalg.eigvalsh(0.5 * (S + S.transpose(-1, -2))) > 0).all())
            for S in (M.detach().cpu() for M in fitted_matrices)
        )
        print(f"\n✅ Local Fréchet Regression completed successfully!")
        print(f"   Input: {n} predictor points with {m} SPD matrices each")
        print(f"   Training: {n - n_out} points, Output: {n_out} points")
        print(f"   All predicted matrices SPD: {all_spd}")
    else:
        print(f"\n❌ Local Fréchet Regression failed")
