import torch
import torch.nn as nn
import random
from stylegan_model import StyledGenerator, Discriminator
def stylegan(stylegan_ckpt):
    """Build the StyleGAN generator and discriminator and load their weights.

    Both networks are moved to the default CUDA device and switched to
    evaluation mode before the checkpoint weights are loaded.

    Args:
        stylegan_ckpt: Path (or file-like object) accepted by ``torch.load``.
            The loaded object must provide the keys ``'discriminator'`` and
            ``'g_running'``, each holding a state dict.

    Returns:
        A ``(style_g_running, style_discriminator)`` tuple. The discriminator
        is wrapped in ``nn.DataParallel``.
    """
    # from_rgb_activate=True; the original note says this corresponds to
    # "no_from_rgb_activate is False" (presumably a training-script option).
    style_discriminator = nn.DataParallel(Discriminator(from_rgb_activate=True)).cuda()
    style_discriminator.eval()

    style_g_running = StyledGenerator(512).cuda()
    style_g_running.eval()  # same as train(False)

    # Deserialize onto the CPU: this avoids failures when the checkpoint was
    # saved from a CUDA device index that does not exist on this machine, and
    # avoids holding a second full copy of the weights on the GPU.
    # load_state_dict() copies the values into the (CUDA) parameters.
    # NOTE: torch.load unpickles its input -- only load trusted checkpoints.
    checkpoint = torch.load(stylegan_ckpt, map_location='cpu')

    # The discriminator is wrapped in DataParallel, so load into `.module`.
    style_discriminator.module.load_state_dict(checkpoint['discriminator'])
    style_g_running.load_state_dict(checkpoint['g_running'])
    return style_g_running, style_discriminator


def _generate_sample_set(style_g_running, style_discriminator, latents):
    """Render three image variants from one latent pair and encode each one.

    The generator is called three times with the same ``latents``: with no
    extra arguments, with ``reference=True``, and with a randomly drawn
    ``mixing_range``. Every generated image is then passed through the
    discriminator to obtain its feature output.

    Returns:
        A 10-tuple ``(latents, img, g_feat, d_feat, img_ref, g_feat_ref,
        d_feat_ref, img_mix, g_feat_mix, d_feat_mix)``.
    """
    def encode(image):
        # The discriminator returns (score, feature); only the feature is kept.
        _, feature = style_discriminator(image, step=6, alpha=1, E1_output_feat=True)
        return feature

    img_plain, g_feat_plain = style_g_running(latents)
    d_feat_plain = encode(img_plain)

    img_ref, g_feat_ref = style_g_running(latents, reference=True)
    d_feat_ref = encode(img_ref)

    # In StarGAN v2 the reference image only provides style information, so
    # the mixing range starts at 2 or later to keep the first latent code in
    # charge of the pose information (the bottom style code).
    mix_begin = random.randint(2, 6)
    mix_end = random.randint(mix_begin, 6)
    img_mix, g_feat_mix = style_g_running(latents, mixing_range=(mix_begin, mix_end))
    d_feat_mix = encode(img_mix)

    return (
        latents,
        img_plain, g_feat_plain, d_feat_plain,
        img_ref, g_feat_ref, d_feat_ref,
        img_mix, g_feat_mix, d_feat_mix,
    )


def generating_data(style_g_running, style_discriminator, b_size, *, device='cuda'):
    """Sample two independent sets of generated images and their features.

    Four standard-normal latent batches of shape ``(b_size, 512)`` are drawn
    and grouped into two pairs. For each pair the generator produces three
    images (plain call, ``reference=True``, and a random ``mixing_range``),
    and the discriminator is run on each image to extract its feature output.
    Everything runs under ``torch.no_grad()``.

    Args:
        style_g_running: Callable returning ``(image, generator_feature)``;
            must accept the keyword arguments ``reference`` and
            ``mixing_range``.
        style_discriminator: Callable returning ``(score, feature)`` when
            called with ``step``, ``alpha`` and ``E1_output_feat`` keywords.
        b_size: Number of samples per latent batch.
        device: Device on which the latent codes are created. Defaults to
            ``'cuda'``, the previously hard-coded value.

    Returns:
        ``(set1, set2)``. Each set is a 10-tuple ``(latents, img, g_feat,
        d_feat, img_ref, g_feat_ref, d_feat_ref, img_mix, g_feat_mix,
        d_feat_mix)``, where ``latents`` is a list of two tensors.
    """
    with torch.no_grad():
        # One randn call for all four codes; unbind(0) yields four
        # (b_size, 512) tensors.
        code_a1, code_a2, code_b1, code_b2 = torch.randn(
            4, b_size, 512, device=device
        ).unbind(0)

        # set1 is fully generated before set2 so that random draws are
        # consumed in the same order as the original inline implementation.
        set1 = _generate_sample_set(
            style_g_running, style_discriminator, [code_a1, code_a2]
        )
        set2 = _generate_sample_set(
            style_g_running, style_discriminator, [code_b1, code_b2]
        )
    return (set1, set2)
