from notebook_init import *
import argparse
from localbasis_utils import get_mapping_network, old_efficient_get_random_local_basis
from torch import nn

parser = argparse.ArgumentParser(
    description='Compute Inception features of generator images perturbed '
                'along a local-basis direction and a GANSpace direction.')

parser.add_argument('--n_sample', type=int, default=50000,
                    help='number of random seeds (one image per method per seed)')
parser.add_argument('--perturb', type=float, default=3,
                    help='step size applied along each perturbation direction')
parser.add_argument('--comp', type=int, default=0,
                    help='component index (also recorded in output filenames)')
parser.add_argument('--sublayer_str', type=str, default='dense7_act',
                    help='sub-layer name, used when --generator=stylegan1')
parser.add_argument('--sublayer_num', type=int, default=3,
                    help='sub-layer index, used when --generator=stylegan2')
parser.add_argument('--generator', type=str, default='stylegan1',
                    choices=['stylegan1', 'stylegan2'],
                    help='which pretrained FFHQ generator to load')

args = parser.parse_args()
torch.autograd.set_grad_enabled(True)


def rand():
    """Return a random seed in [0, int32 max), drawn from NumPy's global RNG."""
    return np.random.randint(np.iinfo(np.int32).max)

# Load the pretrained generator through the instrumentation wrapper.
# Only 'stylegan1' and 'stylegan2' are supported; anything else aborts early.
if args.generator not in ('stylegan1', 'stylegan2'):
    raise NotImplementedError('Generator Wrong!!')

# Both supported generators use the FFHQ checkpoint in W-space mode.
use_w = True
dataset = 'ffhq'

if args.generator == 'stylegan1':
    inst = get_instrumented_model('StyleGAN', dataset, 'g_mapping', device,
                                  use_w=use_w, inst=inst)
else:
    inst = get_instrumented_model('StyleGAN2', dataset, 'style', device,
                                  inst=inst, use_w=use_w)

model = inst.model
# NOTE: the original StyleGAN1 path marks truncation as "NOT IMPLEMENTED".
model.truncation = 1.0

if args.generator == 'stylegan2':
    # Only the StyleGAN2 path explicitly moves the model to `device`.
    model = model.to(device)

#load inception
from calc_inception import load_patched_inception_v3
from torch import nn
from tqdm import tqdm

# Re-assert that autograd is enabled before the main loop
# (presumably needed by the local-basis computation — TODO confirm).
torch.autograd.set_grad_enabled(True)

# Patched Inception-v3 feature extractor on the generator's device,
# switched to inference mode.
inception = load_patched_inception_v3()
inception = inception.to(device)
inception.eval()

# --- Experiment configuration ---------------------------------------------
perturb_intensity = args.perturb  # scalar step multiplied onto each direction
ncomp = args.comp                 # component index requested via --comp
nsample = args.n_sample           # number of seeds to draw
batch_size = 1                    # NOTE: not used in the code shown here
num_seeds = nsample
seeds_ffhq = [rand() for _ in range(num_seeds)]

# Directory that receives the saved feature tensors.
out_root = Path('out/fid/iter')
makedirs(out_root, exist_ok=True)

# Accumulators for per-seed Inception features (local basis / GANSpace).
lb_features, gs_features = [], []

# Sub-layer identifier: a layer name for StyleGAN1, an index for StyleGAN2,
# both stored as strings.
_sublayer_by_generator = {
    'stylegan1': args.sublayer_str,
    'stylegan2': args.sublayer_num,
}
subLayerNames = [str(_sublayer_by_generator[args.generator])]

print(f"Sub layer name : {subLayerNames}")
#sv_dict = {'dense2_act':2.24648825, 'dense3_act':2.36979631, 'dense4_act':2.31150035, 'dense5_act':2.51148924,  'dense6_act':2.351194, 'dense7_act':1.65715194}
print(args)
def _load_global_basis():
    """Load the GANSpace (global) directions for the selected sub-layer onto `device`.

    The path is built from the command-line sub-layer argument of the chosen
    generator (not from the loop variable), as in the original script.
    """
    if args.generator == 'stylegan1':
        path = f"./stylegan1_global_basis/stylegan-ffhq_g_mapping.{args.sublayer_str.split('_')[0]}_ipca_c512_n1000000.npy"
    elif args.generator == 'stylegan2':
        path = f"./ganspace_ffhq_stylegan2_subnetwork/ganspace_directions_ffhq_stylegan2_style-{int(args.sublayer_num)}.npy"
    else:
        raise NotImplementedError('Generator Wrong!!')
    return torch.from_numpy(np.load(path)).to(device)


def _inception_features(w, res):
    """Render latent `w` via `res` and the generator; return flattened Inception features on CPU.

    Runs under torch.no_grad(): the image and its features are only stored,
    never differentiated, so building an autograd graph is wasted memory.
    """
    with torch.no_grad():
        img = model(res(w.float()))
        # NOTE(review): the generator output is fed to the patched Inception
        # unchanged; confirm both agree on the pixel value range.
        feat = inception(img)[0].view(img.shape[0], -1)
    return feat.detach().to('cpu')


for last_layer_name in subLayerNames:
    # The third return value is discarded: the latent size is fixed to 512
    # below, exactly as the original script did.
    mapping_network, res, _ = get_mapping_network(model, last_layer_name)
    noise_dim, b = 512, 1

    # The global basis does not depend on the seed, so load it (and pick the
    # requested component) once instead of once per sample.
    gs_dir = _load_global_basis()
    gs_vec = gs_dir.squeeze()[ncomp].view(1, 512)

    for en_idx, seed in enumerate(tqdm(seeds_ffhq)):
        # Seed-specific latent and its local basis.
        rng = np.random.RandomState(seed)
        noise = torch.from_numpy(rng.standard_normal((b, noise_dim))).to(device)

        z, z_local_basis, z_sv = old_efficient_get_random_local_basis(
            mapping_network, None, noise=noise, last_layer_name=last_layer_name,
            rankEst=False, sv_thres_ratio=0.005, alpha=0.1, noise_dim=noise_dim)

        # After the transpose, axis 1 indexes basis vectors
        # (assumed from the original [:, 0, :] indexing — TODO confirm).
        local_lat_comp = z_local_basis.transpose(1, 2).detach().to(device)
        local_lat_mean = z.detach().to(device)

        # Perturb along component `ncomp` (--comp), the index the output
        # filenames are labelled with; previously component 0 was hard-coded.
        lb_w = local_lat_mean + perturb_intensity * local_lat_comp[:, ncomp, :]
        gs_w = local_lat_mean + perturb_intensity * gs_vec

        lb_features.append(_inception_features(lb_w, res))
        gs_features.append(_inception_features(gs_w, res))

        del z, z_sv, z_local_basis
        torch.cuda.empty_cache()

# Stack the per-seed feature rows into one tensor per method.
lb_features = torch.cat(lb_features, 0)
gs_features = torch.cat(gs_features, 0)

# Sub-layer tag embedded in the output filenames.
sublayer = {
    'stylegan1': args.sublayer_str,
    'stylegan2': args.sublayer_num,
}[args.generator]

# Shared filename tail for both feature files.
_feature_tag = (f'comp{args.comp}_ptb{perturb_intensity:.1f}_n{nsample}'
                f'_sublayer{sublayer}_{args.generator}.pt')
torch.save(gs_features, out_root / f'gssn_features_{_feature_tag}')
torch.save(lb_features, out_root / f'lbsn_features_{_feature_tag}')


