import torch.utils.data
import torch
import numpy as np
import scipy
from scipy.stats import multivariate_normal
from src.utils.plot_utils import grid_NN_2_generator
import src.utils.pytorch_utils as PTU
import src.datamodules.data_utils as DTU
#! gaussian


def gaussian_data(num_samples, args):
    """Draw ``num_samples`` points from each of two Gaussian mixtures.

    Mixture ``i`` (``i`` in ``{0, 1}``) has ``args.NUM_GMM_COMPONENT[i]``
    components; component ``j`` uses mean ``args.MEAN[i][j, :]`` and
    covariance ``args.COV[i][j]``. Samples are split as evenly as possible
    between the components and then shuffled along the sample axis.

    Args:
        num_samples: Number of rows to generate for each mixture.
        args: Namespace providing ``INPUT_DIM``, ``NUM_GMM_COMPONENT``,
            ``MEAN`` and ``COV``.

    Returns:
        A tensor of shape ``(num_samples, args.INPUT_DIM, 2)`` whose last
        axis selects the mixture.
    """
    total_data = torch.zeros(
        num_samples, args.INPUT_DIM, 2)

    for i in range(2):
        num_components = args.NUM_GMM_COMPONENT[i]
        # Spread the remainder over the first components so that no row is
        # left at its zero initial value when num_samples is not divisible
        # by the number of components.
        base_count, remainder = divmod(num_samples, num_components)
        start = 0
        for j in range(num_components):
            count = base_count + (1 if j < remainder else 0)
            component_samples = np.random.multivariate_normal(
                args.MEAN[i][j, :], args.COV[i][j], count)
            total_data[start:start + count, :, i] = torch.from_numpy(
                component_samples)
            start += count
        # Shuffle so that rows are not grouped by component.
        index_column = torch.randperm(num_samples)
        total_data[:, :, i] = total_data[index_column, :, i]
    return total_data


def importa_debug_data_gmm(args, mean_mu, cov_mu):
    """Assemble a debug training tensor from GMM and Gaussian samples.

    Slot 0 of the last axis receives slot 1 of ``gaussian_data``'s output;
    slot 1 receives draws from ``torch_samples_generate_Gaussian`` using
    ``mean_mu`` and ``cov_mu``.

    Args:
        args: Namespace providing ``N_TRAIN_SAMPLES`` and ``INPUT_DIM`` plus
            whatever ``gaussian_data`` reads.
        mean_mu: Mean forwarded to ``torch_samples_generate_Gaussian``.
        cov_mu: Covariance forwarded to ``torch_samples_generate_Gaussian``.

    Returns:
        A tensor of shape ``(N_TRAIN_SAMPLES, INPUT_DIM, 2)``.
    """
    n_samples = args.N_TRAIN_SAMPLES
    # Both slots are overwritten below; randn only allocates the buffer.
    paired = torch.randn(n_samples, args.INPUT_DIM, 2)

    mixture_samples = gaussian_data(n_samples, args)
    paired[:, :, 0] = mixture_samples[:, :, 1]

    gaussian_samples = torch_samples_generate_Gaussian(
        n_samples, mean_mu, cov_mu)
    paired[:, :, 1] = gaussian_samples
    return paired


def importa_samp_data_gauss(args):
    """Build a training tensor whose slot 1 is an isotropic Gaussian.

    Slot 0 of the last axis keeps the standard-normal draws from
    ``torch.randn``; slot 1 is sampled with zero mean and covariance
    ``args.mu_var * I``.

    Args:
        args: Namespace providing ``N_TRAIN_SAMPLES``, ``INPUT_DIM`` and
            ``mu_var``.

    Returns:
        A tensor of shape ``(N_TRAIN_SAMPLES, INPUT_DIM, 2)``.
    """
    n_samples, dim = args.N_TRAIN_SAMPLES, args.INPUT_DIM
    samples = torch.randn(n_samples, dim, 2)

    zero_mean = dim * [0]
    isotropic_cov = np.eye(dim) * args.mu_var
    samples[:, :, 1] = torch_samples_generate_Gaussian(
        n_samples, zero_mean, isotropic_cov)
    return samples


def importa_samp_data_gmm(args):
    """Build a training tensor whose slots may be replaced by GMM samples.

    The tensor starts as standard-normal noise scaled by 1.5. Slot 0 of the
    last axis is overwritten with GMM samples when ``args.P0_equal_Q`` is
    truthy. Slot 1 is overwritten with GMM samples when ``args.mu_equal_q``
    is truthy, otherwise with a zero-mean Gaussian of covariance
    ``args.mu_var * I``.

    Args:
        args: Namespace providing ``mu_equal_q``, ``P0_equal_Q``,
            ``N_TRAIN_SAMPLES``, ``INPUT_DIM``, ``mu_var`` and whatever
            ``gaussian_data`` reads.

    Returns:
        A tensor of shape ``(N_TRAIN_SAMPLES, INPUT_DIM, 2)``.

    Raises:
        AssertionError: If both ``mu_equal_q`` and ``P0_equal_Q`` are
            non-zero.
    """
    assert args.mu_equal_q == 0 or args.P0_equal_Q == 0

    n_samples, dim = args.N_TRAIN_SAMPLES, args.INPUT_DIM
    samples = torch.randn(n_samples, dim, 2) * 1.5

    # * P0_equal_Q: P0 follows the settings of the record_mean_cov file.
    if args.P0_equal_Q:
        samples[:, :, 0] = gaussian_data(n_samples, args)[:, :, 1]

    if not args.mu_equal_q:
        samples[:, :, 1] = torch_samples_generate_Gaussian(
            n_samples, dim * [0], np.eye(dim) * args.mu_var)
    else:
        samples[:, :, 1] = gaussian_data(n_samples, args)[:, :, 1]
    return samples

# * Applies only to porous media


def import_aggre(args):
    """Dispatch to the aggregation data loader matching ``args.INPUT_DIM``.

    Args:
        args: Namespace providing ``INPUT_DIM`` plus whatever the selected
            loader reads.

    Returns:
        The result of ``import_aggre_1d`` or ``import_aggre_2d``; ``None``
        when ``INPUT_DIM`` is neither 1 nor 2.
    """
    dim = args.INPUT_DIM
    if dim == 2:
        return import_aggre_2d(args)
    if dim == 1:
        return import_aggre_1d(args)
    # Unsupported dimensions fall through, as before.
    return None


def import_aggre_1d(args):
    """Return standard-normal training data for the 1D aggregation case.

    Args:
        args: Namespace providing ``INPUT_DIM`` (must be 1) and
            ``N_TRAIN_SAMPLES``.

    Returns:
        A ``(data, None)`` tuple where ``data`` has shape
        ``(N_TRAIN_SAMPLES, 1, 3)``.
    """
    assert args.INPUT_DIM == 1
    shape = (args.N_TRAIN_SAMPLES, args.INPUT_DIM, 3)
    return torch.randn(*shape), None


def import_aggre_2d(args):
    """Return half-scaled normal training data for the 2D aggregation case.

    Args:
        args: Namespace providing ``INPUT_DIM`` (must be 2) and
            ``N_TRAIN_SAMPLES``.

    Returns:
        A ``(data, None)`` tuple where ``data`` has shape
        ``(N_TRAIN_SAMPLES, 2, 3)`` and holds standard-normal draws
        multiplied by 0.5.
    """
    assert args.INPUT_DIM == 2
    shape = (args.N_TRAIN_SAMPLES, args.INPUT_DIM, 3)
    scaled_noise = torch.randn(*shape) * 0.5
    return scaled_noise, None


def gaussian_ring_2d(num_samples):
    """Sample 2D points from three noisy unit rings.

    The ring centres lie on a circle of radius 3 at angles ``0``,
    ``2*pi/3`` and ``4*pi/3``. Each point has a uniform angle on its ring
    and a radius of ``1 + 0.05 * N(0, 1)``.

    Args:
        num_samples: Total number of points to generate.

    Returns:
        A tensor of shape ``(num_samples, 2)``. Rows are grouped by ring
        (not shuffled).
    """
    num_rings = 3
    # Spread the remainder over the first rings so that exactly num_samples
    # rows are returned even when num_samples is not a multiple of 3.
    base_count, remainder = divmod(num_samples, num_rings)

    ring_list = []
    for idx in range(num_rings):
        count = base_count + (1 if idx < remainder else 0)
        theta = torch.rand(count, 1) * 2 * np.pi
        cos_sin = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)
        radius = 1 + torch.randn(count, 1) * 0.05
        ring_data = radius * cos_sin
        mean_theta = idx * 2 * np.pi / num_rings
        center = 3 * torch.Tensor(
            [[np.cos(mean_theta), np.sin(mean_theta)]])
        ring_list.append(ring_data + center)
    return torch.cat(ring_list, dim=0)


def import_aggre_diffusion_2d(args):
    """Build 2D aggregation-diffusion training data and its box volume.

    Slots 0 and 2 of the last axis are drawn from ``gaussian_ring_2d`` when
    ``args.keller_segel`` is truthy, otherwise uniformly from the square
    ``[-border_square, border_square]^2``. Slot 1 is uniform on
    ``[-max_bound, max_bound]^2`` where ``max_bound`` is the largest
    absolute coordinate of slot 0 times ``args.q_bound_scale``, capped at 5.

    Args:
        args: Namespace providing ``INPUT_DIM`` (must be 2),
            ``N_TRAIN_SAMPLES``, ``keller_segel``, ``border_square`` and
            ``q_bound_scale``.

    Returns:
        A ``(data, volume)`` tuple: ``data`` has shape
        ``(N_TRAIN_SAMPLES, 2, 3)`` and ``volume`` is ``(2 * max_bound)**2``.
    """
    assert args.INPUT_DIM == 2
    n_samples, dim = args.N_TRAIN_SAMPLES, args.INPUT_DIM

    def uniform_box(half_width):
        # Uniform draws on [-half_width, half_width) along every axis.
        return torch.rand(n_samples, dim) * half_width * 2 - half_width

    samples = torch.randn(n_samples, dim, 3)
    if not args.keller_segel:
        samples[:, :, 0] = uniform_box(args.border_square)
        samples[:, :, 2] = uniform_box(args.border_square)
    else:
        samples[:, :, 0] = gaussian_ring_2d(n_samples)
        samples[:, :, 2] = gaussian_ring_2d(n_samples)

    largest_coord = samples[:, :, 0].abs().max().item()
    max_bound = min(5, largest_coord * args.q_bound_scale)

    samples[:, :, 1] = uniform_box(max_bound)
    return samples, (max_bound * 2)**2


def import_Barenblatt(args, rho0_density):
    """Dispatch to the Barenblatt data loader matching ``args.INPUT_DIM``.

    Args:
        args: Namespace providing ``INPUT_DIM`` plus whatever the selected
            loader reads.
        rho0_density: Callable forwarded unchanged to the selected loader.

    Returns:
        The result of ``import_Barenblatt_1d`` or ``import_Barenblatt_2d``;
        ``None`` when ``INPUT_DIM`` is neither 1 nor 2.
    """
    dim = args.INPUT_DIM
    if dim == 2:
        return import_Barenblatt_2d(args, rho0_density)
    if dim == 1:
        return import_Barenblatt_1d(args, rho0_density)
    # Unsupported dimensions fall through, as before.
    return None


def import_Barenblatt_1d(args, rho0_density):
    """Sample 1D training data from ``rho0_density`` plus a uniform box.

    Slot 0 of the last axis is drawn from ``rho0_density`` with the
    ratio-of-uniforms method; its bounding rectangle is estimated on a
    2000-point grid over ``[-0.6, 0.6]``. Slot 1 is uniform on
    ``[-max_bound, max_bound]`` where ``max_bound`` is the largest absolute
    sample of slot 0 times ``args.q_bound_scale``.

    Args:
        args: Namespace providing ``INPUT_DIM`` (must be 1),
            ``N_TRAIN_SAMPLES`` and ``q_bound_scale``.
        rho0_density: Callable evaluating the (possibly unnormalised)
            density on a NumPy array of positions.

    Returns:
        A ``(data, volume)`` tuple: ``data`` has shape
        ``(N_TRAIN_SAMPLES, 1, 2)`` and ``volume`` is ``2 * max_bound`` as a
        Python float.
    """
    assert args.INPUT_DIM == 1
    train_data = torch.randn(
        args.N_TRAIN_SAMPLES, args.INPUT_DIM, 2)

    # Bounding rectangle for ratio-of-uniforms: u in (0, umax], v in
    # [vmin, vmax], with u = sqrt(p(x)) and v = x * sqrt(p(x)).
    bound = 0.6
    x_array = np.linspace(-bound, bound, 2000)
    sqrt_p = np.sqrt(rho0_density(x_array))
    x_sqrt_p = x_array * sqrt_p
    umax = max(sqrt_p)
    vmin, vmax = min(x_sqrt_p), max(x_sqrt_p)

    # scipy.stats.rvs_ratio_uniforms is deprecated in newer SciPy releases in
    # favour of scipy.stats.sampling.RatioUniforms; prefer the new class and
    # fall back to the legacy function only where the class is missing.
    try:
        from scipy.stats.sampling import RatioUniforms
    except ImportError:
        from scipy.stats import rvs_ratio_uniforms
        rho0_samples = rvs_ratio_uniforms(
            rho0_density, umax, vmin, vmax, size=args.N_TRAIN_SAMPLES)
    else:
        rho0_samples = RatioUniforms(
            rho0_density, umax=umax, vmin=vmin, vmax=vmax
        ).rvs(args.N_TRAIN_SAMPLES)

    train_data[:, 0, 0] = torch.from_numpy(np.asarray(rho0_samples))
    # Per-dimension half-width of the uniform proposal box (shape (1,)).
    max_bound = train_data[:, :, 0].abs().max(axis=0)[0] * args.q_bound_scale
    volume = torch.prod(max_bound * 2)
    train_data[:, :, 1] = torch.rand(
        args.N_TRAIN_SAMPLES, args.INPUT_DIM) * max_bound * 2 - max_bound
    return train_data, volume.item()


def import_Barenblatt_2d(args, rho0_density):
    """Sample 2D training data from ``rho0_density`` plus a uniform box.

    ``rho0_density`` is evaluated on a 100 x 100 grid obtained from
    ``grid_NN_2_generator`` over ``[-1.2, 1.2]`` (presumably a flattened
    list of grid positions -- confirm against ``plot_utils``). Slot 0 of
    the last axis is drawn from that grid via ``sample_from_2d_matrix`` and
    rescaled; slot 1 is uniform on ``[-max_bound, max_bound]^2`` where
    ``max_bound`` is the largest absolute coordinate of slot 0 times
    ``args.q_bound_scale``.

    Args:
        args: Namespace providing ``INPUT_DIM`` (must be 2),
            ``N_TRAIN_SAMPLES`` and ``q_bound_scale``.
        rho0_density: Callable evaluating the density on the grid positions;
            its result must support ``.reshape(100, 100)``.

    Returns:
        A ``(data, volume)`` tuple: ``data`` has shape
        ``(N_TRAIN_SAMPLES, 2, 2)`` and ``volume`` is ``(2 * max_bound)**2``.
    """
    assert args.INPUT_DIM == 2
    n_samples, dim = args.N_TRAIN_SAMPLES, args.INPUT_DIM
    samples = torch.randn(n_samples, dim, 2)

    num_grid = 100
    bound = 1.2
    xy_list = grid_NN_2_generator(num_grid, -bound, bound)[2]
    rho_matrix = rho0_density(xy_list).reshape(num_grid, num_grid)

    grid_samples = sample_from_2d_matrix(
        rho_matrix, n_samples, add_noise=True)
    # Rescale the sampler's grid coordinates into [-bound, bound].
    samples[:, :, 0] = grid_samples / (num_grid / 4) * bound
    # Debug hint: scatter-plot samples[:, :, 0] to inspect the drawn density.

    max_bound = samples[:, :, 0].abs().max().item() * args.q_bound_scale
    samples[:, :, 1] = torch.rand(n_samples, dim) * max_bound * 2 - max_bound
    volume = (max_bound * 2)**2
    return samples, volume

# * Image sampling: draw points from a 2D matrix treated as a density


def sample_from_2d_matrix(n_n_matrix, num_samples, add_noise=True):
    """Sample 2D points with probability proportional to a square matrix.

    A flat cell index is drawn with probability proportional to the cell
    value. It is mapped to ``(index % n, index // n)``, centred by
    subtracting ``int(n / 2)``, halved, and the second coordinate is
    negated.

    Args:
        n_n_matrix: Square matrix of non-negative weights. It is not
            modified; a normalised float copy is used internally.
        num_samples: Number of points to draw.
        add_noise: When true, add uniform jitter in ``[-0.025, 0.025)`` to
            each coordinate.

    Returns:
        ``PTU.numpy2torch`` applied to the ``(num_samples, 2)`` array of
        sampled coordinates.
    """
    # Normalise out of place, in floating point: the previous in-place
    # division overwrote the caller's matrix and failed on integer dtypes.
    weights = np.asarray(n_n_matrix, dtype=float)
    probabilities = weights / weights.sum()

    num_grid = probabilities.shape[0]
    inds = np.random.choice(
        np.arange(num_grid**2), p=probabilities.reshape(-1), size=num_samples)
    inds = inds.astype('float')
    sample_xy = (np.array([inds % num_grid, inds // num_grid]).T - int(num_grid / 2)) / 2
    # Negate the second coordinate so it decreases as the flat index grows.
    sample_xy[:, 1] *= -1

    if add_noise:
        noise = np.random.rand(inds.shape[0], 2) * 0.05 - 0.05 / 2
        sample_xy += noise
    return PTU.numpy2torch(sample_xy)


# * torch type


def torch_normal_gaussian(INPUT_DIM, **kwargs):
    """Draw standard-normal noise, moved to a CUDA device when possible.

    Args:
        INPUT_DIM: Feature dimension of the noise.
        **kwargs: Optional settings:
            ``N_TEST`` -- number of samples; when missing a single vector of
            shape ``(INPUT_DIM,)`` is returned.
            ``kernel_size`` -- when given together with ``N_TEST`` the shape
            is ``(N_TEST, INPUT_DIM, kernel_size, kernel_size)``; otherwise
            ``(N_TEST, INPUT_DIM)``.
            ``device`` -- CUDA device passed to ``Tensor.cuda``.

    Returns:
        The noise tensor. It is placed on ``device`` (or the current CUDA
        device when ``device`` is ``None``); when no device was requested
        and CUDA is unavailable the CPU tensor is returned instead of
        raising.
    """
    N_TEST = kwargs.get('N_TEST')
    device = kwargs.get('device')
    kernel_size = kwargs.get('kernel_size')

    if N_TEST is None:
        shape = (INPUT_DIM,)
    elif kernel_size is None:
        shape = (N_TEST, INPUT_DIM)
    else:
        shape = (N_TEST, INPUT_DIM, kernel_size, kernel_size)
    epsilon_test = torch.randn(*shape)

    # Only fall back to CPU when the caller did not ask for a specific
    # device; an explicit device request still fails loudly without CUDA.
    if device is None and not torch.cuda.is_available():
        return epsilon_test
    return epsilon_test.cuda(device)


def torch_samples_generate_Gaussian(n, mean, cov, **kwargs):
    """Draw ``n`` samples from a multivariate normal as a torch tensor.

    Tensor means are sampled with
    ``torch.distributions.MultivariateNormal``; any other mean (list, tuple,
    NumPy array) is sampled with ``np.random.multivariate_normal`` and
    converted to ``float32``.

    Args:
        n: Number of samples.
        mean: Mean vector (tensor or array-like).
        cov: Covariance matrix compatible with ``mean``.
        **kwargs: Optional ``device`` -- CUDA device passed to
            ``Tensor.cuda``.

    Returns:
        A tensor of shape ``(n, dim)``. It is placed on ``device`` (or the
        current CUDA device when ``device`` is ``None``); when no device was
        requested and CUDA is unavailable the tensor is returned as sampled
        instead of raising.
    """
    device = kwargs.get('device')
    if torch.is_tensor(mean):
        samples = torch.distributions.multivariate_normal.MultivariateNormal(
            mean, cov).sample([n])
    else:
        # Lists as before, plus tuples / ndarrays, which the torch
        # distribution cannot consume directly.
        samples = torch.from_numpy(
            np.random.multivariate_normal(mean, cov, n)).float()

    # Only skip the transfer when the caller did not ask for a specific
    # device; an explicit device request still fails loudly without CUDA.
    if device is None and not torch.cuda.is_available():
        return samples
    return samples.cuda(device)


def torch_samples_generate_GM(num_samples, mean_list, cov_list):
    """Draw ``num_samples`` points from an equally weighted Gaussian mixture.

    Samples are split as evenly as possible between the components and
    then shuffled along the sample axis.

    Args:
        num_samples: Number of rows to generate.
        mean_list: Per-component means, indexed as ``mean_list[j]``.
        cov_list: Per-component covariances, indexed as ``cov_list[j]``.

    Returns:
        A tensor of shape ``(num_samples, input_dim)``, where ``input_dim``
        and the component count come from
        ``DTU.get_gm_dim_component_one(mean_list)``.
    """
    input_dim, NUM_GMM_COMPONENT = DTU.get_gm_dim_component_one(mean_list)
    total_data = torch.zeros(
        num_samples, input_dim)

    # Spread the remainder over the first components so that no row is left
    # at its zero initial value when num_samples is not divisible by the
    # number of components.
    base_count, remainder = divmod(num_samples, NUM_GMM_COMPONENT)
    start = 0
    for j in range(NUM_GMM_COMPONENT):
        count = base_count + (1 if j < remainder else 0)
        component_samples = np.random.multivariate_normal(
            mean_list[j], cov_list[j], count)
        total_data[start:start + count] = torch.from_numpy(component_samples)
        start += count

    # Shuffle so that rows are not grouped by component.
    index_column = torch.randperm(num_samples)
    total_data = total_data[index_column]
    return total_data

# * numpy type


def repeat_list(ndarray, repeat_times):
    """Return a list holding ``repeat_times`` references to ``ndarray``.

    The object is not copied: every entry is the same object.
    """
    return [ndarray for _ in range(repeat_times)]


def np_samples_generate_Gaussian(mean, cov, n):
    """Draw ``n`` samples from ``N(mean, cov)`` with NumPy's global RNG.

    Returns:
        The array produced by ``np.random.multivariate_normal`` for
        ``size=n``.
    """
    return np.random.multivariate_normal(mean=mean, cov=cov, size=n)


def np_PDF_generate_multi_normal(pos_n_d, mean, cov):
    """Evaluate the ``N(mean, cov)`` density at the given positions.

    Args:
        pos_n_d: Positions whose last axis is the coordinate axis; in 2D
            plotting use this may be an ``(n, n, 2)`` grid.
        mean: Mean of the distribution.
        cov: Covariance of the distribution.

    Returns:
        The density values returned by SciPy's frozen
        ``multivariate_normal.pdf``.
    """
    frozen_normal = multivariate_normal(mean=mean, cov=cov)
    return frozen_normal.pdf(pos_n_d)


def np_PDF_generate_multi_normal_NN_1(pos_n_n_2, mean, cov):
    """Evaluate the ``N(mean, cov)`` density and flatten the result to 1D.

    Args:
        pos_n_n_2: Positions forwarded to ``np_PDF_generate_multi_normal``.
        mean: Mean of the distribution.
        cov: Covariance of the distribution.

    Returns:
        A 1D array holding one density value per position.
    """
    densities = np_PDF_generate_multi_normal(pos_n_n_2, mean, cov)
    return densities.reshape(-1)


def np_generate_kde_NN_1(pos_nn_2, kde_analyzer):
    """Return density values from an estimator that reports log-densities.

    Args:
        pos_nn_2: Positions passed to ``kde_analyzer.score_samples``.
        kde_analyzer: Object exposing ``score_samples``; its output is
            exponentiated, so it is presumably a log-density (e.g. a fitted
            scikit-learn ``KernelDensity`` -- confirm against callers).

    Returns:
        ``np.exp`` of the scores, one value per position.
    """
    return np.exp(kde_analyzer.score_samples(pos_nn_2))
