from notebook_init import *
import time, os
import numpy as np
from tqdm import tqdm
import argparse
from pymanopt.manifolds import Grassmann
from frechet_mean_basis import sample_random_local_basis, compute_centroid, compute_frechet_basis

# ---- Command-line options ---------------------------------------------------
# GPU selection used to be a (commented-out) --gpu_number option that set
# CUDA_VISIBLE_DEVICES; set that environment variable externally if needed.
parser = argparse.ArgumentParser()
_cli_options = [
    (("-m", "--model_name"), {"default": 'StyleGAN2', "type": str, "help": "Model name"}),
    (("-sv", "--sv_thres"), {"default": 0.01, "type": float, "help": "Singular Value Threshold"}),
    (("-s", "--n_samples"), {"default": 200, "type": int, "help": "Number of Samples for Frechet Mean"}),
    (("-l", "--layer"), {"default": 'all', "type": str, "help": "Layer name"}),
]
for _flags, _opts in _cli_options:
    parser.add_argument(*_flags, **_opts)
args = parser.parse_args()

# Turn autograd on globally.
# NOTE(review): `torch` is not imported in this file; presumably it comes from
# `from notebook_init import *` — confirm.
torch.autograd.set_grad_enabled(True)

print(f"Computing Frechet Subspace for Model name : {args.model_name} SV Thres : {args.sv_thres}")

###############################################################
# Load Model
#
# Each supported --model_name maps to the positional arguments passed to
# get_instrumented_model: (architecture, dataset, instrumented layer name).
# Every model is loaded with use_w=True and truncation 1.0.
_MODEL_SPECS = {
    'StyleGAN2': ('StyleGAN2', 'ffhq', 'style'),                    # config-f
    'StyleGAN2-config-e': ('StyleGAN2', 'ffhq-config-e', 'style'),
    'StyleGAN1': ('StyleGAN', 'ffhq', 'g_mapping'),
    'StyleGAN2-cat': ('StyleGAN2', 'cat', 'style'),                 # LSUN-CAT
    'StyleGAN1-cat': ('StyleGAN', 'cats', 'g_mapping'),             # LSUN-CAT
}

if args.model_name not in _MODEL_SPECS:
    # An explicit error instead of `assert False`: asserts are stripped under
    # `python -O`, which would let the script continue without a `model`.
    # parser.error prints usage plus this message to stderr and exits (status 2).
    parser.error(f"Model name should be one of {', '.join(_MODEL_SPECS)}")

_gan_name, dataset, _instrumented_layer = _MODEL_SPECS[args.model_name]
use_w = True
# NOTE(review): `inst`, `device` and `get_instrumented_model` are not defined in
# this file; presumably they come from `from notebook_init import *` (with
# `inst` initially a placeholder such as None) — confirm.
inst = get_instrumented_model(_gan_name, dataset, _instrumented_layer, device, inst=inst, use_w=use_w)
model = inst.model
# For the StyleGAN1 models the original script marked this as "NOT IMPLEMENTED".
model.truncation = 1.0

# Candidate layers for the Frechet mean basis. The first two layers of each
# architecture are skipped (the original sliced them off with [2:]).
if "StyleGAN2" in args.model_name:
    subLayerNames = [str(a) for a in range(3, 9)]            # '3' .. '8'
else:
    subLayerNames = [f'dense{a}_act' for a in range(2, 8)]   # 'dense2_act' .. 'dense7_act'
if args.layer != 'all':
    # Explicit validation instead of `assert`, which is stripped under
    # `python -O` and would let an unknown layer name through silently.
    if args.layer not in subLayerNames:
        parser.error(f"Layer name should be 'all' or one of {', '.join(subLayerNames)}")
    subLayerNames = [args.layer]

class compare_basis_config:
    """Settings object passed to ``sample_random_local_basis``.

    The defaults are class attributes. The script creates one instance and
    overrides ``n_samples``, ``sv_thres_ratio``, ``rankEst`` and
    ``last_layer_name`` on it.
    """

    # Sampling
    seed: int = 0
    n_samples: int = 50

    # Local-basis rank handling; sv_thres_ratio is fed from --sv_thres
    # ("Singular Value Threshold").
    subspace_dim: int = 2
    rankEst: bool = False
    sv_thres_ratio: float = 0.001

    # Layer whose local basis is sampled; reassigned for every layer in the loop.
    last_layer_name: str = '8'

# Per-run settings handed to sample_random_local_basis (CLI values override
# the class defaults).
eval_config = compare_basis_config()
eval_config.n_samples = args.n_samples
eval_config.sv_thres_ratio = args.sv_thres
eval_config.rankEst = False
return_samples = True  # NOTE(review): not used anywhere in this script — confirm before removing.
max_iter = 200         # max_iterations passed to compute_frechet_basis (also in the output name)
max_time = 10000       # max_time passed to compute_frechet_basis
# Iteration tag of the precomputed Frechet mean subspaces. It appears in both
# the input and the output file names, so it is kept in one place.
frechet_mean_maxiter = 1000

print(f"Computing Frechet Mean Basis for Layers : {subLayerNames}")
for last_layer_name in subLayerNames:
    print(f"Computing Frechet Mean Basis in Layer {last_layer_name} for {args.n_samples} samples")
    eval_config.last_layer_name = last_layer_name
    local_basis_list = sample_random_local_basis(model, eval_config, full=True)

    # Log the input path before loading, so a failed load is easy to trace.
    frechet_subspace_path = f'./frechet_mean/frechet_{args.model_name}_layer_{last_layer_name}_thres_{args.sv_thres}_maxiter_{frechet_mean_maxiter}.npy'
    print(f'Frechet mean path : {frechet_subspace_path}')
    frechet_subspace = np.load(frechet_subspace_path)

    # perf_counter is monotonic, unlike time.time(), so the elapsed time cannot
    # be skewed by system clock adjustments.
    timer = time.perf_counter()
    aligned_frechet_global_basis_amb = compute_frechet_basis(frechet_subspace, local_basis_list, max_iterations=max_iter, max_time=max_time)
    elapsed = time.perf_counter() - timer

    aligned_frechet_global_basis_path = f'./frechet_mean/aligned_frechet_basis_{args.model_name}_layer_{last_layer_name}_thres_{args.sv_thres}_maxiter_{frechet_mean_maxiter}_basisiter_{max_iter}_samples_{args.n_samples}.npy'
    print(f'Aligned frechet basis path : {aligned_frechet_global_basis_path}')
    print(f"Computing Frechet Mean Basis for dim {frechet_subspace.shape[1]} took {round(elapsed, 2)}")
    np.save(aligned_frechet_global_basis_path, aligned_frechet_global_basis_amb)
    