from __future__ import print_function
import torch
import numpy as np

import math
import random
import h5py
import cv2
import scipy.ndimage as ndimage
import os
# Tell the Intel OpenMP runtime to tolerate several copies of itself in one
# process (a common workaround for "OMP: Error #15" when multiple libraries each
# bundle libiomp). NOTE(review): Intel documents this as an unsupported workaround.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Default compute device: the current CUDA device when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.float32  # same object as torch.float


class COFW(torch.utils.data.Dataset):
    """COFW face-landmark dataset read from the HDF5-format ``.mat`` release files.

    Each item is ``(img, tpts)``:

    * ``img`` -- float tensor of shape ``(1, 256, 256)``: a crop around the
      landmarks whose pixel values are divided by 255 and then standardised
      with ``self.mean`` / ``self.std``.
    * ``tpts`` -- ``LongTensor`` of 58 values: the 29 landmarks in crop pixel
      coordinates, flattened in ``[y0, x0, y1, x1, ...]`` order.

    Args:
        data_path: Prefix joined by plain string concatenation with
            ``"COFW_train.mat"`` / ``"COFW_test.mat"``, so it should include the
            trailing path separator.
        is_train: Read the training split (``IsTr`` / ``phisTr``) if true,
            otherwise the test split (``IsT`` / ``phisT``).
        random_scale: Multiply the crop scale by a factor drawn uniformly from
            ``[1 - scale_factor, 1 + scale_factor]``.
        random_flip: Mirror image and landmarks with probability 0.5.
        random_rotation: With probability 0.6, rotate by an angle drawn from
            ``[-rot_factor, rot_factor]`` degrees. ``rot_factor`` is 0, so this
            currently never produces a non-zero rotation.
    """

    def __init__(self, data_path, is_train, random_scale, random_flip, random_rotation):
        self.is_train = is_train
        self.data_root = data_path
        self.scale_factor = 0.05  # maximum relative scale jitter
        self.rot_factor = 0       # maximum rotation in degrees (0 disables rotation)
        self.random_scale = random_scale
        self.random_flip = random_flip
        self.random_rotation = random_rotation

        # NOTE(review): this h5py handle is opened in the constructing process and
        # inherited by DataLoader workers when num_workers > 0; HDF5 handles are not
        # generally safe to share across forked processes -- consider opening the
        # file lazily inside each worker.
        if is_train:
            self.f = h5py.File(data_path + "COFW_train.mat", 'r')
            # First row of 'IsTr': one HDF5 object reference per image.
            self.images = self.f['IsTr'][0]
            self.pts = self.f['phisTr']
        else:
            self.f = h5py.File(data_path + "COFW_test.mat", 'r')
            self.images = self.f['IsT'][0]
            self.pts = self.f['phisT']

        # Single-channel normalisation statistics, applied after crop() divides by 255.
        self.mean = np.array([0.4637], dtype=np.float32)
        self.std = np.array([0.2591], dtype=np.float32)

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, augment and crop sample ``index``; see the class docstring for the output."""
        # Dereference the stored object reference; the transpose presumably undoes
        # MATLAB's column-major storage.
        image_ref = self.images[index]
        img = np.array(self.f[image_ref]).T

        # COFW: 29 landmarks. The first 58 entries of sample ``index``'s annotation
        # column are [x0..x28, y0..y28]; reshape into [[x0, y0], ..., [x28, y28]].
        # Only this sample's column is read -- transposing the whole dataset would
        # load every annotation from disk on each call.
        pts = np.asarray(self.pts[:, index])[0:58].reshape(2, -1).transpose()

        if len(img.shape) == 2:
            img = img.reshape(img.shape[0], img.shape[1], 1)

        # Crop box: centre of the landmark bounding box; side = 1.25 x the longer
        # bbox side, expressed in the 200-pixel units used by get_transform().
        xmin = np.min(pts[:, 0])
        xmax = np.max(pts[:, 0])
        ymin = np.min(pts[:, 1])
        ymax = np.max(pts[:, 1])
        center_w = (math.floor(xmin) + math.ceil(xmax)) / 2.0
        center_h = (math.floor(ymin) + math.ceil(ymax)) / 2.0
        center = torch.Tensor([center_w, center_h])   # x, y order
        scale = max(math.ceil(xmax) - math.floor(xmin), math.ceil(ymax) - math.floor(ymin)) / 200.0
        scale *= 1.25

        r = 0

        if self.random_scale:
            scale = scale * (random.uniform(1 - self.scale_factor, 1 + self.scale_factor))

        if self.random_rotation:
            r = random.uniform(-self.rot_factor, self.rot_factor) if random.random() <= 0.6 else 0

        if self.random_flip:
            if random.random() <= 0.5:
                img = np.fliplr(img)
                pts = self.fliplr_joints(pts, width=img.shape[1])
                center[0] = img.shape[1] - center[0]

        img, scale_adjustment = self.crop(img, center, scale, [256, 256], rot=r)

        tpts = pts.copy()
        for i in range(pts.shape[0]):
            # Only points with a positive y coordinate are mapped into the crop.
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = self.transform_pixel(tpts[i, 0:2] + 1, center, scale * scale_adjustment,
                                                    [256, 256], rot=r)

        # NOTE(review): this reshape assumes a single-channel image.
        img = img.reshape(256, 256, 1)
        img = (img - self.mean) / self.std
        img = img.transpose([2, 0, 1])
        img = torch.Tensor(img)

        # for [y1, x1, y2, x2, ...] order
        tpts = np.fliplr(tpts).flatten()
        tpts = torch.LongTensor(tpts)
        return img, tpts

    def fliplr_joints(self, landmark_coordinates, width):
        """Mirror landmarks horizontally *in place* and swap matched pairs.

        Each ``x`` becomes ``width - x``. Then the rows named by each pair of
        1-based indices in ``matched_parts`` are exchanged, presumably so that
        left/right-symmetric facial points keep their labels after the flip.
        Returns the same (mutated) array.
        """
        matched_parts = [[1, 2], [5, 7], [3, 4], [6, 8], [9, 10], [11, 12],
                         [13, 15], [17, 18], [14, 16], [19, 20], [23, 24]]
        landmark_coordinates[:, 0] = width - landmark_coordinates[:, 0]
        for pair in matched_parts:
            tmp = landmark_coordinates[pair[0] - 1, :].copy()
            landmark_coordinates[pair[0] - 1, :] = landmark_coordinates[pair[1] - 1, :]
            landmark_coordinates[pair[1] - 1, :] = tmp
        return landmark_coordinates

    def get_transform(self, center, scale, output_size, rot=0):
        """Return the 3x3 homogeneous affine map from source pixels to crop pixels.

        The crop covers a ``200 * scale``-pixel square of the source image
        centred on ``center`` (x, y). It is stretched to ``output_size``
        (``[height, width]``). A non-zero ``rot`` (degrees) additionally rotates
        about the centre of the output crop.
        """
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(output_size[1]) / h
        t[1, 1] = float(output_size[0]) / h
        t[0, 2] = output_size[1] * (-float(center[0]) / h + .5)
        t[1, 2] = output_size[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1

        if not rot == 0:
            rot = -rot
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Translate the crop centre to the origin, rotate, translate back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -output_size[1] / 2
            t_mat[1, 2] = -output_size[0] / 2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def transform_pixel(self, pt, center, scale, output_size, invert=0, rot=0):
        """Map one 1-based ``(x, y)`` point through get_transform() (or its inverse).

        The point is shifted to 0-based, transformed, truncated to ``int`` and
        shifted back to 1-based. Returns a length-2 integer array.
        """
        t = self.get_transform(center, scale, output_size, rot=rot)
        if invert:
            t = np.linalg.inv(t)
        new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
        new_pt = np.dot(t, new_pt)
        return new_pt[:2].astype(int) + 1

    def crop(self, img, center, scale, output_size, rot=0):
        """Cut a ``200 * scale``-pixel square around ``center`` and resize it to ``output_size``.

        Args:
            img: ``(H, W)`` or ``(H, W, C)`` array. Pixel values are divided by 255.
            center: ``(x, y)`` torch tensor in source pixels. It is not modified.
            scale: Crop side length in units of 200 source pixels.
            output_size: ``[height, width]`` of the result.
            rot: Rotation in degrees, applied about the crop centre.

        Returns:
            ``(crop, scale_adjustment)``. ``crop`` is a float32 array; cv2 may
            drop a trailing singleton channel. ``scale_adjustment`` is 1 unless
            ``rot`` is non-zero. In that case it is the rotated, trimmed crop's
            height divided by the unpadded crop width.
        """
        center_new = center.clone()
        ht, wd = img.shape[0], img.shape[1]
        sf = scale * 200.0 / output_size[0]
        scale_adjustment = 1.

        # For large downscales, shrink the whole image first and rescale the
        # crop parameters to match.
        if sf >= 2:
            new_size = math.floor(max(ht, wd) / sf)
            new_ht = math.floor(ht / sf)
            new_wd = math.floor(wd / sf)

            if new_size < 2:
                # Nothing usable left after downscaling: return a blank crop, still
                # as an (image, scale) pair so callers can unpack it.
                blank_shape = [output_size[0], output_size[1]]
                if len(img.shape) > 2:
                    blank_shape.append(img.shape[2])
                return np.zeros(blank_shape, dtype=np.float32), scale_adjustment

            # cv2.resize expects dsize as (width, height).
            img = cv2.resize(img.astype(np.float32), (new_wd, new_ht))
            center_new[0] = center_new[0] * 1.0 / sf
            center_new[1] = center_new[1] * 1.0 / sf
            scale = scale / sf

        # Crop corners in (possibly downscaled) source coordinates.
        ul = np.array(self.transform_pixel([0, 0], center_new, scale, output_size, invert=1))
        br = np.array(self.transform_pixel(output_size, center_new, scale, output_size, invert=1))
        original_size = (br - ul)[0]

        pad = 0
        if rot != 0:
            # Enlarge the box so rotation does not pull in empty corners.
            pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
            ul -= pad
            br += pad

        new_shape = [br[1] - ul[1], br[0] - ul[0]]
        if len(img.shape) > 2:
            new_shape += [img.shape[2]]
        new_img = np.zeros(new_shape, dtype=np.float32)
        # Copy the overlap between the crop box and the image; the rest stays zero.
        new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
        new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
        old_x = max(0, ul[0]), min(len(img[0]), br[0])
        old_y = max(0, ul[1]), min(len(img), br[1])
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] / 255

        if rot != 0:
            new_img = ndimage.rotate(new_img, rot)
            # new_img[0:-0] would be empty, so only trim when there is padding.
            if pad > 0:
                new_img = new_img[pad:-pad, pad:-pad]
            scale_adjustment = new_img.shape[0] / original_size

        # cv2.resize expects dsize as (width, height).
        new_img = cv2.resize(new_img.astype(np.float32), (output_size[1], output_size[0]))
        return new_img, scale_adjustment



def load_data(task, batch_size, random_scale, random_flip, random_rotation,
              path='../../dataset/', num_workers=4, test_drop_last=True):
    """Build the COFW train and test DataLoaders.

    Args:
        task: Unused; kept so existing call sites keep working.
        batch_size: Batch size for both loaders.
        random_scale, random_flip, random_rotation: Augmentation switches for
            the training set only; the test set never augments.
        path: Directory prefix handed to ``COFW``. It is concatenated with the
            file name, so keep the trailing separator.
        num_workers: Worker processes per loader.
        test_drop_last: Whether the test loader drops the final partial batch.
            The default ``True`` preserves the previous behaviour, but it means
            up to ``batch_size - 1`` test images are never evaluated; pass
            ``False`` to cover the whole test split.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and drops
        its last partial batch.
    """
    train_set = COFW(path, True, random_scale, random_flip, random_rotation)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, drop_last=True)
    test_set = COFW(path, False, random_scale=False, random_flip=False, random_rotation=False)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                              num_workers=num_workers, drop_last=test_drop_last)
    return train_loader, test_loader



def load_state(path="../../checkpoint/"):
    """Load ``COFW_state.pth`` and split it into per-branch state dicts.

    On a machine without CUDA, tensors are mapped to the CPU, so a checkpoint
    saved on a GPU can still be loaded. With CUDA available, tensors are
    restored to their saved devices, as before.

    Args:
        path: Directory containing ``COFW_state.pth``.

    Returns:
        ``(leye_state, reye_state, others_state)``. Each dict holds that
        branch's ``featnet`` weights plus its priors. ``coordnet`` is the same
        shared entry in all three. ``others_state`` additionally carries
        ``relcoordnet`` and ``prior_z_rcd``.
    """
    map_location = None if torch.cuda.is_available() else "cpu"
    state = torch.load(os.path.join(path, "COFW_state.pth"), map_location=map_location)
    leye_state = {
        'featnet': state['featnet_leye'],
        'coordnet': state['coordnet'],
        'prior_z_ft': state['leye_z_ft'],
        'prior_z_cd': state['leye_z_cd'],
    }

    reye_state = {
        'featnet': state['featnet_reye'],
        'coordnet': state['coordnet'],
        'prior_z_ft': state['reye_z_ft'],
        'prior_z_cd': state['reye_z_cd'],
    }

    others_state = {
        'featnet': state['featnet_others'],
        'coordnet': state['coordnet'],
        'relcoordnet': state['relcoordnet'],
        'prior_z_ft': state['others_z_ft'],
        'prior_z_cd': state['others_z_cd'],
        'prior_z_rcd': state['others_z_rcd'],
    }

    return leye_state, reye_state, others_state



def make_batch(env, agent, n_landmarks, reference_c=None):
    """Ask ``agent`` for one training sample built from ``env``'s current observation.

    A landmark index in ``[0, n_landmarks)`` is drawn uniformly at random for
    every entry along the first dimension of ``env.current_o``. The indices are
    moved to ``device``. Everything runs under ``torch.no_grad()``, so the
    returned tensors are not tracked by autograd.

    Args:
        env: Object exposing ``current_o`` (a tensor with a batch dimension first).
        agent: Object exposing ``make_training_sample``.
        n_landmarks: Number of landmark indices to sample from.
        reference_c: Optional extra argument forwarded to
            ``agent.make_training_sample`` between the observation and the
            landmark indices.

    Returns:
        The 6-tuple ``(embedding_ft0, embedding_cd0, prior_ft, prior_cd,
        lambda_ft, direction)`` produced by ``agent.make_training_sample``.
    """
    with torch.no_grad():
        o0 = env.current_o
        l_idx = torch.randint(0, n_landmarks, (o0.size(0),)).to(device)

        # Identity check: `reference_c == None` can dispatch to an element-wise
        # __eq__ when reference_c is a tensor/array.
        if reference_c is None:
            sample = agent.make_training_sample(env, o0, l_idx)
        else:
            sample = agent.make_training_sample(env, o0, reference_c, l_idx)
        embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, direction = sample

    return embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, direction




def list_mean(hist):
    """Return the arithmetic mean of ``hist``.

    ``hist`` must support both ``sum()`` and ``len()``. An empty sequence
    raises ``ZeroDivisionError``.
    """
    total = sum(hist)
    count = len(hist)
    return total / count



def abs_coord_to_norm(c, img_size=(256, 256)):
    """Map absolute pixel coordinates to the normalised range [-1, 1].

    Computes ``2 * c / (img_size - 1) - 1`` per axis, so pixel 0 maps to -1 and
    pixel ``size - 1`` maps to +1.

    Args:
        c: Coordinate tensor. ``img_size`` broadcasts against it, presumably
            as shape ``(..., len(img_size))`` -- confirm at call sites.
        img_size: Per-axis image extent, in the same axis order as the last
            dimension of ``c``. Any sequence or tensor; a tuple default avoids
            a shared mutable default.

    Returns:
        Float tensor on ``c``'s device.
    """
    # Build the size vector directly on c's device rather than allocating it on
    # the CPU and copying it over on every call.
    size = torch.as_tensor(img_size, dtype=torch.float, device=c.device)
    return (2 * c / (size - 1)) - 1



def norm_coord_to_abs(c, img_size=(256, 256)):
    """Map normalised [-1, 1] coordinates back to rounded absolute pixel coordinates.

    Computes ``round((c + 1) * ((img_size - 1) / 2))`` per axis, the inverse
    of ``abs_coord_to_norm`` up to rounding. ``torch.round`` rounds halves to
    even.

    Args:
        c: Coordinate tensor. ``img_size`` broadcasts against it, presumably
            as shape ``(..., len(img_size))`` -- confirm at call sites.
        img_size: Per-axis image extent, in the same axis order as the last
            dimension of ``c``. A tuple default avoids a shared mutable default.

    Returns:
        Float tensor of whole-number values on ``c``'s device.
    """
    # Build the size vector directly on c's device rather than allocating it on
    # the CPU and copying it over on every call.
    size = torch.as_tensor(img_size, dtype=torch.float, device=c.device)
    return torch.round((c + 1) * ((size - 1) / 2))


def save_model(names, dirpolnet_leye, dirpolnet_reye, dirpolnet_others, epoch,
               path='../../checkpoint/'):
    """Save the three direction-policy networks' state dicts plus the epoch.

    The file is written to ``<path>/<names>_<epoch>epoch.pth``. ``path`` is
    created if it does not exist yet.

    Args:
        names: Run name used as the file-name prefix.
        dirpolnet_leye, dirpolnet_reye, dirpolnet_others: Modules whose
            ``state_dict()`` is saved under the matching key.
        epoch: Stored in the checkpoint and used in the file name.
        path: Output directory.
    """
    state = {
        'dirpolnet_leye': dirpolnet_leye.state_dict(),
        'dirpolnet_reye': dirpolnet_reye.state_dict(),
        'dirpolnet_others': dirpolnet_others.state_dict(),
        'epoch': epoch,
    }

    # torch.save does not create missing directories; "" means the current dir.
    if path:
        os.makedirs(path, exist_ok=True)
    torch.save(state, os.path.join(path, '{}_{}epoch.pth'.format(names, epoch)))

