from notebook_init import *
import time, os
import numpy as np
from tqdm import tqdm
import argparse
from pymanopt.manifolds import Grassmann
from frechet_mean_basis import sample_random_local_basis, compute_centroid

parser = argparse.ArgumentParser()
# GPU selection: export CUDA_VISIBLE_DEVICES in the shell before launching.
# Setting it from inside this script may be too late, because notebook_init
# (star-imported above) has already imported torch and may have initialised CUDA.
parser.add_argument(
    "-m", "--model_name", default='StyleGAN2', type=str,
    # Reject unsupported models up front instead of failing after startup.
    choices=['StyleGAN2', 'StyleGAN2-config-e', 'StyleGAN1', 'StyleGAN2-cat', 'StyleGAN1-cat'],
    help="Model name",
)
parser.add_argument("-sv", "--sv_thres", default=0.01, type=float, help="Singular Value Threshold")
parser.add_argument("-s", "--n_samples", default=1000, type=int, help="Number of Samples for Frechet Mean")
args = parser.parse_args()
# Gradients are needed to compute local bases of the generator.
torch.autograd.set_grad_enabled(True)

print(f"Computing Frechet Subspace for Model name : {args.model_name} SV Thres : {args.sv_thres}")

###############################################################
# Load Model
# Each supported --model_name maps to the arguments passed to
# get_instrumented_model: (architecture, dataset/checkpoint name, instrumented
# layer name). Every variant is loaded with use_w=True and truncation 1.0.
_MODEL_SPECS = {
    'StyleGAN2':          ('StyleGAN2', 'ffhq',          'style'),      # config-f
    'StyleGAN2-config-e': ('StyleGAN2', 'ffhq-config-e', 'style'),
    'StyleGAN1':          ('StyleGAN',  'ffhq',          'g_mapping'),
    'StyleGAN2-cat':      ('StyleGAN2', 'cat',           'style'),      # LSUN-Cat
    'StyleGAN1-cat':      ('StyleGAN',  'cats',          'g_mapping'),  # LSUN-Cat
}

if args.model_name not in _MODEL_SPECS:
    # Raise rather than `assert False`: asserts are stripped under `python -O`,
    # which would let execution continue with `model` undefined.
    raise SystemExit(f"Model name should be one of {', '.join(_MODEL_SPECS)}")

_arch, dataset, _inst_layer = _MODEL_SPECS[args.model_name]
use_w = True
# NOTE(review): `inst` and `device` are expected to come from notebook_init
# (star-imported above); passing inst=inst presumably allows reuse of an
# existing instrumented model — TODO confirm.
inst = get_instrumented_model(_arch, dataset, _inst_layer, device, inst=inst, use_w=use_w)
model = inst.model
# For the StyleGAN1 variants the original code marked truncation as
# "NOT IMPLEMENTED", so this setting is presumably ignored there.
model.truncation = 1.0

# Layers to process. StyleGAN2 variants name their layers '1'..'8', while the
# StyleGAN1 variants use 'dense0_act'..'dense7_act'; in both cases the first
# two layers are skipped.
if "StyleGAN2" not in args.model_name:
    subLayerNames = [f'dense{idx}_act' for idx in range(2, 8)]
else:
    subLayerNames = [str(idx) for idx in range(3, 9)]

class compare_basis_config:
    """Default settings for sampling local bases and building subspaces.

    All settings are plain class-level defaults. Callers instantiate the class
    and override individual attributes on the instance as needed.
    """

    # Name of the layer whose local basis is sampled (string layer name).
    last_layer_name: str = '8'
    # Number of samples to draw.
    n_samples: int = 50
    # Presumably the RNG seed used when sampling — TODO confirm.
    seed: int = 0
    # Dimension of the subspace kept from each local basis.
    subspace_dim: int = 2
    # Rank-estimation toggle; exact semantics live in the consumer — TODO confirm.
    rankEst: bool = False
    # Singular-value threshold ratio.
    sv_thres_ratio: float = 0.001

# Per-layer rank estimates, indexed as [layer_name][sv_thres] in the loop below.
# NOTE(review): the file has a .dill extension but is read with `pickle`
# (whatever notebook_init binds to that name); it is assumed to be trusted local data.
rank_table_path = f'./subnetwork_stats/LayerThres2rank_dict_{args.model_name}.dill'
with open(rank_table_path, 'rb') as rank_file:
    LayerThres2rank_dict = pickle.load(rank_file)

# Start from the class defaults, then apply the command-line overrides.
eval_config = compare_basis_config()
eval_config.n_samples, eval_config.sv_thres_ratio, eval_config.rankEst = (
    args.n_samples, args.sv_thres, False
)

return_samples = True
# Budget passed to compute_centroid (max_iterations / max_time; time units
# are defined by compute_centroid — presumably seconds, TODO confirm).
max_iter, max_time = 1000, 2000

# np.save does not create parent directories: make sure the output directory
# exists *before* spending time on sampling and optimisation.
output_dir = './frechet_mean'
os.makedirs(output_dir, exist_ok=True)

print(f"Computing Frechet Mean Subspace for Layers : {subLayerNames}")
for last_layer_name in subLayerNames:
    print(f"Computing Frechet Mean Subspace in Layer {last_layer_name}")

    eval_config.last_layer_name = last_layer_name
    # Look up this layer's rank estimates before the expensive sampling step, so
    # a missing layer/threshold entry fails fast with a clear message.
    try:
        layer_ranks = LayerThres2rank_dict[last_layer_name][eval_config.sv_thres_ratio]
    except KeyError as err:
        raise KeyError(
            f"No rank estimate for layer {last_layer_name!r} at sv_thres "
            f"{eval_config.sv_thres_ratio} in LayerThres2rank_dict"
        ) from err

    local_basis_list = sample_random_local_basis(model, eval_config, full=True)
    # subspace_dim is updated only after sampling, exactly as before, because
    # sample_random_local_basis receives eval_config and may read it (unverified).
    eval_config.subspace_dim = int(np.mean(layer_ranks).round())

    ambient_dim = local_basis_list[0].shape[0]
    manifold = Grassmann(ambient_dim, eval_config.subspace_dim, k=1)
    timer = time.perf_counter()
    # Keep the leading subspace_dim columns of each local basis (presumably
    # ordered by decreasing singular value — TODO confirm).
    local_basis_crop_list = [local_basis[:, :eval_config.subspace_dim] for local_basis in local_basis_list]
    opt_result = compute_centroid(manifold, local_basis_crop_list, max_iterations=max_iter, max_time=max_time)
    frechet_mean = opt_result.point
    print(f"Computing Frechet mean for dim {eval_config.subspace_dim} took {round(time.perf_counter() - timer, 2)}")
    # NOTE: max_time is not encoded in the file name; results produced with
    # different time budgets will overwrite each other.
    out_path = os.path.join(
        output_dir,
        f'frechet_{args.model_name}_layer_{last_layer_name}_thres_{args.sv_thres}_maxiter_{max_iter}.npy',
    )
    np.save(out_path, frechet_mean)