import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image, make_grid

from models.model_pose_identity import DiscriminatorAE, GeneratorPI, StyleEncoder
from utilstrain.utils import idx2image

import math, os, cv2
from PIL import Image

class IPI2I(nn.Module):
    """Inference wrapper that mixes the latents of an identity and a pose image.

    Bundles a ``GeneratorPI`` (decoder), the ``convs`` trunk of a
    ``DiscriminatorAE`` (used as a feature extractor) and a ``StyleEncoder``.
    An image is turned into latents by running it through ``convs`` and the
    style encoder. Mixing copies positions ``dislow:dishigh`` along dim 1
    from the pose latents into (a copy of) the identity latents before
    decoding.
    """

    def __init__(self, size=256, latent=512, n_mlp=8, channel_multiplier=2, dislow=3, dishigh=6, n_embed=None, vq=0):
        """Build the sub-networks.

        Args:
            size: image resolution passed to every sub-network.
            latent, n_mlp, channel_multiplier, vq: forwarded to the
                sub-network constructors.
            dislow, dishigh: default slice ``[dislow:dishigh]`` (dim 1) of
                the latents that is taken from the pose image when mixing.
            n_embed: forwarded to ``GeneratorPI``; defaults to ``[2, 6, 6]``
                (a fresh list per instance, to avoid a shared mutable default).
        """
        super().__init__()

        if n_embed is None:
            n_embed = [2, 6, 6]

        self.size = size
        self.channel_multiplier = channel_multiplier
        self.dislow = dislow
        self.dishigh = dishigh

        # Feature-map widths collected by _feat_extract: size, size/2, ..., 4.
        self.extract_res = [2 ** p for p in range(int(math.log2(self.size)), 1, -1)]

        self.generator = GeneratorPI(size, latent, n_mlp, channel_multiplier, dislow, dishigh, n_embed, vq=vq)
        # Only the `vqs` submodule is switched to eval mode here.
        # NOTE(review): confirm whether the whole wrapper should be put in
        # eval mode for inference.
        self.generator.vqs.eval()
        self._generate = self.generator.decode

        # Only the conv trunk of the discriminator is kept; it is used as the
        # feature extractor in front of the style encoder.
        self.convs = DiscriminatorAE(size, channel_multiplier).convs

        self.style_encoder = StyleEncoder(size, latent, channel_multiplier)

    def load_from_training_models(self, ckpt):
        """Load weights from a training checkpoint.

        Args:
            ckpt: path to a checkpoint file, or an already loaded dict holding
                the state dicts under ``"g_ema"`` (generator), ``"se_ema"``
                (style encoder) and ``"d"`` (full discriminator).
        """
        if isinstance(ckpt, str):
            # torch.load unpickles the file: only load trusted checkpoints.
            # map_location="cpu" lets GPU-saved checkpoints open anywhere;
            # load_state_dict copies the values into the existing parameters.
            ckpt = torch.load(ckpt, map_location="cpu")

        self.generator.load_state_dict(ckpt["g_ema"])
        self.style_encoder.load_state_dict(ckpt["se_ema"])

        # The checkpoint stores the whole discriminator, so load it into a
        # temporary one and copy only the conv weights into our own `convs`.
        # Loading in place (instead of replacing the module) keeps the
        # device, dtype and train/eval mode `convs` already has.
        tmp_discriminator = DiscriminatorAE(self.size, self.channel_multiplier)
        tmp_discriminator.load_state_dict(ckpt["d"])
        self.convs.load_state_dict(tmp_discriminator.convs.state_dict())

    @torch.no_grad()
    def _feat_extract(self, image, feature_res=None):
        """Run ``image`` through ``self.convs`` and collect intermediate features.

        A layer output is kept when its last dimension is in ``feature_res``
        (default ``self.extract_res``). Iteration stops as soon as an output
        of width ``feature_res[-1]`` is kept.

        Returns:
            list of the kept feature maps, in layer order.
        """
        if feature_res is None:
            feature_res = self.extract_res
        out = []
        feat = image
        for conv in self.convs:
            feat = conv(feat)
            if feat.shape[-1] in feature_res:
                out.append(feat)
                if feat.shape[-1] == feature_res[-1]:
                    # Deepest requested resolution reached; skip remaining layers.
                    break
        return out

    @torch.no_grad()
    def forward(self, identity, pose, mix_low=None, mix_high=None):
        """Generate images with ``identity``'s latents and ``pose``'s slice mixed in.

        Args:
            identity: identity image batch.
            pose: pose image batch (presumably the same batch size as
                ``identity``; TODO confirm).
            mix_low, mix_high: latent slice (dim 1) taken from ``pose``;
                default to ``self.dislow`` / ``self.dishigh``.

        Returns:
            the decoded image batch.
        """
        latents_idt = self.get_latents(identity)
        latents_pos = self.get_latents(pose)
        rec_img, _ = self.forward_with_mix_latents(latents_idt, latents_pos, mix_low, mix_high)
        return rec_img

    @torch.no_grad()
    def forward_with_segmap(self, latents, segmaps):
        """Decode ``latents`` with explicitly supplied embedding-index maps."""
        return self.generator.decode_with_embed(latents, segmaps)

    @torch.no_grad()
    def get_latents_and_rec_image(self, image):
        """Encode ``image`` and decode it back.

        Returns:
            ``(rec_image, latents, embed_idxs)``.
        """
        latents = self.get_latents(image)
        rec_image, embed_idxs = self.forward_with_latents(latents)
        return rec_image, latents, embed_idxs

    @torch.no_grad()
    def get_latents(self, image):
        """Return the style-encoder latents of ``image``."""
        return self.style_encoder(self._feat_extract(image))

    @torch.no_grad()
    def forward_with_latents(self, latents):
        """Decode ``latents``; returns ``(rec_image, embed_idxs)``."""
        rec_image, _, embed_idxs = self._generate(latents)
        return rec_image, embed_idxs

    @torch.no_grad()
    def forward_with_mix_latents(self, latents_idt, latents_pose, mix_low=None, mix_high=None):
        """Decode ``latents_idt`` with slice ``[mix_low:mix_high]`` (dim 1) from ``latents_pose``.

        ``mix_low`` / ``mix_high`` default to ``self.dislow`` / ``self.dishigh``.
        The input tensors are not modified; the mix is written into a clone.

        Returns:
            ``(rec_image, embed_idxs)``.
        """
        if mix_low is None:
            mix_low = self.dislow
        if mix_high is None:
            mix_high = self.dishigh
        latents = latents_idt.clone()
        latents[:, mix_low:mix_high] = latents_pose[:, mix_low:mix_high]
        return self.forward_with_latents(latents)


def to_image(tensor):
    """Tile a batch of images into a single 8-bit PIL image.

    ``tensor`` is a ``b x c x h x w`` batch. Values are mapped with
    ``(x + 1) / 2`` (i.e. from the ``[-1, 1]`` range produced by this script's
    ``Normalize((0.5,)*3, (0.5,)*3)`` transform to ``[0, 1]``), tiled with
    ``make_grid``, scaled to ``[0, 255]`` with rounding and clamping, and
    converted to a ``PIL.Image``.
    """
    unit_range = tensor.add(1).mul(0.5)
    tiled = make_grid(unit_range)
    # +0.5 before the uint8 cast rounds to nearest instead of truncating.
    scaled = tiled.mul(255).add_(0.5).clamp_(0, 255)
    hwc = scaled.permute(1, 2, 0)
    return Image.fromarray(hwc.to('cpu', torch.uint8).numpy())



def make_gif_from_gif(root, images_idt, duration=200):
    """Re-pose the identity images with every frame of a GIF and save an animated GIF.

    Each frame of ``root`` is used as the pose input to the module-level
    ``net`` against every image in ``images_idt``. The frame and the generated
    images are tiled into one picture (``to_image``). All pictures are written
    to ``<input basename without extension>_mix.gif`` in the current working
    directory.

    Relies on the module-level ``net``, ``transform`` and ``device`` created
    in the ``__main__`` block.

    Args:
        root: path to the input GIF.
        images_idt: batch of identity images, already on ``device``.
        duration: display time of each output frame, in milliseconds.
    """
    result_image = []
    # `with` closes the underlying file once all frames have been read.
    with Image.open(root) as gif:
        for frame_idx in range(gif.n_frames):
            gif.seek(frame_idx)
            # unsqueeze(0) adds the batch dim without hard-coding the image size.
            image_pose = transform(gif.convert('RGB')).unsqueeze(0).to(device)
            mixed_images = net(images_idt, image_pose.repeat(images_idt.shape[0], 1, 1, 1))
            result_image.append(to_image(torch.cat([image_pose, mixed_images])))

    # splitext always yields a new name, so the input file is never overwritten.
    stem = os.path.splitext(os.path.basename(root))[0]
    result_image[0].save(stem + '_mix.gif',
                         save_all=True, append_images=result_image[1:], optimize=True, duration=duration, loop=0)


def make_gif_from_video(root, images_idt, duration=200):
    """Re-pose the identity images with every frame of a video and save an animated GIF.

    Frames are read sequentially from ``root`` with OpenCV, converted from
    BGR to RGB and used as the pose input to the module-level ``net``
    against every image in ``images_idt``. The frame and the generated images
    are tiled into one picture (``to_image``). All pictures are written to
    ``<input basename without extension>_mix.gif`` in the current working
    directory.

    Relies on the module-level ``net``, ``transform`` and ``device`` created
    in the ``__main__`` block.

    Args:
        root: path to the input video.
        images_idt: batch of identity images, already on ``device``.
        duration: display time of each output frame, in milliseconds.

    Raises:
        IOError: if the video cannot be opened.
        ValueError: if no frame could be decoded.
    """
    cap = cv2.VideoCapture(root)
    if not cap.isOpened():
        raise IOError(f"cannot open video: {root}")

    result_image = []
    try:
        # Read sequentially until decoding fails. CAP_PROP_FRAME_COUNT can be
        # inaccurate, and seeking per frame is slow.
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_pose = transform(frame).unsqueeze(0).to(device)

            gimage = net(images_idt, image_pose.repeat(images_idt.shape[0], 1, 1, 1))
            result_image.append(to_image(torch.cat([image_pose, gimage])))
    finally:
        cap.release()

    if not result_image:
        raise ValueError(f"no frames could be read from video: {root}")

    # splitext always yields a new name, so the input file is never overwritten.
    stem = os.path.splitext(os.path.basename(root))[0]
    result_image[0].save(stem + '_mix.gif',
                         save_all=True, append_images=result_image[1:], optimize=True, duration=duration, loop=0)


# Example setup for edit_segmap(), kept inert inside a string literal: it
# builds the generator, discriminator and style encoder with the settings used
# in training and loads their weights from a checkpoint.
'''
dislow_idx = 3
dishigh_idx = 6
n_embed = [2,6,6]
device = torch.device('cuda')
generator = GeneratorPI(256, 512, 8, channel_multiplier=2, dislow=dislow_idx, dishigh=dishigh_idx, n_embed=n_embed).to(device)
discriminator = DiscriminatorAE(256, channel_multiplier=2).to(device)
style_encoder = StyleEncoder(256, 512, channel_multiplier=2).to(device)

ckpt = torch.load('./checkpoint/000003.pt')
generator.load_state_dict(ckpt['g_ema'])
style_encoder.load_state_dict(ckpt['se_ema'])
discriminator.load_state_dict(ckpt['d'])
'''
def edit_segmap(pose_image, idt_images, discriminator, style_encoder, generator, dislow_idx, dishigh_idx, n_embed,
                out_path='test_edit.jpg'):
    """Demonstrate local edits of the embedding-index (segmentation) map.

    The latents of ``pose_image`` are mixed into each identity image's latents
    at positions ``dislow_idx:dishigh_idx`` (dim 1) and decoded with
    ``generator.decode``. The level-``mod_level`` index map returned by the
    decoder is then overwritten with index 2 inside two hand-picked
    rectangles. Each edit is decoded with ``generator.decode_with_embed``.
    Everything is saved as one grid image at ``out_path``:

        row 1: pose image,        identity images
        row 2: original seg map,  mixed images
        row 3: edit-1 seg map,    edit-1 images
        row 4: edit-2 seg map,    edit-2 images

    Args:
        pose_image: a single pose image (batch of 1; it is repeated to match
            ``idt_images``).
        idt_images: batch of identity images.
        discriminator: must provide ``extract(images)`` returning the
            features consumed by ``style_encoder``.
        style_encoder, generator: trained networks.
        dislow_idx, dishigh_idx: latent slice taken from the pose image.
        n_embed: per-level number of embeddings (colors for visualization).
        out_path: where the grid image is written.
    """
    mod_level = 1  # level of the embedding-index maps that gets edited
    n_idt = idt_images.shape[0]
    out_size = tuple(pose_image.shape[-2:])

    def seg_vis(embed_idxs):
        # Colorize the level-mod_level index map, upsample it to the image
        # size and keep the first sample as a single grid tile on the CPU.
        seg_map = F.interpolate(idx2image(embed_idxs[mod_level], num_colors=n_embed[mod_level]), size=out_size)
        return seg_map[0].unsqueeze(0).cpu()

    with torch.no_grad():
        # One pose latent per identity image (was the undefined global `batch - 1`).
        pose_vectors = style_encoder(discriminator.extract(pose_image)).repeat(n_idt, 1, 1)
        idt_vectors = style_encoder(discriminator.extract(idt_images))

        mix_vectors = idt_vectors.clone()
        mix_vectors[:, dislow_idx:dishigh_idx] = pose_vectors[:, dislow_idx:dishigh_idx]
        mixed_imgs, _, pose_embed_idxs = generator.decode(mix_vectors)

        image_grid = [pose_image.cpu(), idt_images.cpu(), seg_vis(pose_embed_idxs), mixed_imgs.cpu()]

        # Hand-picked (rows, cols) rectangles on the level-mod_level index
        # map, overwritten with embedding index 2.
        edits = (
            (slice(9, 13), slice(6, 10)),
            (slice(8, 14), slice(6, 11)),
        )
        for rows, cols in edits:
            edited_idxs = [pei.clone() for pei in pose_embed_idxs]
            edited_idxs[mod_level][:, rows, cols] = 2
            edited_imgs = generator.decode_with_embed(mix_vectors, edited_idxs)
            image_grid.append(seg_vis(edited_idxs))
            image_grid.append(edited_imgs.cpu())

        # Each row holds one leading tile (pose image or seg map) plus n_idt images.
        save_image(torch.cat(image_grid), out_path, normalize=True, value_range=(-1, 1), nrow=n_idt + 1)

# Example: edit_segmap(images_idt[0].unsqueeze(0), images_idt[1:], discriminator, style_encoder, generator, dislow_idx, dishigh_idx, n_embed)


if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description="StyleGAN2 generate")

    parser.add_argument("--path", type=str, default='../../celeba/test', help="path to the image folder",)
    parser.add_argument("--ckpt", type=str, required=True, help="path to the model",)
    parser.add_argument("--pose", type=str, required=True, help="path to the pose gif image or mp4 video",)

    args = parser.parse_args()

    # Validate the pose input before the expensive model loading; previously
    # an unsupported file was silently ignored after everything was loaded.
    root = args.pose
    if not root.endswith(('.mp4', '.gif')):
        parser.error("--pose must be a .mp4 video or a .gif image")

    # NOTE: `device`, `net`, `size`, `transform` and `batch` are module-level
    # globals read by make_gif_from_gif / make_gif_from_video / edit_segmap.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # torch.load unpickles the file: only load trusted checkpoints.
    # map_location="cpu" lets a GPU-saved checkpoint open on any machine; the
    # weights are copied into `net`, which is then moved to `device`.
    # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
    # which may reject ckpt['args'] if it is a pickled Python object.
    ckpt = torch.load(args.ckpt, map_location='cpu')
    net = IPI2I(dislow=ckpt['args'].dislow, dishigh=ckpt['args'].dishigh, n_embed=ckpt['args'].vq_emb, vq=0)
    net.load_from_training_models(ckpt)
    net.to(device)

    size = 256
    path = args.path
    batch = 4

    # Maps images to [-1, 1], the range the networks and to_image expect.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((size, size)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    dataset = ImageFolder(path, transform=transform)
    # drop_last=True yields no batch at all when the folder is too small.
    if len(dataset) < batch:
        parser.error(f"--path must contain at least {batch} images, found {len(dataset)}")
    loader = iter(data.DataLoader(dataset, batch_size=batch, drop_last=True, shuffle=True))
    images_idt = next(loader)[0].to(device)

    if root.endswith('.mp4'):
        make_gif_from_video(root, images_idt)
    else:
        make_gif_from_gif(root, images_idt)