from sklearn.mixture import GaussianMixture
import numpy as np
import torch
from scipy.stats import chi2, multivariate_normal
from scipy.special import softmax
from scipy.optimize import linear_sum_assignment

from utils import logmeanexp
from sample_metrics import compute_distribution_distances

import matplotlib.pyplot as plt
import wandb


@torch.no_grad()
def get_forward_trajectory_metrics(initial_states, gfn, energy, lambda_discretizer):
    """Roll out forward trajectories and report log-partition-function errors.

    Args:
        initial_states: Starting states passed to ``gfn.get_trajectory_fwd``.
        gfn: Sampler exposing ``get_trajectory_fwd``.
        energy: Target exposing ``log_reward`` and ``gt_logz``.
        lambda_discretizer: Time discretization forwarded to the sampler.

    Returns:
        Tuple ``(final_states, log_Z, log_Z_elbo, log_Z_learned)`` where
        ``final_states`` is the last step (index -1 along dim 1) of the
        trajectories, and each log-Z value is reported as a difference from
        ``energy.gt_logz``.
    """
    # The decorator already disables autograd; no inner no_grad block is needed.
    states, log_pfs, log_pbs, log_fs = gfn.get_trajectory_fwd(initial_states, lambda_discretizer, None, energy.log_reward)

    # The sampler may return a list whose first element is the trajectory tensor;
    # in that case the second element is passed to log_reward alongside it.
    if isinstance(states, list):
        final_states = states[0][:, -1]
        log_r = energy.log_reward([final_states, states[1]])
    else:
        final_states = states[:, -1]
        log_r = energy.log_reward(final_states)

    # Per-trajectory log importance weight: log R(x) + log P_B(tau|x) - log P_F(tau).
    log_weight = log_r + log_pbs.sum(-1) - log_pfs.sum(-1)

    log_Z = logmeanexp(log_weight) - energy.gt_logz  # importance-sampling estimate
    log_Z_elbo = log_weight.mean() - energy.gt_logz  # Jensen lower bound (ELBO)
    # Presumably the learned log-flow at the initial step, i.e. the learned log Z.
    log_Z_learned = log_fs[:, 0].mean() - energy.gt_logz

    return final_states, log_Z, log_Z_elbo, log_Z_learned


@torch.no_grad()
def get_backward_trajectory_metrics(initial_states, gfn, energy, lambda_discretizer, num_evals=10):
    """Roll out backward trajectories from given states and score them.

    Every row of ``initial_states`` is repeated ``num_evals`` times and a
    backward trajectory is sampled for each copy.

    Args:
        initial_states: 2-D tensor ``(bsz, dim)`` of states to start from.
        gfn: Sampler exposing ``get_trajectory_bwd``.
        energy: Target exposing ``log_reward`` and ``gt_logz``.
        lambda_discretizer: Time discretization forwarded to the sampler.
        num_evals: Number of backward rollouts per input state.

    Returns:
        Tuple ``(log_likelihood, eubo)``: the batch mean of the per-state
        ``logmeanexp`` over rollouts of ``log P_F - log P_B``, and the mean of
        ``log R + log P_B - log P_F`` minus ``energy.gt_logz``.
    """
    bsz = initial_states.shape[0]
    # Row b * num_evals + k holds copy k of initial_states[b].
    eval_data = initial_states.unsqueeze(1).repeat(1, num_evals, 1).view(bsz * num_evals, -1)
    states, log_pfs, log_pbs, _ = gfn.get_trajectory_bwd(eval_data, lambda_discretizer, energy.log_reward)
    log_r = energy.log_reward(states[:, -1])

    log_pf_minus_pb = log_pfs.sum(-1) - log_pbs.sum(-1)
    # Group the num_evals rollouts belonging to each input state.
    log_weight = log_pf_minus_pb.view(bsz, num_evals, -1)
    eubo = torch.mean(log_r - log_pf_minus_pb) - energy.gt_logz
    return logmeanexp(log_weight, dim=1).mean(), eubo


def is_within_confidence_region(sample, mean, covariance_matrix, confidence_level=0.99):
    """Test whether sample(s) lie inside a Gaussian's confidence ellipsoid.

    A point x is inside the ``confidence_level`` region of N(mean, S) when its
    squared Mahalanobis distance (x - mean)^T S^{-1} (x - mean) does not exceed
    the chi-squared quantile with d = ``mean.shape[-1]`` degrees of freedom.

    Args:
        sample: Tensor of shape (d,) or a batch of shape (..., d).
        mean: Tensor of shape (d,).
        covariance_matrix: Tensor of shape (d, d).
        confidence_level: Probability mass enclosed by the region.

    Returns:
        Boolean tensor: 0-dim for a single sample, shape (...) for a batch.
    """
    diff = sample - mean
    # Solve S y = diff rather than forming S^{-1}: cheaper and numerically more stable.
    # The trailing unit dim makes every diff a column vector, so batches broadcast.
    solved = torch.linalg.solve(covariance_matrix, diff.unsqueeze(-1)).squeeze(-1)
    mahalanobis_distance_squared = (diff * solved).sum(-1)

    degrees_of_freedom = mean.shape[-1]
    chi_squared_threshold = chi2.ppf(confidence_level, degrees_of_freedom)

    return mahalanobis_distance_squared <= chi_squared_threshold


def _known_params_weights(x, means, covariances, n_iter=100):
    """Estimate mixture weights by EM with the component means/covariances held fixed.

    Only the weights are updated. The log-likelihood is concave in the weights,
    so a uniform starting point is sufficient.

    Args:
        x: Samples, array of shape (N, d).
        means: Component means, array of shape (K, d).
        covariances: Component covariance matrices, array of shape (K, d, d).
        n_iter: Number of EM iterations.

    Returns:
        Array of shape (K,) with the estimated mixture weights.
    """
    n_components = len(means)
    # (N, K) log-density of every sample under every fixed component.
    log_dens = np.stack(
        [np.atleast_1d(multivariate_normal.logpdf(x, mean=mu, cov=cov)) for mu, cov in zip(means, covariances)],
        axis=1,
    )
    weights = np.full(n_components, 1.0 / n_components)
    for _ in range(n_iter):
        with np.errstate(divide="ignore"):  # a weight may underflow to exactly 0
            log_weights = np.log(weights)
        # E-step (responsibilities) followed by the weights-only M-step.
        weights = softmax(log_dens + log_weights, axis=1).mean(axis=0)
    return weights


def _confidence_region_counts(samples, energy, confidence_level=0.99):
    """Count, for each true mode, how many samples fall in its confidence region."""
    counts = np.zeros(energy.nmode)
    for i in range(energy.nmode):
        # bool() yields a host-side value even when the tensors live on an accelerator.
        counts[i] = sum(
            bool(is_within_confidence_region(sample, energy.means[i], energy.covariance_matrices[i], confidence_level))
            for sample in samples
        )
    return counts


def gmm_metrics(samples, energy, n_trials=10):
    """Compare samples against a ground-truth Gaussian mixture.

    A ``GaussianMixture`` with ``energy.nmode`` components is fitted to the
    samples once per seed in ``range(n_trials)``. Its components are matched to
    the true modes in two ways: greedily, by the nearest mean ("naive"), and by
    an optimal one-to-one assignment ("mincost"). Separately, the mixture
    weights are estimated with the true means/covariances held fixed.

    Args:
        samples: Tensor of samples, shape (N, d).
        energy: Target expected to provide ``nmode``, ``means`` (K, d),
            ``covariance_matrices`` (K, d, d) and ``gmm_weights`` (K,).
        n_trials: Number of GMM fits with different random seeds.

    Returns:
        List ``[l1_means_mincost, l1_covariances_mincost, l1_means_naive,
        l1_covariances_naive, l1_weight, unique_match, gm_weights,
        samples_in99interval_cnt, gm, l1_weight_w_known_params,
        weights_w_known_params]``. The first nine entries come from the trial
        with the smallest ``l1_means_mincost``.
    """
    nmode = energy.nmode
    x = samples.detach().cpu().numpy()
    true_means = energy.means.cpu().numpy()
    true_covs = np.stack([cov.cpu().numpy() for cov in energy.covariance_matrices])

    # Seed-independent quantities are computed once, outside the trial loop.
    # sklearn's E-step reads ``precisions_cholesky_`` (not ``covariances_``), so
    # pinning parameters on a GaussianMixture instance does not fix them.
    # Estimate the known-parameter weights directly instead.
    weights_w_known_params = _known_params_weights(x, true_means, true_covs)
    l1_weight_w_known_params = np.abs(weights_w_known_params - energy.gmm_weights).mean()
    samples_in99interval_cnt = _confidence_region_counts(samples, energy)

    mode_idx = np.arange(nmode)
    statistics = []
    for random_state in range(n_trials):
        gm = GaussianMixture(n_components=nmode, random_state=random_state).fit(x)
        assert gm.covariances_.shape[1:] == true_covs.shape[1:], f"{gm.covariances_.shape=}, {true_covs.shape=}"

        # Entry [i, j]: mean absolute difference between true mode i and fitted component j.
        l1_means_distances = np.abs(true_means[:, None, :] - gm.means_[None, :, :]).mean(-1)
        l1_covariances_dists = np.array(
            [[np.abs(fitted_cov - true_cov).mean() for fitted_cov in gm.covariances_] for true_cov in true_covs]
        )

        # Greedy matching: each true mode to its nearest fitted mean (components may repeat).
        nearest = l1_means_distances.argmin(axis=1)
        l1_means_naive = l1_means_distances[mode_idx, nearest].mean()
        l1_covariances_naive = l1_covariances_dists[mode_idx, nearest].mean()
        unique_match = np.unique(nearest).size == nmode

        # Optimal one-to-one matching on the mean distances.
        row_idxs, col_idxs = linear_sum_assignment(l1_means_distances)
        gm_weights = gm.weights_[col_idxs]
        l1_means_mincost = l1_means_distances[row_idxs, col_idxs].mean()
        l1_covariances_mincost = l1_covariances_dists[row_idxs, col_idxs].mean()
        l1_weight = np.abs(gm_weights - energy.gmm_weights).mean()

        statistics.append(
            [
                l1_means_mincost,
                l1_covariances_mincost,
                l1_means_naive,
                l1_covariances_naive,
                l1_weight,
                np.array(unique_match, dtype=np.float64),
                np.array(gm_weights),
                np.array(samples_in99interval_cnt),
                gm,
            ]
        )

    best = min(statistics, key=lambda stat: stat[0])
    return best + [l1_weight_w_known_params, weights_w_known_params]


def _log_mode_histogram(metrics, name, nmode, weights, density):
    """Store a wandb histogram and a rendered bar plot of per-mode ``weights``.

    Writes ``metrics["hist/<name>"]`` and ``metrics["visualization/<name>"]``.
    """
    modes = range(nmode)
    metrics[f"hist/{name}"] = wandb.Histogram(
        np_histogram=np.histogram(modes, bins=nmode, weights=weights, density=density)
    )
    fig, ax = plt.subplots()
    ax.hist(modes, bins=nmode, weights=weights, density=density)
    metrics[f"visualization/{name}"] = wandb.Image(fig)
    # wandb.Image renders the figure when it is constructed. pyplot keeps every
    # figure alive until it is closed, so without this each eval call would leak figures.
    plt.close(fig)


def add_gaussian_metrics(samples, energy, metrics, wandb_mode=False):
    """Compute GMM-based metrics for ``samples`` and add them to ``metrics``.

    Args:
        samples: Tensor of generated samples.
        energy: Gaussian-mixture target (see ``gmm_metrics``).
        metrics: Dict updated in place with ``eval/...`` scalars and, when
            ``wandb_mode`` is set, ``hist/...`` and ``visualization/...`` entries.
        wandb_mode: Whether to also build wandb histograms and images.

    Returns:
        Tuple ``(metrics, gm)`` where ``gm`` is the best-fitting GaussianMixture.
    """
    (
        metrics["eval/l1_means_mincost"],
        metrics["eval/l1_covariances_mincost"],
        metrics["eval/l1_means_naive"],
        metrics["eval/l1_covariances_naive"],
        metrics["eval/l1_weight"],
        metrics["eval/unique_match"],
        gaussian_weights,
        in99interval_cnt,
        gm,
        metrics["eval/l1_weight_w_known_params"],
        gaussian_weights_w_known_params,
    ) = gmm_metrics(samples, energy, n_trials=3)

    if wandb_mode:
        nmode = energy.nmode
        _log_mode_histogram(metrics, "gaussian_weights", nmode, gaussian_weights, density=True)
        _log_mode_histogram(
            metrics, "gaussian_weights_w_known_params", nmode, gaussian_weights_w_known_params, density=True
        )
        # Raw counts, so no density normalization.
        _log_mode_histogram(metrics, "in99interval_cnt", nmode, in99interval_cnt, density=False)

    return metrics, gm


@torch.no_grad()
def get_sample_metrics(samples, gt_samples=None, final_eval=False):
    """Compute distribution distances between generated and ground-truth samples.

    Returns ``None`` when ``gt_samples`` is not provided. Otherwise, a singleton
    axis is inserted at dim 1 of both tensors and they are passed, together
    with ``final_eval``, to ``compute_distribution_distances``.
    """
    if gt_samples is not None:
        generated = samples.unsqueeze(1)
        reference = gt_samples.unsqueeze(1)
        return compute_distribution_distances(generated, reference, final_eval)
    return None
