import time
import numpy as np
import pymanopt
from pymanopt import tools
from pymanopt.optimizers.optimizer import Optimizer, OptimizerResult
from pymanopt.optimizers.steepest_descent import SteepestDescent
from localbasis_utils import compare_basis_componentwise
import torch
from tqdm import tqdm
from localbasis_utils import get_random_local_basis
from pymanopt.manifolds import SpecialOrthogonalGroup, Grassmann

# Source of compute_centroid: pymanopt/pymanopt/optimizers/nelder_mead.py (pymanopt v2.0.1).
def compute_centroid(manifold, points, max_iterations=15, max_time=1000, verbosity=2):
    """Compute the centroid of `points` on the `manifold` as Karcher mean.

    The centroid is searched for with pymanopt's ``SteepestDescent`` applied to
    half the sum of squared ``manifold.dist`` values between the candidate and
    every entry of ``points``.

    Args:
        manifold: pymanopt manifold; ``dist``, ``log``, ``zero_vector`` and
            ``num_values`` are used.
        points: Points on ``manifold``. The collection is traversed on every
            cost and gradient evaluation, so it must be re-iterable
            (e.g. a list).
        max_iterations: Iteration budget forwarded to ``SteepestDescent``.
        max_time: Time budget forwarded to ``SteepestDescent``.
        verbosity: Logging level forwarded to ``SteepestDescent``. The default
            of 2 matches the level that used to be hard-coded here; pass 0 to
            silence the optimizer.

    Returns:
        Whatever ``optimizer.run(problem)`` returns (the optimizer result, not
        the bare point); the centroid estimate is its ``point`` attribute.
    """

    @pymanopt.function.numpy(manifold)
    def objective(*y):
        # Product-free manifolds hand over a 1-tuple; unwrap it to the point.
        if manifold.num_values == 1:
            (y,) = y
        return sum(manifold.dist(y, point) ** 2 for point in points) / 2

    @pymanopt.function.numpy(manifold)
    def gradient(*y):
        if manifold.num_values == 1:
            (y,) = y
        # Minus the sum of the log maps from y to each point; the zero tangent
        # vector is the start value so the sum stays a tangent vector.
        return -sum(
            (manifold.log(y, point) for point in points),
            manifold.zero_vector(y),
        )

    optimizer = SteepestDescent(
        max_iterations=max_iterations,
        verbosity=verbosity,
        max_time=max_time,
    )
    problem = pymanopt.Problem(
        manifold, objective, riemannian_gradient=gradient
    )
    return optimizer.run(problem)

def sample_random_local_basis(model, eval_config, full=False):
    """Sample ``eval_config.n_samples`` local bases and return them as NumPy arrays.

    Each sample comes from ``get_random_local_basis`` called with a shared
    ``np.random.RandomState`` seeded by ``eval_config.seed``; only the third
    element of its return tuple (the local basis) is kept.

    Args:
        model: Passed through to ``get_random_local_basis``.
        eval_config: Object providing ``seed``, ``subspace_dim``, ``n_samples``,
            ``last_layer_name`` and ``sv_thres_ratio``.
        full: If true keep every column of the basis, otherwise only the first
            ``eval_config.subspace_dim`` columns.

    Returns:
        List with one NumPy array per sample.

    Note:
        Gradient tracking is switched on globally through
        ``torch.autograd.set_grad_enabled(True)`` and is not restored here.
    """
    torch.autograd.set_grad_enabled(True)
    rng = np.random.RandomState(eval_config.seed)
    subspace_dim = eval_config.subspace_dim

    sampled_bases = []
    for _ in tqdm(range(eval_config.n_samples)):
        (
            _noise,
            _z,
            basis,
            _singular_values,
            _noise_basis,
            _rank,
            _noise_level,
        ) = get_random_local_basis(
            model,
            rng,
            last_layer_name=eval_config.last_layer_name,
            rankEst=False,
            sv_thres_ratio=eval_config.sv_thres_ratio,
        )
        kept = basis if full else basis[:, :subspace_dim]
        sampled_bases.append(kept.numpy())
    return sampled_bases

def is_proj_to_SO(proj_localbasis):
    """Return whether the determinant of ``proj_localbasis`` is strictly positive."""
    determinant = np.linalg.det(proj_localbasis)
    return determinant > 0

def proj_to_SO(localbasis_projected):
    '''
    Map a square matrix to an orthogonal matrix with an orientation fix.

    input: local basis projected onto the frechet mean (optimal subspace)
    output: ``u @ vh`` from the SVD of the (possibly column-flipped) input

    The input itself is left untouched; all work happens on a copy.
    '''
    candidate = np.copy(localbasis_projected)

    # A non-positive determinant would not land in the special orthogonal
    # group, so reverse the direction of the last component first.
    needs_flip = not is_proj_to_SO(candidate)
    if needs_flip:
        candidate[:, -1] *= -1

    # Projection onto the orthogonal group: drop the singular values.
    left, _, right_h = np.linalg.svd(candidate, full_matrices=True, compute_uv=True)
    return left @ right_h
        
def align_to_reference(reference, special_orthogonal_points):
    """Align every point to ``reference`` and return the results as NumPy arrays.

    Each point is passed through ``align_local_basis(reference, point)``; the
    second element of its return value (the aligned basis) is converted with
    ``.numpy()`` and collected in input order.
    """
    return [
        align_local_basis(reference, point)[1].numpy()
        for point in special_orthogonal_points
    ]
        
        
def align_local_basis(reference_basis, align_basis):
    '''
    Align "align_basis" to make <i-th refer, i-th align> > 0 for all i.
    reference_basis = (basis_dim, N), align_basis = (basis_dim, M)
    Each basis should be normalized.

    NumPy inputs are converted to torch tensors first. Returns the pair
    (detached reference_basis, aligned align_basis), both torch tensors.
    '''
    def _as_tensor(basis):
        # Leave anything that is not a NumPy array as it is.
        return torch.from_numpy(basis) if isinstance(basis, np.ndarray) else basis

    reference_basis = _as_tensor(reference_basis)
    align_basis = _as_tensor(align_basis)

    # Only the orientation output is needed; the similarity matrix is dropped.
    # NOTE(review): presumably `.diag()` yields one sign per basis component
    # that broadcasts over the columns -- confirm against
    # compare_basis_componentwise.
    _, orientation = compare_basis_componentwise(reference_basis.t(), align_basis.t())
    oriented_basis = align_basis.detach() * orientation.diag()
    return reference_basis.detach(), oriented_basis
    
def compute_frechet_basis(frechet_subspace, local_basis_list, max_iterations=100, max_time=3000):
    '''
    frechet_subspace = Target Subspace where we find a frechet basis,
                       Shape = (ambient_dim, subspace_dim)
    local_basis_list = list of local basis
                       Shape = (ambient_dim, number of basis component)
    max_iterations, max_time = budgets forwarded to compute_centroid

    Returns frechet_subspace multiplied by the centroid found on the special
    orthogonal group, i.e. the result expressed in ambient coordinates.
    '''
    _, subspace_dim = frechet_subspace.shape

    # Keep only the leading subspace_dim components of every local basis and
    # orient them consistently with the (float32) target subspace.
    cropped = [basis[:, :subspace_dim] for basis in local_basis_list]
    aligned = align_to_reference(frechet_subspace.astype(np.float32), cropped)

    # Express each aligned basis in the coordinates of frechet_subspace and
    # push the resulting square matrix through proj_to_SO.
    subspace_t = frechet_subspace.transpose()
    rotations = []
    for basis in aligned:
        coords = np.matmul(subspace_t, basis)
        rotations.append(proj_to_SO(coords))

    # Size of the rotation matrices; this needs at least one local basis.
    rotation_dim = rotations[0].shape[0]
    manifold = SpecialOrthogonalGroup(rotation_dim, k=1)
    opt_result = compute_centroid(
        manifold,
        rotations,
        max_iterations=max_iterations,
        max_time=max_time,
    )

    # Map the centroid back from subspace coordinates to the ambient space.
    return np.matmul(frechet_subspace, opt_result.point)

def grass_geodesic_interp(grass_1, grass_2, n_step=5, overshoot=False):
    '''
    Geodesic Interpolation on Grassmannian Manifold from grass_1 to grass_2
    (Start and End point included)
    grass_i : shape= (Ambient dim, subspace dim)

    n_step    : number of evenly spaced parameters in [0, 1] to evaluate
    overshoot : if True, one extra point is prepended at t = -1/n_step and
                one appended at t = 1 + 1/n_step

    Returns a list of arrays, one per evaluated parameter.
    '''
    # Tangent direction at grass_1 pointing to grass_2, built transposed and
    # then decomposed with a thin SVD.
    gram = grass_2.transpose() @ grass_1
    residual_t = grass_2.transpose() - gram @ grass_1.transpose()
    tangent_t = np.linalg.solve(gram, residual_t)
    u, s, vt = np.linalg.svd(tangent_t.transpose(), full_matrices=False)
    # Row vector of angles so it scales the columns it is multiplied with.
    theta = np.expand_dims(np.arctan(s), -2)

    def point_at(t):
        # Point on the geodesic at parameter t (t=0 -> grass_1 side).
        return grass_1 @ (vt.transpose() * np.cos(theta * t)) @ vt + (u * np.sin(theta * t)) @ vt

    interps = [point_at(t) for t in np.linspace(0, 1, num=n_step, endpoint=True)]

    if overshoot:
        # NOTE(review): the interior samples are spaced 1/(n_step-1) apart,
        # while the overshoot offset is 1/n_step -- confirm this is intended.
        step_size = 1 / n_step
        interps = [point_at(-step_size)] + interps + [point_at(1 + step_size)]
    return interps