import numpy as np
import pandas as pd
import torch
import os
import random
from src.models.mvnmixture import MultivariateNormalMixture
from src.svgd import hSVGD, SSVGD

import seaborn as sns
import matplotlib.pyplot as plt
import argparse
from tqdm import tqdm


def plot_marginal_scatter_single_row(particles_dict, model, plot_dims, save_path,
                                     contour_levels=20, grid_space=200, buffer=0.1):
    """
    Draw one row of three scatter panels (SVGD, h-SVGD, S-SVGD) over the
    contours of the model's 2D marginal density and save it to disk.

    All three panels use the same axis limits, taken from the pooled particle
    coordinates and widened by `buffer` on every side.

    Args:
        particles_dict: Dict with keys 'SVGD', 'h-SVGD', 'SSVGD' containing particle arrays
        model: Object whose `marginal_log_prob(points, dims=...)` gives the marginal log density
        plot_dims: Pair of dimension indices; the first goes on x, the second on y
        save_path: Path to save the figure
        contour_levels: Number of contour levels
        grid_space: Number of grid points per axis for the contour evaluation
        buffer: Margin added around the particle extent
    """
    dim_x, dim_y = plot_dims[0], plot_dims[1]

    # One entry per panel: (title, colour, particles). The 'SSVGD' key is
    # displayed under the title 'S-SVGD'.
    panels = [
        ('SVGD', 'tab:blue', particles_dict['SVGD']),
        ('h-SVGD', 'tab:orange', particles_dict['h-SVGD']),
        ('S-SVGD', 'tab:green', particles_dict['SSVGD']),
    ]
    xs = [pts[:, dim_x] for _, _, pts in panels]
    ys = [pts[:, dim_y] for _, _, pts in panels]

    # Common limits across the three methods
    pooled_x = np.concatenate(xs)
    pooled_y = np.concatenate(ys)
    x_limits = (pooled_x.min() - buffer, pooled_x.max() + buffer)
    y_limits = (pooled_y.min() - buffer, pooled_y.max() + buffer)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=True)

    # Regular grid covering the plotting window, flattened to (x, y) rows
    x_axis = np.linspace(x_limits[0], x_limits[1], grid_space)
    y_axis = np.linspace(y_limits[0], y_limits[1], grid_space)
    xx, yy = np.meshgrid(x_axis, y_axis)
    grid_points = torch.tensor(np.column_stack([xx.ravel(), yy.ravel()]),
                               dtype=torch.float32)

    # Marginal density on the grid (no gradients needed for plotting)
    with torch.no_grad():
        grid_log_probs = model.marginal_log_prob(grid_points, dims=plot_dims)
    density = grid_log_probs.exp().numpy().reshape(xx.shape)

    for ax, (title, color, _), x, y in zip(axes, panels, xs, ys):
        ax.contour(xx, yy, density, levels=contour_levels, cmap='Greys', alpha=0.5)
        ax.scatter(x, y, color=color, s=10)
        ax.set_title(title)
        ax.set_xlim(x_limits)
        ax.set_ylim(y_limits)
        # Ticks and axis labels are deliberately blanked out
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', dpi=600)
    plt.close(fig)


def plot_marginal_scatter_multi_row(particles_dict, model, plot_dims_list, save_path,
                                    contour_levels=20, grid_space=200, buffer=0.1):
    """
    Draw a grid of scatter panels: one row per dimension pair, one column per
    method (SVGD, h-SVGD, S-SVGD), each over the contours of the model's 2D
    marginal density, and save the figure to disk.

    Axis limits are shared within a row (pooled particle extent widened by
    `buffer`) but computed independently for every row.

    Args:
        particles_dict: Dict with keys 'SVGD', 'h-SVGD', 'SSVGD' containing particle arrays
        model: Object whose `marginal_log_prob(points, dims=...)` gives the marginal log density
        plot_dims_list: List of pairs of dimension indices, one pair per row
        save_path: Path to save the figure
        contour_levels: Number of contour levels
        grid_space: Number of grid points per axis for the contour evaluation
        buffer: Margin added around the particle extent
    """
    # One entry per column: (title, colour, particles). The 'SSVGD' key is
    # displayed under the title 'S-SVGD'.
    methods = (
        ('SVGD', 'tab:blue', particles_dict['SVGD']),
        ('h-SVGD', 'tab:orange', particles_dict['h-SVGD']),
        ('S-SVGD', 'tab:green', particles_dict['SSVGD']),
    )

    n_rows = len(plot_dims_list)
    n_cols = 3
    # squeeze=False keeps `axes` two-dimensional even when there is one row
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows),
                             sharex=False, sharey=False, squeeze=False)

    for row_idx, plot_dims in enumerate(plot_dims_list):
        dim_x, dim_y = plot_dims

        # Per-method coordinates for this pair of dimensions
        xs = [pts[:, dim_x] for _, _, pts in methods]
        ys = [pts[:, dim_y] for _, _, pts in methods]

        # Limits shared by the three panels of this row
        pooled_x = np.concatenate(xs)
        pooled_y = np.concatenate(ys)
        x_limits = (pooled_x.min() - buffer, pooled_x.max() + buffer)
        y_limits = (pooled_y.min() - buffer, pooled_y.max() + buffer)

        # Regular grid covering the row's window, flattened to (x, y) rows
        x_axis = np.linspace(x_limits[0], x_limits[1], grid_space)
        y_axis = np.linspace(y_limits[0], y_limits[1], grid_space)
        xx, yy = np.meshgrid(x_axis, y_axis)
        grid_points = torch.tensor(np.column_stack([xx.ravel(), yy.ravel()]),
                                   dtype=torch.float32)

        # Marginal density on the grid (no gradients needed for plotting)
        with torch.no_grad():
            grid_log_probs = model.marginal_log_prob(grid_points, dims=plot_dims)
        density = grid_log_probs.exp().numpy().reshape(xx.shape)

        for col_idx, ((title, color, _), x, y) in enumerate(zip(methods, xs, ys)):
            ax = axes[row_idx, col_idx]

            ax.contour(xx, yy, density, levels=contour_levels, cmap='Greys', alpha=0.5)
            ax.scatter(x, y, color=color, s=10)
            ax.set_xlim(x_limits)
            ax.set_ylim(y_limits)
            ax.set_xticks([])
            ax.set_yticks([])

            # Method names appear once, above the first row
            if row_idx == 0:
                ax.set_title(title, fontsize=14, weight='bold')

            # The first column carries the row's dimension pair as its y-label
            if col_idx == 0:
                ax.set_ylabel(f'Dimensions {dim_x}, {dim_y}', fontsize=12)

    fig.tight_layout()
    fig.savefig(save_path, dpi=600)
    plt.close(fig)


if __name__ == '__main__':
    # Experiment driver: for every target dimension d, run SVGD, h-SVGD and
    # S-SVGD on a K-component Gaussian mixture, record per-run metrics, plot
    # the metrics against d, and finally scatter the last run's particles over
    # 2D marginals of the last model.

    parser = argparse.ArgumentParser()
    # Types are declared here so malformed values are rejected by argparse
    # instead of failing later in a manual int()/float() conversion. Float
    # defaults are written as floats so the values embedded in the output
    # filenames keep their formatting (e.g. "initvar-2.0").
    parser.add_argument('-v', '--initial-particle-variance', default=2.0, type=float)
    parser.add_argument('-r', '--random-seed', default=42, type=int)
    parser.add_argument('-e', '--num-exps', default=10, type=int)
    parser.add_argument('-n', '--num-particles', default=50, type=int)
    parser.add_argument('-i', '--num-iterations', default=2000, type=int)
    parser.add_argument('-s', '--step-size', default=0.01, type=float)
    parser.add_argument('--dim-min', default=100, type=int)
    parser.add_argument('--dim-max', default=1000, type=int)
    parser.add_argument('--dim-interval', default=100, type=int)
    args = parser.parse_args()

    d_min = args.dim_min
    d_max = args.dim_max
    d_interval = args.dim_interval
    if d_interval == 0:
        parser.error('--dim-interval must be non-zero')

    # Only multiples of d_interval inside [d_min, d_max] are evaluated.
    d_list = [d for d in range(d_min, d_max + 1) if d % d_interval == 0]
    if not d_list:
        # Everything after the sweep needs the model and particles of the last
        # dimension, so an empty sweep cannot produce any output.
        parser.error(
            f'no multiple of --dim-interval={d_interval} lies in '
            f'[{d_min}, {d_max}]; nothing to run'
        )

    # Rendering text through LaTeX requires a working LaTeX installation.
    plt.rcParams.update({'text.usetex': True})

    K = 10 # Number of mixture components
    N = args.num_particles
    num_iterations = args.num_iterations
    num_exps = args.num_exps
    eps = args.step_size

    random_seed = args.random_seed
    initial_particle_variance = args.initial_particle_variance

    metrics_list = []

    def _record_metrics(svgd, model, scheme):
        """Append the metrics of a finished run to `metrics_list`.

        Reads the current dimension `d` and particle count `N` from the
        enclosing script scope. The final particles are stored as a clone.
        """
        metrics_list.append({
            'Scheme': scheme, 'd': d, 'N': N, 'Time (sec)': svgd.time_seconds, 'DAMV': svgd.damv,
            'PARF-initial': svgd.parf_initial, 'PAKSG-initial': svgd.paksg_initial,
            'PARF-final': svgd.parf_final, 'PAKSG-final': svgd.paksg_final, 'KSD': svgd.ksd,
            'Energy Distance': model.energy_distance_mc(svgd.particles_current, n_mc=10000),
            'Particles': svgd.particles_current.clone()
        })

    def run_hsvgd(particles, model, num_iterations, eps, scheme, k_1, k_2, **kwargs):
        """Run hSVGD with kernel specs `k_1` and `k_2` and record its metrics.

        Extra keyword arguments are forwarded to `hSVGD.update`.
        """
        svgd = hSVGD(particles, model, k_1, k_2)
        svgd.update(num_iterations, eps=eps, **kwargs)
        _record_metrics(svgd, model, scheme)

    def run_ssvgd(particles, model, num_iterations, eps, scheme, k, **kwargs):
        """Run SSVGD with kernel spec `k` and record its metrics.

        `n_g_update=1` is always passed; extra keyword arguments are forwarded
        to `SSVGD.update`.
        """
        svgd = SSVGD(particles, model, k)
        svgd.update(num_iterations, eps=eps, n_g_update=1, **kwargs)
        _record_metrics(svgd, model, scheme)

    # NOTE(review): this helper is not called anywhere in this script; the
    # recorded 'Energy Distance' metric comes from model.energy_distance_mc.
    def energy_distance(x, y):
        """
        Computes the (empirical) energy distance between two sets of samples.

        The within-sample terms are plain means over the full pairwise distance
        matrices, so the zero self-distances on the diagonal are included.

        Args:
            x (tensor[N, D]): Samples from distribution P.
            y (tensor[M, D]): Samples from distribution Q.

        Returns:
            float: Energy distance between the two distributions.
        """
        x = torch.as_tensor(x, dtype=torch.float32)
        y = torch.as_tensor(y, dtype=torch.float32)

        # Pairwise Euclidean distances
        xy_dist = torch.cdist(x, y, p=2)       # [N, M]
        xx_dist = torch.cdist(x, x, p=2)       # [N, N]
        yy_dist = torch.cdist(y, y, p=2)       # [M, M]

        # 2 E|X - Y| - E|X - X'| - E|Y - Y'|
        energy = 2.0 * xy_dist.mean() - xx_dist.mean() - yy_dist.mean()
        return energy.item()

    rbf_unit = {'family': 'rbf', 'weight': 1, 'bandwidth_factor': 1, 'preconditioning': None}

    for d in tqdm(d_list):

        # NOTE(review): once d exceeds 300 the run is permanently cut down to a
        # single experiment with a single iteration. The overridden values are
        # also what end up in the output filenames below. This looks like a
        # leftover quick-run shortcut; confirm it is intended.
        if d > 300:
            num_exps = 1
            num_iterations = 1

        # Target: equally weighted mixture with random means and identity
        # covariances, rescaled by the model itself afterwards. Means are drawn
        # one component at a time to keep the random stream unchanged.
        torch.manual_seed(random_seed)
        means = torch.stack([torch.randn(d) for _ in range(K)])
        covariances = torch.eye(d).repeat(K, 1, 1)
        weights = torch.ones(K) / K
        model = MultivariateNormalMixture(means, covariances, weights)
        model.scale_to_unit_marginal_variance()

        for exp in range(num_exps):

            torch.manual_seed(random_seed+exp)
            # NOTE(review): the standard normal draws are multiplied by
            # initial_particle_variance directly, so the value acts as a
            # standard deviation (the resulting variance is its square).
            # Confirm whether a square root was intended.
            particles = torch.randn(N, d) * initial_particle_variance

            # Every scheme receives its own copy so all three start from the
            # same initial particles even if a sampler updates its input
            # in place.
            run_hsvgd(
                particles.clone(), model, num_iterations, eps,
                'SVGD',
                k_1=dict(rbf_unit),
                k_2=dict(rbf_unit)
            )

            # Same as SVGD except that the second kernel is weighted by sqrt(d)
            run_hsvgd(
                particles.clone(), model, num_iterations, eps,
                'h-SVGD',
                k_1=dict(rbf_unit),
                k_2={'family': 'rbf', 'weight': np.sqrt(d), 'bandwidth_factor': 1, 'preconditioning': None}
            )

            run_ssvgd(
                particles.clone(), model, num_iterations, eps,
                'SSVGD',
                k=dict(rbf_unit), adagrad=True, g_lr=0.0005
            )

    metrics_df = pd.DataFrame(metrics_list)

    # Taken from the model of the last dimension only.
    avg_var = model.averaged_marginal_variance()

    os.makedirs('mvn', exist_ok=True)

    ### Plot metrics ###

    # Settings shared by all metric filenames.
    run_tag = (f'N-{N}-eps-{eps}-iter-{num_iterations}-dint-{d_interval}'
               f'-K-{K}-initvar-{initial_particle_variance}.png')

    # errorbar=None replaces the deprecated ci=None: plot the mean over
    # repetitions without an uncertainty band.
    fig, ax = plt.subplots()
    sns.lineplot(metrics_df, x='d', y='DAMV', hue='Scheme', ax=ax, errorbar=None)
    sns.lineplot(x=d_list, y=np.full(len(d_list), float(avg_var)), linestyle='--',
                 color='black', label='Target', ax=ax)
    fig.savefig(os.path.join('mvn', f'damv-numexp-{num_exps}-{run_tag}'), dpi=300)
    plt.close(fig)

    # NOTE(review): unlike the other two, this filename has no "numexp-" label
    # before the experiment count; kept as is so existing outputs still match.
    fig, ax = plt.subplots()
    sns.lineplot(metrics_df, x='d', y='Time (sec)', hue='Scheme', ax=ax, errorbar=None)
    ax.set_yscale('log')
    fig.savefig(os.path.join('mvn', f'time-{num_exps}-{run_tag}'), dpi=300)
    plt.close(fig)

    fig, ax = plt.subplots()
    sns.lineplot(metrics_df, x='d', y='Energy Distance', hue='Scheme', ax=ax, errorbar=None)
    fig.savefig(os.path.join('mvn', f'energy-dist-numexp-{num_exps}-{run_tag}'), dpi=300)
    plt.close(fig)

    ### Plot marginal scatter plots ###

    def _final_particles(scheme):
        """Return the most recently recorded particles of `scheme` as a NumPy array."""
        latest = next(m for m in reversed(metrics_list) if m['Scheme'] == scheme)
        # detach/cpu make the conversion safe for tensors that track gradients
        # or live on another device.
        return latest['Particles'].detach().cpu().numpy()

    # Final particles of the last recorded run of each scheme (last dimension
    # in d_list, last repetition).
    particles_dict = {
        'SVGD': _final_particles('SVGD'),
        'h-SVGD': _final_particles('h-SVGD'),
        'SSVGD': _final_particles('SSVGD')
    }

    # Plot for dimensions 0 and 1
    plot_marginal_scatter_single_row(
        particles_dict, model, plot_dims=(0, 1),
        save_path=os.path.join('mvn', f'marginal-scatter-dims-0-1-N-{N}-d-{d}-K-{K}-iter-{num_iterations}.png')
    )

    # Plot for 5 randomly selected dimension pairs
    random.seed(random_seed)
    num_dims = particles_dict['SVGD'].shape[1]
    plot_dims_list = [tuple(random.sample(range(num_dims), 2)) for _ in range(5)]

    plot_marginal_scatter_multi_row(
        particles_dict, model, plot_dims_list,
        save_path=os.path.join('mvn', f'marginal-scatter-5dims-N-{N}-d-{d}-K-{K}-iter-{num_iterations}.png')
    )