import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys, os, json, glob
from datetime import datetime
from scipy.stats import wasserstein_distance

# Prepend the TeX binary directory to PATH before any figure is rendered, so
# that the `usetex` / PGF settings below can locate the LaTeX executables.
# NOTE(review): '/Library/TeX/texbin' looks like a macOS TeX install location
# and the ':' separator is POSIX-specific — confirm before running elsewhere.
os.environ['PATH'] = '/Library/TeX/texbin:' + os.environ.get('PATH', '')

# Apply the base style first: style.use() rewrites many rcParams, so it has to
# run before the explicit overrides below or it would undo them.
plt.style.use('seaborn-v0_8-darkgrid')

# Explicit overrides applied on top of the style: LaTeX text rendering, PGF
# export settings, font sizes, and a white background with a light grey grid.
plt.rcParams.update({
    'text.usetex': True,
    'pgf.texsystem': 'pdflatex',
    'font.family': 'serif',
    'pgf.rcfonts': False,
    'font.size': 28,
    'axes.labelsize': 20,
    'axes.titlesize': 22,
    'legend.fontsize': 14,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    # White figure/axes/savefig backgrounds replace the style's own colours.
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'savefig.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'grid.color': 'gray',
    'grid.alpha': 0.3,
    'grid.linestyle': '-',
    'grid.linewidth': 0.5,
})

# When True, plot_single() also writes a .pgf copy of every figure next to
# the .png (a failure of the PGF export is reported but not fatal).
USE_LATEX = True

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from magnitude import norm_diff_magnitude_distance_grad

def gaussian_kernel(X, Y, sigma=1.0):
    """Return the Gaussian (RBF) kernel matrix between the rows of X and Y.

    Entry ``(i, j)`` is ``exp(-||X[i] - Y[j]||^2 / (2 * sigma**2))``.

    Args:
        X: 2-D tensor whose rows are points (reduced over ``dim=1``).
        Y: 2-D tensor with the same number of columns as ``X``.
        sigma: Kernel bandwidth.

    Returns:
        Tensor of shape ``(len(X), len(Y))`` with entries in ``[0, 1]``.
    """
    x_sq = torch.sum(X**2, dim=1, keepdim=True)
    y_sq = torch.sum(Y**2, dim=1, keepdim=True).t()
    # The expansion ||x||^2 + ||y||^2 - 2 x.y can round to a small negative
    # number for (nearly) identical points, which would yield kernel values
    # above 1. Squared distances are non-negative, so clamp at zero.
    sq_dist = torch.clamp(x_sq + y_sq - 2 * (X @ Y.t()), min=0)
    return torch.exp(-sq_dist / (2 * sigma**2))

def mmd_gaussian(X, Y, sigma=1.0):
    """Estimate the Maximum Mean Discrepancy between two samples.

    Uses the unbiased estimator of the squared MMD with a Gaussian kernel:
    the within-sample terms exclude the diagonal (self-similarities) and are
    normalised by ``n * (n - 1)`` and ``m * (m - 1)``.

    Args:
        X: 2-D tensor of ``n`` sample points (rows).
        Y: 2-D tensor of ``m`` sample points (rows).
        sigma: Bandwidth forwarded to ``gaussian_kernel``.

    Returns:
        0-dim tensor holding the square root of the estimate, with the
        squared estimate clamped at zero first.

    Raises:
        ValueError: If either sample has fewer than two points; the unbiased
            normalisation would otherwise divide by zero and yield NaN.
    """
    n, m = len(X), len(Y)
    if n < 2 or m < 2:
        raise ValueError(
            f"mmd_gaussian needs at least 2 samples per set, got n={n}, m={m}"
        )

    K_XX = gaussian_kernel(X, X, sigma)
    K_YY = gaussian_kernel(Y, Y, sigma)
    K_XY = gaussian_kernel(X, Y, sigma)

    within_x = (K_XX.sum() - torch.trace(K_XX)) / (n * (n - 1))
    within_y = (K_YY.sum() - torch.trace(K_YY)) / (m * (m - 1))
    cross = 2 * K_XY.sum() / (n * m)

    # The unbiased estimate can dip below zero for close distributions, so
    # clamp before taking the square root.
    return torch.sqrt(torch.clamp(within_x + within_y - cross, min=0))

def wasserstein_nd(X, Y, n_projections=50):
    """Sliced Wasserstein distance between two multi-dimensional samples.

    Both samples are projected onto ``n_projections`` random unit directions
    (drawn with ``torch.randn`` on ``X``'s device); the 1-D Wasserstein
    distance of each projected pair is computed with SciPy and the results
    are averaged.

    Args:
        X: 2-D tensor of sample points (rows); ``X.shape[1]`` is the dimension.
        Y: 2-D tensor of sample points with the same number of columns.
        n_projections: Number of random directions to average over.

    Returns:
        The mean of the per-projection distances (as returned by ``np.mean``).
    """
    n_features = X.shape[1]

    def _sliced_distance():
        # One random direction, normalised to unit length.
        direction = torch.randn(n_features, device=X.device)
        direction = direction / torch.norm(direction)
        x_line = (X @ direction).cpu().numpy()
        y_line = (Y @ direction).cpu().numpy()
        return wasserstein_distance(x_line, y_line)

    return np.mean([_sliced_distance() for _ in range(n_projections)])

def _utc_timestamp():
    """Return the current UTC time as a naive ISO-8601 string suffixed with ``Z``."""
    # Local import: the file-level import only brings in ``datetime``.
    from datetime import timezone
    # Same text format as the deprecated ``datetime.utcnow().isoformat() + "Z"``.
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _trial_stats(trials):
    """Return ``(mean, sample standard deviation)`` of *trials* as Python floats."""
    values = torch.tensor(trials)
    return values.mean().item(), values.std().item()


def _dump_note(path, note):
    """Write *note* to *path* as indented JSON."""
    with open(path, "w") as f:
        json.dump(note, f, indent=2)


def experiment(device, dims, t_values, sigma_values, n_trials=30, n_samples=500, mean_difference=0.0):
    """Measure MMD, sliced Wasserstein and magnitude distances across dimensions.

    For every dimension ``d`` in ``dims`` the function draws ``n_trials`` pairs
    of standard-normal samples ``X`` and ``Y`` of shape ``(n_samples, d)``.
    ``Y`` is shifted by ``mean_difference / sqrt(d)`` in every coordinate, so
    the Euclidean norm of the mean shift equals ``mean_difference`` for all
    ``d``. Per-dimension results are written as JSON notes to
    ``experiments_meandiff_<mean_difference>/``.

    All reported standard deviations are sample standard deviations
    (Bessel-corrected, ddof=1), so the spreads of the three metrics are
    directly comparable.

    Args:
        device: Torch device used for sampling; also passed as ``str(device)``
            to ``norm_diff_magnitude_distance_grad``.
        dims: Iterable of dimensions to evaluate.
        t_values: Magnitude scale parameters. ``'auto'`` resolves to ``1/d``
            and ``'sqrt_auto'`` to ``1/sqrt(d)``; other values are used as-is.
        sigma_values: MMD bandwidths. ``'sqrt_auto'`` resolves to
            ``1/sqrt(d)``; other values are used as-is.
        n_trials: Number of independent sample pairs per dimension.
        n_samples: Number of points in each sample.
        mean_difference: Norm of the mean shift applied to ``Y``.

    Returns:
        ``(mmd_stats_dict, wass_stats, mag_stats_dict)``. Each stats dict has
        ``'mean'``, ``'std'`` and ``'all_trials'`` lists with one entry per
        dimension; the MMD and magnitude results are keyed by the raw
        ``sigma`` / ``t`` value.
    """
    mmd_stats_dict = {sigma: {'mean': [], 'std': [], 'all_trials': []} for sigma in sigma_values}
    mag_stats_dict = {t: {'mean': [], 'std': [], 'all_trials': []} for t in t_values}
    wass_stats = {'mean': [], 'std': [], 'all_trials': []}

    print(f"Running experiments... Device: {device}, Mean difference: {mean_difference}", flush=True)

    # The output directory does not depend on the dimension; create it once.
    base_dir = f"experiments_meandiff_{mean_difference}"
    os.makedirs(base_dir, exist_ok=True)

    for d in dims:
        mmd_trials_dict = {sigma: [] for sigma in sigma_values}
        mag_trials_dict = {t: [] for t in t_values}
        wass_trials = []

        inv_sqrt_d = 1.0 / (d ** 0.5)
        scaled_mean_diff = mean_difference / (d ** 0.5)

        for trial in range(n_trials):
            X = torch.randn(n_samples, d, device=device)
            Y = torch.randn(n_samples, d, device=device) + scaled_mean_diff

            # MMD for every bandwidth.
            for sigma in sigma_values:
                actual_sigma = inv_sqrt_d if sigma == 'sqrt_auto' else sigma
                mmd = mmd_gaussian(X, Y, sigma=actual_sigma)
                mmd_trials_dict[sigma].append(mmd.item())

            # Sliced Wasserstein. Kept between MMD and magnitude so the order
            # of random draws inside a trial is unchanged.
            wass_trials.append(wasserstein_nd(X, Y))

            # Magnitude distance for every scale.
            # NOTE(review): the return value is assumed to be a scalar that
            # torch.tensor() and float() accept — confirm against `magnitude`.
            for t in t_values:
                if t == 'auto':
                    actual_t = 1.0 / d
                elif t == 'sqrt_auto':
                    actual_t = inv_sqrt_d
                else:
                    actual_t = t
                mag = norm_diff_magnitude_distance_grad(X, Y, device=str(device), t=actual_t, normalize=True, eps=0)
                mag_trials_dict[t].append(mag)

        # --- Aggregate this dimension's trials ---------------------------
        for sigma in sigma_values:
            mmd_mean, mmd_std = _trial_stats(mmd_trials_dict[sigma])
            mmd_stats_dict[sigma]['mean'].append(mmd_mean)
            mmd_stats_dict[sigma]['std'].append(mmd_std)
            mmd_stats_dict[sigma]['all_trials'].append(mmd_trials_dict[sigma])

        # ddof=1 matches torch's default (Bessel-corrected) std used for the
        # MMD and magnitude statistics.
        wass_mean, wass_std = np.mean(wass_trials), np.std(wass_trials, ddof=1)
        wass_stats['mean'].append(wass_mean)
        wass_stats['std'].append(wass_std)
        wass_stats['all_trials'].append(wass_trials)

        for t in t_values:
            mag_mean, mag_std = _trial_stats(mag_trials_dict[t])
            mag_stats_dict[t]['mean'].append(mag_mean)
            mag_stats_dict[t]['std'].append(mag_std)
            mag_stats_dict[t]['all_trials'].append(mag_trials_dict[t])

        # --- Persist one JSON note per metric variant --------------------
        for sigma in sigma_values:
            sigma_str = str(sigma)
            mmd_note = {
                "timestamp": _utc_timestamp(),
                "dimension": d,
                "n_trials": n_trials,
                "n_samples": n_samples,
                "sigma": sigma_str,
                "mean_difference": mean_difference,
                "mmd_mean": mmd_stats_dict[sigma]['mean'][-1],
                "mmd_std": mmd_stats_dict[sigma]['std'][-1],
                "mmd_trials": [float(x) for x in mmd_trials_dict[sigma]],
            }
            _dump_note(f"{base_dir}/mmd_sigma{sigma_str}_n{n_samples}_trials{n_trials}_dim{d}.json", mmd_note)

        wass_note = {
            "timestamp": _utc_timestamp(),
            "dimension": d,
            "n_trials": n_trials,
            "n_samples": n_samples,
            "mean_difference": mean_difference,
            "wass_mean": wass_mean,
            "wass_std": wass_std,
            "wass_trials": [float(x) for x in wass_trials],
        }
        _dump_note(f"{base_dir}/wass_n{n_samples}_trials{n_trials}_dim{d}.json", wass_note)

        for t in t_values:
            t_str = str(t)
            mag_note = {
                "timestamp": _utc_timestamp(),
                "dimension": d,
                "n_trials": n_trials,
                "n_samples": n_samples,
                "t": t_str,
                "mean_difference": mean_difference,
                "mag_mean": mag_stats_dict[t]['mean'][-1],
                "mag_std": mag_stats_dict[t]['std'][-1],
                "mag_trials": [float(x) for x in mag_trials_dict[t]],
            }
            _dump_note(f"{base_dir}/mag_t{t_str}_n{n_samples}_trials{n_trials}_dim{d}.json", mag_note)

        print(f"Dim {d}: Wass={wass_mean:.6f}±{wass_std:.6f}", flush=True)

    return mmd_stats_dict, wass_stats, mag_stats_dict

def _mmd_label(sigma):
    """Return the legend label of the MMD curve for bandwidth *sigma*."""
    if sigma == 'sqrt_auto':
        return r'MMD $\sigma=1/\sqrt{D}$'
    return f'MMD $\\sigma={sigma}$'


def _mag_label(t):
    """Return the legend label of the magnitude curve for scale *t*."""
    if t == 'auto':
        return r'Mag $t=1/D$'
    if t == 'sqrt_auto':
        return r'Mag $t=1/\sqrt{D}$'
    return f'Mag $t={t}$'


def plot_single(dims, mmd_stats_dict, wass_stats, mag_stats_dict, n_trials, n_samples, sigma_values, mean_difference):
    """Write one figure per summary statistic comparing all distance metrics.

    Five figures are produced (``mean``, ``mean_log``, ``std``, ``cv`` and
    ``errorbar``) and saved as PNG under
    ``experiments_meandiff_<mean_difference>/plots``. When the module-level
    ``USE_LATEX`` flag is set, a PGF copy is attempted as well; a failing PGF
    export is reported on stdout and does not abort the run.

    Args:
        dims: Dimensions used as x-values.
        mmd_stats_dict: Per-sigma stats dicts with ``'mean'`` / ``'std'`` lists.
        wass_stats: Stats dict for the Wasserstein distance.
        mag_stats_dict: Per-t stats dicts; its key order sets the curve order.
        n_trials: Number of trials (used in the output file name only).
        n_samples: Samples per trial (used in the output file name only).
        sigma_values: MMD bandwidths to plot, in plotting order.
        mean_difference: Selects the output directory.
    """
    plots = [
        ('mean', 'Mean Distance vs Dimension', 'Mean Distance', lambda s: s['mean']),
        ('mean_log', 'Mean Distance vs Dimension (Log)', 'Mean Distance', lambda s: s['mean']),
        ('std', 'Std Deviation vs Dimension', 'Standard Deviation', lambda s: s['std']),
        ('cv', 'Relative Variability vs Dimension', 'Coefficient of Variation',
         lambda s: (torch.tensor(s['std']) / torch.tensor(s['mean'])).numpy()),
        ('errorbar', r'Mean Distance $\pm$ Std Dev', 'Distance', None)
    ]

    base_dir = f"experiments_meandiff_{mean_difference}/plots"
    os.makedirs(base_dir, exist_ok=True)

    # One (label, marker, stats) entry per curve, in drawing order:
    # MMD variants, then Wasserstein, then magnitude variants.
    series = [(_mmd_label(sigma), 'o', mmd_stats_dict[sigma]) for sigma in sigma_values]
    series.append(('Wasserstein', 's', wass_stats))
    series.extend((_mag_label(t), '^', mag_stats_dict[t]) for t in mag_stats_dict)
    colors = plt.cm.tab10(range(len(series)))

    for plot_type, title, ylabel, data_fn in plots:
        fig, ax = plt.subplots(figsize=(7, 6))
        try:
            for color, (label, marker, stats) in zip(colors, series):
                if plot_type == 'errorbar':
                    # Mean curve with a shaded +/- one-std band.
                    mean, std = np.array(stats['mean']), np.array(stats['std'])
                    ax.plot(dims, mean, f'-{marker}', label=label, linewidth=2, markersize=4, color=color)
                    ax.fill_between(dims, mean - std, mean + std, color=color, alpha=0.2)
                elif plot_type == 'mean_log':
                    ax.semilogy(dims, data_fn(stats), f'{marker}-', label=label, linewidth=2, markersize=4, color=color)
                else:
                    ax.plot(dims, data_fn(stats), f'{marker}-', label=label, linewidth=2, markersize=4, color=color)

            if plot_type == 'mean_log':
                ax.set_yscale('log')
                ax.set_xscale('linear')
            elif plot_type == 'cv':
                ax.set_yscale('log')

            ax.set_xlabel('Dimension')
            ax.set_ylabel(ylabel)
            ax.set_title(title, fontweight='bold')
            leg = ax.legend(
                fontsize=14,
                ncol=2,
                loc='upper right',
                bbox_to_anchor=(0.98, 0.75),
                frameon=True,
                fancybox=True,
                shadow=True,
                framealpha=0.95,
            )

            leg.get_frame().set_facecolor('white')
            leg.get_frame().set_edgecolor('black')
            ax.grid(True, alpha=0.2)

            fname = f"{base_dir}/{plot_type}_dims{min(dims)}-{max(dims)}_n{n_samples}_trials{n_trials}"

            fig.savefig(f"{fname}.png", dpi=300, bbox_inches='tight')

            if USE_LATEX:
                # Best effort: PGF export needs a working LaTeX toolchain.
                try:
                    fig.savefig(f"{fname}.pgf", bbox_inches='tight')
                    print(f"Saved: {fname}.pgf", flush=True)
                except Exception as e:
                    print(f"Could not save PGF (LaTeX issue): {e}", flush=True)
        finally:
            # Release the figure even if drawing or saving raised, so a
            # failure does not leave open figures behind.
            plt.close(fig)

        print(f"Saved: {fname}.png", flush=True)

def _parse_note_setting(raw, symbolic):
    """Return *raw* unchanged if it is one of the *symbolic* names, else as float."""
    return raw if raw in symbolic else float(raw)


def load_experiment_notes(base_dir):
    """Rebuild per-dimension statistics from the JSON notes in *base_dir*.

    Every ``*.json`` file directly inside ``base_dir`` is read. Each note must
    carry a ``"dimension"`` key and is classified by the first of
    ``"mmd_trials"``, ``"wass_trials"`` or ``"mag_trials"`` it contains; notes
    with none of these are ignored. Trials from several notes with the same
    dimension and setting are pooled.

    All standard deviations are sample standard deviations (ddof=1), so the
    three metrics are summarised consistently.

    Args:
        base_dir: Directory containing the note files.

    Returns:
        ``(dims, mmd_stats_dict, wass_stats, mag_stats_dict)`` where ``dims``
        is the sorted list of dimensions found, and each stats dict holds
        ``'mean'``, ``'std'`` and ``'all_trials'`` lists. MMD stats are keyed
        by sigma and magnitude stats by t (``float``, or the symbolic strings
        ``'sqrt_auto'`` / ``'auto'``).

        NOTE: a metric missing for some dimension contributes no entry for
        it, so its lists can be shorter than ``dims``.
    """
    data_by_dim = {}

    # glob returns files in arbitrary order; sorting makes the pooled trial
    # order and the key order of the result dicts reproducible.
    for filepath in sorted(glob.glob(f"{base_dir}/*.json")):
        with open(filepath, "r") as f:
            note = json.load(f)

        per_dim = data_by_dim.setdefault(
            note["dimension"],
            {"mmd_trials": {}, "wass_trials": [], "mag_trials": {}},
        )

        if "mmd_trials" in note:
            sigma = _parse_note_setting(note["sigma"], ('sqrt_auto',))
            per_dim["mmd_trials"].setdefault(sigma, []).extend(note["mmd_trials"])
        elif "wass_trials" in note:
            per_dim["wass_trials"].extend(note["wass_trials"])
        elif "mag_trials" in note:
            t = _parse_note_setting(note["t"], ('auto', 'sqrt_auto'))
            per_dim["mag_trials"].setdefault(t, []).extend(note["mag_trials"])

    dims = sorted(data_by_dim.keys())
    mmd_stats_dict = {}
    wass_stats = {'mean': [], 'std': [], 'all_trials': []}
    mag_stats_dict = {}

    for d in dims:
        # MMD statistics
        for sigma, mmd_trials in data_by_dim[d]["mmd_trials"].items():
            stats = mmd_stats_dict.setdefault(sigma, {'mean': [], 'std': [], 'all_trials': []})
            values = torch.tensor(mmd_trials)
            stats['mean'].append(values.mean().item())
            stats['std'].append(values.std().item())
            stats['all_trials'].append(mmd_trials)

        # Wasserstein statistics (skipped when the dimension has no trials)
        wass_trials = data_by_dim[d]["wass_trials"]
        if wass_trials:
            wass_stats['mean'].append(np.mean(wass_trials))
            # ddof=1 matches torch's default Bessel-corrected std used for
            # the MMD and magnitude statistics.
            wass_stats['std'].append(np.std(wass_trials, ddof=1))
            wass_stats['all_trials'].append(wass_trials)

        # Magnitude statistics
        for t, mag_trials in data_by_dim[d]["mag_trials"].items():
            stats = mag_stats_dict.setdefault(t, {'mean': [], 'std': [], 'all_trials': []})
            values = torch.tensor(mag_trials)
            stats['mean'].append(values.mean().item())
            stats['std'].append(values.std().item())
            stats['all_trials'].append(mag_trials)

    return dims, mmd_stats_dict, wass_stats, mag_stats_dict

