import numpy as np
import pandas as pd
import os
import torch
import pickle
from torchvision.utils import save_image
import sys
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Setup: load labels and the generator, then reconstruct the target image.
# Usage: python <script> NETWORK_PKL [TARGET_IMAGE]
# ---------------------------------------------------------------------------
if len(sys.argv) < 2:
    sys.exit(f'usage: {sys.argv[0]} NETWORK_PKL [TARGET_IMAGE]')
path = sys.argv[1]
device = 'cpu'
# Image whose label row conditions the generator; optional second CLI argument,
# defaulting to the previously hard-coded file.
target_fname = sys.argv[2] if len(sys.argv) > 2 else '01009.png'
target_id = os.path.basename(target_fname)

# Label table: column 0 is matched against the image id, the remaining columns
# are used as the conditioning attribute vector.
df = pd.read_csv('ffhq_aug_labels_boss300.csv')
table = df.to_numpy()
ids = table[:, 0]
values = table[:, 1:]
# np.asarray covers the case where `==` yields a plain bool instead of an array.
is_target = np.asarray(ids == target_id)
n_matches = int(is_target.sum())
if n_matches != 1:
    # Zero matches would silently give an empty (1, 0) label vector, and several
    # matches would be flattened into one over-long vector by view(1, -1) below.
    raise ValueError(f'expected exactly one label row for {target_id!r}, found {n_matches}')
target_attr = values[is_target]
target_attr = torch.tensor(target_attr.astype(np.float32)).to(device).view(1, -1)

data = np.load('z_opt.npz')
print(data['z'].shape)
# NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
with open(path, 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module

# Inference only: no autograd graph is needed for the reconstruction.
with torch.no_grad():
    z_latent = torch.tensor(data['z']).to(device).squeeze()
    ws = G.mapping(z_latent, target_attr)
    # Take the first broadcast copy of each w and stack them as the layers of a
    # single W+ code. Presumably data['z'] holds one latent per synthesis layer
    # -- TODO confirm against the script that produced z_opt.npz.
    new_ws = torch.stack([w[0] for w in ws], 0).unsqueeze(0)
    recon_img = G.synthesis(new_ws, force_fp32=True)


def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` to unit root-mean-square along ``dim``.

    Each slice along ``dim`` is multiplied by ``1 / sqrt(mean(x**2) + eps)``;
    ``eps`` keeps all-zero slices finite (they remain zero).
    """
    mean_square = torch.mean(torch.square(x), dim=dim, keepdim=True)
    return x * torch.rsqrt(mean_square + eps)

def generate_ws(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False,
                sparse_loss=False, label=None):
    """Run the causal mapping-network forward pass as a free function.

    ``self`` is the mapping module; the script calls this as
    ``generate_ws(G.mapping, z, c)``.

    The first ``ori_c_dim * i_dim`` columns of ``z`` are split into one chunk per
    attribute group ``i``. Each chunk is normalized and repeated ``causal_nodes[i]``
    times. The stacked result goes through the activate / deactivate / null MLPs.

    For each group the merged output is:
    - the label-weighted sum of its activations if its labels sum to > 0,
    - otherwise the mean of its deactivations.

    That result is then blended with the null output using the thresholded
    ``importance`` weights. The remaining columns of ``z`` go through
    ``main_mlp_net``, and both parts are concatenated.

    Args:
        self: Mapping module. It must provide z_dim, ori_c_dim, c_dim, i_dim,
            causal_nodes, cum_causal_nodes, importance, activate_mlp_net,
            deactivate_mlp_net, null_mlp_net, main_mlp_net, w_avg, w_avg_beta,
            num_ws and training. With ``sparse_loss`` it also needs causal_weights.
        z: 2-D latent batch (indexed as ``z[:, cols]``).
        c: Label batch; columns ``cum_causal_nodes[i]:cum_causal_nodes[i + 1]``
            belong to group ``i``. Presumably 0/1 labels -- TODO confirm.
        truncation_psi: Interpolation factor toward ``w_avg``; 1 disables it.
        truncation_cutoff: When set (and ``num_ws`` is set), only the first
            ``truncation_cutoff`` broadcast copies are truncated.
        skip_w_avg_update: Suppress the ``w_avg`` moving-average update that
            otherwise runs while ``self.training`` is true.
        sparse_loss: Also return regularization statistics.
        label: Unused; kept for signature compatibility.

    Returns:
        ``x`` as float32. Its shape is ``(N, num_ws, D)`` if ``num_ws`` is set,
        else ``(N, D)``. With ``sparse_loss`` the return is instead a tuple
        ``(x, mean |importance * causal_weights|, mean z_out**2,
        count of importance > 0.1 (int), mean |z_out| (float),
        mean |z_main| (float))``.

    Raises:
        ValueError: if ``self.z_dim <= 0``. All outputs are built from ``z``, so
            there is nothing to compute without it.
    """
    if self.z_dim <= 0:
        # Without this guard the function fails later with an UnboundLocalError
        # on z_out / z_main, which hides the real cause.
        raise ValueError(f'generate_ws requires z_dim > 0, got {self.z_dim}')

    with torch.autograd.profiler.record_function('input'):
        # Every causal node of group i receives that group's normalized chunk;
        # stacking gives (N, c_dim, i_dim), as asserted below.
        zs = []
        for i in range(self.ori_c_dim):
            chunk = normalize_2nd_moment(z[:, i * self.i_dim:(i + 1) * self.i_dim])
            zs = zs + [chunk] * self.causal_nodes[i]
        zs = torch.stack(zs, 1)
        assert zs.size() == (len(z), self.c_dim, self.i_dim)
        all_activations = self.activate_mlp_net(zs)
        all_deactivations = self.deactivate_mlp_net(zs)
        all_nulls = self.null_mlp_net(zs)
        assert all_activations.size() == all_deactivations.size() == (len(z), self.c_dim, self.i_dim)
        z_out = []
        # Importance weights at or below 0.1 are zeroed before merging.
        importance = self.importance * (self.importance > 0.1)
        for i in range(self.ori_c_dim):
            start_dim = self.cum_causal_nodes[i]
            end_dim = self.cum_causal_nodes[i + 1]
            cur_activations = all_activations[:, start_dim:end_dim]  # B,T,I
            cur_deactivations = all_deactivations[:, start_dim:end_dim]  # B,T,I
            cur_labels = c[:, start_dim:end_dim].unsqueeze(-1)  # B,T,1
            # Label-weighted sum of this group's activations.
            cur_act = torch.sum(cur_activations * cur_labels, dim=1)  # B,I
            # Plain mean over the group's nodes; the former multiply by
            # ones_like(cur_labels) only allocated a temporary.
            cur_deact = cur_deactivations.mean(dim=1)  # B,I
            cur_imp = importance[:, i]  # 1,I
            assert cur_imp.size() == (1, self.i_dim)
            # 1.0 where the group's labels sum to a positive value, else 0.0.
            cur_ci = (torch.sum(cur_labels, dim=1) > 0).float()
            assert cur_ci.size() == (len(c), 1)
            cur_z_out = cur_act * cur_ci + cur_deact * (1 - cur_ci)
            # NOTE(review): nulls are indexed by group i, not node start_dim, and
            # their shape is not asserted. Confirm null_mlp_net's output layout.
            cur_z_out = cur_z_out * cur_imp + all_nulls[:, i] * (1 - cur_imp)
            assert cur_z_out.size() == (len(z), self.i_dim)
            z_out.append(cur_z_out)

        z_out = torch.cat(z_out, 1)
        z_main = self.main_mlp_net(normalize_2nd_moment(z[:, self.ori_c_dim * self.i_dim:]))
    x = torch.cat([z_out, z_main], 1).to(torch.float32)

    # Update moving average of W.
    if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
        with torch.autograd.profiler.record_function('update_w_avg'):
            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

    # Broadcast.
    if self.num_ws is not None:
        with torch.autograd.profiler.record_function('broadcast'):
            x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

    # Apply truncation.
    if truncation_psi != 1:
        with torch.autograd.profiler.record_function('truncate'):
            assert self.w_avg_beta is not None
            if self.num_ws is None or truncation_cutoff is None:
                x = self.w_avg.lerp(x, truncation_psi)
            else:
                x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)

    if sparse_loss:
        assert importance.size() == (1, self.ori_c_dim, self.i_dim)
        reweight_importance = importance * (self.causal_weights.view(1, -1, 1))
        return x, torch.mean(torch.abs(reweight_importance)), torch.mean(z_out ** 2), torch.sum(
            importance > 0.1).item(), torch.mean(torch.abs(z_out)).item(), torch.mean(torch.abs(z_main)).item()
    return x


# Flip each attribute of the target's label vector one at a time and render the
# edit next to the unedited reconstruction. The saved grid alternates
# (recon_img, edited) images: 8 pairs per 16-image row.
NUM_EDITED_ATTRS = 134  # number of leading label columns to flip
# Loop-invariant latent tensor: build it once instead of on every iteration.
z_edit = torch.tensor(data['z']).to(device).squeeze()
images = []
# Inference only: if any loaded parameter requires grad, each pass would
# otherwise keep its autograd graph alive through `images`.
with torch.no_grad():
    for i in tqdm(range(NUM_EDITED_ATTRS)):
        new_attrs = target_attr.clone()
        # 1 - v toggles a 0/1 label (assumes binary labels -- TODO confirm).
        new_attrs[0, i] = 1 - new_attrs[0, i]
        ws = generate_ws(G.mapping, z_edit, new_attrs)
        new_ws = torch.stack([w[0] for w in ws], 0).unsqueeze(0)
        img = G.synthesis(new_ws, force_fp32=True)
        images.append(recon_img)
        images.append(img)
    images = torch.cat(images, 0)

# (x + 1) / 2 maps [-1, 1] to [0, 1]; assumes the generator outputs that range.
save_image((images + 1) / 2, 'example.png', nrow=16, normalize=False)
