import numpy as np
from localbasis_utils import compute_grsm_metric
import matplotlib.pyplot as plt
import argparse
from notebook_init import *
import time, os
from tqdm import tqdm
from frechet_mean_basis import sample_random_local_basis, compute_centroid, align_to_reference, align_local_basis, compute_frechet_basis
from localbasis_utils import compare_basis_componentwise

# Command-line options. --sv_thres and --n_samples are interpolated into the
# .npy file names read and written further down, so they select which
# precomputed files this run works on.
parser = argparse.ArgumentParser()
#parser.add_argument("-n", "--gpu_number", default = 0, type = int, help="gpu number")
parser.add_argument("-sv", "--sv_thres", default=0.01, type=float, help="Singular Value Threshold")
parser.add_argument("-s", "--n_samples", default=200, type=int, help="Number of Samples for Frechet Mean")
# `choices` lets argparse reject an unknown reference with a usage message.
# The previous `assert` was silently skipped under `python -O` and otherwise
# surfaced as a bare AssertionError.
parser.add_argument("--ref", default='gs', type=str, choices=['gs', 'sf'], help="Reference Global Basis GS or SF")
args = parser.parse_args()
#os.environ['CUDA_VISIBLE_DEVICES']=str(args.gpu_number)

# Fixed experiment identifiers; both are embedded in the file names below.
model_name = 'StyleGAN2'
layer_name = '8'

# Long name of the reference global basis chosen via --ref ('gs' -> GANSpace,
# anything else -> SeFa).
reference_global_basis = 'GANSpace' if args.ref == 'gs' else 'SeFa'

# Precomputed aligned Frechet-mean basis, located by model / layer /
# singular-value threshold / sample count.
fm_basis_path = (
    f'./frechet_mean/aligned_frechet_basis_{model_name}_layer_{layer_name}'
    f'_thres_{args.sv_thres}_maxiter_1000_basisiter_200_samples_{args.n_samples}.npy'
)
fm_basis = np.load(fm_basis_path)  # shape : (ambient_dim, intrinsic_layer_dim)
layer_dim = fm_basis.shape[1]  # column count of the loaded basis

# gs_basis_path = f'./global_directions/ganspace_directions_ffhq_{model_name}_style-{layer_name}.npy'
# gs_basis = np.load(gs_basis_path).squeeze()[:layer_dim].transpose()
# print(fm_basis.shape, gs_basis.shape)

# Stack of interpolated subspaces between the Frechet-mean basis and the
# chosen reference basis; the refinement loop iterates over its first axis.
interps_path = (
    f'./geodesic_interp_basis/interpSubspace_Frechet2{reference_global_basis}_{model_name}'
    f'_layer_{layer_name}_thres_{args.sv_thres}_samples_{args.n_samples}_n_step_{7}_ovsht.npy'
)
interps = np.load(interps_path)
print(f"Refining Interps from Frechet Mean to {reference_global_basis}")

print(model_name, layer_name)
# Names '1' .. '8'; only printed in the visible part of this script.
subLayerNames = list(map(str, range(1, 9)))
#subLayerNames = [f'dense{a}_act' for a in range(0, 8)]
# Candidate singular-value threshold ratios (not referenced again in the
# visible part of this script).
sv_thres_ratio_candi = [0.0005, 0.001, 0.005, 0.01]
print(subLayerNames)


# StyleGAN2
# Build the instrumented generator and keep a handle on the wrapped model.
# NOTE(review): `get_instrumented_model`, `device` and the incoming `inst`
# are not defined in this file; presumably they come from the
# `from notebook_init import *` star import -- confirm.
dataset = 'ffhq'   #config-f
#dataset = 'ffhq-config-e'
use_w = True
inst = get_instrumented_model(
    'StyleGAN2',
    dataset,
    'style',
    device,
    inst=inst,
    use_w=use_w,
)
model = inst.model
model.truncation = 1.0

class compare_basis_config:
    """Default settings handed to the local-basis sampling helper.

    All values are plain class attributes. The script instantiates the class
    and then overrides several of them on the instance, so the numbers below
    act only as defaults.
    """

    # Seed value; never modified by this script.
    seed = 0
    # Sample count; replaced by the --n_samples CLI value.
    n_samples = 50
    # Subspace dimension; later replaced by a rounded mean rank.
    subspace_dim = 2
    # Singular-value threshold ratio; replaced by the --sv_thres CLI value.
    sv_thres_ratio = 0.001
    # Name of the last layer; replaced by the script's `layer_name`.
    last_layer_name = '8'
    # Rank-estimation switch; the script keeps it False.
    rankEst = False

# Turn gradient tracking on globally before any model evaluation happens.
torch.autograd.set_grad_enabled(True)
eval_config = compare_basis_config()

# Rank statistics, indexed below as [layer name][sv threshold ratio].
# NOTE(review): the file carries a .dill suffix but is read with `pickle`;
# only load files you produced yourself.
#with open('./subnetwork_stats/LayerThres2rank_dict_StyleGAN2.dill', 'rb') as f:
rank_dict_path = f'./subnetwork_stats/LayerThres2rank_dict_{model_name}.dill'
with open(rank_dict_path, 'rb') as rank_file:
    LayerThres2rank_dict = pickle.load(rank_file)

print(model_name)
print(f"Layer name : {layer_name}")

# Override the class defaults with the CLI / script-level settings.
eval_config.sv_thres_ratio = args.sv_thres
eval_config.last_layer_name = layer_name
eval_config.n_samples = args.n_samples
eval_config.rankEst = False

# Iteration and time budgets forwarded to compute_frechet_basis in the loop.
max_iter = 200
max_time = 10000

# Sampling runs while `subspace_dim` still holds its class default; the
# assignment below only takes effect afterwards.
local_basis_list = sample_random_local_basis(model, eval_config, full=True)

# mean of Local Rank, rounded to the nearest integer
layer_ranks = LayerThres2rank_dict[eval_config.last_layer_name][eval_config.sv_thres_ratio]
eval_config.subspace_dim = int(np.mean(layer_ranks).round())

# For every interpolated subspace, compute a Frechet basis inside it and save
# the result as one .npy file per interpolation index.
for i, interp_subspace in enumerate(interps):
    # Output path, built once per index from the same identifiers used for
    # the input files. The two former copies hard-coded 'StyleGAN2',
    # 'layer_8' and 'basisiter_200', which would go stale if model_name,
    # layer_name or max_iter changed. With the current values the resulting
    # file names are identical to the old ones.
    aligned_frechet_global_basis_path = (
        f'./geodesic_interp_basis/interpBasis_Frechet2{reference_global_basis}_{model_name}'
        f'_layer_{layer_name}_thres_{args.sv_thres}_samples_{args.n_samples}'
        f'_n_step_7_ovsht_idx_{i}_basisiter_{max_iter}.npy'
    )

    # The step that equals the already-aligned Frechet-mean basis needs no
    # recomputation: store it unchanged and move on.
    if np.allclose(interp_subspace, fm_basis):
        print(f"(Interpolation Subspace = Frechet Mean Basis) at Idx {i}")
        np.save(aligned_frechet_global_basis_path, interp_subspace)
        continue

    frechet_subspace = interp_subspace
    timer = time.time()
    aligned_frechet_global_basis_amb = compute_frechet_basis(
        frechet_subspace,
        local_basis_list,
        max_iterations=max_iter,
        max_time=max_time,
    )
    print(f'Aligned frechet basis path : {aligned_frechet_global_basis_path}')
    print(f"Computing Frechet Mean Basis for dim {frechet_subspace.shape[1]} took {round(time.time() - timer, 2)}")
    np.save(aligned_frechet_global_basis_path, aligned_frechet_global_basis_amb)