import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import os
import sys
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

# Module-level plotting configuration. The statements below are order-dependent:
# PATH first, then the style sheet, then the rcParams overrides.

# Prepend the TeX binary directory to PATH before any figure is rendered, so
# the LaTeX executables matplotlib shells out to (required by
# 'text.usetex': True below) are found first.
# NOTE(review): '/Library/TeX/texbin' looks like the MacTeX location on macOS;
# on other systems this entry is presumably just an unused PATH component.
os.environ['PATH'] = '/Library/TeX/texbin:' + os.environ.get('PATH', '')

# Apply the style sheet BEFORE the explicit rcParams below: the update() call
# that follows overwrites whichever of these keys the style sheet also sets.
plt.style.use('seaborn-v0_8-darkgrid')

# Explicit overrides applied on top of the style sheet.
plt.rcParams.update({
    # Render all text through LaTeX; pdflatex is the engine for PGF export.
    'text.usetex': True,
    'pgf.texsystem': 'pdflatex',
    'font.family': 'serif',
    'pgf.rcfonts': False,
    # Base font size, followed by explicit sizes for individual elements.
    'font.size': 28,
    'axes.labelsize': 20,
    'axes.titlesize': 20,
    'legend.fontsize': 14,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    # White backgrounds (figure, axes, saved files) with black axes borders
    # and a light gray grid, replacing the dark-grid look of the style sheet.
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'savefig.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'grid.color': 'gray',
    'grid.alpha': 0.3,
    'grid.linestyle': '-',
    'grid.linewidth': 0.5,
})

# When True, plot_results() additionally writes a .pgf copy of every figure.
# This flag does not control 'text.usetex', which is enabled unconditionally
# above.
USE_LATEX = True

# Make the parent directory of this file's directory importable so that the
# project-local `magnitude` module can be imported from it.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from magnitude import norm_diff_magnitude_distance_grad


def _wasserstein_1d(x, y):
    """Return the 1-Wasserstein distance between two 1-D samples (torch tensors)."""
    x_sorted = torch.sort(x)[0]
    y_sorted = torch.sort(y)[0]
    n, m = x_sorted.shape[0], y_sorted.shape[0]

    if n == m:
        # Equal sizes: the optimal 1-D coupling matches order statistics.
        return torch.mean(torch.abs(x_sorted - y_sorted)).item()

    # Unequal sizes: integrate |F_x - F_y| over the merged support, where F is
    # the empirical CDF (piecewise constant between consecutive merged points).
    grid = torch.sort(torch.cat([x_sorted, y_sorted]))[0]
    widths = grid[1:] - grid[:-1]
    cdf_x = torch.searchsorted(x_sorted, grid[:-1], right=True) / n
    cdf_y = torch.searchsorted(y_sorted, grid[:-1], right=True) / m
    return torch.sum(torch.abs(cdf_x - cdf_y) * widths).item()


def compute_wasserstein_distance_gpu(X, Y, device='cuda'):
    """
    Compute the 1-Wasserstein distance between two uniformly weighted point clouds.

    The pairwise Euclidean cost matrix is built with torch on ``device``. The
    optimal-transport problem itself is solved by POT (``ot.emd2``) on NumPy
    arrays, i.e. the solve does not run on the GPU.

    If POT is not installed, an approximation is returned instead: the 1-D
    Wasserstein distance of each coordinate, averaged over all coordinates.
    This is generally NOT equal to the exact distance in more than one
    dimension.

    Parameters:
    -----------
    X : array-like, shape (n, d)
        First point cloud
    Y : array-like, shape (m, d)
        Second point cloud
    device : str
        Device for the cost-matrix computation ('cpu' or 'cuda'). If a CUDA
        device is requested but CUDA is unavailable, the CPU is used and a
        RuntimeWarning is issued.

    Returns:
    --------
    float
        Wasserstein distance (a NumPy floating-point scalar)

    Raises:
    -------
    ValueError
        If X or Y is not 2-D, is empty, or if their point dimensions differ.
    """
    import warnings

    # Mirror run_outlier_experiment: degrade to CPU instead of crashing when
    # the (default) CUDA device is not present.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        warnings.warn(
            "CUDA not available; computing the Wasserstein distance on CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        device = 'cpu'

    # as_tensor accepts lists, NumPy arrays and tensors alike.
    X_torch = torch.as_tensor(X, dtype=torch.float32, device=device)
    Y_torch = torch.as_tensor(Y, dtype=torch.float32, device=device)

    if X_torch.dim() != 2 or Y_torch.dim() != 2:
        raise ValueError("X and Y must be 2-D arrays of shape (n, d) and (m, d).")
    n, m = X_torch.shape[0], Y_torch.shape[0]
    if n == 0 or m == 0:
        raise ValueError("X and Y must each contain at least one point.")
    if X_torch.shape[1] != Y_torch.shape[1]:
        raise ValueError(
            f"X and Y must have the same point dimension, "
            f"got {X_torch.shape[1]} and {Y_torch.shape[1]}."
        )

    # Only the import is guarded, so an ImportError raised anywhere else is
    # not mistaken for "POT is missing".
    try:
        import ot
    except ImportError:
        ot = None

    if ot is not None:
        # Cost matrix of pairwise Euclidean distances, computed on `device`.
        M = torch.cdist(X_torch, Y_torch, p=2)

        # Uniform weights, built directly in float64 on the host: the solver
        # consumes NumPy arrays, so creating them on the GPU is a wasted
        # round trip.
        a = np.full(n, 1.0 / n)
        b = np.full(m, 1.0 / m)

        return ot.emd2(a, b, M.cpu().numpy().astype(np.float64))

    print("Warning: POT library not found. Installing it is recommended for accurate Wasserstein distances.")
    print("Falling back to approximate method...")

    # Fallback: average the per-coordinate 1-D Wasserstein distances.
    distances = [
        _wasserstein_1d(X_torch[:, dim], Y_torch[:, dim])
        for dim in range(X_torch.shape[1])
    ]
    return np.mean(distances)


def run_outlier_experiment(D=5, n=500, epsilon_values=None, t=0.1, device='cuda', seed=42):
    """
    Run the outlier corruption experiment for multiple epsilon values.

    For every contamination level epsilon, a clean standard-normal sample X of
    shape (n, D) is drawn. Then, for each outlier radius R, floor(epsilon * n)
    randomly chosen rows of a copy of X are replaced by points drawn uniformly
    at random on the sphere of radius R. The Wasserstein distance and the
    magnitude distance between the clean and the contaminated sample are
    recorded.

    The random seeds are reset at the start of every epsilon, so each
    contamination level starts from the same clean sample X.

    Parameters:
    -----------
    D : int
        Dimension of the data
    n : int
        Number of samples
    epsilon_values : list of float, optional
        List of contamination levels. Defaults to [0.01].
    t : float
        Scaling parameter for magnitude distance
    device : str
        Device for computation ('cpu' or 'cuda'). Falls back to 'cpu' with a
        printed warning when a CUDA device is requested but unavailable.
    seed : int
        Random seed

    Returns:
    --------
    dict
        Dictionary with epsilon values as keys and
        (radii, wasserstein_distances, magnitude_distances) as values
    """
    # None sentinel instead of a mutable list default shared across calls.
    if epsilon_values is None:
        epsilon_values = [0.01]

    # Check device availability (also covers indexed devices such as 'cuda:0').
    on_cuda = str(device).startswith('cuda')
    if on_cuda and not torch.cuda.is_available():
        print("Warning: CUDA not available. Falling back to CPU.")
        device = 'cpu'
        on_cuda = False

    banner = "="*60
    print(banner)
    print("OUTLIER CORRUPTION EXPERIMENT")
    print(banner)
    print(f"Device: {device}")
    if on_cuda:
        print(f"GPU: {torch.cuda.get_device_name(torch.device(device))}")
    print(f"Dimension: {D}")
    print(f"Number of samples: {n}")
    print(f"Contamination levels: {epsilon_values}")
    print(f"Magnitude distance t: {t}")
    print(banner)
    print()

    # Outlier radii, in increasing order.
    radii = [2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 150, 200, 250, 300,
             350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000]

    results = {}

    for epsilon in epsilon_values:
        print(f"\n{banner}")
        print(f"Processing epsilon = {epsilon}")
        print(f"{banner}")

        # Reset the seeds so every epsilon starts from the same clean sample.
        np.random.seed(seed)
        torch.manual_seed(seed)
        if on_cuda:
            torch.cuda.manual_seed(seed)

        # Number of outliers
        k = int(np.floor(epsilon * n))
        print(f"Number of outliers: {k}")

        # Generate clean data
        X = np.random.randn(n, D).astype(np.float32)

        # Storage for results
        wasserstein_distances = []
        magnitude_distances = []

        print("Computing distances for each outlier radius...")
        for R in tqdm(radii, desc=f"Processing radii (ε={epsilon})"):
            # Contaminate a copy so the clean sample stays untouched.
            Y = X.copy()

            # Randomly select k distinct rows to replace with outliers.
            outlier_indices = np.random.choice(n, size=k, replace=False)

            # Draw k Gaussian directions in one call (same draws, in the same
            # order, as k successive randn(D) calls), normalize each row onto
            # the unit sphere, then scale to radius R.
            directions = np.random.randn(k, D).astype(np.float32)
            directions /= np.linalg.norm(directions, axis=1, keepdims=True)
            Y[outlier_indices] = R * directions

            # Compute Wasserstein distance
            w_dist = compute_wasserstein_distance_gpu(X, Y, device=device)
            wasserstein_distances.append(w_dist)

            # Convert NumPy arrays to PyTorch tensors for the magnitude
            # distance. This is done per iteration on purpose: the helper is
            # defined outside this file, so its inputs are never reused.
            X_tensor = torch.from_numpy(X).to(device)
            Y_tensor = torch.from_numpy(Y).to(device)

            m_dist = norm_diff_magnitude_distance_grad(
                X_tensor, Y_tensor,
                device=str(device),
                t=t,
                normalize=False,
                eps=0
            )
            magnitude_distances.append(m_dist)

        results[epsilon] = (radii, wasserstein_distances, magnitude_distances)

    print("\nAll experiments completed!")
    return results


def _decorate_axes(ax, **legend_kwargs):
    """Add the shared boxed legend (white face, black edge) and a light grid to ax."""
    leg = ax.legend(
        fontsize=14,
        loc='best',
        frameon=True,
        fancybox=True,
        shadow=True,
        framealpha=0.95,
        **legend_kwargs,
    )
    leg.get_frame().set_facecolor('white')
    leg.get_frame().set_edgecolor('black')
    ax.grid(True, alpha=0.2)


def _save_figure(fig, fname):
    """Save fig as <fname>.png (plus <fname>.pgf when USE_LATEX is set), then close it."""
    fig.savefig(f"{fname}.png", dpi=300, bbox_inches='tight')

    if USE_LATEX:
        # The PGF export is best-effort: a failure there must not abort the
        # run, because the PNG has already been written.
        try:
            fig.savefig(f"{fname}.pgf", bbox_inches='tight')
            print(f"Saved: {fname}.pgf", flush=True)
        except Exception as e:
            print(f"Could not save PGF (LaTeX issue): {e}", flush=True)

    plt.close(fig)
    print(f"Saved: {fname}.png", flush=True)


def plot_results(results, D=5, n=500, t=0.1, save_dir='outlier_plots'):
    """
    Plot the results of the outlier experiment for multiple epsilon values.

    Writes one figure per epsilon (Wasserstein and magnitude distance against
    the outlier radius) and one combined figure with the curves of all epsilon
    values. Each figure is saved as a PNG and, when the module-level USE_LATEX
    flag is set, additionally as a PGF file.

    Parameters:
    -----------
    results : dict
        Dictionary with epsilon values as keys and
        (radii, wasserstein_distances, magnitude_distances) as values
    D : int
        Dimension (for filename)
    n : int
        Number of samples (for filename)
    t : float
        Scaling parameter (for filename)
    save_dir : str
        Directory to save plots (created if it does not exist)
    """
    os.makedirs(save_dir, exist_ok=True)

    epsilon_values = sorted(results.keys())
    n_epsilon = len(epsilon_values)

    # tab10 has a fixed number of entries and integer indices past its end do
    # not cycle, so wrap the indices explicitly. For up to 5 epsilon values
    # (10 in the combined figure) the colours are unchanged.
    palette = plt.cm.tab10
    pair_colors = palette(np.arange(n_epsilon * 2) % palette.N)

    # Individual figure per epsilon: each epsilon gets its own pair of
    # colours, one for Wasserstein and one for magnitude.
    for idx, epsilon in enumerate(epsilon_values):
        radii, wasserstein_distances, magnitude_distances = results[epsilon]

        fig, ax = plt.subplots(figsize=(7, 6))

        ax.plot(radii, wasserstein_distances, 'o-', label='Wasserstein',
                linewidth=2, markersize=4, color=pair_colors[idx * 2])
        ax.plot(radii, magnitude_distances, 's-', label='Magnitude',
                linewidth=2, markersize=4, color=pair_colors[idx * 2 + 1])

        ax.set_xlabel('Outlier Radius $R$')
        ax.set_ylabel('Distance')
        ax.set_title(rf'Outlier Sensitivity ($\epsilon = {epsilon}$)', fontweight='bold')

        _decorate_axes(ax)
        _save_figure(fig, f"{save_dir}/outlier_eps{epsilon}_D{D}_n{n}_t{t}")

    # Combined figure: one colour per epsilon, with the marker distinguishing
    # the two distances (circle = Wasserstein, square = magnitude).
    fig, ax = plt.subplots(figsize=(7, 6))
    epsilon_colors = palette(np.arange(n_epsilon) % palette.N)

    # All Wasserstein curves are plotted first and all magnitude curves second
    # so that, with ncol=2, the legend lists each distance in its own column.
    for idx, epsilon in enumerate(epsilon_values):
        radii, wasserstein_distances, _ = results[epsilon]
        ax.plot(radii, wasserstein_distances, 'o-',
                label=f'Wass $\\epsilon={epsilon}$',
                linewidth=2, markersize=4, color=epsilon_colors[idx])

    for idx, epsilon in enumerate(epsilon_values):
        radii, _, magnitude_distances = results[epsilon]
        ax.plot(radii, magnitude_distances, 's-',
                label=f'Mag $\\epsilon={epsilon}$',
                linewidth=2, markersize=4, color=epsilon_colors[idx])

    ax.set_xlabel('Outlier Radius $R$')
    ax.set_ylabel('Distance')
    ax.set_title('Outlier Sensitivity Comparison', fontweight='bold')

    _decorate_axes(ax, ncol=2)
    _save_figure(fig, f"{save_dir}/outlier_combined_D{D}_n{n}_t{t}")


def _growth_ratio(last, first):
    """Return last / first, yielding inf/nan instead of raising when first is zero."""
    if first == 0:
        # Plain Python floats raise ZeroDivisionError here while NumPy scalars
        # return inf/nan; use the inf/nan convention for every scalar type.
        if last > 0:
            return float('inf')
        if last < 0:
            return float('-inf')
        return float('nan')
    return last / first


def print_summary_statistics(results):
    """
    Print summary statistics of the experiment results for all epsilon values.

    For every epsilon (in increasing order) this prints, for both the
    Wasserstein and the magnitude distance: the minimum and maximum value with
    the radius at which each occurs, and the growth rate, defined as the
    distance at the last radius divided by the distance at the first radius.
    Finally the ratio of the two growth rates is printed. A zero denominator
    yields "inf" or "nan" in the output instead of an exception.

    Parameters:
    -----------
    results : dict
        Dictionary with epsilon values as keys and
        (radii, wasserstein_distances, magnitude_distances) as values
    """
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)

    for epsilon in sorted(results.keys()):
        radii, wasserstein_distances, magnitude_distances = results[epsilon]

        print(f"\n{'='*60}")
        print(f"Epsilon = {epsilon}")
        print(f"{'='*60}")

        # Growth rate: distance at the last radius relative to the first one.
        # This is not max/min unless the curve happens to be monotone.
        w_growth = _growth_ratio(wasserstein_distances[-1], wasserstein_distances[0])
        m_growth = _growth_ratio(magnitude_distances[-1], magnitude_distances[0])

        print("\nWasserstein Distance:")
        print(f"  Min distance: {min(wasserstein_distances):.6f} (R={radii[np.argmin(wasserstein_distances)]})")
        print(f"  Max distance: {max(wasserstein_distances):.6f} (R={radii[np.argmax(wasserstein_distances)]})")
        print(f"  Growth rate: {w_growth:.2f}x")

        print("\nMagnitude Distance:")
        print(f"  Min distance: {min(magnitude_distances):.6f} (R={radii[np.argmin(magnitude_distances)]})")
        print(f"  Max distance: {max(magnitude_distances):.6f} (R={radii[np.argmax(magnitude_distances)]})")
        print(f"  Growth rate: {m_growth:.2f}x")

        print("\nRelative sensitivity:")
        print(f"  Magnitude/Wasserstein growth ratio: {_growth_ratio(m_growth, w_growth):.2f}")

    print("\n" + "="*60)

