from __future__ import print_function

import math
import os
import random

# Third-party imports are kept in their original relative order.
import torch
try:
    # ``simpson`` is the current name; ``simps`` was removed in SciPy 1.14.
    from scipy.integrate import simpson as simps
except ImportError:
    from scipy.integrate import simps
import numpy as np
import h5py
import cv2
import scipy.ndimage as ndimage
# Presumably set so that a second copy of the Intel OpenMP runtime loaded into
# this process is tolerated instead of aborting -- TODO confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Device used by the helpers in this module: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class COFW(torch.utils.data.Dataset):
    """COFW landmark dataset read from ``COFW_train.mat`` / ``COFW_test.mat``.

    Every sample is a 256x256 single-channel crop around the landmark
    bounding box, returned together with the 29 landmark positions in crop
    coordinates and in original-image coordinates.
    """

    # 1-based landmark index pairs that trade places under a horizontal flip.
    _FLIP_PAIRS = ((1, 2), (5, 7), (3, 4), (6, 8), (9, 10), (11, 12),
                   (13, 15), (17, 18), (14, 16), (19, 20), (23, 24))

    def __init__(self, data_path, is_train, random_scale, random_flip, random_rotation):
        """Open the HDF5 file of the requested split.

        Args:
            data_path: Prefix that is concatenated with the ``.mat`` file
                name (so a directory must end with a path separator).
            is_train: Selects ``COFW_train.mat`` when true, else ``COFW_test.mat``.
            random_scale: Enable random scale jitter of +/- ``scale_factor``.
            random_flip: Enable random horizontal flipping (p = 0.5).
            random_rotation: Enable random rotation within +/- ``rot_factor``
                degrees (``rot_factor`` is 0, so this currently has no effect).
        """
        self.is_train = is_train
        self.data_root = data_path
        self.scale_factor = 0.05
        self.rot_factor = 0
        self.random_scale = random_scale
        self.random_flip = random_flip
        self.random_rotation = random_rotation

        if is_train:
            file_name, image_key, pts_key = "COFW_train.mat", 'IsTr', 'phisTr'
        else:
            file_name, image_key, pts_key = "COFW_test.mat", 'IsT', 'phisT'

        # The file handle stays open for the lifetime of the dataset because
        # images are dereferenced lazily in __getitem__.
        self.f = h5py.File(data_path + file_name, 'r')
        self.images = self.f[image_key][0]   # one HDF5 object reference per image
        self.pts = self.f[pts_key]

        # Normalisation constants applied to the [0, 1]-scaled crop.
        self.mean = np.array([0.4637], dtype=np.float32)
        self.std = np.array([0.2591], dtype=np.float32)

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, augment and crop one sample.

        Returns:
            Tuple ``(img, tpts, pts, center, scale)``:
            ``img`` is a float tensor of shape (1, 256, 256);
            ``tpts`` is a LongTensor of landmark positions inside the crop,
            flattened in ``[y1, x1, y2, x2, ...]`` order;
            ``pts`` is a FloatTensor of landmark positions in the (possibly
            flipped) source image, flattened in ``[x1, y1, x2, y2, ...]`` order;
            ``center`` (x, y tensor) and ``scale`` describe the crop box.
        """
        image_ref = self.images[index]
        img = np.array(self.f[image_ref]).T
        # COFW dataset : 29 landmarks. The first 29 table rows are x, the next
        # 29 are y. Only this sample's column is read, instead of loading and
        # transposing the whole landmark table on every call.
        pts = np.asarray(self.pts[0:58, index]).reshape(2, -1).transpose()

        if len(img.shape) == 2:
            img = img.reshape(img.shape[0], img.shape[1], 1)

        # Integer-aligned bounding box of the landmarks.
        x_lo, x_hi = math.floor(np.min(pts[:, 0])), math.ceil(np.max(pts[:, 0]))
        y_lo, y_hi = math.floor(np.min(pts[:, 1])), math.ceil(np.max(pts[:, 1]))
        center = torch.Tensor([(x_lo + x_hi) / 2.0, (y_lo + y_hi) / 2.0])   # x, y order
        # Scale is expressed in units of 200 px, enlarged by 25% for context.
        scale = max(x_hi - x_lo, y_hi - y_lo) / 200.0
        scale *= 1.25

        r = 0

        if self.random_scale:
            scale = scale * (random.uniform(1 - self.scale_factor, 1 + self.scale_factor))

        if self.random_rotation:
            r = random.uniform(-self.rot_factor, self.rot_factor) if random.random() <= 0.6 else 0

        if self.random_flip and random.random() <= 0.5:
            img = np.fliplr(img)
            pts = self.fliplr_joints(pts, width=img.shape[1])
            center[0] = img.shape[1] - center[0]

        img, scale_factor = self.crop(img, center, scale, [256, 256], rot=r)

        # Project landmarks into crop coordinates; rows with y <= 0 are left untouched.
        tpts = pts.copy()
        for i in range(pts.shape[0]):
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = self.transform_pixel(tpts[i, 0:2] + 1, center, scale * scale_factor,
                                                    [256, 256], rot=r)

        # NOTE(review): assumes the crop is single-channel -- confirm for colour inputs.
        img = img.reshape(256, 256, 1)
        img = (img - self.mean) / self.std
        img = img.transpose([2, 0, 1])
        img = torch.Tensor(img)

        # for [y1, x1, y2, x2, ...] order
        tpts = np.fliplr(tpts).flatten()
        tpts = torch.LongTensor(tpts)

        # [x1, y1, x2, y2, ...] order
        pts = torch.FloatTensor(pts).flatten()

        return img, tpts, pts, center, scale

    def fliplr_joints(self, landmark_coordinates, width):
        """Mirror landmarks horizontally, in place, and swap left/right pairs.

        Args:
            landmark_coordinates: Array of shape (n_landmarks, 2) in (x, y) order.
            width: Image width used for the reflection ``x -> width - x``.

        Returns:
            The same (mutated) array.
        """
        landmark_coordinates[:, 0] = width - landmark_coordinates[:, 0]
        for first, second in self._FLIP_PAIRS:
            a, b = first - 1, second - 1
            # Fancy indexing copies the right-hand side, so this is a safe swap.
            landmark_coordinates[[a, b]] = landmark_coordinates[[b, a]]
        return landmark_coordinates

    def get_transform(self, center, scale, output_size, rot=0):
        """Build the 3x3 affine matrix mapping image pixels to crop pixels.

        Args:
            center: (x, y) centre of the crop box in the image.
            scale: Crop box side length divided by 200.
            output_size: ``[height, width]`` of the crop.
            rot: Rotation in degrees about the crop centre.
        """
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(output_size[1]) / h
        t[1, 1] = float(output_size[0]) / h
        t[0, 2] = output_size[1] * (-float(center[0]) / h + .5)
        t[1, 2] = output_size[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1

        if not rot == 0:
            rot = -rot
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Rotate about the crop centre: shift to origin, rotate, shift back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -output_size[1]/2
            t_mat[1, 2] = -output_size[0]/2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def transform_pixel(self, pt, center, scale, output_size, invert=0, rot=0):
        """Map a 1-based (x, y) point through the crop transform.

        With ``invert`` set, the point is mapped from crop space back to image
        space. The result is truncated to integers and returned 1-based.
        """
        t = self.get_transform(center, scale, output_size, rot=rot)
        if invert:
            t = np.linalg.inv(t)
        new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
        new_pt = np.dot(t, new_pt)
        return new_pt[:2].astype(int) + 1

    def crop(self, img, center, scale, output_size, rot=0):
        """Cut the box described by ``center``/``scale`` out of ``img``.

        The box is zero-padded where it leaves the image, optionally rotated,
        scaled to [0, 1] and resized to ``output_size`` (``[height, width]``).

        Returns:
            Tuple ``(crop, scale_adjustment)``. ``crop`` is a float32 array;
            ``scale_adjustment`` is the factor by which rotation changed the
            effective scale (1.0 without rotation). When the image would
            shrink below 2 px during pre-scaling, an all-zero crop is returned
            with ``scale_adjustment`` 1.0.
        """
        center_new = center.clone()
        ht, wd = img.shape[0], img.shape[1]
        sf = scale * 200.0 / output_size[0]
        scale_adjustment = 1.

        # Pre-shrink large boxes so the crop buffer below stays small.
        if sf >= 2:
            new_size = int(math.floor(max(ht, wd) / sf))
            new_ht = int(math.floor(ht / sf))
            new_wd = int(math.floor(wd / sf))

            if new_size < 2:
                # Nothing meaningful is left to crop. Return a blank image in
                # the same (image, scale_adjustment) form as the normal path
                # so callers can unpack it.
                blank_shape = (output_size[0], output_size[1])
                if len(img.shape) > 2 and img.shape[2] > 1:
                    blank_shape += (img.shape[2],)
                return np.zeros(blank_shape, dtype=np.float32), scale_adjustment

            # cv2.resize expects dsize as (width, height).
            img = cv2.resize(img.astype(np.float32), (new_wd, new_ht))
            center_new[0] = center_new[0] * 1.0 / sf
            center_new[1] = center_new[1] * 1.0 / sf
            scale = scale / sf

        # Upper-left and bottom-right corners of the box in image coordinates.
        ul = np.array(self.transform_pixel([0, 0], center_new, scale, output_size, invert=1))
        br = np.array(self.transform_pixel(output_size, center_new, scale, output_size, invert=1))
        original_size = (br - ul)[0]

        pad = 0
        if not rot == 0:
            # Enlarge the box so the rotated content still covers the final crop.
            pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
            ul -= pad
            br += pad

        new_shape = [br[1] - ul[1], br[0] - ul[0]]
        if len(img.shape) > 2:
            new_shape += [img.shape[2]]
        new_img = np.zeros(new_shape, dtype=np.float32)

        # Destination (new_*) and source (old_*) ranges of the overlapping region.
        new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
        new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
        old_x = max(0, ul[0]), min(len(img[0]), br[0])
        old_y = max(0, ul[1]), min(len(img), br[1])
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] / 255

        if not rot == 0:
            new_img = ndimage.rotate(new_img, rot)
            if pad > 0:
                # A zero pad would make ``[pad:-pad]`` an empty slice.
                new_img = new_img[pad:-pad, pad:-pad]
            scale_adjustment = new_img.shape[0] / original_size

        # output_size is [height, width]; cv2.resize wants (width, height).
        new_img = cv2.resize(new_img.astype(np.float32), (int(output_size[1]), int(output_size[0])))
        return new_img, scale_adjustment



def load_data(task, batch_size, random_scale, random_flip, random_rotation):
    """Create the COFW train and test data loaders.

    The training split uses the requested augmentations and is shuffled; the
    test split is loaded without augmentation and in order. Both loaders drop
    the last incomplete batch. ``task`` is accepted but not used.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    path = './dataset/'

    train_set = COFW(
        path,
        True,
        random_scale,
        random_flip,
        random_rotation,
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )

    test_set = COFW(
        path,
        False,
        random_scale=False,
        random_flip=False,
        random_rotation=False,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=True,
    )

    return train_loader, test_loader




def load_state(train = False) :
    """Load ``COFW_state.pth`` and split it into per-region state dicts.

    ``train`` is accepted but not used.

    Returns:
        Tuple ``(leye_state, reye_state, others_state)`` of dictionaries whose
        values are taken from the loaded checkpoint.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load("./pretrained_modules/" + "COFW_state.pth", map_location=device)

    def pick(key_map):
        # Build {output key: checkpoint[source key]} in the listed order.
        return {out_key: checkpoint[src_key] for out_key, src_key in key_map}

    leye_state = pick((
        ('f_net', 'f_net_leye'),
        ('c_net', 'c_net'),
        ('preferred_z_f', 'leye_z_f'),
        ('preferred_z_c', 'leye_z_c'),
    ))

    reye_state = pick((
        ('f_net', 'f_net_reye'),
        ('c_net', 'c_net'),
        ('preferred_z_f', 'reye_z_f'),
        ('preferred_z_c', 'reye_z_c'),
    ))

    others_state = pick((
        ('f_net', 'f_net_others'),
        ('c_net', 'c_net'),
        ('relative_c_net', 'relative_c_net'),
        ('preferred_z_f', 'others_z_f'),
        ('preferred_z_c', 'others_z_c'),
        ('preferred_z_rc_22', 'others_z_rc_22'),
    ))

    return leye_state, reye_state, others_state



def list_mean(hist) :
    """Return the arithmetic mean of ``hist`` (sum divided by length).

    An empty sequence raises ``ZeroDivisionError``.
    """
    total = sum(hist)
    count = len(hist)
    return total / count



def load_dirpolnet(dirpolnet_leye, dirpolnet_reye, dirpolnet_others) :
    """Load the weights stored in ``dirPolNet.pth`` into the three networks.

    Each network receives the checkpoint entry ``habit_leye``, ``habit_reye``
    or ``habit_others`` respectively via ``load_state_dict``.

    Returns:
        The same three networks, in the order they were passed in.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load("./pretrained_modules/" + "dirPolNet.pth", map_location=device)

    targets = (
        (dirpolnet_leye, 'habit_leye'),
        (dirpolnet_reye, 'habit_reye'),
        (dirpolnet_others, 'habit_others'),
    )
    for network, key in targets:
        network.load_state_dict(checkpoint[key])

    return dirpolnet_leye, dirpolnet_reye, dirpolnet_others



def load_optuna_setting() :
    """Read the stored hyper-parameters from ``optuna_setting.pth``.

    Returns:
        A 10-tuple, in this order: ``lambda_control_start``,
        ``lambda_decrease``, ``lambda_f_init``, ``lambda_freq``,
        ``thr_control_start``, ``thr_increase``, ``thr_init``, ``thr_freq``,
        ``lambda_f_1stage``, ``lambda_f_2stage``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    settings = torch.load("./pretrained_modules/" + "optuna_setting.pth", map_location=device)

    ordered_keys = (
        'lambda_control_start',
        'lambda_decrease',
        'lambda_f_init',
        'lambda_freq',
        'thr_control_start',
        'thr_increase',
        'thr_init',
        'thr_freq',
        'lambda_f_1stage',
        'lambda_f_2stage',
    )
    return tuple(settings[key] for key in ordered_keys)




def scheduler_step(opt, epoch, decay_interval, gamma) :
    """Multiply every parameter group's learning rate by ``gamma``.

    The decay is applied only when ``epoch`` is a multiple of
    ``decay_interval`` (which includes epoch 0); otherwise nothing happens.
    The new learning rates are printed. Returns ``None``.
    """
    if epoch % decay_interval != 0:
        return

    decayed = []
    for group in opt.param_groups:
        new_lr = group['lr'] * gamma
        group['lr'] = new_lr
        decayed.append(new_lr)
    print("=== learning rate decayed to {}.\n".format(decayed))



def abs_coord_to_norm(c, img_size):
    """Rescale absolute coordinates so that 0 maps to -1 and ``size - 1`` maps to 1.

    ``img_size`` is turned into a float tensor on the module-level ``device``.
    """
    max_index = torch.FloatTensor(img_size) - 1
    max_index = max_index.to(device)
    return (2 * c / max_index) - 1



def norm_coord_to_abs(c, img_size) :
    """Map coordinates from [-1, 1] back to ``[0, size - 1]`` and round them.

    ``img_size`` is turned into a float tensor on the module-level ``device``.
    """
    half_extent = (torch.FloatTensor(img_size) - 1).to(device) / 2
    return torch.round((c + 1) * half_extent)



def KL_div_from_mean(mean1, mean2) :
    """Return half the squared difference of the two means."""
    difference = mean1 - mean2
    return 0.5 * (difference ** 2)




def get_transform(center, scale, output_size, rot=0):
    """Build the 3x3 affine matrix mapping image pixels to crop pixels.

    Args:
        center: (x, y) centre of the crop box in the image.
        scale: Crop box side length divided by 200.
        output_size: ``[height, width]`` of the crop.
        rot: Rotation in degrees about the crop centre.
    """
    out_h, out_w = output_size[0], output_size[1]
    box = 200 * scale

    # Scale the box to the output size and move its centre to the crop centre.
    t = np.eye(3)
    t[0, 0] = float(out_w) / box
    t[1, 1] = float(out_h) / box
    t[0, 2] = out_w * (-float(center[0]) / box + .5)
    t[1, 2] = out_h * (-float(center[1]) / box + .5)

    if rot == 0:
        return t

    angle = -rot * np.pi / 180
    sn, cs = np.sin(angle), np.cos(angle)
    rotation = np.eye(3)
    rotation[0, :2] = [cs, -sn]
    rotation[1, :2] = [sn, cs]

    # Rotate about the crop centre: shift to origin, rotate, shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -out_w/2
    to_origin[1, 2] = -out_h/2
    from_origin = to_origin.copy()
    from_origin[:2, 2] *= -1

    shifted = np.dot(to_origin, t)
    rotated = np.dot(rotation, shifted)
    return np.dot(from_origin, rotated)


def transform_pixel(pt, center, scale, output_size, invert=0, rot=0):
    """Map a 1-based (x, y) point through the crop transform.

    With ``invert`` set, the point is mapped from crop space back to image
    space. The result is truncated to integers and returned 1-based.
    """
    forward = get_transform(center, scale, output_size, rot=rot)
    matrix = np.linalg.inv(forward) if invert else forward

    # Shift to 0-based homogeneous coordinates before applying the matrix.
    x0, y0 = pt[0] - 1, pt[1] - 1
    homogeneous = np.array([x0, y0, 1.]).T
    mapped = np.dot(matrix, homogeneous)
    return mapped[:2].astype(int) + 1



# Evaluation: NME computed in original-image coordinates.
def NME_calc(inferred_coord_abs_x_y, pts, center, scale) :
    """Compute the per-sample normalised mean error (NME) in image space.

    Predictions are mapped from the 256x256 crop back to the source image
    with ``transform_pixel`` and compared against the ground truth.

    Args:
        inferred_coord_abs_x_y: Tensor of predicted (x, y) crop coordinates,
            indexed as ``[batch, landmark]``. It is not modified.
        pts: Ground-truth landmarks, viewable as ``(-1, 29, 2)``.
        center: Per-sample crop centres.
        scale: Per-sample crop scales.

    Returns:
        Tuple ``(NME_pupil, NME_ocular)``: the mean landmark error of each
        sample divided by the distance between landmarks 16 and 17, and
        between landmarks 8 and 9, respectively.
    """
    pts = pts.view(-1, 29, 2)

    B = inferred_coord_abs_x_y.size(0)
    n_l = inferred_coord_abs_x_y.size(1)

    # Work on a private copy: for CPU tensors ``.numpy()`` shares memory with
    # the tensor, so writing the back-projected points would otherwise
    # overwrite the caller's predictions. ``detach`` also allows tensors that
    # require grad.
    inferred_np = inferred_coord_abs_x_y.detach().cpu().numpy().copy()

    for b in range(B) :
        for l in range(n_l) :
            inferred_np[b, l] = transform_pixel(inferred_np[b, l]+1, center[b], scale[b], [256,256], invert=1, rot=0)

    # Convert back to a tensor explicitly instead of mixing tensor and ndarray.
    inferred = torch.from_numpy(inferred_np.reshape(-1, 29, 2)).to(pts.device)
    error = torch.norm(pts - inferred, dim=-1).mean(-1)

    inter_pupil_distance = torch.norm(pts[:, 17] - pts[:, 16], dim=-1)
    inter_ocular_distance = torch.norm(pts[:, 9] - pts[:, 8], dim=-1)

    NME_pupil = error / inter_pupil_distance
    NME_ocular = error / inter_ocular_distance
    return NME_pupil, NME_ocular



def NME_calc_landmarkwise(inferred_coord_abs_x_y, pts, center, scale) :
    """Compute the normalised error of every landmark in image space.

    Predictions are mapped from the 256x256 crop back to the source image
    with ``transform_pixel`` and compared against the ground truth.

    Args:
        inferred_coord_abs_x_y: Tensor of predicted (x, y) crop coordinates,
            indexed as ``[batch, landmark]``. It is not modified.
        pts: Ground-truth landmarks, viewable as ``(-1, 29, 2)``.
        center: Per-sample crop centres.
        scale: Per-sample crop scales.

    Returns:
        Tuple ``(NME_pupil, NME_ocular)`` of ``[B, n_l]`` tensors: each
        landmark error divided by the sample's distance between landmarks 16
        and 17, and between landmarks 8 and 9, respectively.
    """
    pts = pts.view(-1, 29, 2)

    B = inferred_coord_abs_x_y.size(0)
    n_l = inferred_coord_abs_x_y.size(1)

    # Work on a private copy: for CPU tensors ``.numpy()`` shares memory with
    # the tensor, so writing the back-projected points would otherwise
    # overwrite the caller's predictions. ``detach`` also allows tensors that
    # require grad.
    inferred_np = inferred_coord_abs_x_y.detach().cpu().numpy().copy()

    for b in range(B) :
        for l in range(n_l) :
            inferred_np[b, l] = transform_pixel(inferred_np[b, l]+1, center[b], scale[b], [256,256], invert=1, rot=0)

    # Convert back to a tensor explicitly instead of mixing tensor and ndarray.
    inferred = torch.from_numpy(inferred_np.reshape(-1, 29, 2)).to(pts.device)
    error = torch.norm(pts - inferred, dim=-1)  # [B, n_l]
    inter_pupil_distance = torch.norm(pts[:, 17] - pts[:, 16], dim=-1)
    inter_ocular_distance = torch.norm(pts[:, 9] - pts[:, 8], dim=-1)
    # Reshape the per-sample distances to [B, 1] so they broadcast over landmarks.
    NME_pupil = error / inter_pupil_distance.view(-1, 1)
    NME_ocular = error / inter_ocular_distance.view(-1, 1)
    return NME_pupil, NME_ocular



# Evaluation: area under the cumulative error distribution and failure rate.
def AUC_calc(NME_samples, failure_threshold=0.1, step=0.0001):
    """Compute the AUC of the cumulative error distribution and the failure rate.

    Args:
        NME_samples: Per-sample NME values (anything ``torch.FloatTensor`` accepts).
        failure_threshold: Upper integration limit; samples with a larger
            error count as failures.
        step: Spacing of the error thresholds at which the distribution is sampled.

    Returns:
        Tuple ``(AUC, failure_rate)``: the Simpson-rule integral of the
        cumulative distribution over ``[0, failure_threshold]`` divided by
        ``failure_threshold``, and the fraction of samples whose error
        exceeds the last threshold.

    Raises:
        ZeroDivisionError: If ``NME_samples`` is empty.
    """
    # Imported here with a fallback so the function works with both the
    # current SciPy name and the legacy ``simps`` (removed in SciPy 1.14).
    try:
        from scipy.integrate import simpson
    except ImportError:
        from scipy.integrate import simps as simpson

    NME_samples = torch.FloatTensor(NME_samples)
    nErrors = len(NME_samples)
    xAxis = torch.arange(0., failure_threshold+step, step)
    # Float rounding in ``failure_threshold + step`` can make arange emit one
    # extra sample beyond the threshold; drop it so the integral and the
    # failure rate are evaluated on [0, failure_threshold] only.
    xAxis = xAxis[xAxis <= failure_threshold + 0.5 * step]
    # Fraction of samples with error <= x, for every threshold x.
    ced = [float(torch.count_nonzero(NME_samples <= x)) / nErrors for x in xAxis]
    AUC = simpson(np.asarray(ced), x=xAxis.numpy()) / failure_threshold
    failure_rate = 1. - ced[-1]
    return AUC, failure_rate
