from notebook_init import *
import time, os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Output directory for this experiment (created if it does not exist yet).
# NOTE(review): `Path`, `makedirs`, `torch`, `device`, `inst` and
# `get_instrumented_model` are not imported in this file; presumably they come
# from `from notebook_init import *` — confirm there.
out_root = Path('out/consistency')
makedirs(out_root, exist_ok=True)
# rand = lambda : np.random.randint(np.iinfo(np.int32).max)

# StyleGAN2
# use_w is forwarded to get_instrumented_model; presumably selects W-space
# latents instead of Z — TODO confirm against the helper.
use_w = True
#dataset = 'ffhq'   #config-f
dataset = 'ffhq-config-e'
# The previous `inst` (from the star import / an earlier run) is passed back in;
# presumably so the helper can reuse or release it — verify in notebook_init.
inst = get_instrumented_model('StyleGAN2', dataset, 'style', device, inst=inst, use_w=use_w)
model = inst.model
# truncation = 1.0 presumably disables the truncation trick — confirm in the
# model wrapper.
model.truncation = 1.0

#Load directions
# Precomputed global directions (GANSpace and SeFa), loaded from .npy and moved
# to `device` as torch tensors.
# NOTE(review): the model above is the 'ffhq-config-e' checkpoint; confirm these
# direction files were computed for the same checkpoint.
#gs_dir = np.load('./global_directions/ganspace_directions_ffhq_stylegan2.npy')#Note! Only ffhq is provided.
gs_dir = np.load('./global_directions/ganspace_directions_ffhq_StyleGAN2_style-8.npy')#Note! Only ffhq is provided.
gs_dir = torch.from_numpy(gs_dir).to(device)
sf_dir = np.load('./global_directions/sefa_directions_ffhq_stylegan2.npy')#Note! Only ffhq is provided.
sf_dir = torch.from_numpy(sf_dir).to(device)
class compare_basis_config:
    """Settings for sampling and comparing local bases.

    A plain attribute namespace: every setting is a class attribute, so the
    class itself and all of its instances expose the same values.
    """

    n_samples = 50          # how many random samples to draw
    seed = 0                # seed for the numpy RandomState used when sampling
    subspace_dim = 2        # leading basis columns kept per sample
    rankEst = False         # rank-estimation flag
    sv_thres_ratio = 1e-3   # singular-value threshold ratio
    last_layer_name = '8'   # layer name passed as `last_layer_name`

# Turn autograd on globally for the rest of the script (the local-basis helpers
# presumably differentiate through the model — TODO confirm).
torch.autograd.set_grad_enabled(True)
eval_config = compare_basis_config()

model_name = 'StyleGAN2'

# Layer names '1'..'8' (strings), presumably the same naming used for
# `last_layer_name` in compare_basis_config.
subLayerNames = [str(a) for a in range(1, 9)]
#subLayerNames = [f'dense{a}_act' for a in range(0, 8)]
# Candidate singular-value threshold ratios to evaluate.
sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]
print(subLayerNames)

# Load a precomputed per-model dictionary from disk. The name suggests
# layer -> threshold -> rank, but its structure is not visible here — verify.
# NOTE(review): the file has a .dill extension yet is read with pickle.load;
# this only works if it holds no dill-specific objects. `pickle` is not imported
# in this file (presumably provided by notebook_init). Only load trusted files.
with open(f'./subnetwork_stats/LayerThres2rank_dict_{model_name}.dill', 'rb') as f:
    LayerThres2rank_dict = pickle.load(f)

''' pymanopt ver 2.0.1 '''

# compute_centroid below is adapted from pymanopt/pymanopt/optimizers/nelder_mead.py (pymanopt 2.0.1).
import time

import numpy as np

import pymanopt
from pymanopt import tools
from pymanopt.optimizers.optimizer import Optimizer, OptimizerResult
from pymanopt.optimizers.steepest_descent import SteepestDescent


def compute_centroid(manifold, points, max_iterations=15, max_time=1000, verbosity=2):
    """Search for the Karcher mean of `points` on `manifold` by steepest descent.

    Adapted from ``compute_centroid`` in ``pymanopt/optimizers/nelder_mead.py``
    (pymanopt 2.0.1). Unlike the upstream helper, this returns the whole
    optimizer result instead of only the final point.

    Args:
        manifold: pymanopt manifold on which the points lie.
        points: iterable of points on ``manifold``. It is materialized into a
            list once, so generators and other one-shot iterables are safe.
        max_iterations: iteration cap passed to ``SteepestDescent``.
        max_time: run-time cap passed to ``SteepestDescent``.
        verbosity: ``SteepestDescent`` verbosity level (default 2, as before).

    Returns:
        The result of ``SteepestDescent.run``; the centroid estimate is its
        ``.point`` attribute.
    """
    # The cost and gradient closures are evaluated many times during the
    # optimization. A one-shot iterator would be exhausted after the first call,
    # and every later evaluation would silently see no points.
    points = list(points)

    @pymanopt.function.numpy(manifold)
    def objective(*y):
        # Single-component manifolds deliver the point as one positional arg.
        if manifold.num_values == 1:
            (y,) = y
        # Karcher cost: half the sum of squared geodesic distances.
        return sum(manifold.dist(y, point) ** 2 for point in points) / 2

    @pymanopt.function.numpy(manifold)
    def gradient(*y):
        if manifold.num_values == 1:
            (y,) = y
        # Riemannian gradient of the cost above: -sum_i log_y(point_i).
        # The zero tangent vector is the start value, so an empty `points`
        # still yields a valid tangent vector.
        return -sum(
            (manifold.log(y, point) for point in points),
            manifold.zero_vector(y),
        )

    optimizer = SteepestDescent(
        max_iterations=max_iterations, verbosity=verbosity, max_time=max_time
    )
    problem = pymanopt.Problem(manifold, objective, riemannian_gradient=gradient)
    return optimizer.run(problem)

def sample_random_local_basis(model, eval_config, full=False):
    """Collect one local basis per random sample as numpy arrays.

    Draws ``eval_config.n_samples`` samples with a numpy ``RandomState`` seeded
    by ``eval_config.seed`` and calls ``get_random_local_basis`` for each one.

    Args:
        model: generator model forwarded to ``get_random_local_basis``.
        eval_config: settings object. It must provide ``seed``, ``n_samples``,
            ``subspace_dim``, ``last_layer_name`` and ``sv_thres_ratio``.
            ``rankEst`` is optional and defaults to False when absent.
        full: if True, keep every column of each basis. Otherwise keep only the
            first ``eval_config.subspace_dim`` columns.

    Returns:
        list of numpy arrays, one per sample, in sampling order.
    """
    # Process-global switch; it is left enabled after this function returns.
    torch.autograd.set_grad_enabled(True)
    rng = np.random.RandomState(eval_config.seed)
    # Previously hard-coded to False. Read it from the config like every other
    # setting, and fall back to the old value for configs that lack the field.
    rank_est = getattr(eval_config, 'rankEst', False)
    subspace_dim = eval_config.subspace_dim

    local_basis_list = []
    for _ in tqdm(range(eval_config.n_samples)):
        # Only the z-space local basis is used; the other outputs are unpacked
        # just to keep the 7-value return contract explicit.
        (_noise, _z, z_local_basis, _z_sv,
         _noise_basis, _rank, _noise_level) = get_random_local_basis(
            model, rng,
            last_layer_name=eval_config.last_layer_name,
            rankEst=rank_est,
            sv_thres_ratio=eval_config.sv_thres_ratio,
        )
        basis = z_local_basis if full else z_local_basis[:, :subspace_dim]
        # detach/cpu make the conversion safe when the tensor requires grad or
        # lives on an accelerator; both are no-ops for a plain CPU tensor.
        local_basis_list.append(basis.detach().cpu().numpy())
    return local_basis_list

