from __future__ import print_function
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import torchvision.transforms.functional as F_

import math
import random
import pandas as pd
import cv2
import scipy.ndimage as ndimage
import os
# Intel OpenMP: tolerate several copies of the runtime being loaded in one
# process instead of aborting (presumably needed for this torch/numpy/cv2 mix).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Module-wide compute device (first CUDA device when available) and float dtype.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float


class IBUG_300W(torch.utils.data.Dataset):
    """300W facial-landmark dataset yielding augmented 256x256 face crops.

    Rows are read from ``<data_path>/300W/300W_train_data.csv`` (training) or
    ``<data_path>/300W/300W_test_data_full.csv`` (test). Each row must have an
    ``image_path`` column (resolved relative to ``data_path``); every column
    after the first is read as a landmark coordinate in ``x0, y0, x1, y1, ...``
    order (assumes ``image_path`` is the first column — confirm against the CSVs).

    ``__getitem__`` returns ``(img, tpts)``:
        img:  float tensor (3, 256, 256), scaled to [0, 1] then normalised
              with the ImageNet mean/std stored in ``self.mean``/``self.std``.
        tpts: int64 tensor of landmark positions inside the crop, flattened in
              ``[y0, x0, y1, x1, ...]`` order.
    """

    # 1-based landmark index pairs whose rows are exchanged after a horizontal
    # flip (mirror-symmetric counterparts in the 68-point annotation).
    _FLIP_PAIRS = ((1, 17), (2, 16), (3, 15), (4, 14), (5, 13), (6, 12), (7, 11), (8, 10),
                   (18, 27), (19, 26), (20, 25), (21, 24), (22, 23),
                   (32, 36), (33, 35),
                   (37, 46), (38, 45), (39, 44), (40, 43), (41, 48), (42, 47),
                   (49, 55), (50, 54), (51, 53), (62, 64), (61, 65), (68, 66), (59, 57), (60, 56))

    def __init__(self, data_path, is_train, random_scale, random_flip, random_rotation,
                 scale_factor=0.05, rot_factor=0):
        """
        Args:
            data_path: dataset root; CSVs live in ``<data_path>/300W/`` and the
                CSV ``image_path`` entries are joined onto it.
            is_train: read the training CSV if True, else the full test CSV.
            random_scale: multiply the crop scale by U(1 - scale_factor, 1 + scale_factor).
            random_flip: mirror image and landmarks horizontally with probability 0.5.
            random_rotation: with probability 0.6 rotate by U(-rot_factor, rot_factor) degrees.
            scale_factor: scale-jitter amplitude (default 0.05, the previous hard-coded value).
            rot_factor: max rotation in degrees (default 0, the previous hard-coded
                value, which makes ``random_rotation`` a no-op).
        """
        self.is_train = is_train
        self.data_root = data_path
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.random_scale = random_scale
        self.random_flip = random_flip
        self.random_rotation = random_rotation

        csv_name = "300W_train_data.csv" if is_train else "300W_test_data_full.csv"
        # os.path.join works whether or not data_path ends with a separator.
        self.df = pd.read_csv(os.path.join(data_path, "300W", csv_name))

        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __len__(self):
        """Number of annotated images in the selected CSV."""
        return len(self.df)

    @staticmethod
    def _parse_landmarks(row):
        """Return the coordinate columns of ``row`` as an (N, 2) float64 array of (x, y)."""
        # Select by position (iloc) so the slice never depends on column labels,
        # and force a numeric dtype: the row Series is object-typed because it
        # also carries the image path.
        return row.iloc[1:].to_numpy(dtype=np.float64).reshape(-1, 2)

    def __getitem__(self, index):
        """Load, augment and crop sample ``index``; returns ``(img, tpts)``."""
        output_size = [256, 256]
        row = self.df.iloc[index]
        img_path = os.path.join(self.data_root, row['image_path'])
        # Close the underlying file deterministically rather than relying on GC.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'), dtype=np.float32)

        # 300W dataset: 68 landmarks, [[x0, y0], [x1, y1], ..., [x67, y67]] order.
        pts = self._parse_landmarks(row)

        # Crop box from the landmark extent: centre in (x, y) order; scale is the
        # larger side expressed in units of 200 px, enlarged by 25 %.
        xmin, ymin = pts.min(axis=0)
        xmax, ymax = pts.max(axis=0)
        center_w = (math.floor(xmin) + math.ceil(xmax)) / 2.0
        center_h = (math.floor(ymin) + math.ceil(ymax)) / 2.0
        center = torch.Tensor([center_w, center_h])   # x, y order

        scale = max(math.ceil(xmax) - math.floor(xmin), math.ceil(ymax) - math.floor(ymin)) / 200.0
        scale *= 1.25

        if self.random_scale:
            scale *= random.uniform(1 - self.scale_factor, 1 + self.scale_factor)

        r = 0
        if self.random_rotation:
            r = random.uniform(-self.rot_factor, self.rot_factor) if random.random() <= 0.6 else 0

        if self.random_flip and random.random() <= 0.5:
            img = np.fliplr(img)
            pts = self.fliplr_joints(pts, width=img.shape[1])
            center[0] = img.shape[1] - center[0]

        img, scale_adjustment = self.crop(img, center, scale, output_size, rot=r)

        tpts = pts.copy()
        for i in range(tpts.shape[0]):
            # Points with y <= 0 are left as-is (presumably missing/invalid annotations).
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = self.transform_pixel(tpts[i, 0:2] + 1, center, scale * scale_adjustment,
                                                    output_size, rot=r)

        img = (img - self.mean) / self.std
        img = torch.Tensor(img.transpose([2, 0, 1]))   # HWC -> CHW

        # Swap columns to (y, x) and flatten: [y0, x0, y1, x1, ...].
        tpts = np.fliplr(tpts).flatten().astype(np.float32)
        # Explicit conversion instead of the legacy torch.LongTensor(ndarray) ctor;
        # .long() truncates toward zero.
        tpts = torch.from_numpy(tpts).long()
        return img, tpts

    def fliplr_joints(self, landmark_coordinates, width):
        """Mirror landmarks horizontally IN PLACE and return the same array.

        Every x becomes ``width - x``; then the rows of each pair in
        ``_FLIP_PAIRS`` (1-based indices) are exchanged.
        """
        landmark_coordinates[:, 0] = width - landmark_coordinates[:, 0]
        for a, b in self._FLIP_PAIRS:
            # Fancy indexing on the right-hand side makes a copy, so this swap is safe.
            landmark_coordinates[[a - 1, b - 1]] = landmark_coordinates[[b - 1, a - 1]]
        return landmark_coordinates

    def get_transform(self, center, scale, output_size, rot=0):
        """Return the 3x3 affine matrix from image pixels to crop pixels.

        The crop is a ``200 * scale`` window centred on ``center`` (x, y),
        scaled to ``output_size`` (output_size[1] along x, output_size[0] along y).
        A non-zero ``rot`` (degrees) additionally rotates by ``-rot`` about the
        crop centre.
        """
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(output_size[1]) / h
        t[1, 1] = float(output_size[0]) / h
        t[0, 2] = output_size[1] * (-float(center[0]) / h + .5)
        t[1, 2] = output_size[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1

        if not rot == 0:
            rot = -rot
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Rotate about the crop centre: translate it to the origin, rotate, translate back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -output_size[1]/2
            t_mat[1, 2] = -output_size[0]/2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def transform_pixel(self, pt, center, scale, output_size, invert=0, rot=0):
        """Map point ``pt`` (x, y) through ``get_transform`` (its inverse if ``invert``).

        The point is shifted by -1 before and +1 after the transform. The result
        is an int array truncated toward zero.
        """
        t = self.get_transform(center, scale, output_size, rot=rot)
        if invert:
            t = np.linalg.inv(t)
        new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
        new_pt = np.dot(t, new_pt)
        return new_pt[:2].astype(int) + 1

    def crop(self, img, center, scale, output_size, rot=0):
        """Cut a ``200 * scale`` window around ``center`` and resize it to ``output_size``.

        Out-of-image pixels are zero-filled and values are divided by 255.
        When ``rot`` != 0 the window is padded, rotated with
        ``scipy.ndimage.rotate`` (default ``reshape=True`` enlarges the canvas)
        and trimmed by the pad.

        Returns:
            (new_img, scale_adjustment): float32 image resized to ``output_size``,
            and ``trimmed_side / original_side`` (1.0 when rot == 0), which the
            caller multiplies into ``scale`` before mapping landmarks.
        """
        center_new = center.clone()
        scale_adjustment = 1.

        ul = np.array(self.transform_pixel([0, 0], center_new, scale, output_size, invert=1))
        br = np.array(self.transform_pixel(output_size, center_new, scale, output_size, invert=1))
        original_size = (br - ul)[0]

        if rot != 0:
            # Grow the window by (half-diagonal - half-side) on every edge.
            pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
            ul -= pad
            br += pad

        new_shape = [br[1] - ul[1], br[0] - ul[0]]
        if img.ndim > 2:
            new_shape += [img.shape[2]]
        new_img = np.zeros(new_shape, dtype=np.float32)
        # Overlap of the (possibly out-of-bounds) window with the image, in window
        # coordinates (new_*) and in image coordinates (old_*).
        new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
        new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
        old_x = max(0, ul[0]), min(len(img[0]), br[0])
        old_y = max(0, ul[1]), min(len(img), br[1])
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] / 255

        if rot != 0:
            new_img = ndimage.rotate(new_img, rot)
            # Explicit stop indices: ``new_img[pad:-pad]`` would be EMPTY when pad == 0.
            rows, cols = new_img.shape[:2]
            new_img = new_img[pad:rows - pad, pad:cols - pad]
            scale_adjustment = new_img.shape[0] / original_size

        # NOTE: cv2.resize takes dsize as (width, height); output_size is square here.
        new_img = cv2.resize(new_img.astype(np.float32), output_size)
        return new_img, scale_adjustment



def load_data(task, batch_size, random_scale, random_flip, random_rotation,
              path='../../dataset/', num_workers=4, test_drop_last=True):
    """Build the 300W train and test DataLoaders.

    Args:
        task: unused; kept for call-site compatibility.
        batch_size: samples per batch for both loaders.
        random_scale, random_flip, random_rotation: augmentation switches applied
            to the training set only (the test set is built with all three False).
        path: dataset root handed to IBUG_300W (default: the previous hard-coded value).
        num_workers: DataLoader worker processes for both loaders (default 4).
        test_drop_last: whether the test loader drops its final partial batch.

    Returns:
        (train_loader, test_loader); the train loader shuffles and drops its
        final partial batch, the test loader keeps dataset order.
    """
    train_set = IBUG_300W(path, True, random_scale, random_flip, random_rotation)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, drop_last=True)
    test_set = IBUG_300W(path, False, random_scale=False, random_flip=False, random_rotation=False)
    # NOTE(review): with test_drop_last=True (the historical default) up to
    # batch_size - 1 test images are never evaluated. It is kept because callers
    # may assume full batches; pass False for a complete evaluation.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                              num_workers=num_workers, drop_last=test_drop_last)
    return train_loader, test_loader



def load_state():
    """Load ``../../checkpoint/300W_state.pth`` and split it per facial part.

    Returns a 5-tuple of dicts in the order (leye, reye, mouth, nose, jaw). Each
    dict has keys 'featnet', 'coordnet', 'prior_z_ft' and 'prior_z_cd'. They are
    taken from the checkpoint keys 'featnet_<part>', the shared 'coordnet',
    '<part>_z_ft' and '<part>_z_cd'.
    """
    checkpoint_dir = "../../checkpoint/"
    checkpoint = torch.load(checkpoint_dir + "300W_state.pth")

    def _part_state(part):
        # Same key order as the consumers expect: featnet, coordnet, priors.
        return {
            'featnet': checkpoint['featnet_' + part],
            'coordnet': checkpoint['coordnet'],
            'prior_z_ft': checkpoint[part + '_z_ft'],
            'prior_z_cd': checkpoint[part + '_z_cd'],
        }

    return tuple(_part_state(part) for part in ('leye', 'reye', 'mouth', 'nose', 'jaw'))



def make_batch(env, agent, n_landmarks):
    """Produce one training sample from the environment's current observation.

    Inside ``torch.no_grad()``, one landmark index is drawn uniformly from
    ``[0, n_landmarks)`` for every entry along dim 0 of ``env.current_o``. The
    indices (moved to ``device``) and the observation are passed to
    ``agent.make_training_sample``.

    Returns:
        The 6-tuple ``(embedding_ft0, embedding_cd0, prior_ft, prior_cd,
        lambda_ft, direction)`` returned by the agent.
    """
    with torch.no_grad():
        observation = env.current_o
        landmark_ids = torch.randint(0, n_landmarks, (observation.size(0),)).to(device)
        sample = agent.make_training_sample(env, observation, landmark_ids)
        (embedding_ft0, embedding_cd0,
         prior_ft, prior_cd,
         lambda_ft, direction) = sample

    return embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, direction




def list_mean(hist):
    """Return the arithmetic mean of ``hist`` (ZeroDivisionError when empty)."""
    total = sum(hist)
    return total / len(hist)



def abs_coord_to_norm(c, img_size=(256, 256)):
    """Map absolute pixel coordinates to the [-1, 1] range.

    Computes ``2 * c / (img_size - 1) - 1`` element-wise, so pixel 0 maps to -1
    and pixel ``img_size - 1`` maps to +1. ``img_size`` broadcasts against the
    last dimension of ``c``. This is the inverse of ``norm_coord_to_abs`` up to rounding.

    Args:
        c: coordinate tensor (presumably on ``device`` with a trailing axis of 2 —
            confirm with callers).
        img_size: per-axis image extent. Defaults to an immutable tuple rather
            than a shared mutable list; any sequence of numbers is accepted.

    Returns:
        Float tensor of normalised coordinates.
    """
    # torch.tensor always treats img_size as data; the legacy FloatTensor ctor
    # would read a torch.Size as a *shape* and return uninitialised memory.
    size = torch.tensor(img_size, dtype=torch.float32, device=device)
    return (2 * c / (size - 1)) - 1



def norm_coord_to_abs(c, img_size=(256, 256)):
    """Map [-1, 1] normalised coordinates back to rounded absolute pixel positions.

    Computes ``round((c + 1) * (img_size - 1) / 2)`` element-wise, which is the
    inverse of ``abs_coord_to_norm``. ``torch.round`` rounds halves to even.

    Args:
        c: normalised coordinate tensor (presumably on ``device`` — confirm with callers).
        img_size: per-axis image extent. Defaults to an immutable tuple rather
            than a shared mutable list; any sequence of numbers is accepted.

    Returns:
        Float tensor of whole-number pixel coordinates.
    """
    # torch.tensor always treats img_size as data (unlike the legacy FloatTensor
    # ctor, which interprets a torch.Size as a shape).
    size = torch.tensor(img_size, dtype=torch.float32, device=device)
    return torch.round((c + 1) * ((size - 1) / 2))



def save_model(names, dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth,
               dirpolnet_nose, dirpolnet_jaw, epoch):
    """Write the five per-part policy networks to ``../../checkpoint/<names>_<epoch>epoch.pth``.

    The saved dict maps 'dirpolnet_<part>' to each network's ``state_dict()``
    for the parts leye, reye, mouth, nose and jaw, plus an 'epoch' entry.
    """
    networks = (
        ('dirpolnet_leye', dirpolnet_leye),
        ('dirpolnet_reye', dirpolnet_reye),
        ('dirpolnet_mouth', dirpolnet_mouth),
        ('dirpolnet_nose', dirpolnet_nose),
        ('dirpolnet_jaw', dirpolnet_jaw),
    )
    checkpoint = {key: net.state_dict() for key, net in networks}
    checkpoint['epoch'] = epoch

    target = '../../checkpoint/{}_{}epoch.pth'.format(names, epoch)
    torch.save(checkpoint, target)

