from __future__ import print_function
import ot
import torch
import numpy as np
import scipy
from sklearn.neighbors import KernelDensity
import src.datamodules.generate_data as g_data
from src.datamodules.record_mean_cov import select_mean_and_cov_gmm

# * images


# def get_energy_image(path, embed_size):
#     # e.g. path ='./data/smiley.jpg'
#     img = mpimg.imread(path)
#     img_density, img_energy = prepare_image(
#         img, embed=(embed_size, embed_size),  # crop=(10, 710, 240, 940),
#         white_cutoff=225, gauss_sigma=3, background=0.01)
#     return img_density, img_energy


'''
PyTorch type
'''


def kde_Gaussian_fitting(miu, bandwidth):
    """Fit a Gaussian kernel density estimator to the samples ``miu``.

    Args:
        miu: Samples handed to ``KernelDensity.fit``.
        bandwidth: Kernel bandwidth forwarded to ``KernelDensity``.

    Returns:
        The fitted ``KernelDensity`` estimator.
    """
    estimator = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
    # ``fit`` hands back the estimator itself, so it can be returned directly.
    return estimator.fit(miu)


def second_moment_no_average(batch_dim):
    """Return the sum of squared entries along dim 1, without averaging.

    Args:
        batch_dim: Tensor with at least two dimensions.

    Returns:
        Tensor with dim 1 reduced away (one value per row for a 2-D input).
    """
    squared = batch_dim.pow(2)
    return squared.sum(dim=1)


def second_moment_single_dist(batch_dim):
    """Return the mean over all remaining entries of the dim-1 squared sums.

    Args:
        batch_dim: Tensor with at least two dimensions.

    Returns:
        Scalar tensor: squares are summed along dim 1, then averaged.
    """
    per_sample = batch_dim.pow(2).sum(dim=1)
    return per_sample.mean()


def second_moment_all_dist(batch_dim_dist):
    """Sum squares along dim 1, then average the result along dim 0.

    Args:
        batch_dim_dist: Tensor with at least three dimensions
            (presumably batch x dim x distribution -- confirm with callers).

    Returns:
        Tensor holding one averaged second moment per trailing index.
    """
    squared_norms = batch_dim_dist.pow(2).sum(dim=1)
    return squared_norms.mean(dim=0)


def inprod_average(batch_dim_1, batch_dim_2):
    """Return the flattened inner product divided by the batch size.

    Both tensors are flattened to 1-D, dotted, and the result is divided by
    the size of their (shared) leading dimension.

    Args:
        batch_dim_1: First tensor.
        batch_dim_2: Second tensor; must have the same leading dimension.

    Returns:
        Scalar tensor with the batch-averaged inner product.
    """
    num_rows = batch_dim_1.shape[0]
    assert num_rows == batch_dim_2.shape[0]
    flat_1 = batch_dim_1.reshape(-1)
    flat_2 = batch_dim_2.reshape(-1)
    return torch.dot(flat_1, flat_2) / num_rows


def inprod(batch_dim_1, batch_dim_2):
    """Return the inner product of the two tensors after flattening to 1-D.

    Args:
        batch_dim_1: First tensor.
        batch_dim_2: Second tensor with the same number of elements.

    Returns:
        Scalar tensor with the (un-averaged) inner product.
    """
    flat_1 = batch_dim_1.reshape(-1)
    flat_2 = batch_dim_2.reshape(-1)
    return torch.dot(flat_1, flat_2)


def grad_of_function(input_samples, network):
    """Differentiate the summed output of ``network`` w.r.t. its input.

    Args:
        input_samples: Tensor that is part of the autograd graph
            (it must require gradients).
        network: Callable mapping ``input_samples`` to a tensor.

    Returns:
        Gradient tensor with the same shape as ``input_samples``. The graph is
        retained (``create_graph=True``) so the result can itself be
        differentiated.
    """
    summed_output = network(input_samples).sum()
    # A single input yields a one-element tuple of gradients.
    (gradient,) = torch.autograd.grad(
        summed_output, input_samples, create_graph=True)
    return gradient


'''
localized POT library
'''


def w2_distance_gaussian(mean1, mean2, cov1, cov2):
    """Return the squared 2-Wasserstein distance between two Gaussians.

    Evaluates the closed form

        ||mean1 - mean2||^2
        + tr(cov1 + cov2 - 2 * (cov1^{1/2} cov2 cov1^{1/2})^{1/2}).

    Args:
        mean1: Mean of the first Gaussian.
        mean2: Mean of the second Gaussian.
        cov1: Covariance matrix of the first Gaussian.
        cov2: Covariance matrix of the second Gaussian.

    Returns:
        The squared distance as a real scalar.
    """
    # Compute the square root of cov1 once; it appears on both sides of cov2.
    sqrt_cov1 = scipy.linalg.sqrtm(cov1)
    cross_term = scipy.linalg.sqrtm(sqrt_cov1 @ cov2 @ sqrt_cov1)
    # sqrtm may return a complex array whose imaginary part is pure round-off
    # (e.g. for near-singular covariances); the distance itself is real.
    if np.iscomplexobj(cross_term):
        cross_term = cross_term.real
    mean_term = ((mean1 - mean2) ** 2).sum()
    return mean_term + np.trace(cov1 + cov2 - 2 * cross_term)


def w2_distance_samples_solver(sample1_n_d, sample2_n_d):
    """Solve exact optimal transport between two equally sized sample sets.

    Each sample set gets uniform weights and the ground cost is the squared
    Euclidean distance, so the value returned by ``ot.emd2`` is the squared
    2-Wasserstein distance between the two empirical measures.

    References:
        https://pythonot.github.io/all.html#ot.emd
        https://pythonot.github.io/all.html#ot.emd2

    Args:
        sample1_n_d: Array of shape (n, d).
        sample2_n_d: Array with the same shape as ``sample1_n_d``.

    Returns:
        The transport cost computed by ``ot.emd2``.
    """
    assert sample1_n_d.shape == sample2_n_d.shape
    n_points = sample1_n_d.shape[0]

    # Uniform mass on every sample of each marginal.
    weights_1 = np.ones(n_points) / n_points
    weights_2 = np.ones(n_points) / n_points

    # Broadcast (1, n, d) against (n, 1, d): entry [i, j] compares row i of
    # sample2 with row j of sample1.
    pairwise_diff = (np.expand_dims(sample1_n_d, axis=0)
                     - np.expand_dims(sample2_n_d, axis=1))
    cost_matrix = np.sum(np.abs(pairwise_diff) ** 2, axis=2)
    return ot.emd2(weights_1, weights_2, cost_matrix)


'''
Gaussian utils
'''


def get_gmm_param(trial, num_component=9, seed=1):
    """Fetch GMM means/covariances and derive their dimension information.

    Args:
        trial: Forwarded to ``select_mean_and_cov_gmm``.
        num_component: Forwarded to ``select_mean_and_cov_gmm`` (default 9).
        seed: Forwarded to ``select_mean_and_cov_gmm`` (default 1).

    Returns:
        Tuple ``(means, covs, input_dim, num_components)`` where the last two
        entries come from ``get_gm_dim_component_all(means)``.
    """
    means, covs = select_mean_and_cov_gmm(trial, num_component, seed)
    input_dim, num_components = get_gm_dim_component_all(means)
    return means, covs, input_dim, num_components


def get_gm_dim_component_all(mean_list):
    """Read the input dimension and component counts of two mean arrays.

    Args:
        mean_list: Indexable whose first two entries expose a ``shape`` with
            at least two axes (components x dimension).

    Returns:
        Tuple ``(input_dim, counts)``: ``input_dim`` is the second axis of the
        first entry, ``counts`` lists the first-axis size of entries 0 and 1.
    """
    input_dim = mean_list[0].shape[1]
    # Only the first two entries are inspected, regardless of the list length.
    component_counts = [mean_list[k].shape[0] for k in range(2)]
    return input_dim, component_counts


def get_gm_dim_component_one(mean_list):
    """Read the input dimension and component count of one mean array.

    Args:
        mean_list: Object exposing a ``shape`` with at least two axes
            (components x dimension).

    Returns:
        Tuple ``(input_dim, num_components)`` = ``(shape[1], shape[0])``.
    """
    shape = mean_list.shape
    input_dim = shape[1]
    num_components = shape[0]
    return input_dim, num_components


'''
get xy data handle
'''


def get_ideal_xy(cfg):
    """Load the marginal data selected by ``cfg.type_data``.

    Dispatches to the matching loader in ``src.datamodules.generate_data`` and
    returns the result with its last axis moved to the front.

    Args:
        cfg: Configuration object. ``cfg.type_data`` selects the dataset; the
            ``'mnist_group'`` branch additionally reads ``cfg.N_TEST``. The
            whole object is forwarded to the loader.

    Returns:
        The loaded tensor permuted with ``permute(2, 0, 1)``.

    Raises:
        ValueError: If ``cfg.type_data`` is not one of the supported names.
    """
    type_data = cfg.type_data
    if type_data == 'circ_squa':
        marginal_data = g_data.marginal_data_circ_squa(
            cfg)
    elif type_data == 'mnist0-1':
        marginal_data = g_data.marginal_mnist_3loop_ficnn_handle(
            cfg)
    elif type_data == '3digit':
        # The last slice along the final axis is dropped for this dataset.
        marginal_data = g_data.marginal_data_3digit_3loop_ficnn(
            cfg)[:, :, :-1]
    elif type_data == 'ellipse':
        marginal_data = g_data.marginal_data_ellipse_3loop_ficnn(
            cfg)[:, :, :-1]
    elif type_data == 'line':
        marginal_data = g_data.marginal_data_line_3loop_ficnn(
            cfg)[:, :, :-1]
    elif type_data == 'usps_mnist':
        marginal_data = g_data.marginal_data_usps_mnist(
            cfg)
    elif type_data == 'mnist_group':
        if cfg.N_TEST == 25:
            # Deterministic subset: indices 0-4, 5000-5004, ..., 20000-20004
            # (presumably five samples from each of five groups of 5000 --
            # confirm against the loader).
            idx_digit = torch.zeros(25).long()
            for idx in range(5):
                idx_digit[idx * 5:(idx + 1) * 5] = 5000 * idx + torch.arange(5)
            marginal_data = g_data.marginal_mnist_3loop_ficnn_handle(
                cfg)[idx_digit]
        else:
            # Random reordering along the first axis; assumes the loader
            # returns 25000 entries there -- TODO confirm.
            marginal_data = g_data.marginal_mnist_3loop_ficnn_handle(
                cfg)[torch.randperm(25000)]
    elif type_data == 'cifar':
        marginal_data = g_data.marginal_cifar_handle(cfg)
    elif type_data == 'Gauss2Gauss':
        marginal_data = g_data.marginal_data_gmm(cfg)
    else:
        # Previously an unknown name fell through and failed below with an
        # opaque UnboundLocalError on ``marginal_data``.
        raise ValueError(
            'Unsupported type_data: {!r}'.format(type_data))
    return marginal_data.permute(2, 0, 1)
