import math
from typing import Union

import numpy as np
import torch

from typing import Optional
from .mmd import linear_mmd2, mix_rbf_mmd2, poly_mmd2
# from .optimal_transport import wasserstein
import ot as pot
from functools import partial
import scipy.stats as stats
from scipy.stats import pearsonr, spearmanr
import pandas as pd
from sklearn.metrics.pairwise import rbf_kernel

def wasserstein(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 2,
    **kwargs,
) -> float:
    """Compute the 1- or 2-Wasserstein distance between two sample sets.

    Both sample sets receive uniform marginals (``pot.unif``). Inputs with more
    than two dimensions are flattened to ``(n_samples, -1)`` before the pairwise
    Euclidean cost matrix is built with ``torch.cdist``.

    Args:
        x0: First sample set; the first dimension indexes samples.
        x1: Second sample set; the first dimension indexes samples.
        method: ``"exact"`` or ``None`` uses ``pot.emd2``; ``"sinkhorn"`` uses
            ``pot.sinkhorn2`` with regularisation ``reg``.
        reg: Regularisation strength, only used by the ``"sinkhorn"`` method.
        power: ``1`` uses the Euclidean distance as cost; ``2`` uses the squared
            distance and returns the square root of the transport cost.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The transport cost (``power == 1``) or its square root (``power == 2``).

    Raises:
        ValueError: If ``method`` is not one of the supported values.
    """
    assert power == 1 or power == 2
    # ot_fn should take (a, b, M) as arguments where a, b are marginals and
    # M is a cost matrix
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    if x0.dim() > 2:
        x0 = x0.reshape(x0.shape[0], -1)
    if x1.dim() > 2:
        x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    # numItermax is an integer iteration cap; the float literal 1e7 must be
    # converted explicitly so the solver receives an int.
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))
    if power == 2:
        ret = math.sqrt(ret)
    return ret

def mmd_distance(x, y, gamma):
    """Return the RBF-kernel MMD statistic between samples ``x`` and ``y``.

    Evaluates ``rbf_kernel`` with the given ``gamma`` on the pairs (x, x),
    (x, y) and (y, y) and combines the means of the full kernel matrices
    (diagonals included) as ``mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)``.
    """
    within_x = rbf_kernel(x, x, gamma).mean()
    across = rbf_kernel(x, y, gamma).mean()
    within_y = rbf_kernel(y, y, gamma).mean()
    return within_x + within_y - 2 * across


def compute_scalar_mmd(target, transport, gammas=None):
    """Average ``mmd_distance(target, transport, gamma)`` over several gammas.

    Args:
        target: First sample set passed to ``mmd_distance``.
        transport: Second sample set passed to ``mmd_distance``.
        gammas: Iterable of kernel ``gamma`` values; defaults to
            ``[2, 1, 0.5, 0.1, 0.01, 0.005]`` when ``None``.

    Returns:
        ``np.mean`` of the per-gamma values. A gamma whose evaluation raises
        ``ValueError`` contributes ``np.nan``, which makes the mean ``nan``.
    """
    gamma_values = [2, 1, 0.5, 0.1, 0.01, 0.005] if gammas is None else gammas

    per_gamma = []
    for gamma in gamma_values:
        try:
            value = mmd_distance(target, transport, gamma)
        except ValueError:
            # Best-effort: record the failure as NaN instead of aborting.
            value = np.nan
        per_gamma.append(value)

    return np.mean(per_gamma)


def compute_distances(pred, true):
    """Return ``(mse, sqrt(mse), l1)`` between two tensors as Python floats.

    The second entry is the square root of the mean squared error (i.e. the
    RMSE), and the third is the mean absolute error from ``l1_loss``.
    """
    functional = torch.nn.functional
    squared_error = functional.mse_loss(pred, true).item()
    absolute_error = functional.l1_loss(pred, true).item()
    return squared_error, math.sqrt(squared_error), absolute_error

def compute_distribution_distances(pred, true):
    """Compute distances between predicted and true distributions.

    Only the last time point is evaluated.

    Args:
        pred: ``[batch, times, dims]`` tensor, or a list of ``[batch[i], dims]``
            tensors of length ``times`` (jagged).
        true: ``[batch, times, dims]`` tensor, or a list of ``[batch[i], dims]``
            tensors of length ``times`` (jagged).

    Returns:
        ``(names, values)``: two lists of equal length. The ``RBF_MMD`` entry
        is only present when neither input is a list.
    """
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
        "RBF_MMD",
        "Mean_MSE",
        "Mean_L2",
        "Mean_L1",
        "Median_MSE",
        "Median_L2",
        "Median_L1",
    ]
    is_jagged = isinstance(true, list)
    pred_is_jagged = isinstance(pred, list)
    any_jagged = is_jagged or pred_is_jagged

    # MMD is skipped whenever either input is jagged, so the names must be
    # filtered on the same condition to stay aligned with the values.
    names = [name for name in NAMES if not (any_jagged and name.endswith("MMD"))]

    # Index of the last time point (taken from pred).
    ts = len(pred) if pred_is_jagged else pred.shape[1]
    t = max(ts - 1, 0)
    a = pred[t] if pred_is_jagged else pred[:, t, :]
    b = true[t] if is_jagged else true[:, t, :]

    w1 = wasserstein(a, b, power=1)
    w2 = wasserstein(a, b, power=2)
    mean_dists = compute_distances(torch.mean(a, dim=0), torch.mean(b, dim=0))
    median_dists = compute_distances(torch.median(a, dim=0)[0], torch.median(b, dim=0)[0])

    if any_jagged:
        values = (w1, w2, *mean_dists, *median_dists)
    else:
        # detach() so tensors that require grad can be converted to numpy.
        mmd_rbf = compute_scalar_mmd(b.detach().cpu().numpy(), a.detach().cpu().numpy())
        values = (w1, w2, mmd_rbf, *mean_dists, *median_dists)

    # Values are returned as numpy float64 scalars, as before.
    to_return = list(np.array(values, dtype=np.float64))
    return names, to_return

def _to_float_tensor(data):
    """Return ``data`` as a detached float tensor, leaving the input untouched."""
    if isinstance(data, torch.Tensor):
        # Equivalent to torch.tensor(data) (a detached copy) but without the
        # copy-construct UserWarning PyTorch emits for tensor inputs.
        return data.detach().clone().float()
    return torch.tensor(data).float()


def compute_distribution_distances_new(pred: torch.Tensor, true: Union[torch.Tensor, list]):
    """Compute distances between predicted and true distributions.

    Unlike ``compute_distribution_distances``, a non-list input is used whole
    (converted to a float tensor) rather than sliced at the last time point.
    A list input is indexed at ``t``, the last index derived from ``pred``.

    Args:
        pred: Tensor/array of samples (first dimension indexes samples), or a
            list of per-time tensors.
        true: Tensor/array of samples, or a list of per-time tensors.

    Returns:
        ``(names, values)``: two lists of equal length. The three ``*_MMD``
        entries are only present when neither input is a list.
    """
    NAMES = [
        "1-Wasserstein",
        "2-Wasserstein",
        "Linear_MMD",
        "Poly_MMD",
        "RBF_MMD",
        "Mean_MSE",
        "Mean_L2",
        "Mean_L1",
        "Median_MSE",
        "Median_L2",
        "Median_L1",
    ]
    is_jagged = isinstance(true, list)
    pred_is_jagged = isinstance(pred, list)
    any_jagged = is_jagged or pred_is_jagged

    # MMDs are skipped whenever either input is jagged, so the names must be
    # filtered on the same condition to stay aligned with the values.
    names = [name for name in NAMES if not (any_jagged and name.endswith("MMD"))]

    # NOTE(review): for a non-list pred this takes t from pred.shape[1]; t is
    # only used to index list inputs — confirm this is intended when pred is a
    # tensor and true is a list.
    ts = len(pred) if pred_is_jagged else pred.shape[1]
    t = max(ts - 1, 0)
    a = pred[t] if pred_is_jagged else _to_float_tensor(pred)
    b = true[t] if is_jagged else _to_float_tensor(true)

    w1 = wasserstein(a, b, power=1)
    w2 = wasserstein(a, b, power=2)
    mean_dists = compute_distances(torch.mean(a, dim=0), torch.mean(b, dim=0))
    median_dists = compute_distances(torch.median(a, dim=0)[0], torch.median(b, dim=0)[0])

    if any_jagged:
        values = (w1, w2, *mean_dists, *median_dists)
    else:
        mmd_linear = linear_mmd2(a, b).item()
        mmd_poly = poly_mmd2(a, b, d=2, alpha=1.0, c=2.0).item()
        mmd_rbf = mix_rbf_mmd2(a, b, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
        values = (w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists)

    # Values are returned as numpy float64 scalars, as before.
    to_return = list(np.array(values, dtype=np.float64))
    return names, to_return


def r2_score2(pred, true):
    """Return ``1 - SS_res / SS_tot`` for two numpy arrays.

    ``SS_res`` is the sum of squared differences between ``pred`` and ``true``;
    ``SS_tot`` is the sum of squared deviations of ``true`` from its mean.
    """
    residual = pred - true
    centered = true - np.mean(true)
    ss_res = np.sum(np.square(residual))
    ss_tot = np.sum(np.square(centered))
    return 1 - ss_res / ss_tot
    

def r2_feature(pred, true):
    """Average ``r2_score2`` over the columns of two ``n x d`` arrays.

    One score is computed per column index of ``pred`` and the scores are
    averaged with ``np.mean``.
    """
    column_scores = []
    for col in range(pred.shape[1]):
        column_scores.append(r2_score2(pred[:, col], true[:, col]))
    return np.mean(column_scores)

def r2_cell(pred, true):
    """Average ``r2_score2`` over the rows of two ``n x d`` arrays.

    One score is computed per row index of ``pred`` and the scores are
    averaged with ``np.mean``.
    """
    row_scores = []
    for row in range(pred.shape[0]):
        row_scores.append(r2_score2(pred[row, :], true[row, :]))
    return np.mean(row_scores)


# pearson and spearman correlation

def pearson_corr(pred, true):
    """Return the Pearson correlation coefficient from ``scipy.stats.pearsonr``.

    The p-value reported alongside the coefficient is discarded.
    """
    coefficient, _pvalue = pearsonr(pred, true)
    return coefficient

def spearman_corr(pred, true):
    """Return the Spearman correlation coefficient from ``scipy.stats.spearmanr``.

    The p-value reported alongside the coefficient is discarded.
    """
    coefficient, _pvalue = spearmanr(pred, true)
    return coefficient

def pearson_corr_feature(pred, true):
    """Average ``pearson_corr`` over the columns of two ``n x d`` arrays.

    One coefficient is computed per column index of ``pred`` and the
    coefficients are averaged with ``np.mean``.
    """
    n_columns = pred.shape[1]
    per_column = list(
        map(lambda col: pearson_corr(pred[:, col], true[:, col]), range(n_columns))
    )
    return np.mean(per_column)


def pearson_corr_cell(pred, true):
    """Average ``pearson_corr`` over the rows of two ``n x d`` arrays.

    One coefficient is computed per row index of ``pred`` and the
    coefficients are averaged with ``np.mean``.
    """
    n_rows = pred.shape[0]
    per_row = list(
        map(lambda row: pearson_corr(pred[row, :], true[row, :]), range(n_rows))
    )
    return np.mean(per_row)


def spearman_corr_feature(pred, true):
    """Average ``spearman_corr`` over the columns of two ``n x d`` arrays.

    One coefficient is computed per column index of ``pred`` and the
    coefficients are averaged with ``np.mean``.
    """
    per_column = []
    for col in range(pred.shape[1]):
        per_column.append(spearman_corr(pred[:, col], true[:, col]))
    return np.mean(per_column)

def spearman_corr_cell(pred, true):
    """Average ``spearman_corr`` over the rows of two ``n x d`` arrays.

    One coefficient is computed per row index of ``pred`` and the
    coefficients are averaged with ``np.mean``.
    """
    per_row = []
    for row in range(pred.shape[0]):
        per_row.append(spearman_corr(pred[row, :], true[row, :]))
    return np.mean(per_row)

def new_pearson_corr(pred, true):
    """Placeholder for an alternative Pearson correlation; not implemented.

    Both arguments are ignored and ``None`` is always returned.
    """
    # TODO: implement. An earlier draft called np.corrcoef(data, rowvar=False),
    # but no ``data`` argument exists in this signature.
    return None


"""
from cellot

MMD calculation
corr calculation

"""

import numpy as np
from sklearn.metrics.pairwise import rbf_kernel


def cellot_mmd_distance(x, y, gamma):
    """Return the RBF-kernel MMD statistic between samples ``x`` and ``y``.

    The means of the full ``rbf_kernel`` matrices (diagonals included) for the
    pairs (x, x), (x, y) and (y, y) are combined as
    ``mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)``.
    """
    sample_pairs = ((x, x), (x, y), (y, y))
    kernels = [rbf_kernel(lhs, rhs, gamma) for lhs, rhs in sample_pairs]
    mean_xx, mean_xy, mean_yy = (kernel.mean() for kernel in kernels)
    return mean_xx + mean_yy - 2 * mean_xy


def cellot_compute_scalar_mmd(target, transport, gammas=None):
    """Average ``cellot_mmd_distance(target, transport, gamma)`` over gammas.

    Args:
        target: First sample set passed to ``cellot_mmd_distance``.
        transport: Second sample set passed to ``cellot_mmd_distance``.
        gammas: Iterable of kernel ``gamma`` values; defaults to
            ``[2, 1, 0.5, 0.1, 0.01, 0.005]`` when ``None``.

    Returns:
        ``np.mean`` of the per-gamma values. A gamma whose evaluation raises
        ``ValueError`` contributes ``np.nan``, which makes the mean ``nan``.
    """
    if gammas is None:
        gammas = [2, 1, 0.5, 0.1, 0.01, 0.005]

    def _mmd_or_nan(gamma):
        # Best-effort: a failing gamma is reported as NaN instead of raising.
        try:
            return cellot_mmd_distance(target, transport, gamma)
        except ValueError:
            return np.nan

    return np.mean([_mmd_or_nan(gamma) for gamma in gammas])


def compute_pairwise_corrs(df):
    """Return the correlation of every unordered pair of columns of ``df``.

    Args:
        df: DataFrame whose columns are the features to correlate.

    Returns:
        A Series indexed by a ``('lhs', 'rhs')`` MultiIndex holding the entries
        of ``df.corr()`` strictly above the diagonal. NaN correlations are
        omitted. A Series is returned even when there is only one pair.
    """
    corr = df.corr().rename_axis(index='lhs', columns='rhs')
    # Strict upper triangle (k=1): each pair once, diagonal excluded.
    upper = np.triu(np.ones(corr.shape), k=1).astype(bool)
    return (
        corr
        .where(upper)
        .stack()
        # Drop masked entries explicitly rather than relying on stack() to
        # discard NaNs implicitly.
        .dropna()
        .reset_index()
        .set_index(['lhs', 'rhs'])
        # Squeeze only the column axis: a bare squeeze() would collapse a
        # single remaining pair into a scalar instead of a Series.
        .squeeze(axis=1)
    )

def cellot_corr(pred, ground_truth):
    """Compare per-feature summary statistics of ``pred`` and ``ground_truth``.

    For each input the per-feature mean, per-feature standard deviation
    (``.mean(0)`` / ``.std(0)``) and pairwise feature correlations are
    computed, then compared across the two inputs.
    NOTE(review): the ddof used by ``.std(0)`` depends on the input type
    (numpy vs. pandas) — confirm which is expected.

    Returns:
        dict with keys ``'l2-means'``, ``'l2-stds'``, ``'r2-means'``,
        ``'r2-stds'``, ``'r2-pairwise_feat_corrs'`` and
        ``'l2-pairwise_feat_corrs'``. The ``l2-*`` entries are
        ``np.linalg.norm`` of the difference; the ``r2-*`` entries are the
        value returned by ``pandas.Series.corr`` (not squared here).
    """
    truth_mean = ground_truth.mean(0)
    pred_mean = pred.mean(0)
    truth_std = ground_truth.std(0)
    pred_std = pred.std(0)
    pred_pairs = compute_pairwise_corrs(pd.DataFrame(pred))
    truth_pairs = compute_pairwise_corrs(pd.DataFrame(ground_truth))

    return {
        'l2-means': np.linalg.norm(truth_mean - pred_mean),
        'l2-stds': np.linalg.norm(truth_std - pred_std),
        'r2-means': pd.Series(truth_mean).corr(pd.Series(pred_mean)),
        'r2-stds': pd.Series(truth_std).corr(pd.Series(pred_std)),
        'r2-pairwise_feat_corrs': pd.Series(pred_pairs).corr(pd.Series(truth_pairs)),
        'l2-pairwise_feat_corrs': np.linalg.norm(pred_pairs - truth_pairs),
    }