import jax.numpy as jnp
import jax
import wandb
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from algorithms.common.ipm_eval import discrepancies
from orbax.checkpoint import PyTreeCheckpointer
from algorithms.fab.utils.plot import plot_contours_2D
from sklearn.cluster import KMeans
from jax.scipy.special import logsumexp
from algorithms.scld.resampling import log_effective_sample_size
from algorithms.common.eval_methods.utils import avg_stddiv_across_marginals


def subtrajectories_to_single_traj(subtrajs):
    """Join consecutive subtrajectories into one trajectory per batch element.

    Subtrajectory ``c`` starts at the state where subtrajectory ``c - 1``
    ended, so that shared boundary state is kept only once.

    Args:
        subtrajs: array of shape (S, B, L+1, D) -- S subtrajectories, batch
            size B, L steps per subtrajectory, target dimension D.

    Returns:
        Array of shape (B, 1 + S*L, D).
    """
    pieces = []
    for idx in range(subtrajs.shape[0]):
        # Every piece after the first drops its leading (duplicated) state.
        first_kept = 0 if idx == 0 else 1
        pieces.append(subtrajs[idx][:, first_kept:, :])
    return jnp.concatenate(pieces, axis=1)

# for cluster analysis
def ELBOW_plots(X, max_k=70):
    """Elbow-method plot of k-means inertia against k, for estimating the number of modes.

    Fits ``KMeans`` for the odd values k = 1, 3, 5, ..., up to ``max_k`` and
    plots the within-cluster sum of squares (inertia) of each fit.

    Args:
        X: data accepted by ``sklearn.cluster.KMeans.fit``; ``len(X)`` is
            taken as the number of samples.
        max_k: largest k to try (default 70). Values of k above ``len(X)``
            are skipped because KMeans cannot fit more clusters than samples.

    Returns:
        {"figures/vis": [wandb.Image]} ready for wandb logging.
    """
    plt.close()
    fig, ax = plt.subplots()
    # KMeans raises if n_clusters > n_samples, so cap k at the sample count.
    ks = list(range(1, min(max_k, len(X)) + 1, 2))
    wcss = []
    for k in ks:
        kmeans = KMeans(n_clusters=k, random_state=42)
        kmeans.fit(X)
        wcss.append(kmeans.inertia_)  # sum of squared distances to the closest cluster centre

    ax.plot(ks, wcss, marker='o')
    ax.set_title('ELBOW plot')
    ax.set_xlabel('Number of clusters (k)')
    ax.set_ylabel('Within-cluster Sum of Squares (WCSS)')
    return {"figures/vis": [wandb.Image(fig)]}

# for studying behavior of subtrajectories
# Potentially: plot effect of pre and post SMC operation

def approximate_marginal_density(samples, range, ax=None, fig=None, dim1 = 0, dim2=1):
    """Draw a Gaussian-KDE heatmap of the (dim1, dim2) marginal of ``samples``.

    Args:
        samples: (N, D) array of samples.
        range: if not None, the heatmap covers the square [-range, range]^2;
            otherwise it covers the samples' bounding box padded by 1 on every
            side. (The name shadows the builtin; kept for caller compatibility.)
        ax, fig: matplotlib axes/figure to draw on; a new pair is created when
            ``ax`` is None.
        dim1, dim2: coordinates shown on the horizontal and vertical axes.

    Returns:
        (fig, ax). When ``ax`` is supplied, ``fig`` is returned as passed in
        (possibly None).
    """
    if ax is None:
        fig, ax = plt.subplots()

    horiz = samples[:, dim1]
    vert = samples[:, dim2]
    kde = gaussian_kde(np.vstack([horiz, vert]))

    if range is None:
        x_lo, x_hi = horiz.min() - 1, horiz.max() + 1
        y_lo, y_hi = vert.min() - 1, vert.max() + 1
    else:
        x_lo, x_hi = -range, range
        y_lo, y_hi = -range, range

    # 100 x 100 evaluation grid; grid_x varies along axis 0, grid_y along axis 1.
    grid_x, grid_y = np.mgrid[x_lo:x_hi:100j, y_lo:y_hi:100j]
    density = kde(np.vstack([grid_x.ravel(), grid_y.ravel()])).reshape(grid_x.shape)

    # density is indexed [x, y] while imshow expects [row=y, col=x], hence the
    # transpose; origin="lower" draws small y at the bottom. "rocket" is a
    # seaborn colormap (available because seaborn is imported by this module).
    ax.imshow(density.T, origin="lower", extent=[x_lo, x_hi, y_lo, y_hi], cmap="rocket")
    ax.set_xlabel(f'dim{dim1}')
    ax.set_ylabel(f'dim{dim2}')
    return fig, ax

def plot_subtrajectory_marginal_samples(ndim, num_steps, beta_annealing_fn, params, log_prob,
             log_prob_prior, subtrajs, plot_bounds = 8, groundtruth_samples = None, rnds = None, plot_more_marginals = True):
    """Plot trajectory states at up to 9 evenly spaced time steps over the annealed density.

    Each panel shows a 2D cross-section (all other coordinates fixed to 0) of
    ``beta * log_prob + (1 - beta) * log_prob_prior`` with
    ``beta = beta_annealing_fn(t)``, overlaid with the trajectory states at step t.
    The cross-section is meaningful for any 2D target, funnel and ManyWell.

    Args:
        ndim: dimension of the target.
        num_steps: total number of steps of the stitched trajectory; assumed to
            be a multiple of the number of subtrajectories (TODO confirm).
        beta_annealing_fn: maps a step index to beta.
        params: unused; accepted for call-site compatibility.
        log_prob, log_prob_prior: batched log densities evaluated on (N, ndim) arrays.
        subtrajs: (S, B, L+1, D) array of subtrajectories.
        plot_bounds: half-width of the square plotting window.
        groundtruth_samples: unused (the KDE background is currently disabled).
        rnds: optional per-subtrajectory log weights, indexed [subtraj, sample];
            when given, samples at subtrajectory endpoints are coloured by their
            cumulative log weight.
        plot_more_marginals: show coordinate pairs (0,1), (2,3), (4,5), (6,7)
            per time step instead of only (0,1); honoured only when ndim >= 8.

    Returns:
        {"figures/vis": [wandb.Image]} ready for wandb logging.
    """
    # Cumulative log weights: row k is the weight after subtrajectory k.
    temperatures = None
    if rnds is not None:
        temperatures = jnp.cumsum(rnds, axis=0)

    traj = subtrajectories_to_single_traj(subtrajs)
    # Up to 9 evenly spaced time steps, and at least 2 so that the divisor
    # (num_marginals - 1) is never zero (num_steps == 1 would otherwise crash).
    num_marginals = max(2, min(num_steps, 9))
    times = [num_steps * c // (num_marginals - 1) for c in range(num_marginals)]

    def log_prob_cross_section(x, beta, dim1=0, dim2=1):
        """Annealed log density on the (dim1, dim2) plane; x is an (N, 2) array."""
        x_padded = jnp.zeros((x.shape[0], ndim))
        x_padded = x_padded.at[:, dim1].set(x[:, 0])
        x_padded = x_padded.at[:, dim2].set(x[:, 1])
        return beta * log_prob(x_padded) + (1 - beta) * log_prob_prior(x_padded)

    # The 2x2 block per time step reads coordinates up to 7, so it needs
    # ndim >= 8 (same rule as kde_visualise). With fewer dims JAX would silently
    # clamp/drop the out-of-range indices and plot the wrong coordinates.
    plot_more_marginals = plot_more_marginals and ndim >= 8
    plots_per_frame = 2 if plot_more_marginals else 1
    size_multiplier = 1.5 if plot_more_marginals else 1

    plt.close()
    fig, ax = plt.subplots(3 * plots_per_frame, 3 * plots_per_frame,
                           figsize=(16 * size_multiplier, 12 * size_multiplier))

    num_subtrajs = len(rnds) if rnds is not None else 0
    for t in range(num_marginals):
        beta = beta_annealing_fn(times[t])

        # times[t] lies at the end of subtrajectory k = t * S / (num_marginals - 1)
        # exactly when that ratio is an integer; only then is the cumulative
        # weight temperatures[k - 1] the weight of the plotted states.
        temp_ind = None
        if t > 0 and num_subtrajs and (t * num_subtrajs) % (num_marginals - 1) == 0:
            temp_ind = (t * num_subtrajs) // (num_marginals - 1)

        for i in range(plots_per_frame):
            for j in range(plots_per_frame):
                dim1 = plots_per_frame * (plots_per_frame * i + j)
                dim2 = dim1 + 1
                panel = ax[plots_per_frame * (t // 3) + i][plots_per_frame * (t % 3) + j]

                # Background: cross-section of the annealed density at this beta.
                plot_contours_2D(lambda x: log_prob_cross_section(x, beta, dim1, dim2), ax=panel,
                                 bound=plot_bounds, levels=50, clip_threshold=None)
                xs, ys = traj[:, times[t], dim1], traj[:, times[t], dim2]
                if temp_ind is not None:
                    scatter = panel.scatter(xs, ys, c=temperatures[temp_ind - 1, :], cmap='viridis')
                    fig.colorbar(scatter, ax=panel)
                else:
                    panel.scatter(xs, ys)
                panel.set_title(f"t={times[t]},beta={beta:.2f}")
                panel.set_xlabel(f"dim{dim1}")
                panel.set_ylabel(f"dim{dim2}")

    image = wandb.Image(fig)
    # The figure has been rendered into the image; release it from pyplot.
    plt.close(fig)
    return {"figures/vis": [image]}

def plot_hist(data):
    """Return a 40-bin histogram of ``data`` wrapped for wandb logging."""
    plt.close()
    figure, axis = plt.subplots()
    axis.hist(data, bins=40)
    logged = {"figures/vis": [wandb.Image(figure)]}
    return logged

def visualise_betas(params, beta_fn, num_steps):
    """Plot an annealing schedule (used when it is learnt) for wandb logging.

    Args:
        params: unused; accepted for call-site compatibility.
        beta_fn: maps a step index to beta.
        num_steps: the schedule is evaluated at steps 0, 1, ..., num_steps.

    Returns:
        {"figures/vis": [wandb.Image]}.
    """
    plt.close()
    schedule = list(map(beta_fn, range(num_steps + 1)))
    figure, axis = plt.subplots()
    axis.plot(schedule)
    return {"figures/vis": [wandb.Image(figure)]}

def plot_ess(log_ess_data):
    """Plot the log-ESS trace of a rollout and a histogram of early log weights.

    Args:
        log_ess_data: the raw log-ESS record returned by the simulate function.
            Only these entries are read:
            - log_ess_data[0] and log_ess_data[1]: sequences indexed in lockstep
              (presumably log-ESS after each step and pre-resampling log-ESS --
              TODO confirm against simulate). They are interleaved as
              [a0, b0, a1, b1, ...] in the top panel.
            - log_ess_data[3][0]: histogrammed in the bottom panel, which is
              titled as the log weights after the first subtrajectory.

    Returns:
        {"figures/vis": [wandb.Image]}.
    """
    plt.close()
    logess_series, pre_resample_series = log_ess_data[0], log_ess_data[1]
    series = []
    for idx in range(len(logess_series)):
        series.extend((logess_series[idx], pre_resample_series[idx]))

    fig, (trace_ax, hist_ax) = plt.subplots(2, figsize=(10, 6))
    trace_ax.plot(series)
    trace_ax.set_xlabel("alternate subtraj + resampling scheme")
    trace_ax.set_ylabel("log(ESS)")
    hist_ax.hist(log_ess_data[3][0], bins=40)
    hist_ax.set_title("logWeights after 1st subtrajectory")
    return {"figures/vis": [wandb.Image(fig)]}

def compute_model_logess(rnds):
    """Log-ESS of the cumulative log weights after each subtrajectory.

    Entry i uses the sum of rnds[0..i] along axis 0, so the last entry is the
    log-ESS of the full-trajectory weights. ESS lies between 1 and the batch
    size B, so each value lies in [0, log B].
    """
    running = jnp.cumsum(rnds, axis=0)
    return [log_effective_sample_size(running[row, :]) for row in range(running.shape[0])]

def kde_visualise(samples, target_samples, plot_range = None, temperatures = None, plot_more_marginals = True):
    """Scatter samples over a KDE heatmap estimated from ``target_samples``.

    Args:
        samples: (N, D) samples to scatter.
        target_samples: reference samples used for the KDE background.
        plot_range: forwarded to approximate_marginal_density (square
            half-width, or None for a data-driven window).
        temperatures: optional per-sample values used to colour the scatter.
        plot_more_marginals: show a 2x2 grid of coordinate pairs (0,1), (2,3),
            (4,5), (6,7); honoured only when D >= 8, otherwise only (0,1).

    Returns:
        {"figures/vis": [wandb.Image]}.
    """
    plt.close()
    plot_more_marginals = plot_more_marginals and (samples.shape[1] >= 8)
    grid = 2 if plot_more_marginals else 1

    # squeeze=False always yields a 2-D array of Axes. With the default
    # squeeze=True, plt.subplots() returns a bare Axes and ax[i][j] below
    # raises TypeError for targets with fewer than 8 dimensions.
    if plot_more_marginals:
        fig, ax = plt.subplots(2, 2, figsize=(10, 8), squeeze=False)
    else:
        fig, ax = plt.subplots(squeeze=False)

    for i in range(grid):
        for j in range(grid):
            dim1 = 2 * (grid * i + j)
            dim2 = dim1 + 1
            panel = ax[i][j]

            # Approximate ground-truth marginal as the background.
            approximate_marginal_density(target_samples, plot_range, panel, fig, dim1, dim2)

            if temperatures is None:
                panel.scatter(samples[:, dim1], samples[:, dim2])
            else:
                scatter = panel.scatter(samples[:, dim1], samples[:, dim2], c=temperatures, cmap='viridis')
                fig.colorbar(scatter, ax=panel)

    return {"figures/vis": [wandb.Image(fig)]}

def eval_scld(simulate,
              simulate_fn_list,
              probflow_ode_fn,
              get_schedule_and_prior_fn,
              target,
              target_samples,
              config):
    """Build the stateful evaluation callback used during SCLD training.

    ``simulate`` rolls out the model without any SMC. The SMC rollout is the
    simulate function of the last entry of ``simulate_fn_list`` and behaves as
    dictated by the use_{resampling, markov}_inference settings of the config.
    Every simulate function is called as ``fn(key, model_state, params)`` and
    returns a 6-tuple
    ``(samples_all, _, (lnZ_est, elbo_est), (_, per_subtraj_rnds), subtrajs, log_ess)``
    whose ``samples_all[-1]`` holds the final samples.

    Args:
        simulate: rollout without SMC.
        simulate_fn_list: list of (num_subtraj, simulate_fn) pairs. Its last
            entry provides the SMC rollout; when it has more than one entry,
            each one is also evaluated and logged under "VaryNSubtraj/".
        probflow_ode_fn: returns (samples, sample_log_probs, log_prob_on_target);
            only called when config.algorithm.plot_ode is set.
        get_schedule_and_prior_fn: called with ``params``; the 2nd element of
            its result is the prior (its ``log_prob`` is used) and the 3rd is
            the annealing schedule ``beta_fn``.
        target: target distribution object.
        target_samples: ground-truth samples, or None (sample-based
            discrepancies are then logged as inf).
        config: experiment configuration.

    Returns:
        ``short_eval(model_state, params, key, it_num) -> dict`` of metrics and
        figures. The closure keeps running histories between calls (moving
        averages, best-so-far values), so successive calls are not independent.
    """
    # Scalar metrics whose moving averages are tracked across evaluations.
    tracked_metrics = (
        'model_lnZ', 'smc_lnZ',
        'model_ELBO', 'smc_ELBO',
        'model_sd', 'smc_sd',
        'smc_lnZ_actual',
        'ess_model',
        'delta_mean_marginal_std_model', 'delta_mean_marginal_std_smc',
    )
    # Mutable state shared by every call of short_eval (deliberately impure).
    history = {name: [] for name in tracked_metrics}          # raw value per evaluation
    moving_averages = {name: [] for name in tracked_metrics}  # moving average per evaluation
    trackers_multisubtraj = {}  # num_subtraj -> {'elbo': [...], 'sd': [...]}
    best_multisubtraj = {}      # num_subtraj -> best 5-eval average: elbo (max), sd (min)
    time_optimal_reached = {}   # model_selection key -> it_num at which it was last attained

    simulate_smc = simulate_fn_list[-1][1]

    for (num_subtraj, _) in simulate_fn_list:
        trackers_multisubtraj[num_subtraj] = {'elbo': [], 'sd': []}
        best_multisubtraj[num_subtraj] = {'elbo': -jnp.inf, 'sd': jnp.inf}

    def short_eval(model_state, params, key, it_num):
        logger = {}
        is_finished = it_num + 1 == config.algorithm.n_sim

        # NOTE: the same PRNG key is reused for every rollout below (assumed acceptable).
        if len(simulate_fn_list) > 1:
            for (n_sub_traj, sim_fn) in simulate_fn_list:
                samples_all, _, (_, elbo_est), (_, _), _, _ = sim_fn(key, model_state, params)
                samples = samples_all[-1]
                tracker = trackers_multisubtraj[n_sub_traj]
                best = best_multisubtraj[n_sub_traj]

                if target_samples is not None:
                    tracker['sd'].append(discrepancies.compute_sd(target_samples, samples, config))
                else:
                    tracker['sd'].append(jnp.inf)
                tracker['elbo'].append(elbo_est)

                # Best-so-far of the average over the last 5 evaluations.
                best['elbo'] = max(best['elbo'], jnp.array(tracker['elbo'][-5:]).mean())
                best['sd'] = min(best['sd'], jnp.array(tracker['sd'][-5:]).mean())

                logger[f'VaryNSubtraj/{n_sub_traj}_ELBO'] = elbo_est
                logger[f'VaryNSubtraj/{n_sub_traj}_sd'] = tracker['sd'][-1]
                logger[f'VaryNSubtraj/{n_sub_traj}_ELBO_ma_max'] = best['elbo']
                logger[f'VaryNSubtraj/{n_sub_traj}_sd_ma_min'] = best['sd']

        (model_samples_all, _, (model_lnz_est, model_elbo_est), (_, per_subtraj_rnds),
         subtrajs_model, log_ess_model) = simulate(key, model_state, params)
        (smc_samples_all, _, (smc_lnz_est, smc_elbo_est), (_, per_subtraj_rnds_smc),
         subtrajs_smc, log_ess_smc) = simulate_smc(key, model_state, params)
        # Total log RND per sample: sum over subtrajectories (axis 0).
        total_rnds = per_subtraj_rnds.sum(axis=0)
        total_rnds_smc = per_subtraj_rnds_smc.sum(axis=0)

        model_samples = model_samples_all[-1]
        smc_samples = smc_samples_all[-1]
        assert model_samples.shape[0] == config.eval_samples

        # Probability-flow ODE based metrics (optional).
        if config.algorithm.plot_ode:
            ode_samples, ode_sample_logprobs, ode_log_prob_on_target = probflow_ode_fn(key, model_state, params)
        else:
            ode_samples, ode_sample_logprobs, ode_log_prob_on_target = None, None, None

        _, initial_density, beta_fn, _ = get_schedule_and_prior_fn(params)

        if config.algorithm.plot_ode:
            ode_log_w = target.log_prob(ode_samples) - ode_sample_logprobs
            logger['metric/mean_ode_targetsample_logprobs'] = ode_log_prob_on_target.mean()
            logger['metric/ode_lnZ_estimate'] = logsumexp(ode_log_w) - jnp.log(config.eval_samples)
            logger['metric/ode_ELBO'] = ode_log_w.mean()

        if hasattr(config.algorithm, "prior"):
            dim = model_samples.shape[-1]
            for j in range(min(dim, 5)):
                logger[f'prior/dim_{j}_mean'] = params['params']['prior_mean'][j]
                logger[f'prior/dim_{j}_std'] = jnp.exp(params['params']['prior_log_stds'][j])
        # NOTE(review): read unconditionally, unlike the prior stats above; assumes
        # params always contain 'log_max_diffusion' -- confirm.
        logger['prior/max_diffusion'] = jnp.exp(params['params']['log_max_diffusion'])
        logger['metric/model_lnZ'] = model_lnz_est
        logger['metric/smc_lnZ'] = smc_lnz_est
        logger['metric/model_ess'] = jnp.exp(log_effective_sample_size(total_rnds)) / config.eval_samples

        # The (principled) lnZ estimator from the SMC literature: sum over
        # subtrajectories of the log-mean-exp of the incremental weights.
        # Resampling breaks the usual single-sum estimator.
        logger['metric/smc_lnZ_actual'] = logsumexp(per_subtraj_rnds_smc - jnp.log(config.eval_samples), axis=1).sum()

        if target.log_Z is not None:
            logger['metric/model_delta_lnZ'] = jnp.abs(model_lnz_est - target.log_Z)
            logger['metric/smc_delta_lnZ'] = jnp.abs(smc_lnz_est - target.log_Z)
            if config.algorithm.plot_ode:
                logger['metric/ode_delta_lnZ'] = jnp.abs(logger['metric/ode_lnZ_estimate'] - target.log_Z)

        logger['metric/model_ELBO'] = model_elbo_est
        logger['metric/smc_ELBO'] = smc_elbo_est
        # Mean target log-likelihood of each sampler's OWN samples.
        logger['metric/model_target_llh'] = jnp.mean(target.log_prob(model_samples))
        logger['metric/smc_target_llh'] = jnp.mean(target.log_prob(smc_samples))

        if config.compute_emc and config.target.has_entropy:
            logger['metric/model_entropy'] = target.entropy(model_samples)
            logger['metric/smc_entropy'] = target.entropy(smc_samples)

        for d in config.discrepancies:
            if target_samples is None:
                logger[f'discrepancies/model_{d}'] = jnp.inf
                logger[f'discrepancies/smc_{d}'] = jnp.inf
                if config.algorithm.plot_ode:
                    logger[f'discrepancies/ode_{d}'] = jnp.inf
                continue
            compute_discrepancy = getattr(discrepancies, f'compute_{d}')
            logger[f'discrepancies/model_{d}'] = compute_discrepancy(target_samples, model_samples, config)
            logger[f'discrepancies/smc_{d}'] = compute_discrepancy(target_samples, smc_samples, config)
            if config.algorithm.plot_ode:
                logger[f'discrepancies/ode_{d}'] = compute_discrepancy(target_samples, ode_samples, config)

        # Sample plots.
        logger['model_samples'] = target.visualise(model_samples, show=config.visualize_samples)
        logger['model_samples_smc'] = target.visualise(smc_samples, show=config.visualize_samples)

        # Samples against a KDE approximation of the target density; model
        # samples are coloured by their normalised importance weights.
        if target_samples is not None and not config.algorithm.sweep_mode:
            logger['model_samples_kde'] = kde_visualise(model_samples, target_samples,
                                                        temperatures=jnp.exp(jax.nn.log_softmax(total_rnds)))
            logger['model_samples_smc_kde'] = kde_visualise(smc_samples, target_samples)

        if config.algorithm.plot_ode:
            logger['model_samples_ode'] = target.visualise(ode_samples, show=config.visualize_samples,
                                                           temperatures=jnp.exp(jax.nn.log_softmax(ode_sample_logprobs)))

        if target_samples is not None and is_finished:
            logger['groundtruth_samples'] = target.visualise(target_samples, show=config.visualize_samples)

        # Histograms of the log RNDs: the closer to the true lnZ the better.
        logger['rnds/model_logrnds'] = plot_hist(total_rnds)
        logger['rnds/model_logrnds_smc'] = plot_hist(total_rnds_smc)

        # The annealing schedule, when it is learnt.
        if config.algorithm.annealing_schedule.schedule_type == "learnt":
            if (not config.algorithm.sweep_mode) or is_finished:
                logger['other/learnt_annealing_schedule'] = visualise_betas(params, beta_fn, config.algorithm.num_steps)

        if config.algorithm.loss in ['rev_tb', 'fwd_tb']:
            logger['other/sum_log_Z_second_moment'] = jnp.sum(params['params']['logZ_deltas'])
            # Only meaningful when training with subtrajectories.
            logger['other/final_log_Z_delta_second_moment'] = params['params']['logZ_deltas'][-1]

        if config.algorithm.plot_subtrajs:
            subtraj_plot_scale = getattr(config.target, "plot_range", 8)
            logger['other/marginals_and_smc_setting_samples'] = plot_subtrajectory_marginal_samples(
                config.target.dim, config.algorithm.num_steps, beta_fn, params, target.log_prob,
                initial_density.log_prob, subtrajs_smc, subtraj_plot_scale, groundtruth_samples=target_samples)

            logger['other/marginals_and_model_setting_samples'] = plot_subtrajectory_marginal_samples(
                config.target.dim, config.algorithm.num_steps, beta_fn, params, target.log_prob,
                initial_density.log_prob, subtrajs_model, subtraj_plot_scale, rnds=per_subtraj_rnds,
                groundtruth_samples=target_samples)

        logger['metric/delta_mean_marginal_std_model'] = jnp.abs(avg_stddiv_across_marginals(model_samples) - target.marginal_std)
        logger['metric/delta_mean_marginal_std_smc'] = jnp.abs(avg_stddiv_across_marginals(smc_samples) - target.marginal_std)

        logger["modes/log_ess_model"] = plot_ess(log_ess_model)
        logger["modes/log_ess_smc"] = plot_ess(log_ess_smc)

        subtraj_ess_values = compute_model_logess(per_subtraj_rnds)
        logger["ess/model_ess"] = subtraj_ess_values[-1]
        # log-ESS of the marginals at each subtrajectory endpoint.
        logger['ess/model_ess_prefix'] = subtraj_ess_values

        history['model_lnZ'].append(model_lnz_est)
        history['smc_lnZ'].append(smc_lnz_est)
        history['model_ELBO'].append(model_elbo_est)
        history['smc_ELBO'].append(smc_elbo_est)
        # The sd discrepancy is only computed when 'sd' is in config.discrepancies;
        # track inf instead of raising KeyError otherwise.
        history['model_sd'].append(logger.get('discrepancies/model_sd', jnp.inf))
        history['smc_sd'].append(logger.get('discrepancies/smc_sd', jnp.inf))
        history['smc_lnZ_actual'].append(logger['metric/smc_lnZ_actual'])
        history['ess_model'].append(logger['metric/model_ess'])
        history['delta_mean_marginal_std_model'].append(logger['metric/delta_mean_marginal_std_model'])
        history['delta_mean_marginal_std_smc'].append(logger['metric/delta_mean_marginal_std_smc'])

        moving_average_width = getattr(config.algorithm, "ma_length", 5)

        # Loop variables are deliberately not called `key`, which would shadow
        # the PRNG key argument of this function.
        for name, values in history.items():
            window = values[-moving_average_width:]
            ma_key = f"MovingAverages/{name}"
            logger[ma_key] = sum(window) / len(window)
            moving_averages[name].append(logger[ma_key])

        for name, averages in moving_averages.items():
            # The first `moving_average_width` averages are a warm-up and are
            # excluded from model selection.
            if len(averages) <= moving_average_width:
                continue
            settled = averages[moving_average_width:]
            max_key = f"model_selection/{name}_MovingAverage_MAX"
            min_key = f"model_selection/{name}_MovingAverage_MIN"
            logger[max_key] = max(settled)
            logger[min_key] = min(settled)

            if logger[max_key] == settled[-1]:
                time_optimal_reached[max_key] = it_num
            if logger[min_key] == settled[-1]:
                time_optimal_reached[min_key] = it_num
                # At a new best sd, record the marginal-std error at that point.
                if "sd" in name:
                    if 'smc' in name:
                        logger["model_selection/delta_mean_marginal_std_smc_ma_at_optimal_smc_sd"] = moving_averages['delta_mean_marginal_std_smc'][-1]
                    else:
                        logger["model_selection/delta_mean_marginal_std_model_ma_at_optimal_model_sd"] = moving_averages['delta_mean_marginal_std_model'][-1]

        if is_finished:
            for (selection_key, value) in time_optimal_reached.items():
                logger[f"Optimal_time_{selection_key}"] = value

        if is_finished and not config.algorithm.sweep_mode:
            # Save the final model state and trajectories into the wandb run directory.
            checkpoint_dir = os.path.join(wandb.run.dir, "model_checkpoint")
            checkpointer = PyTreeCheckpointer()
            checkpointer.save(checkpoint_dir, model_state)

            final_model_trajs = subtrajectories_to_single_traj(subtrajs_model)
            final_smc_trajs = subtrajectories_to_single_traj(subtrajs_smc)

            jnp.save(os.path.join(wandb.run.dir, "final_model_trajs"), final_model_trajs)
            jnp.save(os.path.join(wandb.run.dir, "final_smc_trajs"), final_smc_trajs)

        return logger

    return short_eval
