from __future__ import print_function
import torch
import torch.nn.functional as F
import numpy as np
import torchvision.transforms.functional as F_

import math
import random
import h5py
import cv2
import scipy.ndimage as ndimage
import os
# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate runtime
# already initialized" abort when more than one library bundles OpenMP. It is an
# import-time side effect that masks that check for the whole process — confirm
# it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Module-wide compute device: CUDA when available, otherwise CPU. Used by
# anchor_random_crop, which moves its sampled coordinates and outputs here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default floating dtype constant (not referenced by the code visible in this file).
dtype = torch.float


class COFW(torch.utils.data.Dataset):
    """COFW facial-landmark dataset read from MATLAB ``.mat`` files via h5py.

    Each item is ``(img, tpts)``:

    * ``img`` is a normalized single-channel 256x256 crop around the face,
      shaped ``(1, 256, 256)``.
    * ``tpts`` is a flat LongTensor of the 29 landmarks in
      ``[y0, x0, y1, x1, ...]`` order. Landmarks with y > 0 are mapped into
      crop coordinates; others keep their original values.

    Args:
        data_path: Directory prefix containing ``COFW_train.mat`` /
            ``COFW_test.mat``. It is string-concatenated with the file name,
            so it should end with a path separator.
        is_train: When True, load the training split (``IsTr`` / ``phisTr``);
            otherwise load the test split (``IsT`` / ``phisT``).
        random_scale: Jitter the crop scale by up to +/- ``scale_factor``.
        random_flip: Flip horizontally with probability 0.5.
        random_rotation: With probability 0.6, rotate by up to
            +/- ``rot_factor`` degrees. ``rot_factor`` is 0, so this is
            currently a no-op.
    """

    def __init__(self, data_path, is_train, random_scale, random_flip, random_rotation):
        self.is_train = is_train
        self.data_root = data_path
        self.scale_factor = 0.05   # maximum relative scale jitter
        self.rot_factor = 0        # maximum rotation in degrees (0 disables rotation)
        self.random_scale = random_scale
        self.random_flip = random_flip
        self.random_rotation = random_rotation

        # Split-specific file name and dataset keys inside the .mat (HDF5) file.
        if is_train:
            file_name, image_key, pts_key = "COFW_train.mat", 'IsTr', 'phisTr'
        else:
            file_name, image_key, pts_key = "COFW_test.mat", 'IsT', 'phisT'

        # NOTE(review): the HDF5 handle stays open for the dataset's lifetime and
        # is inherited by DataLoader workers forked after construction; h5py
        # handles are generally not fork-safe — consider opening lazily per worker.
        self.f = h5py.File(data_path + file_name, 'r')
        # Row 0 holds one entry per image; each entry is dereferenced through
        # self.f in __getitem__.
        self.images = self.f[image_key][0]
        self.pts = self.f[pts_key]

        # Normalization statistics applied to the [0, 1]-scaled crops.
        self.mean = np.array([0.4637], dtype=np.float32)
        self.std = np.array([0.2591], dtype=np.float32)

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(img, tpts)`` pair for sample ``index`` (see class docstring)."""
        img = np.array(self.f[self.images[index]]).T
        # COFW: 29 landmarks. The first 58 values of the sample's annotation
        # column are x0..x28 followed by y0..y28; reshape them to
        # [[x0, y0], ..., [x28, y28]].
        # Slicing the h5py dataset reads only this sample's column. It is
        # equivalent to np.transpose(self.pts)[index], which would materialize
        # the whole matrix on every call.
        pts = np.asarray(self.pts[:, index])[0:58].reshape(2, -1).transpose()

        if img.ndim == 2:
            img = img.reshape(img.shape[0], img.shape[1], 1)

        # Square box around the landmarks, in units of 200 px
        # (get_transform multiplies by 200), enlarged by 25%.
        xmin, xmax = np.min(pts[:, 0]), np.max(pts[:, 0])
        ymin, ymax = np.min(pts[:, 1]), np.max(pts[:, 1])
        center_w = (math.floor(xmin) + math.ceil(xmax)) / 2.0
        center_h = (math.floor(ymin) + math.ceil(ymax)) / 2.0
        center = torch.Tensor([center_w, center_h])   # x, y order
        scale = max(math.ceil(xmax) - math.floor(xmin), math.ceil(ymax) - math.floor(ymin)) / 200.0
        scale *= 1.25

        # Augmentations. The order of the random draws is part of
        # reproducibility under a fixed seed, so keep it.
        r = 0
        if self.random_scale:
            scale = scale * (random.uniform(1 - self.scale_factor, 1 + self.scale_factor))
        if self.random_rotation:
            r = random.uniform(-self.rot_factor, self.rot_factor) if random.random() <= 0.6 else 0
        if self.random_flip and random.random() <= 0.5:
            img = np.fliplr(img)
            pts = self.fliplr_joints(pts, width=img.shape[1])
            center[0] = img.shape[1] - center[0]

        img, scale_factor = self.crop(img, center, scale, [256, 256], rot=r)

        # Map landmarks into crop coordinates. Points with y <= 0 are left as-is
        # (presumably missing/invalid annotations — TODO confirm).
        tpts = pts.copy()
        for i in range(pts.shape[0]):
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = self.transform_pixel(tpts[i, 0:2] + 1, center, scale * scale_factor, [256, 256], rot=r)

        img = img.reshape(256, 256, 1)
        img = (img - self.mean) / self.std
        img = img.transpose([2, 0, 1])
        img = torch.Tensor(img)

        # Swap columns to get [y1, x1, y2, x2, ...] order.
        tpts = np.fliplr(tpts).flatten()
        tpts = torch.LongTensor(tpts)
        return img, tpts

    def fliplr_joints(self, landmark_coordinates, width):
        """Mirror landmarks horizontally and swap paired indices.

        Each x becomes ``width - x``. Then the rows in each 1-based pair of
        ``matched_parts`` (presumably left/right counterparts) are exchanged.
        ``landmark_coordinates`` (N x 2, ``[x, y]`` rows) is modified in place
        and also returned.
        """
        matched_parts = [[1, 2], [5, 7], [3, 4], [6, 8], [9, 10], [11, 12],
                         [13, 15], [17, 18], [14, 16], [19, 20], [23, 24]]
        landmark_coordinates[:, 0] = width - landmark_coordinates[:, 0]
        for first, second in matched_parts:
            tmp = landmark_coordinates[first - 1, :].copy()
            landmark_coordinates[first - 1, :] = landmark_coordinates[second - 1, :]
            landmark_coordinates[second - 1, :] = tmp
        return landmark_coordinates

    def get_transform(self, center, scale, output_size, rot=0):
        """Return the 3x3 affine matrix from source pixels to output pixels.

        The matrix maps a square box of side ``200 * scale`` centred at
        ``center`` (``[x, y]``) onto ``output_size`` (``[height, width]``).
        When ``rot != 0``, it then applies the rotation matrix for ``-rot``
        degrees about the output centre.
        """
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(output_size[1]) / h
        t[1, 1] = float(output_size[0]) / h
        t[0, 2] = output_size[1] * (-float(center[0]) / h + .5)
        t[1, 2] = output_size[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1

        if rot != 0:
            rot = -rot
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Translate the output centre to the origin, rotate, translate back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -output_size[1] / 2
            t_mat[1, 2] = -output_size[0] / 2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def transform_pixel(self, pt, center, scale, output_size, invert=0, rot=0):
        """Map point ``pt`` through ``get_transform`` (or its inverse if ``invert``).

        ``pt`` is shifted by -1 before the affine map and the result by +1
        after it (a 1-based convention). The result is truncated toward zero
        with ``astype(int)`` and returned as a length-2 integer array.
        """
        t = self.get_transform(center, scale, output_size, rot=rot)
        if invert:
            t = np.linalg.inv(t)
        new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
        new_pt = np.dot(t, new_pt)
        return new_pt[:2].astype(int) + 1

    def crop(self, img, center, scale, output_size, rot=0):
        """Crop the box around ``center``, optionally rotate it, and resize it.

        Args:
            img: Image array of shape (H, W) or (H, W, C). Values are divided
                by 255 when copied into the crop.
            center: Tensor ``[x, y]`` giving the crop centre in ``img`` pixels.
            scale: Crop side length in units of 200 px.
            output_size: ``[height, width]`` of the returned crop.
            rot: Rotation in degrees (0 disables rotation).

        Returns:
            ``(new_img, scale_adjustment)``:

            * ``new_img`` is a float32 array resized to ``output_size``.
              NOTE(review): cv2.resize may drop a trailing singleton channel
              axis; the caller reshapes afterwards.
            * ``scale_adjustment`` is the ratio of the rotated crop side to the
              unrotated one (1.0 without rotation).
        """
        center_new = center.clone()
        ht, wd = img.shape[0], img.shape[1]
        sf = scale * 200.0 / output_size[0]
        scale_adjustment = 1.

        if sf >= 2:
            # Very large crop: downsample the whole image by sf so that the
            # crop box becomes output_size[0] pixels wide.
            new_size = int(math.floor(max(ht, wd) / sf))
            new_ht = int(math.floor(ht / sf))
            new_wd = int(math.floor(wd / sf))

            if new_size < 2:
                # Degenerate case: the image vanishes after downsampling.
                # Return a blank crop in the same (array, scale) form as the
                # normal path so callers can always unpack two values.
                blank_shape = (output_size[0], output_size[1])
                if img.ndim > 2:
                    blank_shape += (img.shape[2],)
                return np.zeros(blank_shape, dtype=np.float32), scale_adjustment

            # cv2.resize takes dsize as (width, height).
            img = cv2.resize(img.astype(np.float32), (new_wd, new_ht))
            center_new[0] = center_new[0] * 1.0 / sf
            center_new[1] = center_new[1] * 1.0 / sf
            scale = scale / sf

        # Crop window corners in source-image pixels.
        ul = np.array(self.transform_pixel([0, 0], center_new, scale, output_size, invert=1))
        br = np.array(self.transform_pixel(output_size, center_new, scale, output_size, invert=1))
        original_size = (br - ul)[0]

        pad = 0
        if rot != 0:
            # Enlarge the window so that the rotated content still covers the crop.
            pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
            ul -= pad
            br += pad

        new_shape = [br[1] - ul[1], br[0] - ul[0]]
        if img.ndim > 2:
            new_shape += [img.shape[2]]
        new_img = np.zeros(new_shape, dtype=np.float32)
        # Overlap between the (possibly out-of-bounds) window and the image;
        # areas outside the image stay zero.
        new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
        new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
        old_x = max(0, ul[0]), min(len(img[0]), br[0])
        old_y = max(0, ul[1]), min(len(img), br[1])
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] / 255

        if rot != 0:
            new_img = ndimage.rotate(new_img, rot)
            # Use explicit end indices: new_img[pad:-pad] would be empty when pad == 0.
            new_img = new_img[pad:new_img.shape[0] - pad, pad:new_img.shape[1] - pad]
            scale_adjustment = new_img.shape[0] / original_size

        # cv2.resize takes dsize as (width, height).
        new_img = cv2.resize(new_img.astype(np.float32), (output_size[1], output_size[0]))
        return new_img, scale_adjustment



def load_data(task, batch_size, random_scale, random_flip, random_rotation,
              data_path='../../dataset/', num_workers=4):
    """Build the COFW training and test DataLoaders.

    Args:
        task: Not used by this function; kept for call-site compatibility.
        batch_size: Batch size for both loaders.
        random_scale: Scale-jitter switch for the training set.
        random_flip: Horizontal-flip switch for the training set.
        random_rotation: Rotation switch for the training set. The test set is
            always built with all augmentations disabled.
        data_path: Directory prefix holding ``COFW_train.mat`` and
            ``COFW_test.mat``. The COFW class concatenates it directly with
            the file name, so keep the trailing separator.
        num_workers: Number of worker processes for each loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_set = COFW(data_path, True, random_scale, random_flip, random_rotation)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True)

    test_set = COFW(data_path, False, random_scale=False, random_flip=False, random_rotation=False)
    # NOTE(review): drop_last=True on the test loader skips the final partial
    # batch (up to batch_size - 1 test images) during evaluation. It is kept
    # for compatibility with callers that may assume full batches.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, drop_last=True)
    return train_loader, test_loader



def extract_shifted_patches(imgs, starting_point, shifts, patch_size):
    """Gather one square patch per image at a per-sample offset.

    With ``r = patch_size // 2``, the patch for sample ``b`` covers:

    * rows ``starting_point + shifts[b, 0] + [-r .. r]``
    * columns ``starting_point + shifts[b, 1] + [-r .. r]``

    Indices wrap modulo the image height and width.

    Args:
        imgs: Tensor of shape (B, C, H, W).
        starting_point: Integer offset added to both row and column indices.
            Callers pass their padding width, so ``shifts`` stay in unpadded
            coordinates.
        shifts: Integer tensor of shape (B, 2) holding (row, col) shifts. It is
            moved to ``imgs.device``.
        patch_size: Side length of the patch; must be a positive odd integer.

    Returns:
        Tensor of shape (B, C, patch_size, patch_size).

    Raises:
        ValueError: If ``patch_size`` is not a positive odd integer.
    """
    if patch_size <= 0 or patch_size % 2 == 0:
        # An even size would yield patch_size + 1 offsets and fail to broadcast.
        raise ValueError("patch_size must be a positive odd integer, got {}".format(patch_size))

    B, C, H, W = imgs.shape
    r = patch_size // 2
    delta = torch.arange(-r, r + 1, device=imgs.device)
    shifts = shifts.to(imgs.device)

    # Row indices (B, P, 1) and column indices (B, 1, P) broadcast to (B, P, P).
    idx_h = (starting_point + delta.view(1, -1, 1) + shifts[:, 0].view(B, 1, 1)) % H
    idx_w = (starting_point + delta.view(1, 1, -1) + shifts[:, 1].view(B, 1, 1)) % W
    batch_idx = torch.arange(B, device=imgs.device).view(B, 1, 1)

    # The advanced indices are separated by the channel slice, so the broadcast
    # (B, P, P) dims come first: the result is (B, P, P, C) -> (B, C, P, P).
    patches = imgs[batch_idx, :, idx_h, idx_w]
    return patches.permute(0, 3, 1, 2).contiguous()



def anchor_random_crop(imgs, landmark_coords):
    """Sample a crop at a random landmark (anchor) and a crop at a random pixel.

    Each crop has two 27x27 channels, both taken from the image after padding
    it with the constant 3:

    * channel 0 is a native 27x27 patch;
    * channel 1 is a 109x109 patch resized to 27x27.

    Args:
        imgs: Image batch; ``imgs.size(0)`` is the batch size B.
        landmark_coords: Landmark positions, viewed as (B, N, 2). The values
            are used as (row, col) shifts by ``extract_shifted_patches``.

    Returns:
        Tuple ``(anchor_o, random_o, landmark_select, a_l_distance,
        a_l_relationship, r_l_distance, r_l_relationship)``:

        * ``anchor_o``, ``random_o``: crops of shape (B, 2, 27, 27).
        * ``landmark_select``: (B,) indices of the anchor landmarks.
        * ``a_l_distance``, ``r_l_distance``: (B, N, 2) differences
          ``centre - landmark``.
        * ``a_l_relationship``, ``r_l_relationship``: (B, N) float masks, 1
          where both components of that difference are within 53 px.

        The random centres are drawn in [0, 256).
    """
    batch = imgs.size(0)
    base = 27                           # side of the native-resolution view
    zoom = 4                            # context view is base * zoom pixels wide
    context = base * zoom               # 108
    half_context = (context - 1) // 2   # 53
    pad = half_context + 1              # 54, the radius of the 109-px patch
    padded = F.pad(imgs, (pad, pad, pad, pad), mode='constant', value=3)

    num_lm = landmark_coords.size(1)
    coords = landmark_coords.view(batch, num_lm, 2)
    # Draw the anchor landmarks first, then the random centres (keeps RNG order).
    landmark_select = torch.randint(0, num_lm, (batch,))
    anchor_coords = coords[torch.arange(batch), landmark_select]
    random_coords = torch.randint(0, 256, size=(batch, 2)).to(device)

    def two_scale_view(centres):
        # Native 27x27 patch stacked with the resized 109x109 context patch.
        fine = extract_shifted_patches(padded, pad, centres, base)
        coarse = F_.resize(extract_shifted_patches(padded, pad, centres, context + 1), [base, base])
        return torch.cat((fine, coarse), dim=1).view(batch, 2, base, base)

    def landmark_offsets(centres):
        # Offsets to every landmark, and whether each one lies inside the context view.
        diff = (centres.view(batch, 1, 2).repeat(1, num_lm, 1) - coords).to(device)
        inside = ((torch.abs(diff) <= half_context).sum(-1) == 2).float().to(device)
        return diff, inside

    anchor_o = two_scale_view(anchor_coords)
    random_o = two_scale_view(random_coords)
    a_l_distance, a_l_relationship = landmark_offsets(anchor_coords)
    r_l_distance, r_l_relationship = landmark_offsets(random_coords)
    return anchor_o, random_o, landmark_select, a_l_distance, a_l_relationship, r_l_distance, r_l_relationship



def landmark_o_crop(imgs, landmark_coords):
    """Extract the two-scale 27x27 crop around every landmark of every sample.

    Args:
        imgs: Image batch; ``imgs.size(0)`` is the batch size B.
        landmark_coords: Tensor of shape (B, N, 2). The values are used as
            (row, col) shifts by ``extract_shifted_patches``.

    Returns:
        Tensor of shape (N, B, 2, 27, 27), taken from the image after padding
        it with the constant 3. Channel 0 is the native 27x27 patch; channel 1
        is the 109x109 patch resized to 27x27.
    """
    batch = imgs.size(0)
    base, zoom = 27, 4
    context = base * zoom                  # 108
    pad = (context - 1) // 2 + 1           # 54, the radius of the 109-px patch
    padded = F.pad(imgs, (pad, pad, pad, pad), mode='constant', value=3)

    per_landmark = landmark_coords.permute(1, 0, 2)    # (N, B, 2)
    num_lm = per_landmark.size(0)
    # Tile the batch once per landmark: row k = l * B + b pairs sample b with
    # landmark l, matching the row order of per_landmark.reshape(-1, 2).
    tiled = padded.repeat(num_lm, 1, 1, 1)
    centres = per_landmark.reshape(-1, 2)

    fine = extract_shifted_patches(tiled, pad, centres, base)
    coarse = F_.resize(extract_shifted_patches(tiled, pad, centres, context + 1), [base, base])
    return torch.cat((fine, coarse), dim=1).view(num_lm, batch, 2, base, base)



def save_model(names, featnet_leye, featnet_reye, featnet_others, optimizer, epoch,
               checkpoint_dir='../../checkpoint/'):
    """Save the three feature networks and the optimizer state to one checkpoint.

    The file is written to ``<checkpoint_dir>/<names>_<epoch>epoch.pth``. It
    holds a dict with these keys:

    * ``featnet_leye``, ``featnet_reye``, ``featnet_others``: model state dicts;
    * ``opt``: the optimizer state dict;
    * ``epoch``: the epoch number.

    ``checkpoint_dir`` is created if it does not exist yet (previously
    ``torch.save`` failed in that case).

    Returns:
        The path of the written checkpoint file.
    """
    state = {
        'featnet_leye': featnet_leye.state_dict(),
        'featnet_reye': featnet_reye.state_dict(),
        'featnet_others': featnet_others.state_dict(),
        'opt': optimizer.state_dict(),
        'epoch': epoch
        }
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, '{}_{}epoch.pth'.format(names, epoch))
    torch.save(state, path)
    return path
    
    

def scheduler_step(opt, epoch, decay_interval, gamma):
    """Multiply every param group's learning rate by ``gamma`` on decay epochs.

    A decay happens when ``epoch`` is a multiple of ``decay_interval``,
    including epoch 0. On a decay epoch the new learning rates are printed.
    Otherwise nothing happens. Always returns None.
    """
    if epoch % decay_interval != 0:
        return
    new_lrs = []
    for group in opt.param_groups:
        group['lr'] = group['lr'] * gamma
        new_lrs.append(group['lr'])
    print("=== learning rate decayed to {}.\n".format(new_lrs))

