import numpy as np
import pandas as pd
import torch
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import seaborn as sns
import argparse
import os


def energy_distance_torch(X, Y, batch_size=2048, device=None, progress=False):
    """
    Compute the energy distance between two samples X and Y using GPU acceleration
    (CUDA or MPS) and batching for memory efficiency.

    The returned value is the V-statistic

        2 * mean|x - y| - mean|x - x'| - mean|y - y'|

    where every mean is the average Euclidean distance over *all* row pairs.
    Self-pairs are included, so the within-sample terms include the zero
    diagonal. No square root is applied.

    Parameters
    ----------
    X, Y : array-like
        Samples with one row per draw and one column per dimension. Both are
        passed through ``np.atleast_2d``, so a 1-D array is treated as a single
        draw (one row), not as a column of scalar draws.
    batch_size : int
        Number of rows of the left operand per ``torch.cdist`` call. Each call
        materialises a ``batch_size x len(other)`` distance matrix.
    device : str or torch.device, optional
        Device to compute on. When omitted, MPS is preferred, then CUDA, then CPU.
    progress : bool
        Show a tqdm progress bar over the batches of each pairwise term.

    Returns
    -------
    float
        The energy-distance statistic described above.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive, either sample has no rows, or the
        feature dimensions of X and Y differ.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Pick device: prefer Apple MPS, then CUDA, then fall back to CPU.
    if device is None:
        if torch.backends.mps.is_available():
            device = 'mps'
        elif torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
    # Normalise so 'cuda:0', torch.device('cpu'), etc. are classified by their
    # device type rather than by raw string equality.
    device = torch.device(device)

    # Ensure 2D arrays (rows are draws).
    X = np.atleast_2d(X)
    Y = np.atleast_2d(Y)
    if X.shape[0] == 0 or Y.shape[0] == 0:
        raise ValueError(
            f"both samples must contain at least one row, got shapes {X.shape} and {Y.shape}"
        )
    if X.shape[-1] != Y.shape[-1]:
        raise ValueError(
            f"feature dimensions differ: X has {X.shape[-1]}, Y has {Y.shape[-1]}"
        )

    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    Y = torch.as_tensor(Y, dtype=torch.float32, device=device)

    def release_cached_memory():
        # Hand unused cached blocks back to the accelerator after each batch so
        # the large per-batch distance matrices do not pile up; no-op on CPU.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        elif device.type == 'mps':
            torch.mps.empty_cache()

    def mean_pairwise_distance(A, B):
        """Mean Euclidean distance over every (row of A, row of B) pair."""
        total = 0.0  # Python float (double) accumulator across batches
        count = 0
        starts = range(0, len(A), batch_size)
        for start in (tqdm(starts) if progress else starts):
            dists = torch.cdist(A[start:start + batch_size], B)
            total += dists.sum().item()
            count += dists.numel()
            del dists
            release_cached_memory()
        return total / count

    E_xy = mean_pairwise_distance(X, Y)
    E_xx = mean_pairwise_distance(X, X)
    E_yy = mean_pairwise_distance(Y, Y)

    return 2 * E_xy - E_xx - E_yy


if __name__ == '__main__':
    # sys.exit is used rather than the exit() builtin, which is only installed by
    # the site module (absent under `python -S` and in some embedded builds).
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--samples-dir', default='bnn/samples', type=str,
                        help='Directory containing NUTS and algorithm samples')
    parser.add_argument('--output-dir', default='bnn/img', type=str,
                        help='Directory to save output plots')
    parser.add_argument('--batch-size', default=8192, type=int,
                        help='Batch size for energy distance computation')
    parser.add_argument('--datasets', nargs='+',
                        default=['boston.csv', 'combined.csv', 'concrete.arff', 'energy.csv', 'kin8nm.arff',
                                 'naval.csv', 'protein.csv', 'wine.arff', 'yacht.csv', 'year.csv'],
                        help='List of dataset filenames to process')
    parser.add_argument('--y-tick-step', default=5.0, type=float,
                        help="Spacing of major y-axis ticks; <= 0 keeps matplotlib's automatic ticks")
    args = parser.parse_args()

    # Schemes compared against NUTS, in plotting order; the index of each scheme
    # selects its marker, colour and horizontal offset below.
    algorithms = ['svgd', 'hsvgd', 'ssvgd']

    # Dataset names are the file stems. Duplicates are dropped (first occurrence
    # wins) so a dataset is never processed, or plotted, twice.
    dataset_names = list(dict.fromkeys(d.split('.')[0] for d in args.datasets))

    # Resolve every sample file up front. Sorting makes the chosen NUTS file and
    # the processing order deterministic (glob order is filesystem-dependent),
    # and the progress-bar total then counts only files that will be processed.
    jobs = []  # (dataset_name, nuts_path, {algo: [sample paths]})
    for dataset_name in dataset_names:
        pattern = f'{dataset_name}*.txt'
        nuts_sample_path = sorted(glob(os.path.join(args.samples_dir, 'nuts', pattern)))
        if not nuts_sample_path:
            print(f"\nWarning: No NUTS samples found for {dataset_name}, skipping...")
            continue

        algo_paths = {}
        for algo in algorithms:
            svgd_sample_path = sorted(glob(os.path.join(args.samples_dir, algo, pattern)))
            if not svgd_sample_path:
                print(f"\nWarning: No {algo} samples found for {dataset_name}, skipping...")
            algo_paths[algo] = svgd_sample_path
        jobs.append((dataset_name, nuts_sample_path[0], algo_paths))

    total_files = sum(len(paths) for _, _, per_algo in jobs for paths in per_algo.values())

    plot_data = []
    with tqdm(total=total_files, desc='Computing energy distances') as pbar:
        for dataset_name, nuts_path, algo_paths in jobs:
            # ndmin=2 keeps the file's (rows, columns) layout for single-row AND
            # single-column files; np.atleast_2d would turn an (n,) column into a
            # single n-dimensional sample.
            nuts_samples = np.loadtxt(nuts_path, ndmin=2)
            print(f"\nLoaded NUTS samples for {dataset_name}: shape {nuts_samples.shape}")

            for algo, paths in algo_paths.items():
                for path in paths:
                    svgd_samples = np.loadtxt(path, ndmin=2)
                    plot_data.append({
                        'Dataset': dataset_name,
                        'Scheme': algo,
                        'Energy Distance': energy_distance_torch(
                            svgd_samples, nuts_samples, batch_size=args.batch_size
                        ),
                    })
                    pbar.update(1)

    # Check if we have data
    if not plot_data:
        print("\nError: No data found. Please check:")
        print(f"  1. Samples directory exists: {args.samples_dir}")
        print(f"  2. Directory structure: {args.samples_dir}/{{nuts,svgd,hsvgd,ssvgd}}/")
        print("  3. Sample files exist with pattern: <dataset>*.txt")
        sys.exit(1)

    # Create dataframe
    plot_df = pd.DataFrame(plot_data)
    print(f"\nFound {len(plot_df)} samples across {plot_df['Dataset'].nunique()} datasets")

    # Per (dataset, scheme): mean and sample std (pandas default ddof=1, so NaN
    # for a single run). reset_index turns the group keys back into columns.
    summary = (
        plot_df.groupby(['Dataset', 'Scheme'])['Energy Distance']
        .agg(['mean', 'std'])
        .reset_index()
    )

    # Set seaborn style
    sns.set_theme(style="ticks", rc={
        "axes.edgecolor": "black",
        "axes.linewidth": 1.5,
        "xtick.direction": "out",
        "ytick.direction": "out"
    })

    # Palette (Tableau blue/orange/green)
    palette = sns.color_palette("tab10")[:3]

    # Figure setup
    fig, ax = plt.subplots(figsize=(10, 4))

    # Markers and labels
    markers = ['o', 'D', 's']
    legend_map = {
        'svgd': 'SVGD',
        'hsvgd': 'h-SVGD',
        'ssvgd': 'SSVGD'
    }

    # x-axis order follows --datasets, restricted to datasets that have results.
    available_datasets = set(summary['Dataset'].str.capitalize())
    datasets = [name.capitalize() for name in dataset_names
                if name.capitalize() in available_datasets]
    dataset_index = {name: idx for idx, name in enumerate(datasets)}

    x = np.arange(len(datasets))
    offset = 0.1
    # Small horizontal shifts so the schemes' error bars do not overlap.
    scheme_offsets = np.linspace(-offset, offset, len(algorithms))

    # Plot each scheme with error bars
    for i, scheme in enumerate(algorithms):
        rows = [
            row for _, row in summary[summary['Scheme'] == scheme].iterrows()
            if row['Dataset'].capitalize() in dataset_index
        ]
        if not rows:
            continue

        ax.errorbar(
            [dataset_index[row['Dataset'].capitalize()] + scheme_offsets[i] for row in rows],
            [row['mean'] for row in rows],
            yerr=[row['std'] for row in rows],
            fmt=markers[i],
            color=palette[i],
            markersize=3,
            capsize=3,
            linestyle='none',
            alpha=0.9,
            label=legend_map.get(scheme, scheme)
        )

    # Axes labels, legend, and ticks
    ax.set_ylabel("Energy Distance", fontsize=12)
    ax.set_xticks(x)
    ax.set_xticklabels(datasets, rotation=0)
    # A fixed tick step can explode into thousands of ticks when distances are
    # large, so it is configurable (and can be disabled with a value <= 0).
    if args.y_tick_step > 0:
        ax.yaxis.set_major_locator(MultipleLocator(args.y_tick_step))

    # Keep black borders on all sides
    sns.despine(ax=ax, top=False, right=False, left=False, bottom=False)

    # Legend inside the plot box
    ax.legend(
        title="Scheme",
        facecolor='white',
    )

    # Save and show
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, 'energy-dist.png')
    plt.tight_layout()
    plt.savefig(output_path, bbox_inches='tight', dpi=600)
    print(f"\nPlot saved to {output_path}")
    plt.show()