from lib2to3.pgen2.pgen import generate_grammar
from notebook_init import *
import argparse
from localbasis_utils import efficient_get_random_local_basis, get_mapping_network
from torch import nn

# --- Command-line interface --------------------------------------------------
# (flag, type, default) for every option, registered in this exact order.
_CLI_OPTIONS = (
    ('--n_sample', int, 50000),             # total number of images to generate
    ('--perturb', float, 3),                # perturbation magnitude
    ('--comp', int, 0),                     # basis component index
    ('--sublayer_str', str, 'dense7_act'),  # StyleGAN1 sublayer name
    ('--sublayer_num', int, 3),             # StyleGAN2 sublayer index
    ('--batch_size', int, 4),
    ('--generator', str, 'stylegan1'),      # 'stylegan1' or 'stylegan2'
)

parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# Autograd is switched on globally; presumably the local-basis computation
# needs gradients through the mapping network (TODO confirm).
torch.autograd.set_grad_enabled(True)


def rand():
    """Return a seed drawn from NumPy's global RNG in [0, int32 max)."""
    return np.random.randint(np.iinfo(np.int32).max)

# --- Generator ----------------------------------------------------------------
# Load the requested StyleGAN variant for the 'ffhq' dataset as an
# instrumented model (use_w=True). Unknown generators are rejected up front.
if args.generator not in ('stylegan1', 'stylegan2'):
    raise NotImplementedError('Generator Wrong!!')

use_w = True
dataset = 'ffhq'
# NOTE(review): `inst` on the right-hand side must already exist at module
# level (presumably provided by `notebook_init`) — confirm.
if args.generator == 'stylegan1':
    inst = get_instrumented_model('StyleGAN', dataset, 'g_mapping', device,
                                  use_w=use_w, inst=inst)
    model = inst.model
    model.truncation = 1.0  # author's note: truncation NOT IMPLEMENTED here
else:
    inst = get_instrumented_model('StyleGAN2', dataset, 'style', device,
                                  inst=inst, use_w=use_w)
    model = inst.model
    model.truncation = 1.0
    model = model.to(device)

# --- Inception feature extractor ----------------------------------------------
from calc_inception import load_patched_inception_v3
from torch import nn
from tqdm import tqdm

# Re-assert global autograd (also set right after argument parsing).
torch.autograd.set_grad_enabled(True)

# nn.Module.to() and .eval() both return the module itself, so they chain.
inception = load_patched_inception_v3().to(device).eval()

# --- Run configuration ----------------------------------------------------------
perturb_intensity = args.perturb  # scale applied to the perturbation direction
ncomp = args.comp                 # basis component index (from --comp)
nsample = args.n_sample           # total number of images to generate
batch_size = args.batch_size

# One RNG seed per batch; a trailing partial batch gets a seed of its own.
_full_batches, _leftover = divmod(nsample, batch_size)
num_seeds = _full_batches + (1 if _leftover else 0)
seeds_ffhq = [rand() for _seed_slot in range(num_seeds)]

out_root = Path('out/fid/iter')
makedirs(out_root, exist_ok=True)

# Per-batch Inception features for the GANSpace (gs) and local-basis (lb)
# perturbations; concatenated and saved once the main loop finishes.
gs_features = []
lb_features = []

# Sublayer defining the latent space. StyleGAN1 takes a layer name
# (one of dense2_act ... dense7_act); StyleGAN2 takes a numeric index.
if args.generator == 'stylegan2':
    subLayerNames = [str(args.sublayer_num)]
elif args.generator == 'stylegan1':
    subLayerNames = [str(args.sublayer_str)]

print(f"Sub layer name : {subLayerNames}")
print(args)
# --- Main loop ------------------------------------------------------------------
# For every seed: draw a batch of latent noise, compute its local basis, perturb
# the resulting latent by `perturb_intensity` along component `ncomp` of
# (a) the local basis and (b) the GANSpace global basis, render both perturbed
# latents and collect their Inception features.

# The GANSpace directions depend only on CLI arguments, so load them once
# instead of once per batch.
if args.generator == 'stylegan1':
    gs_dir = np.load(f"/home/hscho/newvae/stylegan1_global_basis/stylegan-ffhq_g_mapping.{args.sublayer_str.split('_')[0]}_ipca_c512_n1000000.npy")
elif args.generator == 'stylegan2':
    gs_dir = np.load(f"/home/hscho/newvae/ganspace_ffhq_stylegan2_subnetwork/ganspace_directions_ffhq_stylegan2_style-{int(args.sublayer_num)}.npy")
gs_dir = torch.from_numpy(gs_dir).to(device)
# Use the component the output filenames advertise (comp{args.comp}) rather
# than a hard-coded 0.
gs_direction = gs_dir.squeeze()[ncomp]

for last_layer_name in subLayerNames:
    mapping_network, res, noise_dim = get_mapping_network(model, last_layer_name)
    # NOTE(review): the noise dimension is pinned to 512, overriding the value
    # returned by get_mapping_network — confirm this matches the mapping
    # network's input size for every sublayer.
    noise_dim = 512
    for en_idx, seed in enumerate(tqdm(seeds_ffhq)):
        rng = np.random.RandomState(seed)
        # Full batches hold `batch_size` samples; the final batch (if any)
        # holds the remaining `nsample % batch_size`.
        b = min(batch_size, nsample - en_idx * batch_size)
        noise = torch.from_numpy(
            rng.standard_normal(noise_dim * b).reshape(b, noise_dim)
        ).to(device)

        z, z_local_basis, z_sv = efficient_get_random_local_basis(
            mapping_network, None, noise=noise, last_layer_name=last_layer_name,
            rankEst=False, sv_thres_ratio=0.005, alpha=0.1, noise_dim=noise_dim)

        local_lat_comp = z_local_basis.transpose(1, 2).detach().to(device)
        local_lat_mean = z.detach().to(device)
        lb_w = local_lat_mean + perturb_intensity * local_lat_comp[:, ncomp, :]
        gs_w = local_lat_mean + perturb_intensity * gs_direction

        # Rendering and feature extraction need no gradients (the features are
        # detached anyway), so don't build an autograd graph for them.
        with torch.no_grad():
            lb_img = model(res(lb_w.float()))
            gs_img = model(res(gs_w.float()))
            lb_feat = inception(lb_img)[0].view(lb_img.shape[0], -1)
            gs_feat = inception(gs_img)[0].view(gs_img.shape[0], -1)
        lb_features.append(lb_feat.detach().to('cpu'))
        gs_features.append(gs_feat.detach().to('cpu'))

        # Drop per-batch tensors and release cached GPU memory before the
        # next batch.
        del z, gs_img, gs_feat, lb_img, lb_feat, z_sv, z_local_basis
        torch.cuda.empty_cache()

# --- Persist features -------------------------------------------------------------
# Stack the per-batch feature tensors along dim 0 (one row per generated image).
lb_features = torch.cat(lb_features, 0)
gs_features = torch.cat(gs_features, 0)

# The sublayer identifier embedded in the filenames differs per generator.
if args.generator == 'stylegan2':
    sublayer = args.sublayer_num
elif args.generator == 'stylegan1':
    sublayer = args.sublayer_str

# Both files share everything after their 'gssn' / 'lbsn' prefix.
_feature_suffix = (f'_features_comp{args.comp}_ptb{perturb_intensity:.1f}'
                   f'_n{nsample}_sublayer{sublayer}_{args.generator}.pt')
torch.save(gs_features, out_root / f'gssn{_feature_suffix}')
torch.save(lb_features, out_root / f'lbsn{_feature_suffix}')


