import warnings
from collections import namedtuple

import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from scipy.stats import normaltest
from sklearn.decomposition import PCA, FastICA
from torcheval.metrics.functional import r2_score


@torch.no_grad()
def partial_correlation(
    z_pred_mean: torch.Tensor,
    z_pred_log_std: torch.Tensor,
    n_groups: int | None = None,
    n_monte_carlo: int = 1,
    seed: int = 0,
) -> torch.Tensor:
    """Evaluate the partial correlation of the whole dataset.

    Parameters
    ----------
    z_pred_mean : torch.Tensor of shape (n_points, n_components)
        The predicted mean of the latent variable.
    z_pred_log_std : torch.Tensor of shape (n_points, n_components)
        The predicted log standard deviation of the latent variable.
    n_groups : int | None, optional
        Number of groups, by default None.
    n_monte_carlo : int, optional
        Number of Monte Carlo samples, by default 1.
    seed : int, optional
        Random seed, by default 0.

    Returns
    -------
    torch.Tensor
        The estimated partial correlation.

    Raises
    ------
    ValueError
        If group_rank = n_components / n_groups is not an integer.
    """
    n_points, n_components = z_pred_mean.shape
    if n_groups is None:
        n_groups = n_components

    if n_components % n_groups != 0:
        raise ValueError("group_rank = n_components / n_groups is not an integer.")
    group_rank = int(n_components / n_groups)

    generator = torch.Generator().manual_seed(seed)

    z_sampled = (
        torch.randn((n_monte_carlo, n_points, n_components), generator=generator)
        * z_pred_log_std.exp()
        + z_pred_mean
    )
    mat_ln_q_z = -F.gaussian_nll_loss(
        z_pred_mean.view((1, n_points, n_groups, group_rank)),
        z_sampled.view((n_monte_carlo * n_points, 1, n_groups, group_rank)),
        (z_pred_log_std.exp() ** 2).view((1, n_points, n_groups, group_rank)),
        full=True,
        reduction="none",
    )  # (n_monte_carlo*n_points, n_points, n_groups, group_rank)

    ln_q_z = torch.logsumexp(mat_ln_q_z.sum(dim=(2, 3)), dim=1) - np.log(n_points)
    ln_prod_q_zi = (
        torch.logsumexp(mat_ln_q_z.sum(dim=3), dim=1) - np.log(n_points)
    ).sum(dim=1)
    return (ln_q_z - ln_prod_q_zi).mean()


# Immutable record returned by ``pca_ica_check``. Its four per-group
# diagnostic arrays are, in order: pca, ica, ica_warning, ica_normal_test.
PCAICACheckResult = namedtuple(
    "PCAICACheckResult",
    "pca ica ica_warning ica_normal_test",
)


@torch.no_grad()
def pca_ica_check(
    z_pred_mean: torch.Tensor,
    n_groups: int,
    n_monte_carlo: int = 1,
    seed: int = 0,
) -> PCAICACheckResult:
    """Check the within-group structure of the latent variable by PCA and ICA.

    The components are split into ``n_groups`` contiguous groups of
    ``group_rank = n_components / n_groups`` components. For each group this
    does two things:

    * fits a PCA and records how the variance is spread over its directions;
    * fits a FastICA and estimates the total correlation of the unmixed
      sources with ``partial_correlation``. This uses one group per source
      and a fixed kernel bandwidth. It also records whether FastICA reported
      non-convergence, and the normality-test p-value of every source.

    Parameters
    ----------
    z_pred_mean : torch.Tensor of shape (n_points, n_components)
        The predicted mean of the latent variable.
    n_groups : int
        Number of groups.
    n_monte_carlo : int, optional
        Number of Monte Carlo samples passed to ``partial_correlation``,
        by default 1.
    seed : int, optional
        Random seed for FastICA and ``partial_correlation``, by default 0.

    Returns
    -------
    PCAICACheckResult
        Named tuple with fields:

        pca : NDArray of shape (n_groups, group_rank)
            The explained variance ratio by PCA.
        ica : NDArray of shape (n_groups,)
            The within-group total correlation after ICA.
        ica_warning : NDArray of shape (n_groups,)
            True where FastICA warned that it did not converge.
        ica_normal_test : NDArray of shape (n_groups, group_rank)
            The p-value of ``scipy.stats.normaltest`` for each ICA source.

    Raises
    ------
    ValueError
        If ``n_groups`` is not positive, or if
        group_rank = n_components / n_groups is not an integer.
    """
    n_points, n_components = z_pred_mean.shape
    if n_groups <= 0:
        raise ValueError("n_groups must be a positive integer.")
    if n_components % n_groups != 0:
        raise ValueError("group_rank = n_components / n_groups is not an integer.")
    group_rank = n_components // n_groups

    # scikit-learn needs a host-side NumPy array without autograd history.
    z_np = z_pred_mean.detach().cpu().numpy()

    # Scott's-rule-style bandwidth n ** (-1 / (d + 4)). It is used as the
    # (constant) std of the kernels that smooth the ICA sources, and it
    # depends only on the data size, so compute it once.
    log_bandwidth = np.log(n_points ** (-1 / (4 + group_rank)))

    pca_result = np.zeros((n_groups, group_rank))
    ica_result = np.zeros(n_groups)
    ica_warning = np.zeros(n_groups, dtype=bool)
    ica_normal_test = np.zeros((n_groups, group_rank))

    for group in range(n_groups):
        z_group = z_np[:, group * group_rank : (group + 1) * group_rank]

        # PCA: share of the group's variance along each principal direction.
        pca = PCA(n_components=group_rank)
        pca.fit(z_group)
        pca_result[group] = pca.explained_variance_ratio_

        # ICA: unmix the group. The "always" filter ensures the convergence
        # warning is recorded even if the caller silenced it globally. The
        # previous filters are restored when the context exits.
        fastica = FastICA(n_components=group_rank, random_state=seed)
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            y = fastica.fit_transform(z_group)
        ica_warning[group] = any(
            "FastICA did not converge" in str(warn.message)
            for warn in caught_warnings
        )

        y_tensor = torch.from_numpy(y).to(torch.float32)
        ica_result[group] = partial_correlation(
            y_tensor,
            log_bandwidth * torch.ones_like(y_tensor),
            n_monte_carlo=n_monte_carlo,
            seed=seed,
        ).item()
        ica_normal_test[group] = normaltest(y, axis=0).pvalue

    return PCAICACheckResult(
        pca_result,
        ica_result,
        ica_warning,
        ica_normal_test,
    )


@torch.no_grad()
def reduce_group_rank(z_pred, n_groups: int = 1, target_group_rank: int = 1):
    """Project every group of latent components onto its leading principal axes.

    ``z_pred`` is split into ``n_groups`` contiguous groups of
    ``group_rank = n_components / n_groups`` columns. A separate PCA with
    ``target_group_rank`` components is fitted to each group, and the
    resulting scores are concatenated in group order.

    Parameters
    ----------
    z_pred : torch.Tensor of shape (n_samples, n_components)
        The predicted latent variable.
    n_groups : int, optional
        Number of groups, by default 1.
    target_group_rank : int, optional
        Number of principal components kept per group, by default 1.

    Returns
    -------
    torch.Tensor of shape (n_samples, n_groups * target_group_rank)
        The per-group PCA scores. The tensor is allocated on the CPU with
        torch's default dtype.

    Raises
    ------
    ValueError
        If ``n_groups`` is not positive, or if
        group_rank = n_components / n_groups is not an integer. Errors raised
        by scikit-learn's PCA (e.g. an invalid ``target_group_rank``)
        propagate unchanged.
    """
    n_samples, n_components = z_pred.shape
    if n_groups <= 0:
        raise ValueError("n_groups must be a positive integer.")
    if n_components % n_groups != 0:
        raise ValueError("group_rank = n_components / n_groups is not an integer.")
    group_rank = n_components // n_groups

    # scikit-learn needs a host-side NumPy array without autograd history.
    z_np = z_pred.detach().cpu().numpy()

    z_pred_reduced = torch.zeros((n_samples, n_groups * target_group_rank))
    for group in range(n_groups):
        scores = PCA(n_components=target_group_rank).fit_transform(
            z_np[:, group * group_rank : (group + 1) * group_rank]
        )
        z_pred_reduced[
            :, group * target_group_rank : (group + 1) * target_group_rank
        ] = torch.from_numpy(scores).to(torch.float32)
    return z_pred_reduced


@torch.no_grad()
def align(
    z_pred: torch.Tensor, z_true: torch.Tensor, n_groups: int = 1
) -> torch.Tensor:
    """Align the predicted latent variable with the true latent variable.

    Both tensors are split into ``n_groups`` contiguous groups of columns.
    The steps are:

    1. For every (true group, predicted group) pair, fit an affine map from
       the predicted group to the true group by least squares, and score the
       fit with ``r2_score``.
    2. Use ``linear_sum_assignment`` to choose the one-to-one pairing of
       groups that maximises the summed R².
    3. Return the fitted values of the chosen pairs, ordered by true group.

    Parameters
    ----------
    z_pred : torch.Tensor of shape (n_samples, n_pred_components)
        The predicted latent variable. n_pred_components must be divisible
        by n_groups.
    z_true : torch.Tensor of shape (n_samples, n_true_components)
        The true latent variable. n_true_components must be divisible by
        n_groups.
    n_groups : int, optional
        Number of groups, by default 1.

    Returns
    -------
    aligned_z_pred : torch.Tensor of shape (n_samples, n_true_components)
        The aligned predicted latent variable. The result is allocated on
        the CPU with torch's default dtype.
    """
    n_samples, n_true_components = z_true.shape
    # Unpacking also rejects a z_pred that is not 2-D.
    _, n_pred_components = z_pred.shape
    true_group_rank = n_true_components // n_groups

    # Group views, computed once. view() raises if a width is not divisible
    # by n_groups.
    pred_groups = z_pred.view(n_samples, n_groups, n_pred_components // n_groups)
    true_groups = z_true.view(n_samples, n_groups, -1)

    # Column of ones so each least-squares fit includes an intercept. It is
    # created on z_pred's device and dtype so that torch.cat also works for
    # non-CPU inputs.
    ones = torch.ones(n_samples, 1, dtype=z_pred.dtype, device=z_pred.device)

    aligned_z_all = torch.zeros((n_groups, n_groups, n_samples, true_group_rank))
    r2_matrix = torch.zeros((n_groups, n_groups))
    for true_group in range(n_groups):
        z_true_group = true_groups[:, true_group, :]
        for pred_group in range(n_groups):
            z_aug = torch.cat([ones, pred_groups[:, pred_group, :]], dim=1)
            weights = torch.linalg.lstsq(z_aug, z_true_group).solution
            fitted = z_aug @ weights
            aligned_z_all[true_group, pred_group] = fitted
            r2_matrix[true_group, pred_group] = r2_score(fitted, z_true_group)

    # linear_sum_assignment minimises cost, so negate R² to maximise it.
    row_ind, col_ind = linear_sum_assignment(-r2_matrix.numpy())
    # For a square cost matrix, row_ind is 0..n_groups-1 in order. Stacking
    # the chosen fits and moving the group axis next to the features
    # therefore restores the true column order.
    aligned_z_pred = (
        aligned_z_all[row_ind, col_ind]
        .transpose(0, 1)
        .reshape(n_samples, n_true_components)
    )
    return aligned_z_pred


@torch.no_grad()
def pick_align(z_pred: torch.Tensor, z_true: torch.Tensor) -> torch.Tensor:
    """Pick some dimensions from z_pred and align them to the true latent variable.

    ``z_true`` is first padded with ``n_pred_components - n_true_components``
    columns of small Gaussian noise (scale 1e-3, fixed seed 0), so that both
    tensors have the same width. ``align`` is then run with one group per
    component. Each true component is thereby paired with a distinct
    predicted component through a one-to-one assignment on R². The padding
    columns, which should have near-zero R² against any prediction, take the
    unused predicted components and are dropped from the output.

    Parameters
    ----------
    z_pred : torch.Tensor of shape (n_samples, n_pred_components)
        The predicted latent variable.
    z_true : torch.Tensor of shape (n_samples, n_true_components)
        The true latent variable.

    Returns
    -------
    aligned_z_pred : torch.Tensor of shape (n_samples, n_true_components)
        The aligned predicted latent variable.

    Raises
    ------
    ValueError
        If z_pred has fewer components than z_true.
    """
    n_pred_components = z_pred.shape[1]
    n_true_components = z_true.shape[1]
    if n_pred_components < n_true_components:
        raise ValueError("n_pred_components < n_true_components")

    # A fixed-seed CPU generator makes the padding identical on every call.
    generator = torch.Generator().manual_seed(0)
    padding = 1e-3 * torch.randn(
        z_true.shape[0],
        n_pred_components - n_true_components,
        generator=generator,
    )
    # Match z_true's dtype and device so the concatenation also works for
    # non-CPU inputs.
    z_true_aug = torch.concat([z_true, padding.to(z_true)], dim=1)

    return align(z_pred, z_true_aug, n_groups=n_pred_components)[:, :n_true_components]
