'''Sample StyleGAN latents, render images, label them with attribute classifiers, and compute DCI scores per latent basis.'''
import numpy as np
import os, pickle, time, cv2, sys
from tqdm import tqdm
import argparse
import pandas as pd
from notebook_init import *
from localbasis_utils import get_mapping_network
sys.path.append('../')

_BASIS_CHOICES = ['Standard', 'Standard-refine', 'GANSpace', 'SeFa', 'Frechet_mean', 'Full', 'Interp-GS', 'Interp-SF']
_MODEL_CHOICES = ['StyleGAN2', 'StyleGAN2-config-e', 'StyleGAN1']

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--gpu_number", default = 0, type = int, help="gpu number")
parser.add_argument("-m", "--model_name", default = 'StyleGAN2', type = str, choices=_MODEL_CHOICES, help="Model name")
parser.add_argument("-b", "--basis_name", default = 'Standard', type = str, choices=_BASIS_CHOICES, help="Basis name")
parser.add_argument("-sv", "--sv_thres", default = 0.01, type = float, help="Singular Value Threshold")
parser.add_argument("-s", "--n_samples", default = 200, type = int, help="Number of Samples for Frechet Mean")
args = parser.parse_args()

# Interpolated bases ('Interp-GS' / 'Interp-SF') and SeFa are only provided for StyleGAN2.
# The old check compared against the literal 'Interp', which is not an accepted basis name,
# so it never fired. parser.error is used instead of assert so the check survives `python -O`.
if 'Interp' in args.basis_name and args.model_name != 'StyleGAN2':
    parser.error(f"basis {args.basis_name} requires model StyleGAN2")
if args.basis_name == 'SeFa' and args.model_name != 'StyleGAN2':
    parser.error("basis SeFa requires model StyleGAN2")

# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet.
# It is set after the star import from notebook_init — confirm that import does not touch the GPU.
os.environ['CUDA_VISIBLE_DEVICES']=str(args.gpu_number) 
torch.autograd.set_grad_enabled(True)

print(f"DCI Evaluation Model name : {args.model_name} Basis name : {args.basis_name}")

################################################################## 
'''Config'''
num_img = 10000    # total number of latents sampled / images rendered and classified
num_once = 5       # batch size used for latent sampling and image rendering
batch_size = 5     # batch size used when running the attribute classifiers
print(f"DCI Evaluation Num Image : {num_img}")

classifer_path = './metrics_checkpoint/'   # directory holding the celebahq-classifier checkpoints
in_path = f'./dci/{args.model_name}'
output_path = in_path
resize = 256       # side length that generated images are resized to before being saved/classified
use_w = True
# makedirs (rather than mkdir) so a fresh checkout without './dci' does not fail.
os.makedirs(output_path, exist_ok=True)

############################################################### 
'''Load Model'''
# Each supported model maps to: (generator family, checkpoint dataset, layer to instrument).
if args.model_name == 'StyleGAN2':
    gan_family, dataset, instrumented_layer = 'StyleGAN2', 'ffhq', 'style'   # config-f
elif args.model_name == 'StyleGAN2-config-e':
    gan_family, dataset, instrumented_layer = 'StyleGAN2', 'ffhq-config-e', 'style'
elif args.model_name == 'StyleGAN1':
    gan_family, dataset, instrumented_layer = 'StyleGAN', 'ffhq', 'g_mapping'
else:
    print("Model name should be one of StyleGAN2, StyleGAN2-config-e, StyleGAN1")
    assert False

use_w = True
inst = get_instrumented_model(gan_family, dataset, instrumented_layer, device, inst=inst, use_w=use_w)
model = inst.model
model.truncation = 1.0  # original note for StyleGAN1: truncation NOT IMPLEMENTED

# Mapping-network layer names used as the last layer of each evaluated sub-network:
# '3'..'8' for StyleGAN2, 'dense2_act'..'dense7_act' for StyleGAN1.
if "StyleGAN2" in args.model_name:
    subLayerNames = [str(layer_idx) for layer_idx in range(3, 9)]
else:
    subLayerNames = [f'dense{layer_idx}_act' for layer_idx in range(2, 8)]

############################################################### 
'''Sample Latent'''
# Cached outputs of this stage: Z.npy (input noise) and Ws.npy (one latent array per sub-layer).
# Resample whenever EITHER file is missing. The old guard (`not exists(Z) and exists(Ws)`)
# never sampled on a fresh run, so the later load of Ws.npy crashed.
# NOTE(review): regenerating Ws does not invalidate an existing images.npy/attribute.csv.
z_file = output_path + '/Z.npy'
ws_file = in_path + '/Ws.npy'
if not os.path.exists(z_file) or not os.path.exists(ws_file):
    sub_layers = []
    res_layers = []

    latents = np.zeros((num_img, 512), dtype='float32')
    # dlatents_layer[k] stores the outputs of the sub-network ending at subLayerNames[k].
    dlatents_layer = []
    for subnet_last_layer in subLayerNames:
        sub, res, noise_dim = get_mapping_network(inst.model, last_layer_name=subnet_last_layer)
        sub_layers.append(sub)
        res_layers.append(res)
        dlatents_layer.append(np.zeros((num_img, 512), dtype='float32'))

    for i in tqdm(range(num_img // num_once)):
        rows = slice(i * num_once, (i + 1) * num_once)
        src_latents = torch.randn(num_once, noise_dim).to(device)
        with torch.no_grad():
            # The same noise batch is pushed through every sub-network so layers stay paired.
            for layer_num, sub in enumerate(sub_layers):
                src_dlatents = sub(src_latents)  # assumes shape (num_once, 512) — TODO confirm
                dlatents_layer[layer_num][rows, :] = src_dlatents.detach().cpu().numpy().astype('float32')
            latents[rows, :] = src_latents.detach().cpu().numpy().astype('float32')
    print('Save Z')
    np.save(z_file, latents)

    print('Save Ws')
    np.save(ws_file, dlatents_layer)
else:
    print("Latent Exists")

#######################################################################################
'''Generate Image'''
# Render one image per latent from the LAST sub-layer's W codes and cache them in images.npy.

if not os.path.exists(output_path+'/images.npy'):
    print('Generate Image')
    dlatents_layer = np.load(output_path+'/Ws.npy')
    dlatents = dlatents_layer[-1]

    all_images = []
    with torch.no_grad():
        for i in tqdm(range(int(num_img/num_once))):
            w_batch = dlatents[i*num_once: (i+1)*num_once]
            image2 = model(torch.tensor(w_batch).to(device))
            image2 = image2.detach().cpu().numpy()
            if resize is None:
                # Keep the generator's native resolution (this path used to raise NameError).
                images = list(image2)
            else:
                images = []
                for img in image2:
                    # assumes each image is channel-first (C, H, W) — cv2 needs (H, W, C)
                    img = img.transpose(1, 2, 0)
                    img = cv2.resize(img, (resize, resize), interpolation=cv2.INTER_LANCZOS4)
                    images.append(np.array(img).transpose(2, 0, 1))
            all_images += images
        all_images = np.stack(all_images, axis = 0)
    print('Save Image')
    np.save(output_path+'/images.npy', all_images)
else:
    print("Generated Image Exists")

############################################################################################
'''Label Generated Image'''
import dnnlib.tflib as tflib
from DCI import DCI, Test, Train_and_Test

if not os.path.exists(output_path+'/attribute.csv'):
    imgs = np.load(output_path+'/images.npy')
    names = sorted(name for name in os.listdir(classifer_path) if 'celebahq-classifier' in name)
    if not names:
        raise FileNotFoundError(f"No 'celebahq-classifier' checkpoints found in {classifer_path}")

    tflib.init_tf()
    results = {}
    for name in names:
        print(name)
        # NOTE: pickle.load can execute arbitrary code — only point this at trusted checkpoints.
        with open(os.path.join(classifer_path, name), 'rb') as f:
            classifier = pickle.load(f)
        logits = np.zeros(len(imgs))
        # Ceil-divide so a trailing partial batch is classified instead of left at 0.
        n_batches = (len(imgs) + batch_size - 1) // batch_size
        for i in tqdm(range(n_batches), leave = False):
            rows = slice(i*batch_size, (i+1)*batch_size)
            logits[rows] = classifier.run(imgs[rows], None).reshape(-1)

        # Drop the 20-char 'celebahq-classifier-' prefix and a 4-char extension
        # (presumably '.pkl' — TODO confirm the checkpoint naming).
        results[name[20:-4]] = logits

    # Write once after every classifier has run. The CSV's existence is the cache key above,
    # so writing it per classifier let a crashed run leave a partial file that was never redone.
    pd.DataFrame(results).to_csv(output_path+'/attribute.csv', index=False)
else:
    print("Label Annotation Exists")
    
############################################################################################
'''Calculate DCI score'''

class dummy:
    # Plain attribute namespace (used as the class itself, never instantiated) holding the
    # file paths the DCI stage reads from and writes to.
    save_path = os.path.join(output_path, 'DCI_W_sublayer_basis_' + args.basis_name)
    attribute_path = os.path.join(output_path, 'attribute.csv')
    latent_path = os.path.join(output_path, 'Ws.npy')

dci_args = dummy

def get_converted_coord(dlatents, layer_name, basis_name, args):
    """Express latent codes in the coordinates of the requested basis.

    Args:
        dlatents: 2-D array of latent codes, one row per sample.
        layer_name: mapping-network layer the codes came from; it selects which basis file to load.
        basis_name: 'Standard', 'Standard-refine', 'GANSpace', 'SeFa', 'Interp-GS_<k>' /
            'Interp-SF_<k>' (k = last character), or anything else for the Frechet-mean basis.
        args: namespace providing model_name, sv_thres and n_samples (used in basis file names).

    Returns:
        The projected coordinates, or `dlatents` itself for 'Standard'.

    Each basis file is loaded only by the branch that uses it. Previously all of them were
    loaded eagerly, so e.g. GANSpace on StyleGAN1 (which has no SeFa file) raised
    FileNotFoundError.
    """
    if basis_name == 'Standard':
        return dlatents

    if 'Interp' in basis_name:
        # Geodesic interpolation basis between the Frechet basis and GANSpace/SeFa (StyleGAN2 only).
        idx = basis_name[-1]
        target = 'GANSpace' if 'Interp-GS' in basis_name else 'SeFa'
        interp_basis_path = f'./geodesic_interp_basis/interpBasis_Frechet2{target}_StyleGAN2_layer_{layer_name}_thres_{args.sv_thres}_samples_{args.n_samples}_n_step_7_ovsht_idx_{idx}_basisiter_200.npy'
        interp_basis = np.load(interp_basis_path)
        return np.matmul(dlatents, interp_basis)

    # The remaining bases are truncated to the Frechet basis' intrinsic dimension.
    fm_basis_path = f'./frechet_mean/aligned_frechet_basis_{args.model_name}_layer_{layer_name}_thres_{args.sv_thres}_maxiter_1000_basisiter_200_samples_{args.n_samples}.npy'
    fm_basis = np.load(fm_basis_path) # shape : (ambient_dim, intrisic_layer_dim)
    layer_dim = fm_basis.shape[1]

    if basis_name == 'Standard-refine':
        # Keep the layer_dim standard axes with the largest spread.
        axis_std = np.std(dlatents, axis=0)
        max_std_idx = get_topN_idx_sorted(axis_std, top_n = layer_dim)
        print(f'basis name : Standard-refine - idx = {max_std_idx}')
        return dlatents[:, max_std_idx]
    elif basis_name == 'GANSpace':
        gs_basis_path = f'./global_directions/ganspace_directions_ffhq_{args.model_name}_style-{layer_name}.npy'
        # Stored row-wise; keep the first layer_dim directions as columns.
        gs_basis = np.load(gs_basis_path).squeeze()[:layer_dim].transpose()
        print(f'basis name : {gs_basis_path}')
        return np.matmul(dlatents, gs_basis)
    elif basis_name == 'SeFa':
        sf_basis_path = f'./global_directions/sefa_directions_ffhq_{args.model_name}.npy'
        sf_basis = np.load(sf_basis_path).squeeze()[:layer_dim].transpose()
        print(f'basis name : {sf_basis_path}')
        return np.matmul(dlatents, sf_basis)
    else:
        print(f'basis name : {fm_basis_path}')
        return np.matmul(dlatents, fm_basis)
        

attribute_loaded = pd.read_csv(dci_args.attribute_path)
dlatents_layer = np.load(dci_args.latent_path)

if args.basis_name == 'Full':
    if args.model_name == 'StyleGAN2':
        basis_names = ['Standard', 'Standard-refine', 'GANSpace', 'SeFa', 'Frechet_mean']
    else:
        basis_names = ['Standard', 'Standard-refine', 'GANSpace', 'Frechet_mean']
elif 'Interp' in args.basis_name:
    basis_names = [args.basis_name+f'_{i}' for i in range(9)]
else:
    basis_names = [args.basis_name]

for basis_name in basis_names:
    # Resolve this basis's output path BEFORE evaluating it. It used to be updated only after
    # the per-layer Train_and_Test calls, so each basis's per-layer results were saved under
    # the previous basis's name.
    if basis_name != 'Standard':
        dci_args.save_path = os.path.join(output_path, f'DCI_W_sublayer_basis_{basis_name}_sv_{args.sv_thres}_basisiter_200_samples_{args.n_samples}')
    else:
        dci_args.save_path = os.path.join(output_path, f'DCI_W_sublayer_basis_{basis_name}')

    # Interp and SeFa bases are only implemented for the last sub-layer. This keys on the
    # current basis so SeFa behaves the same whether run alone or as part of 'Full'.
    last_layer_only = ('Interp' in basis_name) or ('SeFa' in basis_name)

    dlatents_scores = []
    for i, dlatents in enumerate(dlatents_layer):
        if last_layer_only and i < len(dlatents_layer) - 1:
            continue
        layer_name = subLayerNames[i]
        converted_latents = get_converted_coord(dlatents, layer_name, basis_name, args)
        print(f"Basis type {basis_name} Converted shape : {converted_latents.shape}")

        dci = DCI(converted_latents, attribute_loaded)
        scores = Train_and_Test(dci, dci_args.save_path +'_'+ str(i+1))
        dlatents_scores.append(scores)

    with open(dci_args.save_path, 'wb') as f:
        pickle.dump(dlatents_scores, f)  
