import sys

sys.path.append("..")
sys.path.append("../..")
sys.path.append("../../models/eg3d")

from loguru import logger
from tqdm import tqdm
import pickle
import numpy as np

import torch
from torch.nn.utils.rnn import pad_sequence
import torchvision

from models.third_party.BiSeNet import FaceParser
from models.third_party.utils import parse_indices_str, tuple_of_indices, tuple_of_type
from models.eg3d.torch_utils import misc
from models.eg3d.training.triplane import TriPlaneGenerator
from models.eg3d.training.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
from criteria.parse_related_loss.unet import unet  # , semantic_regions


# Checkpoint and output locations. These are relative paths, resolved against
# the current working directory.
parsenet_weights = '../../pretrained_models/parsenet.pth'
face_parser_ckpt = '../../pretrained_models/BiSeNet.pth'
stylegan_weights = '../../pretrained_models/ffhqrebalanced512-128.pkl'
correlation_PATH = '../../pretrained_models/correlation_w.pt'

# Device and fixed camera geometry used for every rendered sample.
device = 'cuda'
fov_deg = 18.837
cam_radius = 2.7
avg_camera_pivot = [0, 0, 0.2]
intrinsics = FOV_to_intrinsics(fov_deg, device=device)
cam_pivot = torch.Tensor(avg_camera_pivot).to(device)

# Amplitudes of the random yaw / pitch jitter applied when sampling poses.
yaw_range = 0.35
pitch_range = 0.25

# Truncation settings passed to the mapping network.
truncation_psi = 0.7
truncation_cutoff = 14

def _load_generator():
    """Load the EG3D generator from ``stylegan_weights`` onto ``device``."""
    with open(stylegan_weights, 'rb') as f:
        # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
        pretrained_generator = pickle.load(f)['G_ema']
    G = TriPlaneGenerator(*pretrained_generator.init_args, **pretrained_generator.init_kwargs).requires_grad_(False).to(device)
    misc.copy_params_and_buffers(pretrained_generator, G, require_all=True)
    G.neural_rendering_resolution = pretrained_generator.neural_rendering_resolution
    G.rendering_kwargs = pretrained_generator.rendering_kwargs
    # NOTE(review): only ``styles.grad`` is read below; parameter gradients look
    # unnecessary and cost memory/compute — confirm before disabling.
    G.requires_grad_(True)
    return G


def _sample_camera_conditioning():
    """Sample a randomly jittered camera and return the conditioning tensor ``c``.

    ``c`` concatenates the cam2world pose flattened to 16 values with the
    intrinsics flattened to 9 values.
    """
    cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * np.random.rand()),
                                              3.14/2 -0.05 + pitch_range * np.cos(2 * 3.14 * np.random.rand()),
                                              cam_pivot, radius=cam_radius, device=device)
    return torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1).to(device)


def _render_with_styles(G):
    """Render one random face and return ``(styles, images)``.

    ``styles`` is the output of ``G.get_styles`` zero-padded into one tensor
    (with a leading dim of 1). It is detached and marked ``requires_grad``, so
    that ``images.backward`` populates ``styles.grad``.
    """
    z = torch.randn(1, G.z_dim).to(device)
    c = _sample_camera_conditioning()
    w = G.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff).to(device)
    styles = pad_sequence(G.get_styles(w), batch_first=True, padding_value=0)[None]
    styles = styles.detach().requires_grad_(True)
    images = G.synthesis(styles, c)['image']
    return styles, images


def main(num_samples=10000, batch_size=1):
    """Estimate how strongly each style channel influences each face-parsing region.

    For every sampled face, the rendered image is back-propagated once per
    parsing mask. The gradient map is the mask normalised to unit total
    weight. The absolute gradient w.r.t. the style tensor is accumulated per
    (layer, mask, channel). It is then divided by the number of samples in
    which that mask was non-empty. The result is saved to ``correlation_PATH``
    as a list with one ``(num_masks, num_channels)`` tensor per style layer.

    :param num_samples: minimum number of faces to sample.
    :param batch_size: the rendering path handles one face per forward pass,
        so ``ceil(num_samples / batch_size) * batch_size`` single-face
        iterations are run.
    :raises ValueError: if ``num_samples`` or ``batch_size`` is not positive.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError(f"num_samples and batch_size must be positive, got {num_samples} and {batch_size}")

    G = _load_generator()
    logger.info("load generator over")

    face_parser = FaceParser(model_path=face_parser_ckpt, device=device)
    logger.info("load face_parser over")

    num_batch = (num_samples + batch_size - 1) // batch_size
    # Samples actually processed. Each one is a single face, and the printed
    # presence ratios below are normalised by this count.
    total_samples = num_batch * batch_size

    style_grads = None
    style_grad_num = None
    processed = 0

    with tqdm(total=total_samples, ncols=0) as pbar:
        while processed < total_samples:
            styles, images = _render_with_styles(G)

            with torch.no_grad():
                parsing = face_parser.batch_run(images, pre_normalize=True, image_repr=False, compact_mask=True)
            if parsing is None:
                # The parser returned nothing for this face: draw a fresh sample.
                continue

            num_masks = parsing.size(1)
            if style_grads is None:
                style_grads = [[torch.zeros(s.size(-1), device=device) for _ in range(num_masks)] for s in styles[0]]
                style_grad_num = [[0] * num_masks for _ in styles[0]]

            for mask_id in range(num_masks):
                G.zero_grad()
                styles.grad = None
                grad_map = parsing[:, [mask_id]].repeat(1, 3, 1, 1).float()
                # Normalise so each sample's weights sum to 1: the backward pass
                # then yields a weighted mean gradient over the mask's pixels.
                grad_map /= grad_map.abs().sum(dim=[1, 2, 3], keepdim=True).clip_(1e-5)

                # Some masks may be empty (all zeros) for this face; such
                # samples must not count towards the average.
                num_valid = int((grad_map.sum(dim=[-1, -2, -3]) > 0).sum())
                images.backward(grad_map, retain_graph=True)

                for layer, s_grad in enumerate(styles.grad[0]):
                    style_grads[layer][mask_id] += s_grad.abs()
                    style_grad_num[layer][mask_id] += num_valid

            processed += 1
            pbar.update(1)

    # Fraction of samples in which each mask was non-empty (first layer's counts).
    print(','.join(str(float(n) / total_samples) for n in style_grad_num[0]))

    # Average over the samples that actually contained each mask. Clamp the
    # count so a mask that never appeared yields 0 instead of 0/0 = NaN, which
    # would poison any normalisation across masks.
    channel_correlation = [
        torch.stack([grad / max(count, 1) for grad, count in zip(layer_grads, layer_counts)])
        for layer_grads, layer_counts in zip(style_grads, style_grad_num)
    ]
    torch.save(channel_correlation, correlation_PATH)
   

def channel_selector(semantic_idx=10, num_channels=1024, correlation_path=correlation_PATH):
    """Pick the (layer, channel) pairs most specific to one semantic region.

    Loads the list of per-layer ``(num_masks, num_dims)`` correlation tensors.
    Each (layer, channel) entry is normalised across masks, some layers are
    excluded, and the top ``num_channels`` entries for ``semantic_idx`` are
    returned.

    :param semantic_idx: index of the parsing mask to select channels for.
    :param num_channels: number of (layer, channel) pairs to return.
    :param correlation_path: path of the saved correlation list.
    :return: ``(layer_idx, channel_idx)`` index tensors of length ``num_channels``.
    """
    correlation = torch.load(correlation_path)
    # (num_layers, num_masks, num_dims) -> (num_masks, num_layers, num_dims)
    corr = torch.stack(correlation, 0).abs().transpose(0, 1)
    # A mask that never appeared while computing the correlation can leave
    # 0/0 = NaN entries. NaN would poison the sum below, and topk ranks NaN highest.
    corr = corr.masked_fill(torch.isnan(corr), 0.0)
    # Relative importance of each (layer, channel) for each mask, normalised over masks.
    rim = corr / corr.sum(dim=0, keepdim=True).clamp(0.000001)
    # Exclude layers 1, 4, ..., 19 and every layer from 20 on.
    # NOTE(review): presumably these are non-editable / ToRGB-style layers — confirm.
    rim[:, 1:20:3, :] = 0
    rim[:, 20:, :] = 0

    # Decode flat indices using the real channel width instead of a hard-coded 512.
    num_dims = rim.size(-1)
    _, indices = rim[semantic_idx].reshape(-1).topk(num_channels)
    layer_idx = indices // num_dims
    channel_idx = indices % num_dims
    return layer_idx, channel_idx
    
    
def _channel_selector(num_layers=26, rules='ric[10]>0.1', is_indexes_rule=False, correlation_path=correlation_PATH) -> dict:
    """
    Select style channels per layer according to rules.

    :param num_layers: number of layers to process; layers ``0 .. num_layers-1`` are used.
    :param rules: one rule, or one rule per layer (passed through
        ``tuple_of_type(rules, str)``); a single rule is repeated for every layer.
        By default a rule is a Python expression run with ``eval()`` inside this
        function, e.g. "ric[10]>0.1" or "(ric[10]>0.1)&(ric[8]>0.1)". It can
        reference the local tensors ``corr``, ``rcm``, ``rcc`` and ``ric``. The
        literal "all" selects every channel.
        SECURITY: ``eval`` executes arbitrary code — never pass untrusted rules.
    :param is_indexes_rule: if True, each rule is a channel-index string such as
        "12,1,511" (parsed by ``parse_indices_str``). An int is also accepted.
    :param correlation_path: path of the saved per-layer channel-region correlations.
    :return: dict mapping layer index (int) -> 1-D channel mask of length ``corr.size(1)``.
        NOTE(review): the mask is a float 0/1 tensor for index rules and "all",
        but whatever ``eval`` returns (typically bool) for expression rules.
    """
    correlation = torch.load(correlation_path)

    layers = tuple(np.arange(0, num_layers))
    # Allow a single channel index to be passed as a bare int.
    if isinstance(rules, int) and is_indexes_rule:
        rules = f"{rules}"
    rules = tuple_of_type(rules, str)
    # Broadcast a single rule to every layer.
    if len(rules) == 1:
        rules = rules * len(layers)
    assert len(layers) == len(rules)

    result = {}
    for layer, rule in zip(layers, rules):
        # 2-D per-layer correlation; dim 1 indexes channels (see the assert below).
        corr = correlation[layer].abs()
        # The names below are read by eval(rule); do not rename them.
        # relative correlation by mask: each row divided by its max over channels
        rcm = corr / corr.amax(dim=1, keepdim=True)  # noqa
        # relative correlation by channel: each column divided by its max over dim 0
        rcc = corr / corr.amax(dim=0, keepdim=True)  # noqa
        # relative importance by channel: each column divided by its sum over dim 0
        ric = corr / corr.sum(dim=0, keepdim=True)  # noqa
        if is_indexes_rule:
            mask = torch.zeros_like(corr[0])
            for c in parse_indices_str(rule):
                mask[c] = 1
        else:
            if rule != "all":
                # rule from user, very very dangerous!!
                # eval() sees this function's locals (corr, rcm, rcc, ric, ...).
                mask = eval(rule)
                # mask = ric[semantic_idx]>0.1
            else:
                mask = torch.ones_like(corr[0])
        logger.info(f"layer {layer} {int(mask.sum())} dims: {torch.nonzero(mask.float()).flatten().tolist()[:10]}")
        # The rule must produce exactly one entry per channel.
        assert mask.size() == torch.Size([corr.size(1)])
        result[layer] = mask

    return result


if __name__ == '__main__':
    # Script entry point: compute and save the channel/region correlation table.
    main()