import sys
from typing import List

from PyQt5 import QtWidgets
from PyQt5 import QtCore
from PyQt5 import QtGui
from PyQt5.QtWidgets import * 
from PyQt5.QtGui import * 
from PyQt5.QtCore import * 

import torch
from torch.nn import functional as F
from model import IPI2I
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
import os

# Side length (in pixels) that every model input image is resized to.
im_size = 256
# PIL image -> float tensor (C, H, W), resized to im_size x im_size and
# normalized per channel from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((im_size, im_size)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True),
])

# Embedding index -> brush colour used to paint that index on a segment-map grid.
segment_index_to_color = dict(enumerate(
    (Qt.red, Qt.yellow, Qt.blue, Qt.green, Qt.cyan, Qt.magenta)
))
# Reverse lookup: brush colour -> embedding index.
segment_color_to_index = {color: index for index, color in segment_index_to_color.items()}
# NOTE: despite its name this maps colour *name* -> colour (it populates the
# "Brush Color" menu); the name is kept for compatibility with existing callers.
segment_index_to_colorname = dict(
    Red=Qt.red,
    Yellow=Qt.yellow,
    Blue=Qt.blue,
    Green=Qt.green,
    Cyan=Qt.cyan,
    Magenta=Qt.magenta,
)

def image_to_np(x):
    """Convert a single normalized image tensor to an HxWxC uint8 array.

    Args:
        x: tensor of shape (1, C, H, W) whose values are nominally in [-1, 1]
            (the range produced by ``transform``'s Normalize step).

    Returns:
        C-contiguous ``uint8`` numpy array of shape (H, W, C) with values in
        [0, 255]. Out-of-range inputs saturate at 0 / 255.

    Raises:
        AssertionError: if the batch dimension is not 1.
    """
    assert x.shape[0] == 1
    # detach(): .numpy() refuses tensors that require grad.
    x = x.detach().squeeze(0).permute(1, 2, 0)
    # [-1, 1] -> [0, 1]; clamp so values slightly outside the nominal range
    # saturate instead of wrapping around in the uint8 cast below.
    x = ((x + 1) * 0.5).clamp(0, 1)
    pixels = (x * 255).cpu().numpy().astype('uint8')
    # permute() leaves a strided (non-row-major) layout; consumers such as
    # QImage read raw rows, so hand back a C-contiguous array.
    return np.ascontiguousarray(pixels)

def np_to_pixmap(np_arr):
    """Convert an HxWx3 RGB image array to a QPixmap.

    Args:
        np_arr: array of shape (H, W, 3). It is converted to a C-contiguous
            uint8 array first if it is not one already.

    Returns:
        QPixmap holding its own copy of the pixels.
    """
    # QImage reads raw rows straight from the buffer, so the data must be
    # row-major uint8 and the real row stride must be passed explicitly.
    arr = np.ascontiguousarray(np_arr, dtype=np.uint8)
    height, width, _ = arr.shape
    q_image = QImage(arr.data, width, height, arr.strides[0], QImage.Format_RGB888)
    # A QImage built on an external buffer does not own it; copy() detaches it
    # from `arr`, whose memory may be released once this function returns.
    return QPixmap.fromImage(q_image.copy())

def tensor_to_pixmap(tensor):
    """Render an image tensor with values nominally in [-1, 1] to a QPixmap.

    Args:
        tensor: image tensor of shape (B, C, H, W). It is passed to
            torchvision's ``save_image`` after mapping [-1, 1] -> [0, 1], so
            a batch is tiled into a grid and values are clamped there.

    Returns:
        QPixmap of the rendered image.
    """
    import io

    # Encode into an in-memory PNG rather than a fixed 'tmp_img.jpg' in the
    # working directory. This avoids a lossy JPEG round-trip, never clobbers
    # an existing file, leaves nothing behind if encoding fails, and needs no
    # writable working directory.
    buffer = io.BytesIO()
    save_image(tensor.add(1).mul(0.5), buffer, format='PNG')
    q_image = QImage.fromData(buffer.getvalue(), 'PNG')
    return QPixmap.fromImage(q_image)

class SegmentMapGridScene(QtWidgets.QGraphicsScene):
    """Editable grid showing one embedding-index map (segment map).

    Each cell is painted with the colour ``segment_index_to_color`` assigns to
    the embedding index stored at that position. Left-clicking a cell
    overwrites its index with the current brush index ``pen2idx``.
    """

    def __init__(self, emb_idx, grid_width=20, grid_height=20, n_embed=6, *args, **kwargs):
        """Create an empty (white) grid scene.

        Args:
            emb_idx: initial index map (a 2-D tensor, per draw_segmap_from_tensor)
                or None if nothing has been loaded yet.
            grid_width: cell width in pixels (also used as the brush pen width).
            grid_height: cell height in pixels.
            n_embed: number of valid embedding indices; brushes whose index
                is >= n_embed are rejected.
        """
        super().__init__(*args, **kwargs)

        # Grid line items currently in the scene (managed by draw_grid/delete_grid).
        self.lines = []

        # Widget used as parent for warning dialogs; assigned by the owner.
        self.parent_window = None
        # Editable index map and a pristine copy used by reset_segmap.
        self.emb_idx = emb_idx
        self.n_embed = n_embed
        self.original_emb_idx = emb_idx.clone() if emb_idx is not None else None

        self.grid_width = grid_width
        self.grid_height = grid_height
        # NOTE(review): the grid is fixed at 16x16 cells rather than derived
        # from emb_idx.shape; presumably this matches the model's index maps.
        self.num_grids_x = 16
        self.num_grids_y = 16
        # NOTE: these attributes shadow QGraphicsScene.width()/height() on the
        # Python side; kept as-is for backward compatibility.
        self.width = self.grid_width * self.num_grids_x
        self.height = self.grid_height * self.num_grids_y

        # Backing image the painter draws into. It is shown through ONE
        # pixmap item that is updated in place (see _refresh_pixmap), so the
        # scene does not accumulate a new full-size pixmap per edit.
        self.image = QImage(self.width, self.height, QImage.Format_RGB32)
        self.image.fill(Qt.white)
        self.pixmap_item = self.addPixmap(QPixmap.fromImage(self.image))

        # Painter bound to the backing image; its pen colour is the brush.
        self.painter = QtGui.QPainter(self.image)
        self.pen = QtGui.QPen(QtCore.Qt.red, self.grid_width, Qt.SolidLine)
        self.painter.setPen(self.pen)
        self.pen2idx = segment_color_to_index.get(Qt.red)

    def _paint_cell(self, grid_x, grid_y):
        """Stamp cell (grid_x, grid_y) of the backing image with the current pen."""
        # A zero-length line drawn with a pen as wide as a cell (QPen's default
        # square cap) paints a square centred on the cell.
        x_pos = int(self.grid_width * 0.5 + self.grid_width * grid_x)
        y_pos = int(self.grid_height * 0.5 + self.grid_height * grid_y)
        self.painter.drawLine(x_pos, y_pos, x_pos, y_pos)

    def _refresh_pixmap(self):
        """Show the current backing image and redraw the grid on top of it."""
        self.pixmap_item.setPixmap(QPixmap.fromImage(self.image))
        self.draw_grid()

    def reset_segmap(self):
        """Restore the index map last loaded by draw_segmap_from_tensor.

        Does nothing if no map has been loaded yet.
        """
        if self.original_emb_idx is None:
            return
        self.emb_idx = self.original_emb_idx.clone()
        self.draw_segmap_from_tensor(self.emb_idx)

    def draw_grid(self):
        """(Re)draw the grey cell grid over the segment map at 30% opacity.

        Any previously drawn grid lines are removed first, so repeated calls
        do not pile up line items in the scene.
        """
        self.delete_grid()
        self.setSceneRect(0, 0, self.width, self.height)
        self.setItemIndexMethod(QtWidgets.QGraphicsScene.NoIndex)

        pen = QPen(QColor(128, 128, 128), 1, Qt.SolidLine)

        for gx in range(self.num_grids_x + 1):
            xc = gx * self.grid_width
            self.lines.append(self.addLine(xc, 0, xc, self.height, pen))

        for gy in range(self.num_grids_y + 1):
            yc = gy * self.grid_height
            self.lines.append(self.addLine(0, yc, self.width, yc, pen))

        self.set_opacity(0.3)

    def set_visible(self, visible=True):
        """Show or hide all grid lines."""
        for line in self.lines:
            line.setVisible(visible)

    def delete_grid(self):
        """Remove all grid lines from the scene and forget them."""
        for line in self.lines:
            self.removeItem(line)
        del self.lines[:]

    def set_opacity(self, opacity):
        """Set the opacity of every grid line."""
        for line in self.lines:
            line.setOpacity(opacity)

    def draw_segmap_from_tensor(self, tensor):
        """Load *tensor* as the editable index map and paint every cell.

        Args:
            tensor: 2-D tensor of embedding indices, indexed [row][column].
                It becomes ``self.emb_idx`` (no copy; later clicks edit it in
                place), and a clone is kept for reset_segmap.

        Afterwards the brush is reset to embedding index 0.
        """
        self.emb_idx = tensor
        self.original_emb_idx = tensor.clone()
        # One bulk transfer to Python values instead of a .item() per cell
        # (each .item() is a separate device sync when the tensor is on GPU).
        for grid_y, row in enumerate(tensor.tolist()):
            for grid_x, index in enumerate(row):
                # Indices without a palette colour are painted black instead
                # of crashing on QPen(None, ...).
                color = segment_index_to_color.get(index, Qt.black)
                self.painter.setPen(QPen(color, self.grid_width, Qt.SolidLine))
                self._paint_cell(grid_x, grid_y)

        self._refresh_pixmap()

        # Reset the brush to embedding index 0.
        color = segment_index_to_color.get(0)
        self.painter.setPen(QPen(color, self.grid_width, Qt.SolidLine))
        self.pen2idx = 0

    # method for checking mouse clicks
    def mousePressEvent(self, event):
        """On left click, set the clicked cell to the current brush index."""
        if event.button() != Qt.LeftButton:
            return

        if self.pen2idx is None or self.pen2idx > self.n_embed - 1:
            QMessageBox.about(self.parent_window, "Warning", "Wrong brush color, out of embedding number")
            return

        # Nothing to edit until a segment map has been loaded.
        if self.emb_idx is None:
            return

        grid_x = int(event.scenePos().x() // self.grid_width)
        grid_y = int(event.scenePos().y() // self.grid_height)

        # Reject clicks outside the map on every side; a negative index would
        # otherwise silently wrap around to the last row/column.
        n_rows, n_cols = self.emb_idx.shape[0], self.emb_idx.shape[1]
        if not (0 <= grid_y < n_rows and 0 <= grid_x < n_cols):
            return

        self.emb_idx[grid_y][grid_x] = int(self.pen2idx)
        self._paint_cell(grid_x, grid_y)
        self._refresh_pixmap()


class SegmentMapGridView(QtWidgets.QGraphicsView):
    """Graphics view that owns and displays a single SegmentMapGridScene.

    The scene is reachable as ``self.grid_scene``; remaining positional and
    keyword arguments are forwarded to QGraphicsView.
    """

    def __init__(self, emb_idx=None, grid_width=20, grid_height=20, n_embed=6, *args, **kwargs):
        super().__init__(*args, **kwargs)

        scene = SegmentMapGridScene(
            emb_idx,
            grid_width,
            grid_height,
            n_embed=n_embed,
        )
        # Keep a Python reference so the scene lives as long as the view.
        self.grid_scene = scene
        self.setScene(scene)


# window class
class Window(QMainWindow):
    """Main window of the identity/pose transfer demo.

    Left column: three editable segment-map grids (one per embedding level),
    each with a Reset button. Right side: identity / pose / generated image
    previews plus the Load / Generate / Save buttons.
    """

    def _init_segmap_scene(self, n_embeds):
        """Create the three segment-map views, their Reset buttons and the brush menu.

        Args:
            n_embeds: per-level embedding counts; entries 0..2 are used.
        """
        self.segment_map_views = []
        self.segment_map_buttons = []
        self.sv_layout = QVBoxLayout()
        grid_nbr = 16
        # 12-pixel cells for a 16x16 grid; cells shrink as grid_nbr grows.
        grid_width = grid_height = 12 // (grid_nbr // 16)
        for i in range(3):
            segmap_view = SegmentMapGridView(None, grid_width, grid_height, n_embed=n_embeds[i])
            segmap_view.grid_scene.parent_window = self

            segmap_reset_button = QPushButton("Reset", self)
            self.sv_layout.addWidget(segmap_view)
            self.sv_layout.addWidget(segmap_reset_button)
            self.segment_map_views.append(segmap_view)
            self.segment_map_buttons.append(segmap_reset_button)

            # resetSegmap(i) returns a zero-argument closure bound to index i.
            segmap_reset_button.clicked.connect(self.resetSegmap(i))

        # "Brush Color" menu: one action per palette colour.
        mainMenu = self.menuBar()
        brush_color = mainMenu.addMenu("Brush Color")
        for color_name, color in segment_index_to_colorname.items():
            act = QAction(color_name, self)
            brush_color.addAction(act)
            act.triggered.connect(self.setBrushColor(color))

    def resetSegmap(self, idx):
        """Return a slot that resets segment-map view *idx* to its loaded map."""
        def _reset_segmap():
            self.segment_map_views[idx].grid_scene.reset_segmap()
        return _reset_segmap

    def setBrushColor(self, color):
        """Return a slot that sets *color* as the brush on every segment-map view."""
        def _set_color():
            for sv in self.segment_map_views:
                pen = QPen(color, sv.grid_scene.grid_width, Qt.SolidLine)
                sv.grid_scene.painter.setPen(pen)
                sv.grid_scene.pen2idx = segment_color_to_index.get(color)
        return _set_color

    def __init__(self, ckpt):
        """Build the UI and load the IPI2I network from *ckpt*.

        Args:
            ckpt: checkpoint mapping with an ``'args'`` entry exposing
                ``dislow``, ``dishigh``, ``vq_emb`` and ``vq_type``, and a
                ``'net'`` entry holding the network state dict.
        """
        super().__init__()

        self.setWindowTitle("Pose editor")
        self.setGeometry(100, 100, 1080, 800)

        central_widget = QWidget()
        main_layout = QGridLayout(central_widget)

        # load the main network
        print('loading the neural network ...')
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model_args = ckpt.get('args')
        self.net = IPI2I(dislow=model_args.dislow,
                         dishigh=model_args.dishigh,
                         n_embed=model_args.vq_emb,
                         vq=model_args.vq_type,
                         ).eval().to(self.device)
        self.net.load_state_dict(ckpt['net'])
        print('done ...')

        # create embedding-mask scenes (one per embedding level)
        self._init_segmap_scene(model_args.vq_emb)

        # Latents of the most recently loaded identity / pose images.
        self.idt_latents = None
        self.pose_latents = None

        # image previews, each filled with a placeholder colour
        self.img_width = self.img_height = 256
        img_layout_1, self.identity_image_label, self.identity_image = \
            self._make_image_panel("Identity image", Qt.yellow)
        img_layout_2, self.pose_image_label, self.pose_image = \
            self._make_image_panel("Pose image", Qt.green)
        img_layout_3, self.editted_image_label, self.edit_image = \
            self._make_image_panel("generated image", Qt.gray)

        img_layout = QHBoxLayout()
        for column in (img_layout_1, img_layout_2, img_layout_3):
            img_layout.addLayout(column)

        # action buttons
        btn_layout = QHBoxLayout()
        self.load_idt_image_button = self._add_button(
            btn_layout, "Load Identity Image", self.load_identity_image)
        self.load_pose_image_button = self._add_button(
            btn_layout, "Load Pose Image", self.load_pose_image)
        self.generate_image_button = self._add_button(
            btn_layout, "Generate Image", self.generate_image)
        self.save_image_button = self._add_button(
            btn_layout, "Save Image", self.save_image)

        # combine all the layouts
        main_layout.addLayout(self.sv_layout, 0, 0, 3, 1)
        main_layout.addLayout(img_layout, 0, 1, 2, 3)
        main_layout.addLayout(btn_layout, 2, 1)
        self.setCentralWidget(central_widget)

    def _make_image_panel(self, title, fill_color):
        """Build a titled preview column.

        Returns:
            (layout, label, placeholder pixmap) for a column showing a
            solid *fill_color* image of img_width x img_height.
        """
        label = QLabel(self)
        placeholder = QImage(self.img_width, self.img_height, QImage.Format_RGB32)
        placeholder.fill(fill_color)
        pixmap = QPixmap.fromImage(placeholder)
        label.setPixmap(pixmap)
        layout = QVBoxLayout()
        layout.addWidget(QPushButton(title))
        layout.addWidget(label)
        return layout, label, pixmap

    def _add_button(self, layout, text, slot):
        """Create a button parented to the window, add it to *layout*, connect it to *slot*."""
        button = QPushButton(text, self)
        layout.addWidget(button)
        button.clicked.connect(slot)
        return button

    def _scaled_preview(self, pixmap):
        """Scale *pixmap* (keeping aspect ratio) to the preview height."""
        return pixmap.scaledToWidth(self.img_width).scaledToHeight(self.img_height)

    def _ask_open_image_path(self):
        """Show an open-file dialog; return the chosen path or "" if cancelled."""
        filePath, _ = QFileDialog.getOpenFileName(self, "Open Image", "",
                          "PNG/JPG(*.png *.jpg *.jpeg);;All Files(*.*) ")
        return filePath

    def _load_model_input(self, filePath):
        """Load *filePath* as a (1, 3, im_size, im_size) model input on self.device.

        Returns None (after showing a warning) if PIL cannot read the file.
        """
        try:
            with Image.open(filePath) as pil_image:
                # convert('RGB') so RGBA, grayscale and palette files also
                # yield exactly 3 channels for the 3-channel view below.
                rgb_image = pil_image.convert('RGB')
        except OSError as err:  # PIL's UnidentifiedImageError is an OSError
            QMessageBox.about(self, "Warning", "Cannot read image: {}".format(err))
            return None
        return transform(rgb_image).view(1, 3, im_size, im_size).to(self.device)

    def _show_generated(self, tensor):
        """Display a generated image tensor in the "generated image" preview."""
        self.edit_image = self._scaled_preview(tensor_to_pixmap(tensor))
        self.editted_image_label.setPixmap(self.edit_image)
        self.update()

    def _draw_segmaps(self, embed_idxs):
        """Load each level's index map (first batch item) into its grid view."""
        for ei, embed in enumerate(embed_idxs):
            self.segment_map_views[ei].grid_scene.draw_segmap_from_tensor(embed[0])

    # method for saving canvas
    def save_image(self):
        """Ask for a destination and save the generated image there."""
        filePath, _ = QFileDialog.getSaveFileName(self, "Save Image", "",
                          "PNG/JPG(*.png *.jpg *.jpeg);;All Files(*.*) ")

        if filePath == "":
            return
        # QPixmap.save reports failure (e.g. unknown extension) via its return value.
        if not self.edit_image.save(filePath):
            QMessageBox.about(self, "Warning", "Failed to save image to {}".format(filePath))

    def load_identity_image(self):
        """Load the identity image.

        Without a pose image, the reconstruction and its segment maps are
        shown. Otherwise the image is regenerated from the current maps.
        """
        filePath = self._ask_open_image_path()
        if filePath == "":
            return
        image = self._load_model_input(filePath)
        if image is None:
            return

        self.identity_image = self._scaled_preview(QPixmap.fromImage(QImage(filePath)))
        self.identity_image_label.setPixmap(self.identity_image)
        self.update()

        if self.pose_latents is None:
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                rec_image, latents, embed_idxs = self.net.get_latents_and_rec_image(image)
            self.idt_latents = latents
            self._show_generated(rec_image)
            self._draw_segmaps(embed_idxs)
        else:
            with torch.no_grad():
                self.idt_latents = self.net.get_latents(image)
            self.generate_image()

    def load_pose_image(self):
        """Load a pose image and show the identity/pose mix and its segment maps."""
        if self.idt_latents is None:
            QMessageBox.about(self, "Warning", "Please first select an identity image.")
            return

        filePath = self._ask_open_image_path()
        if filePath == "":
            return
        image = self._load_model_input(filePath)
        if image is None:
            return

        self.pose_image = self._scaled_preview(QPixmap.fromImage(QImage(filePath)))
        self.pose_image_label.setPixmap(self.pose_image)
        self.update()

        with torch.no_grad():
            self.pose_latents = self.net.get_latents(image)
            rec_image, embed_idxs = self.net.forward_with_mix_latents(self.idt_latents, self.pose_latents)

        self._show_generated(rec_image)
        self._draw_segmaps(embed_idxs)

    def generate_image(self):
        """Generate an image from the identity latents and the (edited) segment maps."""
        scenes = [self.segment_map_views[ei].grid_scene for ei in range(3)]
        if self.idt_latents is None or any(scene.emb_idx is None for scene in scenes):
            QMessageBox.about(self, "Warning", "Please first select an identity image.")
            return

        emb_idxs = [scene.emb_idx.detach().clone().unsqueeze(0) for scene in scenes]
        with torch.no_grad():
            g_image = self.net.forward_with_segmap(self.idt_latents, emb_idxs)
        self._show_generated(g_image)

if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description="demo app for images identity/pose transfer")

    # required: the app cannot run without a checkpoint (torch.load(None) would crash).
    parser.add_argument("--ckpt", type=str, required=True, help="path to the checkpoints of the trained model")

    args = parser.parse_args()

    app = QtWidgets.QApplication(sys.argv)

    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # machine; Window moves the network to its own device before
    # load_state_dict copies the weights in.
    # NOTE(review): the checkpoint stores an argparse-style 'args' object;
    # on torch versions defaulting to weights_only=True this load may need
    # weights_only=False (only for trusted checkpoints).
    ckpt = torch.load(args.ckpt, map_location='cpu')
    window = Window(ckpt)
    window.show()
    sys.exit(app.exec_())