import matplotlib.pyplot as plt
import numpy as np
from modules.utils import plotting
from sklearn.decomposition import PCA


def make_rotation_matrix_2d(angles):
    """
    Create 2D rotation matrices for an array of angles.

    Each angle ``a`` is mapped to the matrix ``[[cos a, -sin a], [sin a, cos a]]``,
    i.e. a counter-clockwise rotation when applied to column vectors.

    Args:
        angles: scalar or array-like of angles in radians, of any shape. Plain
            Python scalars and lists are accepted as well as NumPy arrays.

    Returns:
        np.ndarray of shape ``np.shape(angles) + (2, 2)`` with one rotation
        matrix per input angle (a single ``(2, 2)`` matrix for scalar input).
    """
    # np.asarray lets plain Python scalars/lists through; they have no `.shape`.
    angles = np.asarray(angles)
    cos_angle = np.cos(angles)
    sin_angle = np.sin(angles)
    # Stack the four matrix entries in row-major order along a new last axis...
    matrix = np.stack((cos_angle, -sin_angle, sin_angle, cos_angle), axis=-1)
    # ...then split that axis into the (2, 2) matrix dimensions.
    return matrix.reshape(angles.shape + (2, 2))


def plot_principal_components(z, colors_flat):
    """
    Scatter-plot the first two columns of ``z`` in a new 5x5-inch figure.

    Args:
        z: array whose ``[:, 0]`` and ``[:, 1]`` slices are used as the x and y
            coordinates of the points.
        colors_flat: per-point colours, forwarded to ``plt.scatter(c=...)``.
    """
    plt.figure(figsize=(5, 5))
    x_coords, y_coords = z[:, 0], z[:, 1]
    plt.scatter(x_coords, y_coords, c=colors_flat)




def create_combinations_k_values_range(start_value=-10, end_value=10):
    """
    Enumerate every ordered pair of integers in [start_value, end_value].

    Args:
        start_value: smallest candidate value (inclusive).
        end_value: largest candidate value (inclusive).

    Returns:
        np.ndarray of shape (n * n, 2), with n = max(0, end_value - start_value + 1),
        holding one ``(first, second)`` pair per row. The first column cycles
        fastest: rows run (start, start), (start + 1, start), ..., (end, end).
    """
    candidates = np.asarray(range(start_value, end_value + 1))
    first_grid, second_grid = np.meshgrid(candidates, candidates)
    return np.stack((first_grid.ravel(), second_grid.ravel()), axis=-1)


def calculate_metric_k_list(z_loc, k_values, verbose=0):
    """
    Evaluate ``pca_metric2`` for every candidate ``k`` and return the best one.

    Args:
        z_loc: latent array passed unchanged to ``pca_metric2``.
        k_values: sequence/array of candidate ``(k1, k2)`` pairs, e.g. the output
            of ``create_combinations_k_values_range``.
        verbose: if 1, print the score of every combination.

    Returns:
        Tuple ``(score, k_min)``: the lowest total score and the candidate ``k``
        that achieved it.

    Raises:
        ValueError: if ``k_values`` is empty.
    """
    if len(k_values) == 0:
        raise ValueError("k_values must contain at least one candidate k")
    metric_values = np.zeros(len(k_values))
    for num_k, k in enumerate(k_values):
        # pca_metric2 returns the per-group variances (var1, var2); assigning
        # that tuple to a scalar slot fails, so score by their sum. np.sum also
        # accepts a plain scalar return value.
        metric_values[num_k] = np.sum(pca_metric2(z_loc, k=k))
        if verbose == 1:
            print("Combination number {} for k = {}, score = {}".format(num_k, k, metric_values[num_k]))
    best = np.argmin(metric_values)
    return metric_values[best], k_values[best]

def pca_metric2(z, angles_combinations=None, k=None, plot=False, verbose=0):
    """
    PCA-based equivariance metric for latents sampled on a two-angle grid.

    For each group the latents are centred by subtracting their mean over that
    group's grid axis (G1: axis 0, G2: axis 1). The residuals are projected onto
    their first two principal components and normalised. Each point is then
    rotated back by ``-k_i * angle``. The returned values are the mean squared
    distances of the rotated-back 2-D points to their centroid. They are small
    when the inverse rotation maps all residuals of a group onto the same point.

    Args:
        z: latent array of shape (angles1, angles2, z_dim). The first two axes
            must have the same length.
        angles_combinations: optional array with a leading axis of length 2
            holding the two angles of every grid point, flattened in the same
            order as z's first two axes. Defaults to a regular [0, 2*pi) grid
            built with ``np.meshgrid`` (default 'xy' indexing). In that grid
            ``angles_combinations[1]`` varies along z's axis 0 and
            ``angles_combinations[0]`` along z's axis 1.
        k: optional pair ``(k1, k2)`` of rotation frequencies. ``k1`` multiplies
            ``angles_combinations[1]`` for G1 and ``k2`` multiplies
            ``angles_combinations[0]`` for G2. When None they are estimated from
            the data and rounded to the nearest integer.
        plot: if True, scatter-plot both sets of rotated-back components.
        verbose: if 1, print intermediate PCA quantities and the scores.

    Returns:
        Tuple ``(var1, var2)`` of the G1 and G2 variances. Their sum is the
        overall score.

    Raises:
        ValueError: if the first two axes of ``z`` differ in length.
    """
    if angles_combinations is None:
        # Regular spacing of angles in [0, 2pi) for each grid axis
        angles1, angles2, _ = z.shape
        angles_regular1 = np.linspace(0, 1, angles1, endpoint=False) * 2 * np.pi
        angles_regular2 = np.linspace(0, 1, angles2, endpoint=False) * 2 * np.pi
        # All possible combinations of the regular angles
        angles_combinations = np.array(np.meshgrid(angles_regular1, angles_regular2))
    if z.shape[0] != z.shape[1]:
        raise ValueError(
            "pca_metric2 requires a square angle grid, got z.shape = {}".format(z.shape))

    # Row 0 / row 1: the two angles of every grid point, flattened like z
    angles_flat = angles_combinations.reshape(2, -1)

    # Residuals w.r.t. each group: subtract the mean over that group's axis
    g1_elements = z - np.mean(z, axis=0, keepdims=True)  # Variations w.r.t. G1
    g2_elements = z - np.mean(z, axis=1, keepdims=True)  # Variations w.r.t. G2
    # np.prod (np.product was removed in NumPy 2.0)
    num_points = int(np.prod(z.shape[:-1]))
    g1_elements_flat = g1_elements.reshape(num_points, z.shape[-1])
    g2_elements_flat = g2_elements.reshape(num_points, z.shape[-1])

    # PCA of the residuals of each group
    pca1 = PCA(n_components=2, svd_solver="full")
    pca2 = PCA(n_components=2, svd_solver="full")
    pca1.fit(g1_elements_flat)
    pca2.fit(g2_elements_flat)
    if verbose == 1:
        print("PCA Components 1", pca1.components_)
        print("PCA Components 2", pca2.components_)

    # Sample variance (ddof=1) and std of the data along each principal axis
    singular_values1 = pca1.singular_values_
    singular_values2 = pca2.singular_values_
    eigenvalues1 = singular_values1 ** 2 / (g1_elements_flat.shape[0] - 1)
    eigenvalues2 = singular_values2 ** 2 / (g2_elements_flat.shape[0] - 1)
    std1 = np.sqrt(eigenvalues1)
    std2 = np.sqrt(eigenvalues2)
    if verbose == 1:
        print("Singularvalues1", singular_values1)
        print("Eigenvalues1", eigenvalues1)
        print("std1", std1)
        print("Singularvalues2", singular_values2)
        print("Eigenvalues2", eigenvalues2)
        print("std2", std2)

    # Divide by std * sqrt(2). Points spread uniformly on a circle of radius r
    # have per-axis std r / sqrt(2), so such a circle maps to the unit circle.
    projected_g1 = pca1.transform(g1_elements_flat)
    projected_g2 = pca2.transform(g2_elements_flat)
    projected_g1 /= std1 * np.sqrt(2)
    projected_g2 /= std2 * np.sqrt(2)

    if k is None:
        # Wrap the grid angles into arctan2's range [-pi, pi] so they can be
        # compared with the angles observed in the projections.
        z_p = np.stack([np.cos(angles_combinations[0]), np.sin(angles_combinations[0]),
                        np.cos(angles_combinations[1]), np.sin(angles_combinations[1])],
                       axis=-1)
        z_p_flat = z_p.reshape((int(np.prod(z_p.shape[:-1])), z_p.shape[-1]))
        angles_flat2 = np.stack(
            [np.arctan2(z_p_flat[:, 1], z_p_flat[:, 0]), np.arctan2(z_p_flat[:, 3], z_p_flat[:, 2])], axis=0)

        # Rotate each projection so that its first data point lies at angle 0
        beta_1 = np.arctan2(projected_g1[0, 1], projected_g1[0, 0])
        beta_2 = np.arctan2(projected_g2[0, 1], projected_g2[0, 0])
        rotation_1 = make_rotation_matrix_2d(-beta_1)
        rotation_2 = make_rotation_matrix_2d(-beta_2)
        projected_g1_r = np.squeeze(np.matmul(rotation_1, np.expand_dims(projected_g1, -1)), -1)
        projected_g2_r = np.squeeze(np.matmul(rotation_2, np.expand_dims(projected_g2, -1)), -1)

        # Observed angles, stored in the row of the grid angle paired with each
        # group below: G1 <-> angles_flat[1] (row 1), G2 <-> angles_flat[0] (row 0).
        angles_starred = np.zeros(angles_flat2.shape)
        angles_starred[1, :] = np.arctan2(projected_g1_r[:, 1], projected_g1_r[:, 0])
        angles_starred[0, :] = np.arctan2(projected_g2_r[:, 1], projected_g2_r[:, 0])
        ratios = (angles_starred + np.finfo(float).eps) / (angles_flat2 + np.finfo(float).eps)

        # k1 rotates G1 by angles_flat[1], so it is estimated from row 1 (G1).
        # k2 comes from row 0 (G2). Previously these rows were swapped.
        k1 = np.round(np.mean(ratios[1, :]), 0)
        k2 = np.round(np.mean(ratios[0, :]), 0)
        if verbose == 1:
            print("k1 = ", k1, "k2 = ", k2)
    else:
        k1, k2 = k[0], k[1]

    # Inverse rotations by -k * angle, shape (num_points, 2, 2)
    inv_rotations1 = make_rotation_matrix_2d(- k1 * angles_flat[1])
    inv_rotations2 = make_rotation_matrix_2d(- k2 * angles_flat[0])

    # Apply them to the projections viewed as column vectors (num_points, 2, 1)
    pc_g1 = np.squeeze(np.matmul(inv_rotations1, np.expand_dims(projected_g1, axis=-1)), axis=-1)
    pc_g2 = np.squeeze(np.matmul(inv_rotations2, np.expand_dims(projected_g2, axis=-1)), axis=-1)

    if plot:
        # The colour embedding is only needed for plotting
        colors_flat = plotting.yiq_embedding(angles_flat[0], angles_flat[1])
        plot_principal_components(pc_g1, colors_flat)
        plot_principal_components(pc_g2, colors_flat)

    # Metric: mean squared distance of the rotated-back points to their centroid
    mean1 = np.mean(pc_g1, axis=0)
    mean2 = np.mean(pc_g2, axis=0)
    var1 = np.mean(np.sum((pc_g1 - mean1) ** 2, axis=-1), axis=0)
    var2 = np.mean(np.sum((pc_g2 - mean2) ** 2, axis=-1), axis=0)
    if verbose == 1:
        print("Variance 1:", var1)
        print("Variance 2:", var2)
        print("Final metric score:", var1 + var2)
    return var1, var2

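# NOTE(review): the commented-out function below is an earlier draft of
# pca_metric_only_rotations, superseded by the active definition that follows.
# Unlike the active version, it measures the variance around a single mean of
# all points instead of around each object's own mean.
#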
# def pca_metric_only_rotations(z, k=None, verbose=0):
#     """
#     Calculate the metric
#     Args:
#         z:
#         angles_combinations:
#         k:
#         verbose:
#
#     Returns:
#
#     """
#     # Regular spacing of angles in [0,2pi)
#
#     num_objects, angles, z_dim = z.shape
#     angles_regular = np.expand_dims(np.linspace(0, 1, angles, endpoint=False) * 2 * np.pi, 0)
#     angles_regular = np.concatenate([angles_regular] * num_objects, axis=0)
#     angles_flat = angles_regular.reshape(num_objects * angles)
#     # The mean latent for each group (angles_per_group, z_dim, n_groups)
#     mean_latent = np.mean(z, axis=1)
#     g_elements = z - np.expand_dims(mean_latent, axis=1)  # Variations w.r.t. G2
#     g_elements_flat = g_elements.reshape(np.product(g_elements.shape[:-1]), g_elements.shape[-1])
#
#     # PCA
#     pca = PCA(n_components=2, svd_solver="full")
#     pca.fit(g_elements_flat)
#
#     projected_g = pca.transform(g_elements_flat)
#     # Calculate normalization constant
#     singular_values = pca.singular_values_
#     eigenvalues = singular_values ** 2 / (g_elements_flat.shape[0] - 1)
#     std = np.sqrt(eigenvalues)
#     projected_g /= std * np.sqrt(2)
#     #     plt.scatter(projected_g[:, 0], projected_g[:, 1], c=range(angles), cmap="Reds")
#
#     k2 = k
#
#     # make rotation matrices for inverse transformations, shape (..., 2, 2)
#     inv_rotations = make_rotation_matrix_2d(- k2 * angles_flat)
#     # add dim of size 1 at the end, so shape is (..., 2, 1)
#     pc_g = np.expand_dims(projected_g, axis=-1)
#
#     # apply inverse rotations
#     pc_g = np.matmul(inv_rotations, pc_g)  # resulting shape (..., 2, 1)
#     pc_g = np.squeeze(pc_g, axis=-1)  # resulting shape (*batch_dims, 2)
#
#     # compute metric
#     mean = np.mean(pc_g, axis=0)
#
#     var = np.mean(np.sum((pc_g - mean) ** 2, axis=-1), axis=0)
#     if verbose == 1:
#         print("Variance:", var)
#     score = var
#     if verbose == 1:
#         print("Final metric score:", score)
#     return var


def pca_metric_only_rotations(z, k=None, verbose=0):
    """
    Score how well a single rotation frequency ``k`` explains the latents ``z``.

    Each object's views are assumed to sit at regularly spaced angles in
    [0, 2*pi) along axis 1 of ``z`` (TODO confirm against the data pipeline).
    The function proceeds in four steps:

    1. Subtract each object's mean latent.
    2. Project the residuals onto their first two principal components (fitted
       jointly over all objects) and normalise them.
    3. Rotate every point back by ``-k * angle``.
    4. Return the mean squared distance of the rotated-back points to their
       per-object mean.

    The score is small when the rotated-back points of each object coincide,
    presumably meaning the latents rotate with frequency ``k``.

    Args:
        z: array of shape (num_objects, angles, z_dim).
        k: rotation frequency to test (a scalar). Required despite the ``None``
            default, which is kept only for call compatibility.
        verbose: if 1, print the variance / final score.

    Returns:
        The mean squared distance of the rotated-back 2-D points to their
        per-object mean.

    Raises:
        TypeError: if ``k`` is None.
    """
    if k is None:
        raise TypeError("pca_metric_only_rotations requires a rotation frequency k")

    num_objects, angles, z_dim = z.shape
    # Regular spacing of angles in [0, 2pi), repeated for every object so that
    # it lines up with z flattened over its first two axes.
    angles_regular = np.linspace(0, 1, angles, endpoint=False) * 2 * np.pi
    angles_flat = np.tile(angles_regular, num_objects)

    # Per-object mean latent, shape (num_objects, z_dim). Subtracting it keeps
    # only the variation across angles within each object.
    mean_latent = np.mean(z, axis=1)
    g_elements = z - np.expand_dims(mean_latent, axis=1)
    g_elements_flat = g_elements.reshape(num_objects * angles, z_dim)

    # PCA, fitted jointly over all objects
    pca = PCA(n_components=2, svd_solver="full")
    pca.fit(g_elements_flat)
    projected_g = pca.transform(g_elements_flat)

    # Normalise each component by sqrt(2) * its sample std (ddof=1)
    singular_values = pca.singular_values_
    eigenvalues = singular_values ** 2 / (g_elements_flat.shape[0] - 1)
    std = np.sqrt(eigenvalues)
    projected_g /= std * np.sqrt(2)

    # Inverse rotations by -k * angle, shape (num_objects * angles, 2, 2),
    # applied to the projections viewed as column vectors
    inv_rotations = make_rotation_matrix_2d(-k * angles_flat)
    pc_g = np.matmul(inv_rotations, np.expand_dims(projected_g, axis=-1))
    pc_g = np.squeeze(pc_g, axis=-1).reshape((num_objects, angles, 2))

    # Metric: mean squared distance of each point to its object's centroid
    mean = np.mean(pc_g, axis=1, keepdims=True)
    var = np.mean(np.sum((pc_g - mean) ** 2, axis=-1))
    if verbose == 1:
        print("Variance:", var)
        print("Final metric score:", var)
    return var


def calculate_metric_rotations(z_loc, k_values, verbose=0):
    """
    Evaluate ``pca_metric_only_rotations`` for every candidate ``k``.

    Args:
        z_loc: latent array passed unchanged to ``pca_metric_only_rotations``.
        k_values: sequence/array of scalar candidate rotation frequencies.
        verbose: if 1, print the score of every candidate. The parameter was
            previously ignored and the scores were always printed.

    Returns:
        Tuple ``(score, k_min)``: the lowest score and the ``k`` achieving it.

    Raises:
        ValueError: if ``k_values`` is empty.
    """
    if len(k_values) == 0:
        raise ValueError("k_values must contain at least one candidate k")
    metric_values = np.zeros(len(k_values))
    for num_k, k in enumerate(k_values):
        metric_values[num_k] = pca_metric_only_rotations(z_loc, k=k, verbose=0)
        if verbose == 1:
            print("Combination number {} for k = {}, score = {}".format(num_k, k, metric_values[num_k]))
    best = np.argmin(metric_values)
    return metric_values[best], k_values[best]




def pca_metric(z, k=None):
    """
    PCA-based equivariance score for latents on a regular two-angle grid.

    Verbose, always-plotting variant. It prints every intermediate quantity and
    draws both sets of rotated-back components with ``plot_principal_components``.

    For each group the latents are centred by their mean over that group's grid
    axis (G1: axis 0, G2: axis 1). The residuals are projected onto their first
    two principal components and normalised, then rotated back by
    ``-k_i * angle``.

    Args:
        z: latent array of shape (angles1, angles2, z_dim) with
            ``angles1 == angles2``.
        k: optional pair ``(k1, k2)`` of rotation frequencies for G1 and G2.
            When None they are estimated from the data and rounded to the
            nearest integer.

    Returns:
        ``var1 + var2``, the summed mean squared distances of each group's
        rotated-back 2-D points to their centroid.

    Raises:
        ValueError: if ``z`` is not 3-D or its first two axes differ in length.
    """
    angles1, angles2, _ = z.shape
    if angles1 != angles2:
        raise ValueError(
            "pca_metric requires a square angle grid, got z.shape = {}".format(z.shape))

    # Regular spacing of angles in [0, 2pi)
    angles_regular1 = np.linspace(0, 1, angles1, endpoint=False) * 2 * np.pi
    angles_regular2 = np.linspace(0, 1, angles2, endpoint=False) * 2 * np.pi
    # All combinations of the regular angles. With meshgrid's default 'xy'
    # indexing, [1] varies along z's axis 0 and [0] along z's axis 1.
    angles_combinations = np.array(np.meshgrid(angles_regular1, angles_regular2))

    # color map
    angles_flat = angles_combinations.reshape(2, -1)
    colors_flat = plotting.yiq_embedding(angles_flat[0], angles_flat[1])

    # Residuals w.r.t. each group: subtract the mean over that group's axis
    g1_elements = z - np.mean(z, axis=0, keepdims=True)  # Variations w.r.t. G1
    g2_elements = z - np.mean(z, axis=1, keepdims=True)  # Variations w.r.t. G2
    # np.prod (np.product was removed in NumPy 2.0)
    num_points = int(np.prod(z.shape[:-1]))
    g1_elements_flat = g1_elements.reshape(num_points, z.shape[-1])
    g2_elements_flat = g2_elements.reshape(num_points, z.shape[-1])

    # PCA of the residuals of each group
    pca1 = PCA(n_components=2, svd_solver="full")
    pca2 = PCA(n_components=2, svd_solver="full")
    pca1.fit(g1_elements_flat)
    pca2.fit(g2_elements_flat)
    print("PCA Components 1", pca1.components_)
    print("PCA Components 2", pca2.components_)

    # Sample variance (ddof=1) and std of the data along each principal axis
    singular_values1 = pca1.singular_values_
    singular_values2 = pca2.singular_values_
    eigenvalues1 = singular_values1 ** 2 / (g1_elements_flat.shape[0] - 1)
    eigenvalues2 = singular_values2 ** 2 / (g2_elements_flat.shape[0] - 1)
    std1 = np.sqrt(eigenvalues1)
    std2 = np.sqrt(eigenvalues2)
    print("Singularvalues1", singular_values1)
    print("Eigenvalues1", eigenvalues1)
    print("std1", std1)
    print("Singularvalues2", singular_values2)
    print("Eigenvalues2", eigenvalues2)
    print("std2", std2)

    # Divide by std * sqrt(2). A circle of radius r has per-axis std r / sqrt(2),
    # so such a circle maps to the unit circle.
    projected_g1 = pca1.transform(g1_elements_flat)
    projected_g2 = pca2.transform(g2_elements_flat)
    projected_g1 /= std1 * np.sqrt(2)
    projected_g2 /= std2 * np.sqrt(2)

    if k is None:
        # Wrap the grid angles into arctan2's range [-pi, pi]
        z_p = np.stack([np.cos(angles_combinations[0]), np.sin(angles_combinations[0]),
                        np.cos(angles_combinations[1]), np.sin(angles_combinations[1])],
                       axis=-1)
        z_p_flat = z_p.reshape((int(np.prod(z_p.shape[:-1])), z_p.shape[-1]))
        angles_flat2 = np.stack(
            [np.arctan2(z_p_flat[:, 1], z_p_flat[:, 0]), np.arctan2(z_p_flat[:, 3], z_p_flat[:, 2])], axis=0)

        # Rotate each projection so that its first data point lies at angle 0
        beta_1 = np.arctan2(projected_g1[0, 1], projected_g1[0, 0])
        beta_2 = np.arctan2(projected_g2[0, 1], projected_g2[0, 0])
        rotation_1 = make_rotation_matrix_2d(-beta_1)
        rotation_2 = make_rotation_matrix_2d(-beta_2)
        projected_g1_r = np.squeeze(np.matmul(rotation_1, np.expand_dims(projected_g1, -1)), -1)
        projected_g2_r = np.squeeze(np.matmul(rotation_2, np.expand_dims(projected_g2, -1)), -1)

        # Observed angles, stored in the row of the grid angle paired with each
        # group below: G1 <-> angles_flat[1] (row 1), G2 <-> angles_flat[0] (row 0).
        angles_starred = np.zeros(angles_flat2.shape)
        angles_starred[1, :] = np.arctan2(projected_g1_r[:, 1], projected_g1_r[:, 0])
        angles_starred[0, :] = np.arctan2(projected_g2_r[:, 1], projected_g2_r[:, 0])
        print(angles_starred)
        print(angles_flat2)
        ratios = (angles_starred + np.finfo(float).eps) / (angles_flat2 + np.finfo(float).eps)

        # k1 rotates G1 by angles_flat[1], so it is estimated from row 1 (G1).
        # k2 comes from row 0 (G2). Previously these rows were swapped.
        k1 = np.mean(ratios[1, :])
        k2 = np.mean(ratios[0, :])

        print(k1, k2)
        k1 = np.round(k1, 0)
        k2 = np.round(k2, 0)
        print(k1, k2)
    else:
        k1, k2 = k[0], k[1]

    # Inverse rotations by -k * angle, shape (num_points, 2, 2)
    inv_rotations1 = make_rotation_matrix_2d(- k1 * angles_flat[1])
    inv_rotations2 = make_rotation_matrix_2d(- k2 * angles_flat[0])

    # Apply them to the projections viewed as column vectors (num_points, 2, 1)
    pc_g1 = np.squeeze(np.matmul(inv_rotations1, np.expand_dims(projected_g1, axis=-1)), axis=-1)
    pc_g2 = np.squeeze(np.matmul(inv_rotations2, np.expand_dims(projected_g2, axis=-1)), axis=-1)

    plot_principal_components(pc_g1, colors_flat)
    plot_principal_components(pc_g2, colors_flat)

    # Metric: mean squared distance of the rotated-back points to their centroid
    mean1 = np.mean(pc_g1, axis=0)
    mean2 = np.mean(pc_g2, axis=0)
    var1 = np.mean(np.sum((pc_g1 - mean1) ** 2, axis=-1), axis=0)
    var2 = np.mean(np.sum((pc_g2 - mean2) ** 2, axis=-1), axis=0)
    print("Variance 1:", var1)
    print("Variance 2:", var2)
    score = var1 + var2
    print("Final metric score:", score)
    return score
