import torch
from torch.utils.data import Dataset, DataLoader
import os
import sys
import cv2
import numpy as np
# import cupy as cp
import warnings
# import chainer
from PIL import Image
import math
warnings.filterwarnings("ignore")
sys.path.append("./neural_renderer/")
import matplotlib.pyplot as plt
import utils.nmr_test as nmr
import neural_renderer
from torchvision import transforms
from torchvision.transforms import functional as F

class MyDatasetTestAdv(Dataset):
    """Dataset that renders the current adversarial texture into recorded frames.

    Each sample file in ``data_dir`` is read with ``np.load`` and must provide the
    keys ``'img'``, ``'veh_trans'`` and ``'cam_trans'``. For every item, the vehicle
    mesh (``vertices`` / ``faces``) is rendered with ``self.textures_adv`` from the
    camera pose returned by ``nmr.get_params``. The render is then composited onto
    the resized frame.

    Args:
        data_dir: Directory holding the sample files. Names are sorted by
            ``int(name[4:-4])``, so a 4-character prefix and a 4-character
            extension are assumed (e.g. ``data12.npz`` -- TODO confirm).
        img_size: Side length used for resizing frames/masks and for the renderer.
        texture_size: Per-face texture resolution of the initial all-ones texture
            tensor of shape ``(1, n_faces, T, T, T, 3)``.
        faces, vertices: Mesh data. A leading batch axis is added with
            ``[None, :, :]``, so both are expected to be 2-D.
        distence: Optional threshold on the *squared* distance between the first
            rows of ``cam_trans`` and ``veh_trans``. Samples above it are dropped.
            ``None`` keeps every file.
        mask_dir: Directory of ``<sample name minus last 4 chars>.png`` masks.
            Only read when ``ret_mask`` is true.
        ret_mask: If true, the render is composited only inside the mask, and the
            mask is returned as an extra element.
    """

    def __init__(self, data_dir, img_size, texture_size, faces, vertices, distence=None, mask_dir='', ret_mask=False):
        self.data_dir = data_dir
        files = sorted(os.listdir(data_dir), key=lambda name: int(name[4:-4]))
        if distence is None:
            self.files = files
        else:
            self.files = [
                name for name in files
                if self._squared_distance(os.path.join(data_dir, name)) <= distence
            ]

        self.img_size = img_size
        # Uniform all-ones texture until set_textures() supplies another one.
        textures = np.ones((1, faces.shape[0], texture_size, texture_size, texture_size, 3), 'float32')
        self.textures_adv = torch.from_numpy(textures).cuda(device=0)
        self.faces_var = faces[None, :, :]
        self.vertices_var = vertices[None, :, :]
        self.mask_renderer = nmr.NeuralRenderer(img_size=img_size).cuda()
        self.mask_dir = mask_dir
        self.ret_mask = ret_mask

    @staticmethod
    def _squared_distance(path):
        """Return the squared distance between the first rows of ``cam_trans`` and ``veh_trans`` stored in ``path``."""
        # np.load on an .npz keeps the file open until closed; ``with`` releases it.
        with np.load(path) as data:
            offset = (data['cam_trans'] - data['veh_trans'])[0, :]
        return np.sum(offset ** 2)

    @staticmethod
    def _mask_from_bgr(mask):
        """Collapse an ``(H, W, C)`` mask image to a float32 ``(H, W)`` array that is 1 where any channel is non-zero."""
        # np.logical_or accepts only two inputs. A third positional argument is its
        # ``out`` buffer, so reduce over the channel axis explicitly instead.
        return np.any(mask, axis=2).astype('float32')

    def set_textures(self, textures_adv):
        """Replace the texture tensor used for all subsequent renders."""
        self.textures_adv = textures_adv

    def __getitem__(self, index):
        """Render the sample at ``index`` and composite it onto its frame.

        Returns:
            ``(index, total_img, imgs_pred, file_name)``, or
            ``(index, total_img, imgs_pred, mask, file_name)`` when ``ret_mask``
            is set. ``total_img`` and ``imgs_pred`` have their leading batch axis
            squeezed away.

        Raises:
            FileNotFoundError: if ``ret_mask`` is set and the mask image cannot be read.
        """
        name = self.files[index]
        # NOTE: allow_pickle=True permits unpickling object arrays; only use with trusted data files.
        with np.load(os.path.join(self.data_dir, name), allow_pickle=True) as data:
            img = data['img']
            veh_trans, cam_trans = data['veh_trans'], data['cam_trans']

        eye, camera_direction, camera_up = nmr.get_params(cam_trans, veh_trans)
        renderer = self.mask_renderer.renderer.renderer
        renderer.eye = eye
        renderer.camera_direction = camera_direction
        renderer.camera_up = camera_up
        renderer.background_color = [0.6, 0.6, 0.6]

        imgs_pred = self.mask_renderer.forward(self.vertices_var, self.faces_var, self.textures_adv)

        img = img[:, :, ::-1]  # reverse channel order (presumably RGB <-> BGR -- confirm)
        img = cv2.resize(img, (self.img_size, self.img_size))
        # HWC -> 1xCxHxW, made contiguous so torch.from_numpy accepts it.
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))[None])
        img = torch.from_numpy(img).cuda(device=0)

        # Scale the render so its brightest value is 1.
        imgs_pred = imgs_pred / torch.max(imgs_pred)

        if self.ret_mask:
            mask_file = os.path.join(self.mask_dir, "%s.png" % name[:-4])
            mask = cv2.imread(mask_file)
            if mask is None:
                # cv2.imread signals a missing/unreadable file by returning None.
                raise FileNotFoundError("could not read mask image: %s" % mask_file)
            mask = cv2.resize(mask, (self.img_size, self.img_size))
            mask = torch.from_numpy(self._mask_from_bgr(mask)).cuda()
            total_img = (1 - mask) * img + (255 * imgs_pred) * mask
            return index, total_img.squeeze(0), imgs_pred.squeeze(0), mask, name

        # NOTE(review): without a mask, the whole render (including its gray
        # background) is added onto the frame, and values may exceed 255 -- confirm intent.
        total_img = img + 255 * imgs_pred
        return index, total_img.squeeze(0), imgs_pred.squeeze(0), name

    def __len__(self):
        """Number of sample files kept after optional distance filtering."""
        return len(self.files)

class MyDatasetTestAdv_Shift(Dataset):
    """Adversarial-texture test dataset that re-fits the rendered vehicle into the mask.

    Each sample file in ``data_dir`` is read with ``np.load`` and must provide the
    keys ``'img'``, ``'veh_trans'`` and ``'cam_trans'``. The vehicle mesh is
    rendered with ``self.textures_adv`` from the camera pose returned by
    ``nmr.get_params``. With ``ret_mask`` set, the following steps are applied:

    1. Locate the bounding box of the largest contour in the mask image.
    2. Locate the bounding box of the largest contour in the render.
    3. Bilinearly resize the render crop into the mask box on a 0.4-gray canvas.
    4. Blend the result with the frame using a Gaussian-blurred soft mask.

    Args:
        data_dir: Directory holding the sample files. Names are sorted by
            ``int(name[4:-4])``, so a 4-character prefix and a 4-character
            extension are assumed (e.g. ``data12.npz`` -- TODO confirm).
        img_size: Side length used for resizing frames/masks and for the renderer.
        texture_size: Per-face texture resolution of the initial all-ones texture
            tensor of shape ``(1, n_faces, T, T, T, 3)``.
        faces, vertices: Mesh data. A leading batch axis is added with
            ``[None, :, :]``, so both are expected to be 2-D.
        distence: Optional threshold on the *squared* distance between the first
            rows of ``cam_trans`` and ``veh_trans``. Samples above it are dropped.
            ``None`` keeps every file.
        mask_dir: Directory of ``<sample name minus last 4 chars>.png`` masks.
            Only read when ``ret_mask`` is true.
        ret_mask: If true, apply the shift-and-blend compositing above and return
            the soft mask as an extra element.
    """

    def __init__(self, data_dir, img_size, texture_size, faces, vertices, distence=None, mask_dir='', ret_mask=False):
        self.data_dir = data_dir
        files = sorted(os.listdir(data_dir), key=lambda name: int(name[4:-4]))
        if distence is None:
            self.files = files
        else:
            self.files = [
                name for name in files
                if self._squared_distance(os.path.join(data_dir, name)) <= distence
            ]
        print(len(self.files))
        self.img_size = img_size
        # Uniform all-ones texture until set_textures() supplies another one.
        textures = np.ones((1, faces.shape[0], texture_size, texture_size, texture_size, 3), 'float32')
        self.textures_adv = torch.from_numpy(textures).cuda(device=0)
        self.faces_var = faces[None, :, :]
        self.vertices_var = vertices[None, :, :]
        self.mask_renderer = nmr.NeuralRenderer(img_size=img_size).cuda()
        self.mask_dir = mask_dir
        self.ret_mask = ret_mask

    @staticmethod
    def _squared_distance(path):
        """Return the squared distance between the first rows of ``cam_trans`` and ``veh_trans`` stored in ``path``."""
        # np.load on an .npz keeps the file open until closed; ``with`` releases it.
        with np.load(path) as data:
            offset = (data['cam_trans'] - data['veh_trans'])[0, :]
        return np.sum(offset ** 2)

    @staticmethod
    def _largest_contour_rect(binary, source):
        """Return ``(x, y, w, h)`` of the contour in ``binary`` that has the most points.

        Raises:
            ValueError: if ``binary`` contains no contour (e.g. an empty mask).
        """
        contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            raise ValueError("no contour found in %s" % source)
        # NOTE(review): "largest" means most points, not largest area. With
        # CHAIN_APPROX_SIMPLE a small jagged blob can beat a big smooth one;
        # consider cv2.contourArea. On ties, max() keeps the first contour.
        return cv2.boundingRect(max(contours, key=len))

    def set_textures(self, textures_adv):
        """Replace the texture tensor used for all subsequent renders."""
        self.textures_adv = textures_adv

    def __getitem__(self, index):
        """Render the sample at ``index`` and composite it onto its frame.

        Returns:
            ``(index, total_img, imgs_pred, file_name)``, or, when ``ret_mask`` is
            set, ``(index, total_img, imgs_pred_shift, mask, file_name)``. Image
            tensors have their leading batch axis squeezed away.

        Raises:
            FileNotFoundError: if ``ret_mask`` is set and the mask image cannot be read.
            ValueError: if the mask or the render contains no contour.
        """
        name = self.files[index]
        # NOTE: allow_pickle=True permits unpickling object arrays; only use with trusted data files.
        with np.load(os.path.join(self.data_dir, name), allow_pickle=True) as data:
            img = data['img']
            veh_trans, cam_trans = data['veh_trans'], data['cam_trans']

        eye, camera_direction, camera_up = nmr.get_params(cam_trans, veh_trans)
        renderer = self.mask_renderer.renderer.renderer
        renderer.eye = eye
        renderer.camera_direction = camera_direction
        renderer.camera_up = camera_up
        renderer.background_color = [0.4, 0.4, 0.4]

        imgs_pred = self.mask_renderer.forward(self.vertices_var, self.faces_var, self.textures_adv)

        img = img[:, :, ::-1]  # reverse channel order (presumably RGB <-> BGR -- confirm)
        img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
        # HWC -> 1xCxHxW, made contiguous so torch.from_numpy accepts it.
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))[None])
        img = torch.from_numpy(img).cuda(device=0)

        # Scale the render so its brightest value is 1.
        imgs_pred = imgs_pred / torch.max(imgs_pred)

        if not self.ret_mask:
            total_img = img + 255 * imgs_pred
            return index, total_img.squeeze(0), imgs_pred.squeeze(0), name

        mask_file = os.path.join(self.mask_dir, "%s.png" % name[:-4])
        mask_bgr = cv2.imread(mask_file)
        if mask_bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("could not read mask image: %s" % mask_file)
        mask_bgr = cv2.resize(mask_bgr, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
        gray_mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY)
        # Pixels brighter than 1 become foreground (255).
        _, binary_mask = cv2.threshold(gray_mask, 1, 255, cv2.THRESH_BINARY)
        x2, y2, w2, h2 = self._largest_contour_rect(binary_mask, mask_file)
        # Soft blending weights in [0, 1].
        mask = cv2.GaussianBlur(gray_mask, (3, 3), 0) / 255
        mask = torch.from_numpy(mask.astype('float32')).cuda()

        pred_u8 = (255 * imgs_pred).data.cpu().numpy()[0].transpose((1, 2, 0)).astype('uint8')
        gray_pred = cv2.cvtColor(pred_u8, cv2.COLOR_BGR2GRAY)
        # Treat the top-left pixel's value as background and zero every pixel that shares it.
        gray_pred[gray_pred == gray_pred[0, 0]] = 0
        _, binary_pred = cv2.threshold(gray_pred, 1, 255, cv2.THRESH_BINARY)
        x1, y1, w1, h1 = self._largest_contour_rect(binary_pred, "rendered image for %s" % name)

        # The 0.4 canvas matches the renderer background colour set above.
        imgs_pred_shift = torch.ones_like(imgs_pred) * 0.4
        imgs_pred_shift[:, :, y2:y2 + h2, x2:x2 + w2] = torch.nn.functional.interpolate(
            imgs_pred[:, :, y1:y1 + h1, x1:x1 + w1], size=(h2, w2), mode='bilinear', align_corners=True)

        total_img = (1 - mask) * img + (255 * imgs_pred_shift) * mask
        return index, total_img.squeeze(0), imgs_pred_shift.squeeze(0), mask, name

    def __len__(self):
        """Number of sample files kept after optional distance filtering."""
        return len(self.files)
