from __future__ import print_function
import ot
from scipy.stats import entropy
import numpy as np
import scipy.linalg as ln
from numpy import linalg as LA
import torch
from torch.autograd import Variable
import src.models.generate_NN as g_NN
import src.datamodules.data_utils as DTU
import src.utils.pytorch_utils as PTU


def normalize(np_array):
    """Rescale ``np_array`` in place so that its entries sum to one.

    The division happens in place: the caller's array is modified and the very
    same object is returned. The array must have a floating-point dtype,
    because in-place true division is not defined for integer arrays.
    """
    total_mass = np_array.sum()
    np_array /= total_mass
    return np_array


# * ideal barycenter for Gaussian


# def ideal_mean_and_cov_originD(args, weights_distribution):
def ideal_gaussian_baryc_originD(args, weights_distribution, tol=1e-10, max_iter=10000):
    """Compute the Wasserstein-2 barycenter of a set of Gaussian distributions.

    The mean is the weighted sum of the input means, reduced with
    ``.mean(axis=0)``. The covariance is found with the fixed-point iteration

        S_{n+1} = S_n^{-1/2} (sum_i w_i (S_n^{1/2} C_i S_n^{1/2})^{1/2})^2 S_n^{-1/2},

    starting from the identity.

    Args:
        args: config object exposing ``NUM_DISTRIBUTION``, ``INPUT_DIM``,
            ``MEAN`` and ``COV``. ``args.COV[i][0, :, :]`` is read as the i-th
            covariance. ``args.MEAN[i]`` is presumably shaped ``(1, d)``
            (TODO confirm against the data generator).
        weights_distribution: barycentric weight of each distribution.
        tol: stop once the summed squared element-wise change between two
            consecutive iterates is at most ``tol``.
        max_iter: upper bound on the number of fixed-point iterations, so a
            non-converging run terminates instead of looping forever.

    Returns:
        ``(mean, cov)``. ``cov`` is an ``np.matrix``.

    Raises:
        ValueError: if ``max_iter`` is smaller than 1.

    Warns:
        RuntimeWarning: if ``max_iter`` iterations pass without reaching
            ``tol``. The last iterate is returned in that case.
    """
    import warnings

    if max_iter < 1:
        raise ValueError("max_iter must be >= 1")

    mean_ideal_originD = sum(
        args.MEAN[i] * weights_distribution[i]
        for i in range(args.NUM_DISTRIBUTION)
    ).mean(axis=0)

    # Naming: Sn is the current iterate S_n, s is S_n^{1/2} and si is S_n^{-1/2}.
    Sn = np.asmatrix(np.eye(args.INPUT_DIM))
    for _ in range(max_iter):
        s = ln.sqrtm(Sn)
        si = LA.inv(s)
        ans_medium = np.asmatrix(np.zeros_like(Sn))
        for i in range(args.NUM_DISTRIBUTION):
            ans_medium += weights_distribution[i] * ln.sqrtm(
                np.matmul(np.matmul(s, np.asmatrix(args.COV[i][0, :, :])), s)
            )
        Sn_1 = np.matmul(ans_medium, ans_medium)
        Sn_1 = np.matmul(np.matmul(si, Sn_1), si)

        # np.power is element-wise even for np.matrix operands.
        if np.power(Sn_1 - Sn, 2).sum() <= tol:
            break
        Sn = Sn_1
    else:
        warnings.warn(
            "ideal_gaussian_baryc_originD did not converge within {} "
            "iterations; returning the last iterate".format(max_iter),
            RuntimeWarning,
        )
    COV_ideal_originD = Sn_1

    return mean_ideal_originD, COV_ideal_originD

# * ideal Monge map for Gaussian


def ideal_gaussian_map(mu1, sigma1, mu2, sigma2):
    """Return the optimal (Monge) affine transport map between two Gaussians.

    The W2-optimal map from N(mu1, sigma1) to N(mu2, sigma2) is
    ``T(x) = B @ x + a`` with

        B = sigma1^{-1/2} (sigma1^{1/2} sigma2 sigma1^{1/2})^{1/2} sigma1^{-1/2},
        a = mu2 - B @ mu1.

    Args:
        mu1, mu2: mean vectors of the source and target Gaussians.
        sigma1, sigma2: covariance matrices. ``sigma1`` must be invertible.

    Returns:
        ``(B, a)``: the linear part and the offset of the map.
    """
    # Compute sigma1^{1/2} and its inverse once instead of repeating them.
    sqrt_sigma1 = ln.sqrtm(sigma1)
    inv_sqrt_sigma1 = LA.inv(sqrt_sigma1)
    inside = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    B = inv_sqrt_sigma1 @ ln.sqrtm(inside) @ inv_sqrt_sigma1
    a = mu2 - B @ mu1
    return B, a


def ideal_gaussian_w2(mu1, sigma1, mu2, sigma2):
    """Return the squared 2-Wasserstein distance between two Gaussians.

    Closed form for N(mu1, sigma1) and N(mu2, sigma2):

        ||mu1 - mu2||^2 + tr(sigma1 + sigma2 - 2 (sigma1^{1/2} sigma2 sigma1^{1/2})^{1/2})
    """
    root_sigma1 = ln.sqrtm(sigma1)
    cross_term = ln.sqrtm(root_sigma1 @ sigma2 @ root_sigma1)
    mean_part = LA.norm(mu1 - mu2, 2) ** 2
    cov_part = np.matrix.trace(sigma1 + sigma2 - 2 * cross_term)
    return mean_part + cov_part


def gaussian_compare_package(samples, cov, mu=None):
    """Compare the empirical mean/covariance of ``samples`` with a target Gaussian.

    Args:
        samples: samples stored as rows, either a NumPy array or a torch
            tensor. Tensors are moved to CPU with ``.cpu()`` before the
            statistics are computed.
        cov: target covariance matrix.
        mu: target mean. ``None`` means the zero vector. Otherwise it may be a
            NumPy array, a sequence, or a torch tensor.

    Returns:
        ``(mean_error, cov_error)``: the norm of the difference between the
        empirical and target means, and the relative Frobenius error of the
        empirical covariance with respect to ``cov``.
    """
    samples = samples if isinstance(samples, np.ndarray) else samples.cpu()
    mean_ours, cov_ours = mean_cov_from_samples(samples)
    if mu is None:
        mean_ideal = np.zeros_like(mean_ours)
    elif isinstance(mu, torch.Tensor):
        mean_ideal = mu.detach().cpu().numpy()
    else:
        mean_ideal = np.asarray(mu)
    mean_error = Frobenius_absolute_error(mean_ours, mean_ideal)
    cov_error = Frobenius_relative_error(cov, cov_ours)
    # Optional extra metric (currently disabled): BW2-UVP in percent,
    #   100 * BW2_distance(mean_ideal, mean_ours, cov, cov_ours) * 2 / np.trace(cov)
    return mean_error, cov_error


def mean_cov_from_samples(miu):
    """Return ``(mean, covariance)`` of the samples stored as rows of ``miu``.

    Both statistics are returned as NumPy values. ``miu`` may be a NumPy
    array or a torch tensor.
    """
    sample_mean = mean_real_originD(miu)
    sample_cov = cov_real_originD(miu)
    return sample_mean, sample_cov


def cov_real_originD(miu):
    """Return the empirical covariance of the samples stored as rows of ``miu``.

    Each row is one observation and each column one variable, which is why
    ``np.cov`` is applied to the transpose.

    Args:
        miu: 2-D NumPy array or torch tensor. Any tensor, including subclasses
            such as ``nn.Parameter`` and tensors on a GPU or requiring grad, is
            detached and copied to CPU first.

    Returns:
        The covariance matrix as a NumPy array.
    """
    if isinstance(miu, torch.Tensor):
        miu = miu.detach().cpu().numpy()
    return np.cov(miu.T)


def mean_real_originD(miu):
    """Return the empirical mean of the samples stored as rows of ``miu`` (axis 0).

    Args:
        miu: NumPy array or torch tensor. Any tensor, including subclasses
            such as ``nn.Parameter`` and tensors on a GPU or requiring grad, is
            detached and copied to CPU first, so the result is always NumPy.

    Returns:
        The per-column mean as a NumPy array.
    """
    if isinstance(miu, torch.Tensor):
        miu = miu.detach().cpu().numpy()
    return np.mean(miu, 0)


def Frobenius_relative_error(COV_ideal_originD, COV_real_originD):
    """Return ``||ideal - real||_F / ||ideal||_F``.

    This is the error relative to the ideal (reference) matrix, which is the
    first argument.
    """
    difference_norm = LA.norm(COV_ideal_originD - COV_real_originD)
    reference_norm = LA.norm(COV_ideal_originD)
    return difference_norm / reference_norm


def Frobenius_absolute_error(COV_ideal_originD, COV_real_originD):
    """Return the norm of ``ideal - real``.

    For matrices this is the Frobenius norm and for vectors the Euclidean
    norm (``numpy.linalg.norm`` with its default ``ord``). Despite the
    parameter names, it works for any pair of equally shaped arrays.
    """
    residual = COV_ideal_originD - COV_real_originD
    return LA.norm(residual)


def BW2_distance(mean_ideal, mean_ours, cov_ideal, cov_ours):
    """Return half the squared Bures-Wasserstein (W2) distance between two Gaussians.

    Computes

        0.5 ||m_ideal - m_ours||^2 + 0.5 tr(C_ideal) + 0.5 tr(C_ours)
        - tr((C_ours^{1/2} C_ideal C_ours^{1/2})^{1/2})
    """
    root_ours = ln.sqrtm(cov_ours)
    cross = ln.sqrtm(root_ours @ cov_ideal @ root_ours)
    return (
        0.5 * LA.norm(mean_ideal - mean_ours) ** 2
        + 0.5 * np.trace(cov_ideal)
        + 0.5 * np.trace(cov_ours)
        - np.trace(cross)
    )


def ideal_projected_param(args, weights_distribution, miu):
    """Collect ideal-barycenter statistics, optionally projected, for evaluation.

    Args:
        args: config exposing ``NUM_DISTRIBUTION``, ``INPUT_DIM``, ``MEAN``,
            ``COV`` and ``high_dim_flag``.
        weights_distribution: barycentric weight of each input distribution.
        miu: NumPy array of samples stored as rows (presumably samples of the
            learned barycenter — TODO confirm against callers).

    Returns:
        A 7-tuple:

        - ``miu``: the samples, projected onto the rows of the projection
          matrix when ``args.high_dim_flag`` is set, otherwise unchanged.
        - ``mean_ideal``: ideal barycenter mean, projected when
          ``high_dim_flag`` is set.
        - ``COV_ideal_originD``: ideal barycenter covariance in the original
          dimension.
        - ``COV_real_originD``: empirical covariance of the unprojected
          samples.
        - ``COV_ideal``: ideal covariance, projected when ``high_dim_flag`` is
          set.
        - ``W2_total``: unweighted mean over the input distributions of the
          squared W2 distance from each input to the barycenter.
        - ``half_moment``: ``0.5 * (tr(COV) + ||mean||^2)`` of the barycenter.
    """
    mean_ideal_originD, COV_ideal_originD = ideal_gaussian_baryc_originD(
        args, weights_distribution)
    root_of_Sn = ln.sqrtm(COV_ideal_originD)
    COV_real_originD = np.cov(miu.T)

    if args.high_dim_flag:
        # Row 0 spreads weight evenly over the odd coordinates and row 1 over
        # the even ones. Each row has unit norm when INPUT_DIM is even.
        # NOTE(review): only rows 0 and 1 are filled. This needs
        # NUM_DISTRIBUTION >= 2, and any further rows stay zero. Confirm this
        # is intended.
        projection_component = np.zeros(
            [args.NUM_DISTRIBUTION, args.INPUT_DIM])

        projection_component[0, 1::2] = 1 / np.sqrt(args.INPUT_DIM / 2)
        projection_component[1, ::2] = 1 / np.sqrt(args.INPUT_DIM / 2)

        miu = np.matmul(miu, projection_component.T)

        # Project the barycenter mean computed above.
        mean_ideal = np.matmul(projection_component, mean_ideal_originD)
        COV_ideal = np.matmul(
            np.matmul(projection_component, COV_ideal_originD), projection_component.T)
    else:
        mean_ideal = mean_ideal_originD
        COV_ideal = COV_ideal_originD

    # Squared W2 distance from each input Gaussian to the barycenter:
    # ||m_i - m||^2 + tr(S + C_i - 2 (S^{1/2} C_i S^{1/2})^{1/2}).
    W2_list = []
    for i in range(args.NUM_DISTRIBUTION):
        cov_i = np.asmatrix(args.COV[i][0, :, :])
        cross = ln.sqrtm(np.matmul(np.matmul(root_of_Sn, cov_i), root_of_Sn))
        W2_list.append(
            LA.norm(args.MEAN[i][0, :] - mean_ideal_originD, 2)**2 +
            np.matrix.trace(COV_ideal_originD + cov_i - 2 * cross)
        )

    W2_total = sum(W2_list) / args.NUM_DISTRIBUTION
    half_moment = 0.5 * \
        (np.trace(COV_ideal_originD) +
         np.inner(mean_ideal_originD, mean_ideal_originD))

    return miu, mean_ideal, COV_ideal_originD, COV_real_originD, COV_ideal, W2_total, half_moment

#!mnist


def test_output_classi(model, device, test_loader):
    model.eval()
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            output = model(data)
    return output


def test_correct_rate_classi(output, device, test_loader):

    for _, target in test_loader:
        target = target.cuda(device)
        test_loss = torch.nn.functional.mse_loss(
            output, target, reduction='sum').item()

        rounded_output = torch.round(output)
        correct = (rounded_output == target).sum().item()

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def test_mnist_classifier(model, device, test_loader, cfg=None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            target = (target == 7).float()
            data, target = data.to(device), target.to(device)
            output = model(data)
            # test_loss += torch.nn.functional.mse_loss(
            #     output, target, reduction='sum').item()
            test_loss += (-target * torch.log(output) -
                          (1 - target) * torch.log(1 - output)).sum()

            # rounded_output = torch.round(output)
            # rounded_output = (output > output.mean()).float()
            rounded_output = (output > 0.5).float()
            correct += (rounded_output == target).sum().item()
            # import matplotlib.pyplot as plt
            # plt.scatter(np.arange(len(output)),output.cpu())
            # plt.savefig('vis0.png')
            # plt.close()

    test_loss /= len(test_loader.dataset)

    if cfg != None and hasattr(cfg, 'idx_subset'):
        print(cfg.idx_subset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

#! generate barycenter data


def pushforward(cfg, device, results_save_path=None, load_epoch=None, x_data=None, psb=False, return_marginal=False):
    """Push source samples through the trained map T and optionally back through phi.

    Args:
        cfg: experiment config. This function reads ``get_save_path()``,
            ``N_TEST`` and ``epochs``; ``cfg`` is also passed to
            ``DTU.get_ideal_xy`` and ``g_NN.generate_monge_NN``.
        device: device used when loading the checkpoints. The default test
            data is moved there with ``.cuda(device)``.
        results_save_path: checkpoint directory. Defaults to
            ``cfg.get_save_path()``.
        load_epoch: checkpoint epoch to load. Defaults to ``cfg.epochs``.
        x_data: source samples (tensor or NumPy array). When ``None``, the
            first ``cfg.N_TEST`` samples of ``DTU.get_ideal_xy(cfg)`` are used:
            index 0 as x and index 1 as y.
        psb: also return ``x_psf = y + grad phi(y)`` for the target samples y.
        return_marginal: also return the loaded ``(x, y)`` data.

    Returns:
        ``y_psf = T(x_data)``. If ``psb`` is set, ``(y_psf, x_psf)`` instead.
        Otherwise, if ``return_marginal`` is set, ``(y_psf, data)``.

    Raises:
        ValueError: if ``psb`` or ``return_marginal`` is requested together
            with an explicit ``x_data``. No target samples or marginal data
            are loaded in that case.
    """
    if x_data is not None and (psb or return_marginal):
        raise ValueError(
            "psb and return_marginal require the default test data; "
            "they cannot be combined with an explicit x_data")

    if results_save_path is None:
        results_save_path = cfg.get_save_path()
    if x_data is None:
        data = (DTU.get_ideal_xy(cfg).float())[:, :cfg.N_TEST]
        x_data = data[0].cuda(device)
        y_data = data[1].cuda(device)

        # PLU.mnist_alone(data[0][:64], results_save_path +
        #                 f'/x.png', gan=True)
        # PLU.mnist_alone(data[1][:64], results_save_path +
        #                 f'/y.png', gan=True)

    phi, capital_t = g_NN.generate_monge_NN(cfg)
    epoch = cfg.epochs if load_epoch is None else load_epoch
    generator_t = g_NN.load_generator(
        results_save_path, capital_t, epoch, device=device)

    if isinstance(x_data, np.ndarray):
        x_data = PTU.numpy2torch(x_data)

    y_psf = generator_t(x_data)
    # PLU.mnist_alone(y_psf[:64], results_save_path +
    # f'/{cfg.load_epoch}_{cfg.repeat}.png', gan = True)
    if psb:
        # phi is only needed for the backward map. Load it on demand, from
        # the same epoch as T, so this also works when load_epoch is None.
        generator_phi = g_NN.load_generator(
            results_save_path, phi, epoch, device=device, choice='phi')
        y_data = Variable(y_data, requires_grad=True)
        phi_y = generator_phi(y_data).sum()
        x_psf = y_data + torch.autograd.grad(phi_y, y_data)[0]
        return y_psf, x_psf
    if return_marginal:
        return y_psf, data
    return y_psf
