import torch
import torch.nn.functional as F
from torch.utils import data
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image
from utilstrain.utils import idx2image


def generate_matrix(discriminator, generator, style_encoder, images, image_path, mix_idx=3, cut_idx=6, n_embed=(5, 7, 9)):
    """Save a diagnostic grid of inputs, reconstructions, index maps and mixes.

    The grid is written with ``nrow = batch + 1`` tiles per row; each row
    starts with a single "header" tile:

    * a white tile followed by the input ``images``;
    * a white tile followed by their reconstructions;
    * for every entry of the index list returned by ``generator.decode``:
      a white tile followed by that entry rendered with ``idx2image`` and
      resized to the input resolution;
    * for every sample ``i``: its reconstruction followed by ``batch``
      decoded mixes. In a mix, the vectors at positions
      ``[mix_idx:cut_idx]`` along dim 1 come from each sample of the batch,
      all other positions come from sample ``i``.

    Args:
        discriminator: object exposing ``extract(images)``.
        generator: object exposing ``decode(vectors, random_noise=...)`` and
            a quantizer at ``generator.vqs`` or ``generator.gen_256.vqs``.
        style_encoder: callable mapping extracted features to style vectors
            (indexed here as ``[sample, position, ...]``).
        images: 4-D image batch; pixel values are rendered assuming the
            range ``[-1, 1]``.  # NOTE(review): 3 channels are assumed by the white tiles.
        image_path: destination passed to ``save_image``.
        mix_idx: first position (dim 1) taken from the other samples.
        cut_idx: first position (dim 1) taken from sample ``i`` again.
        n_embed: ``num_colors`` handed to ``idx2image`` for each index map.

    The quantizer is put in eval mode for the duration of the call and is
    returned to whatever mode it was in before, even if an error occurs.
    """
    # The quantizer sits directly on the generator or on its ``gen_256``
    # sub-module, depending on the generator variant.
    try:
        vqs = generator.vqs
    except AttributeError:
        vqs = generator.gen_256.vqs

    # Objects without a ``training`` flag are treated as "training", which
    # matches the mode this function has always left them in.
    was_training = getattr(vqs, "training", True)
    vqs.eval()

    try:
        with torch.no_grad():
            batch = images.shape[0]
            height, width = images.shape[2], images.shape[3]
            blank = torch.ones(1, 3, height, width)

            style_vectors = style_encoder(discriminator.extract(images))
            rec_img, _, embed_idxs = generator.decode(style_vectors)
            rec_cpu = rec_img.detach().cpu()

            image_grid = [blank, images.cpu(), blank, rec_cpu]

            for eidx, idx_map in enumerate(embed_idxs):
                # Resize to (H, W) rather than to the width alone so the map
                # can be concatenated with the other tiles for any aspect ratio.
                seg_map = F.interpolate(
                    idx2image(idx_map, num_colors=n_embed[eidx]),
                    size=(height, width),
                )
                image_grid.append(blank)
                image_grid.append(seg_map.detach().cpu())

            for i in range(len(style_vectors)):
                image_grid.append(rec_cpu[i].unsqueeze(0))
                # Sample i's vectors, repeated so they pair with every sample.
                anchor = style_vectors[i].unsqueeze(0).repeat(batch, 1, 1)
                mixed = torch.cat(
                    [
                        anchor[:, :mix_idx],
                        style_vectors[:, mix_idx:cut_idx],
                        anchor[:, cut_idx:],
                    ],
                    dim=1,
                )
                gimgs = generator.decode(mixed, random_noise=True)[0]
                image_grid.append(gimgs.detach().cpu())

            save_image(torch.cat(image_grid), image_path, nrow=batch + 1, value_range=(-1, 1), normalize=True)
    finally:
        if was_training:
            vqs.train()
        else:
            vqs.eval()

def generate_matrix_joint(discriminator, generator, style_encoder, domain, images, image_path, mix_idx=3, cut_idx=6, n_embed=(5, 7, 9)):
    """Save the inputs / reconstructions / index maps / mixes grid for one branch.

    Same grid as ``generate_matrix`` but every ``generator.decode`` call is
    given ``branch=domain``. Rows hold ``batch + 1`` tiles:

    * a white tile followed by the input ``images``;
    * a white tile followed by their reconstructions;
    * for every entry of the index list returned by ``generator.decode``:
      a white tile followed by that entry rendered with ``idx2image`` and
      resized to the input resolution;
    * for every sample ``i``: its reconstruction followed by ``batch``
      decoded mixes whose positions ``[mix_idx:cut_idx]`` along dim 1 come
      from each sample of the batch and whose other positions come from
      sample ``i``.

    Args:
        discriminator: object exposing ``extract(images)``.
        generator: object exposing ``decode(vectors, branch=..., random_noise=...)``
            and a quantizer at ``generator.vqs`` or ``generator.gen_256.vqs``.
        style_encoder: callable mapping extracted features to style vectors.
        domain: value forwarded as ``branch`` to ``generator.decode``.
        images: 4-D image batch; rendered assuming the range ``[-1, 1]``.
        image_path: destination passed to ``save_image``.
        mix_idx: first position (dim 1) taken from the other samples.
        cut_idx: first position (dim 1) taken from sample ``i`` again.
        n_embed: ``num_colors`` handed to ``idx2image`` for each index map.

    The quantizer is put in eval mode for the duration of the call and is
    returned to whatever mode it was in before, even if an error occurs.
    """
    # The quantizer sits directly on the generator or on its ``gen_256``
    # sub-module, depending on the generator variant.
    try:
        vqs = generator.vqs
    except AttributeError:
        vqs = generator.gen_256.vqs

    # Objects without a ``training`` flag are treated as "training", which
    # matches the mode this function has always left them in.
    was_training = getattr(vqs, "training", True)
    vqs.eval()

    try:
        with torch.no_grad():
            batch = images.shape[0]
            height, width = images.shape[2], images.shape[3]
            blank = torch.ones(1, 3, height, width)

            style_vectors = style_encoder(discriminator.extract(images))
            rec_img, _, embed_idxs = generator.decode(style_vectors, branch=domain)
            rec_cpu = rec_img.detach().cpu()

            image_grid = [blank, images.cpu(), blank, rec_cpu]

            for eidx, idx_map in enumerate(embed_idxs):
                # Follow the input resolution instead of a fixed 256 so the
                # map can always be concatenated with the other tiles.
                seg_map = F.interpolate(
                    idx2image(idx_map, num_colors=n_embed[eidx]),
                    size=(height, width),
                )
                image_grid.append(blank)
                image_grid.append(seg_map.detach().cpu())

            for i in range(len(style_vectors)):
                image_grid.append(rec_cpu[i].unsqueeze(0))
                # Sample i's vectors, repeated so they pair with every sample.
                anchor = style_vectors[i].unsqueeze(0).repeat(batch, 1, 1)
                mixed = torch.cat(
                    [
                        anchor[:, :mix_idx],
                        style_vectors[:, mix_idx:cut_idx],
                        anchor[:, cut_idx:],
                    ],
                    dim=1,
                )
                gimgs = generator.decode(mixed, random_noise=True, branch=domain)[0]
                image_grid.append(gimgs.detach().cpu())

            save_image(torch.cat(image_grid), image_path, nrow=batch + 1, value_range=(-1, 1), normalize=True)
    finally:
        if was_training:
            vqs.train()
        else:
            vqs.eval()



def generate_matrix_grad(discriminator, generator, style_encoder, images, image_path):
    """Save a grid that gradually swaps sample 0's style vectors into each sample.

    For every sample ``i`` (one grid row) and every split point ``j`` from 1
    to ``L - 1`` (one grid column, ``L`` being the number of style vectors
    per sample), an image is decoded from the first ``j`` vectors of sample
    0 followed by the vectors ``j:`` of sample ``i``. Decoding is done with
    ``random_noise=False``.

    Args:
        discriminator: object exposing ``extract(images)``.
        generator: object exposing ``decode(vectors, random_noise=...)`` and
            a quantizer at ``generator.vqs`` or ``generator.gen_256.vqs``.
        style_encoder: callable mapping extracted features to style vectors
            (indexed here as ``[sample, position, ...]``).
        images: image batch fed to ``discriminator.extract``; decoded images
            are rendered assuming the range ``[-1, 1]``.
        image_path: destination passed to ``save_image``.

    The quantizer is put in eval mode for the duration of the call and is
    returned to whatever mode it was in before, even if an error occurs.
    """
    # The quantizer sits directly on the generator or on its ``gen_256``
    # sub-module, depending on the generator variant.
    try:
        vqs = generator.vqs
    except AttributeError:
        vqs = generator.gen_256.vqs

    # Objects without a ``training`` flag are treated as "training", which
    # matches the mode this function has always left them in.
    was_training = getattr(vqs, "training", True)
    vqs.eval()

    try:
        with torch.no_grad():
            style_vectors = style_encoder(discriminator.extract(images))
            n_positions = len(style_vectors[0])

            image_grid = []
            for i in range(len(style_vectors)):
                for j in range(1, n_positions):
                    curr_vector = torch.cat(
                        [style_vectors[0, :j], style_vectors[i, j:]], dim=0
                    ).unsqueeze(0)
                    gimg = generator.decode(curr_vector, random_noise=False)[0]
                    image_grid.append(gimg.detach().cpu())

            # ``value_range`` is the current name of the former ``range``
            # keyword and is what the other grid helpers in this file use.
            save_image(torch.cat(image_grid), image_path, nrow=n_positions - 1, value_range=(-1, 1), normalize=True)
    finally:
        if was_training:
            vqs.train()
        else:
            vqs.eval()

def generate_matrix_ref(discriminator, generator, style_encoder, images_idt, images_pose, image_path, mix_idx=3, dishigh_idx=6):
    """Save a grid crossing one image batch with a second reference batch.

    Rows hold ``batch + 1`` tiles (``batch`` is the size of ``images_idt``):

    * header row: a white tile followed by ``images_pose``;
    * one row per sample ``i`` of ``images_idt``: its reconstruction followed
      by one decoded mix per sample of ``images_pose``. In a mix, the
      vectors at positions ``[mix_idx:dishigh_idx]`` along dim 1 come from
      the ``images_pose`` sample, all other positions from sample ``i``.

    Args:
        discriminator: object exposing ``extract(images)``.
        generator: object exposing ``decode(vectors, random_noise=...)`` and
            a quantizer at ``generator.vqs`` or ``generator.gen_256.vqs``.
        style_encoder: callable mapping extracted features to style vectors.
        images_idt: 4-D image batch providing the vectors outside the band.
        images_pose: image batch providing the vectors inside the band.
            # NOTE(review): must have the same batch size as ``images_idt``
            # for the concatenation along dim 1 to succeed.
        image_path: destination passed to ``save_image``.
        mix_idx: first position (dim 1) taken from ``images_pose``.
        dishigh_idx: first position (dim 1) taken from ``images_idt`` again.

    The quantizer is put in eval mode for the duration of the call and is
    returned to whatever mode it was in before, even if an error occurs.
    """
    # The quantizer sits directly on the generator or on its ``gen_256``
    # sub-module, depending on the generator variant.
    try:
        vqs = generator.vqs
    except AttributeError:
        vqs = generator.gen_256.vqs

    # Objects without a ``training`` flag are treated as "training", which
    # matches the mode this function has always left them in.
    was_training = getattr(vqs, "training", True)
    vqs.eval()

    try:
        with torch.no_grad():
            batch = images_idt.shape[0]
            idt_vectors = style_encoder(discriminator.extract(images_idt))
            rec_cpu = generator.decode(idt_vectors)[0].detach().cpu()

            pose_vectors = style_encoder(discriminator.extract(images_pose))

            image_grid = [
                torch.ones(1, 3, images_idt.shape[2], images_idt.shape[3]),
                images_pose.cpu(),
            ]

            for i in range(len(idt_vectors)):
                image_grid.append(rec_cpu[i].unsqueeze(0))
                # Sample i's vectors, repeated so they pair with every reference.
                anchor = idt_vectors[i].unsqueeze(0).repeat(batch, 1, 1)
                mixed = torch.cat(
                    [
                        anchor[:, :mix_idx],
                        pose_vectors[:, mix_idx:dishigh_idx],
                        anchor[:, dishigh_idx:],
                    ],
                    dim=1,
                )
                gimgs = generator.decode(mixed, random_noise=True)[0]
                image_grid.append(gimgs.detach().cpu())

            # ``value_range`` is the current name of the former ``range``
            # keyword and is what ``generate_matrix`` in this file uses.
            save_image(torch.cat(image_grid), image_path, nrow=batch + 1, value_range=(-1, 1), normalize=True)
    finally:
        if was_training:
            vqs.train()
        else:
            vqs.eval()

def generate_matrix_ref_joint(discriminator, generator, style_encoder, domain, images_idt, images_pose, image_path, dislow, dishigh):
    """Save the two-batch cross grid, decoding through one generator branch.

    Same grid as ``generate_matrix_ref`` but every ``generator.decode`` call
    is given ``branch=domain``. Rows hold ``batch + 1`` tiles (``batch`` is
    the size of ``images_idt``):

    * header row: a white tile followed by ``images_pose``;
    * one row per sample ``i`` of ``images_idt``: its reconstruction followed
      by one decoded mix per sample of ``images_pose``. In a mix, the
      vectors at positions ``[dislow:dishigh]`` along dim 1 come from the
      ``images_pose`` sample, all other positions from sample ``i``.

    Args:
        discriminator: object exposing ``extract(images)``.
        generator: object exposing ``decode(vectors, branch=..., random_noise=...)``
            and a quantizer at ``generator.vqs`` or ``generator.gen_256.vqs``.
        style_encoder: callable mapping extracted features to style vectors.
        domain: value forwarded as ``branch`` to ``generator.decode``.
        images_idt: 4-D image batch providing the vectors outside the band.
        images_pose: image batch providing the vectors inside the band.
            # NOTE(review): must have the same batch size as ``images_idt``
            # for the concatenation along dim 1 to succeed.
        image_path: destination passed to ``save_image``.
        dislow: first position (dim 1) taken from ``images_pose``.
        dishigh: first position (dim 1) taken from ``images_idt`` again.

    The quantizer is put in eval mode for the duration of the call and is
    returned to whatever mode it was in before, even if an error occurs.
    """
    # The quantizer sits directly on the generator or on its ``gen_256``
    # sub-module, depending on the generator variant.
    try:
        vqs = generator.vqs
    except AttributeError:
        vqs = generator.gen_256.vqs

    # Objects without a ``training`` flag are treated as "training", which
    # matches the mode this function has always left them in.
    was_training = getattr(vqs, "training", True)
    vqs.eval()

    try:
        with torch.no_grad():
            batch = images_idt.shape[0]
            idt_vectors = style_encoder(discriminator.extract(images_idt))
            rec_cpu = generator.decode(idt_vectors, branch=domain)[0].detach().cpu()

            pose_vectors = style_encoder(discriminator.extract(images_pose))

            image_grid = [
                torch.ones(1, 3, images_idt.shape[2], images_idt.shape[3]),
                images_pose.cpu(),
            ]

            for i in range(len(idt_vectors)):
                image_grid.append(rec_cpu[i].unsqueeze(0))
                # Sample i's vectors, repeated so they pair with every reference.
                anchor = idt_vectors[i].unsqueeze(0).repeat(batch, 1, 1)
                mixed = torch.cat(
                    [
                        anchor[:, :dislow],
                        pose_vectors[:, dislow:dishigh],
                        anchor[:, dishigh:],
                    ],
                    dim=1,
                )
                gimgs = generator.decode(mixed, random_noise=True, branch=domain)[0]
                image_grid.append(gimgs.detach().cpu())

            # ``value_range`` is the current name of the former ``range``
            # keyword and is what ``generate_matrix`` in this file uses.
            save_image(torch.cat(image_grid), image_path, nrow=batch + 1, value_range=(-1, 1), normalize=True)
    finally:
        if was_training:
            vqs.train()
        else:
            vqs.eval()

if __name__ == '__main__':
    # Script entry point: load a checkpoint, draw two random batches from
    # --path and write reference grids with generate_matrix_ref. A second
    # grid against --path_c is written only when that folder is supplied.

    import argparse

    from models.model_pose_identity import GeneratorPI, DiscriminatorAE, StyleEncoder

    parser = argparse.ArgumentParser(description="StyleGAN2 generate")

    parser.add_argument("--path", type=str, default='../../celeba/test', help="path to the image folder",)
    parser.add_argument("--path_b", type=str, default=None, help="path to the image folder",)
    parser.add_argument("--path_c", type=str, default=None, help="path to an optional second reference image folder",)
    parser.add_argument("--ckpt", type=str, default='checkpoint/000002.pt', help="path to the checkpoints to resume training",)
    parser.add_argument('--name', type=str, default='test1')
    parser.add_argument("--size", type=int, default=256, help="image sizes for the model")
    parser.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier factor for the model. config-f = 2, else = 1",)
    parser.add_argument("--batch", type=int, default=8, help="number of images per batch")
    parser.add_argument("--dislow_idx", type=int, default=4, help="conv-layer index, e.g. 8-2 -> 2, 16-1 -> 3",)
    parser.add_argument("--dishigh_idx", type=int, default=7, help="conv-layer index, e.g. 8-2 -> 2, 16-1 -> 3",)

    args = parser.parse_args()
    # Fixed model hyper-parameters that are not exposed on the command line.
    args.latent = 512
    args.n_mlp = 8

    device = torch.device("cuda")
    dislow_idx = args.dislow_idx
    dishigh_idx = args.dishigh_idx

    generator = GeneratorPI(
        args.size,
        args.latent,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        dislow=dislow_idx,
        dishigh=dishigh_idx,
        n_embed=[2, 6, 6],
    ).to(device)

    discriminator = DiscriminatorAE(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)

    style_encoder = StyleEncoder(args.size, args.latent, channel_multiplier=args.channel_multiplier).to(device)

    ckpt = torch.load(args.ckpt)#, map_location=lambda storage, loc: storage)
    discriminator.load_state_dict(ckpt["d"])
    generator.load_state_dict(ckpt["g_ema"])
    style_encoder.load_state_dict(ckpt['se_ema'])

    discriminator.eval()
    generator.eval()
    style_encoder.eval()

    #from pytorch_lightning.utilities.seed import seed_everything
    #seed_everything(777)

    # Images are resized to the model resolution and mapped to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.size, args.size)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    dataset = ImageFolder(args.path, transform=transform)
    loader = iter(data.DataLoader(dataset, batch_size=args.batch, drop_last=True, shuffle=True))
    images_idt = next(loader)[0].to(device)
    images_pose = next(loader)[0].to(device)

    # Other grids that were produced with the same inputs (kept for reference):
    #
    # generate_matrix(discriminator, generator, style_encoder, images_idt, 'test_6_only.jpg', mix_idx=7, cut_idx=10, n_embed=[5,7,9])
    # generate_matrix(discriminator, generator, style_encoder, images_idt, 'test_6_without.jpg', mix_idx=4, cut_idx=7, n_embed=[5,7,9])
    # generate_matrix(discriminator, generator, style_encoder, images_idt, 'test_6_with.jpg', mix_idx=4, cut_idx=10, n_embed=[5,7,9])
    #
    # generate_matrix_grad(discriminator, generator, style_encoder, images_idt, 'test_grad.jpg')
    #
    # generate_matrix_ref(discriminator, generator, style_encoder, images_idt, images_idt, './test_result_animal.jpg', mix_idx=dislow_idx)
    #
    # dataset = ImageFolder(args.path_b, transform=transform)
    # loader = iter(data.DataLoader(dataset, batch_size=args.batch, drop_last=True, shuffle=True))
    # images_pose = next(loader)[0].to(device)

    generate_matrix_ref(discriminator, generator, style_encoder, images_idt, images_pose, './test_result_face.jpg', mix_idx=dislow_idx)

    # --path_c defaults to None; building an ImageFolder from it would crash,
    # so the second grid is only produced when a folder was actually given.
    if args.path_c is not None:
        dataset = ImageFolder(args.path_c, transform=transform)
        loader = iter(data.DataLoader(dataset, batch_size=args.batch, drop_last=True, shuffle=True))
        images_pose = next(loader)[0].to(device)

        generate_matrix_ref(discriminator, generator, style_encoder, images_idt, images_pose, './test_result_anime.jpg', mix_idx=dislow_idx)
    else:
        print("--path_c not given; skipping './test_result_anime.jpg'")
    