# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Zigang Geng (zigang@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pickle
from tqdm import tqdm
import numpy as np

from sklearn.preprocessing import QuantileTransformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from pycocotools.cocoeval import COCOeval as COCOEval
from crowdposetools.cocoeval import COCOeval as CrowdposeEval

# Skeleton links used by get_feature: link i connects joint *_LINK_1[i] to
# joint *_LINK_2[i], and the feature is built from the offset between them.
# COCO: 19 links over 17 keypoints (indices 0-16).
JOINT_COCO_LINK_1 = [0, 0, 1, 1, 2, 3, 4, 5, 5, 5, 6, 6, 7, 8, 11, 11, 12, 13, 14]
JOINT_COCO_LINK_2 = [1, 2, 2, 3, 4, 5, 6, 6, 7, 11, 8, 12, 9, 10, 12, 13, 14, 15, 16]

# CrowdPose: 15 links over 14 keypoints (indices 0-13). Per the crowd index
# list in the commented-out cocopose2exlpose helper, 12 is head and 13 is neck.
JOINT_CROWDPOSE_LINK_1 = [12, 13, 13, 0, 1, 2, 3, 0, 1, 6, 7,  8,  9, 6, 0]
JOINT_CROWDPOSE_LINK_2 = [13,  0,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 7, 1]

# CrowdPose links with the first three (those touching joints 12/13) removed,
# i.e. only links among the first 12 joints. Not referenced in this file's
# visible code.
JOINT_CROWDPOSE_LINK_1_cut_2 = [0, 1, 2, 3, 0, 1, 6, 7,  8,  9, 6, 0]
JOINT_CROWDPOSE_LINK_2_cut_2 = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 7, 1]

# def cocopose2exlpose(poses):
#     '''
#     poses: list of nparray [(17, 3), (17,3), ..., (17,3)]
#     '''    
#     import numpy as np
#     def transform_pose(pose):
#        '''
#        coco =[0: nose                  crowd=[0: l.shou. -> 5
#               1: l.eye                        1: r.shou. -> 6
#               2: r.eye                        2: l.elb   -> 7
#               3: l.ear                        3: r.elb   -> 8
#               4: r.ear                        4: l.wrist -> 9
#               5: l.shou                       5: r.wrist -> 10
#               6: r.shou                       6: l.hip -> 11
#               7: l.elb                        7: r.hip -> 12
#               8: r.elb                        8: l.knee -> 13
#               9: l.wrist                      9: r.knee -> 14
#               10: r.wrist                     10: l.ankle -> 15
#               11: l.hip                       11: r.ankle -> 16
#               12: r.hip                       12: head 
#               13: l.knee                      13: neck
#               14: r.knee
#               15: l.ankle
#               16: r.ankle]
#        ''' 
#        new_pose = np.zeros((14, 3))
#        new_score = np.zeros(14)

#        nose = pose[0]
#        neck = (pose[5] + pose[6]) / 2.0
#        head = (nose[:2] - neck[:2]) + nose[:2]
#        new_pose[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]] = pose[[5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]
#        new_pose[12][:2] = head
#        new_pose[12][2] = (nose[2] + neck[2]) / 2.0
#        new_pose[13] = neck
#        return new_pose
    
#     pose_list = []
#     for pose in poses:
#          pose_list.append(transform_pose(pose))
#     return np.stack(pose_list, axis=0)[:, :12, :]

# Per-model CrowdPose trainval pickles. read_rescore_data loads and
# concatenates all of them when the configured train split contains 'all'.
# NOTE(review): 'cid.pkl' is listed twice (first and fifth entries), so its
# samples are loaded twice and over-weighted in training. Confirm whether
# this is intentional or whether the fifth entry should name another model.
all_crowdpose_trainval_pkls = [
    'pkls/bottom_up/crowdpose/trainval/cid.pkl',
    'pkls/bottom_up/crowdpose/trainval/dekr.pkl',
    'pkls/bottom_up/crowdpose/trainval/pinet.pkl',
    'pkls/bottom_up/crowdpose/trainval/swahr.pkl',
    'pkls/bottom_up/crowdpose/trainval/cid.pkl',
    'pkls/top_down/crowdpose/trainval/vit_base.pkl',
    'pkls/one_stage/crowdpose/trainval/grouppose.pkl',
    'pkls/one_stage/crowdpose/trainval/edpose.pkl'
]

# Data process for the RescoreNet
def read_rescore_data(args, cfg, is_train=True):
    """Load rescoring data and turn the poses into RescoreNet features.

    Args:
        args: namespace providing ``hpe_model``, the base name of the
            ``.pkl`` file to read.
        cfg: config whose ``cfg.dataset`` dict supplies ``num_joints``,
            ``pkl_root``, ``train`` and ``test``.
        is_train: read the train split when True, the test split otherwise.

    Returns:
        ``(features, labels, joints, scores, None)``. The last element is
        always ``None`` (presumably a placeholder; confirm with callers).
    """
    data_cfg = cfg.dataset
    num_joints = data_cfg['num_joints']

    if not is_train:
        split = data_cfg['test']
    elif 'all' in data_cfg['train']:
        # Substring match: any train split name containing 'all' selects the
        # fixed multi-model pickle list instead of a single file.
        split = None
    else:
        split = data_cfg['train']

    if split is None:
        joints, labels, scores = get_all_joint_and_score(
            all_crowdpose_trainval_pkls, num_joints)
    else:
        pkl_path = os.path.join(data_cfg['pkl_root'], split, f'{args.hpe_model}.pkl')
        joints, labels, scores = get_joint_and_score(pkl_path, num_joints)

    features = get_feature(joints, cfg)
    return features, labels, joints, scores, None


def get_all_joint_and_score(filename_list, num_joints):
    """Load several rescoring pickles and stack them along the sample axis.

    Every path is read with ``get_joint_and_score``. The returned
    ``(joints, labels, scores)`` tensors are the per-file tensors
    concatenated in list order.
    """
    per_file = [get_joint_and_score(path, num_joints) for path in filename_list]
    joints = torch.cat([item[0] for item in per_file], dim=0)
    labels = torch.cat([item[1] for item in per_file], dim=0)
    scores = torch.cat([item[2] for item in per_file], dim=0)
    return joints, labels, scores

def get_joint_and_score(filename, num_joints):
    """Load one rescoring pickle and convert it into float tensors.

    The pickle is expected to hold a list whose element 0 is a header row
    (skipped) followed by rows laid out like those appended by
    ``COCORescoreEval`` / ``CrowdRescoreEval``:
    ``row[0]`` joint coordinates ``(num_joints, 2)``, ``row[1]`` per-joint
    values ``(num_joints, 1)``, ``row[2]`` the pose score and ``row[-1]`` the
    OKS of the matched ground truth.

    Returns:
        x: ``(N, num_joints, 3)`` tensor of coordinates plus per-joint value.
        y: ``(N, 1)`` OKS targets, with an OKS of exactly 1 replaced by 0.
        s: ``(N, 1)`` original pose scores.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(filename, "rb") as f:
        obj = pickle.load(f)

    posx, posy, pos_scores = [], [], []
    for row in tqdm(obj[1:]):
        joint_row = np.concatenate((row[0], row[1]), axis=1)
        posx.append(joint_row.reshape(3 * num_joints))
        pos_scores.append(row[2])
        # NOTE(review): rows whose OKS is exactly 1 are relabelled 0; the reason
        # is not visible here, so confirm the intent. The loaded row is left unmodified.
        oks = row[-1]
        posy.append(0 if oks == 1 else oks)

    x = torch.tensor(np.array(posx).reshape((-1, num_joints, 3)), dtype=torch.float)
    y = torch.tensor(np.array(posy).reshape((-1, 1)), dtype=torch.float)
    s = torch.tensor(np.array(pos_scores).reshape((-1, 1)), dtype=torch.float)
    return x, y, s

# def get_normalized_shape(x):
#     valid = x[:, :, -1] >= 0.0 # b, n
#     point_x = x[:, :, 0] # b, n
#     point_y = x[:, :, 1] # b, n
#     point_x = np.where(valid, point_x, np.nan) 
#     point_y = np.where(valid, point_y, np.nan)

#     center_x = np.nanmean(point_x, axis=1)[:, None]
#     center_y = np.nanmean(point_y, axis=1)[:, None]

#     point_x_max = np.nanmax(point_x, axis=1)[:, None]
#     point_x_min = np.nanmin(point_x, axis=1)[:, None]
#     max_x_length = point_x_max - point_x_min + 1e-6

#     point_y_max = np.nanmax(point_y, axis=1)[:, None]
#     point_y_min = np.nanmin(point_y, axis=1)[:, None]
#     max_y_length = point_y_max - point_y_min + 1e-6

#     normalized_point_x = (point_x - center_x) / max_x_length 
#     normalized_point_y = (point_y - center_y) / max_y_length
    
#     normalized_point_x = np.where(valid, normalized_point_x, 0.0)
#     normalized_point_y = np.where(valid, normalized_point_y, 0.0)

#     normalized_shape = np.stack([normalized_point_x, normalized_point_y], axis=-1)    
#     return torch.tensor(normalized_shape, dtype=torch.float)

def get_feature(x, cfg, device='cuda'):
    """Build RescoreNet input features for a batch of poses.

    Args:
        x: ``(N, K, 3)`` poses; channels 0-1 are joint coordinates and
            channel 2 is the per-joint value (named ``vis`` here). Tensors are
            used as-is (dtype preserved). Other array-likes are converted to a
            float32 tensor.
        cfg: config whose ``cfg.dataset['dataset']`` names the dataset. A plain
            dataset-name string is also accepted.
        device: device to build the features on. The default ``'cuda'``
            matches the previous hard-coded behaviour.

    Returns:
        ``(N, 3 * L + K)`` tensor, where L is the number of skeleton links:
        the unit direction of every link (2L), link lengths divided by the
        maximum link length over the whole input (L), and channel 2 of every
        joint (K).

    Raises:
        ValueError: if the dataset name contains neither 'coco' nor
            'crowd' / 'exl'.
    """
    dataset = cfg if isinstance(cfg, str) else cfg.dataset['dataset']
    if 'coco' in dataset:
        joint_1, joint_2 = JOINT_COCO_LINK_1, JOINT_COCO_LINK_2
    elif 'crowd' in dataset or 'exl' in dataset:
        # Each membership test must be written out: ``'crowd' or 'exl' in s``
        # is always truthy, which made the error branch unreachable.
        joint_1, joint_2 = JOINT_CROWDPOSE_LINK_1, JOINT_CROWDPOSE_LINK_2
    else:
        raise ValueError(
            'Please implement joint links for new dataset: %s.' % dataset)

    if not torch.is_tensor(x):
        x = torch.as_tensor(x, dtype=torch.float)
    x = x.to(device)
    joint_abs = x[:, :, :2]
    vis = x[:, :, 2]

    # Offset (dx, dy) along every link, its length, and its unit direction.
    joint_relate = joint_abs[:, joint_1] - joint_abs[:, joint_2]
    joint_length = torch.sqrt(joint_relate[:, :, 0] ** 2 +
                              joint_relate[:, :, 1] ** 2)
    joint_relate = joint_relate / (joint_length.unsqueeze(-1) + 1e-6)
    # flatten (not reshape(b, -1)) so an empty batch yields shape (0, 2L).
    joint_relate = joint_relate.flatten(start_dim=1)

    # Lengths are normalised by the maximum over the whole input, so one
    # pose's feature depends on the other poses in the same call. An all-zero
    # maximum leaves the lengths at zero instead of producing NaN.
    if joint_length.numel() > 0:
        max_length = joint_length.max()
        joint_length = joint_length / torch.where(
            max_length > 0, max_length, torch.ones_like(max_length))

    return torch.cat([joint_relate, joint_length, vis], dim=1)


# def get_feature(x, dataset):
#     joint_abs = x[:, :, :2]
#     vis = torch.from_numpy(x[:, :, 2]).float().cuda()

#     # if 'coco' in dataset:
#     #     joint_1, joint_2 = JOINT_COCO_LINK_1, JOINT_COCO_LINK_2
#     # elif 'crowd' or 'exlpose' in dataset:
#     #     joint_1, joint_2 = JOINT_CROWDPOSE_LINK_1, JOINT_CROWDPOSE_LINK_2
#     # else:
#     #     raise ValueError(
#     #         'Please implement flip_index for new dataset: %s.' % dataset)

#     #To get the Delta x Delta y
#     #joint_relate = joint_abs[:, joint_1] - joint_abs[:, joint_2]
#     joint_abs = torch.from_numpy(joint_abs).float().cuda()
#     joint_relate = joint_abs[:, :, None, :] - joint_abs[:, None, :, :]

#     b, n, _, c = joint_relate.shape
#     joint_relate = joint_relate.reshape(b, n*n, c)
#     joint_length = torch.sqrt((joint_relate)[:, :, 0]**2 +
#                     (joint_relate)[:, :, 1]**2)

#     joint_relate = joint_relate / (joint_length[:, :, np.newaxis]+1e-6)
#     joint_relate = joint_relate.reshape(b, -1)
#     joint_length = joint_length / joint_length.max()
  
#     feature = [joint_relate, joint_length, vis]

#     # feature = np.concatenate(feature, axis=1)
#     # feature = torch.tensor(feature, dtype=torch.float)
#     feature = torch.cat(feature, dim=1)
#     return feature

def quantile_strench(y_temp):
    '''
    Map each group of scores onto a uniform distribution.

    y_temp: b, n, 1
    Every one of the b groups gets its own sklearn QuantileTransformer
    (n quantiles, uniform output) fitted to and applied on its n values.
    The stacked result is returned as a tensor on y_temp's device.
    '''
    groups = y_temp.cpu().numpy()
    stretched = np.stack(
        [
            QuantileTransformer(output_distribution='uniform',
                                n_quantiles=len(group)).fit_transform(group)
            for group in groups
        ],
        axis=0,
    )
    return torch.from_numpy(stretched).to(y_temp.device)

def cal_order_acc(y_pred, y_temp):
    '''
    Pairwise ranking accuracy between predicted and target scores.

    y_pred, y_temp: tensors of shape (b, n, C); only channel 0 is used.

    For each of the b groups, count the ordered pairs (i, j) with i != j on
    which "score j > score i" has the same truth value for the prediction and
    the target. Divide the count by n * (n - 1) and return the mean over the
    groups. With n < 2 there are no pairs, so NaN is returned.
    '''
    pred = y_pred[:, :, 0].detach().cpu().numpy()
    target = y_temp[:, :, 0].detach().cpu().numpy()

    n = pred.shape[1]
    if n < 2:
        return float('nan')

    # Direct comparison is sign-safe. A ratio test (a / (b + eps) > 1) flips
    # the ordering whenever the denominator is negative.
    pred_gt = pred[:, None, :] > pred[:, :, None]      # [g, i, j]: pred[g, j] > pred[g, i]
    target_gt = target[:, None, :] > target[:, :, None]
    agree = (pred_gt == target_gt).astype(np.float32)
    # The diagonal is always False == False; drop those n agreements.
    per_group = (agree.sum(axis=(1, 2)) - n) / (n * n - n)
    return per_group.mean()

def train_core(x_data, y_data, s_data, optimizer, scheduler, model, loss_fn, batchsize, pose_size):
    """Run one training epoch over shuffled data.

    Samples are shuffled and cut into ``len(x_data) // (batchsize * pose_size)``
    full batches. A trailing partial batch is dropped. Each batch is viewed as
    ``(batchsize, pose_size, -1)`` and fed to ``model(x, s)``.
    ``loss_fn(pred, target)`` returns ``(l1_loss, order_loss)``; their sum is
    back-propagated, and ``optimizer`` and ``scheduler`` each step once per batch.

    Returns:
        (mean l1 loss, mean order loss, mean pairwise order accuracy) over
        the batches.

    Raises:
        ValueError: if there are fewer samples than one full batch.
    """
    datasize = len(x_data)
    poses_per_batch = batchsize * pose_size
    num_batches = datasize // poses_per_batch
    if num_batches == 0:
        raise ValueError(
            f'train_core needs at least batchsize * pose_size = {poses_per_batch} '
            f'samples, got {datasize}.')

    index = np.arange(datasize)
    np.random.shuffle(index)

    l1_loss_sum, order_loss_sum, acc_sum = 0, 0, 0
    model.train()
    for i in tqdm(range(num_batches)):
        batch_idx = index[i * poses_per_batch:(i + 1) * poses_per_batch]
        x_temp = x_data[batch_idx].cuda().view(batchsize, pose_size, -1)
        y_temp = y_data[batch_idx].cuda().view(batchsize, pose_size, -1)
        s_temp = s_data[batch_idx].cuda().view(batchsize, pose_size, -1)

        optimizer.zero_grad()
        y_pred = model(x_temp, s_temp)
        l1_loss, order_loss = loss_fn(y_pred, y_temp)
        loss = l1_loss + order_loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        l1_loss_sum += l1_loss.item()
        order_loss_sum += order_loss.item()
        acc_sum += cal_order_acc(y_pred, y_temp)

    return l1_loss_sum / num_batches, order_loss_sum / num_batches, acc_sum / num_batches


def train_test_core(x_data, x_test, y_data, s_data, train_shape, test_shape,  optimizer, scheduler, model, loss_fn, batchsize, pose_size):
    """Run one training epoch that also feeds random test-set poses to the loss.

    Training samples are shuffled and cut into full batches; a trailing partial
    batch is dropped. For each batch, an equally sized random draw (with
    replacement) from ``x_test`` / ``test_shape`` is also passed through the
    model. ``loss_fn(pred, target, train_shape, test_shape, pred_test)``
    returns ``(l1_loss, order_loss, quantile_loss)``; their sum is
    back-propagated, and ``optimizer`` and ``scheduler`` step once per batch.

    Returns:
        (mean l1 loss, mean order loss, mean quantile loss, mean pairwise order
        accuracy) over the batches.

    Raises:
        ValueError: if there is less than one full training batch or the test
            set is empty.
    """
    datasize = len(x_data)
    test_datasize = len(test_shape)
    poses_per_batch = batchsize * pose_size
    num_batches = datasize // poses_per_batch
    if num_batches == 0:
        raise ValueError(
            f'train_test_core needs at least batchsize * pose_size = {poses_per_batch} '
            f'training samples, got {datasize}.')
    if test_datasize == 0:
        raise ValueError('train_test_core needs a non-empty test set.')

    index = np.arange(datasize)
    np.random.shuffle(index)
    model.train()

    l1_loss_sum, order_loss_sum, quantile_loss_sum, acc_sum = 0, 0, 0, 0
    for i in tqdm(range(num_batches)):
        # Training batch.
        batch_idx = index[i * poses_per_batch:(i + 1) * poses_per_batch]
        x_temp = x_data[batch_idx].cuda().view(batchsize, pose_size, -1)
        y_temp = y_data[batch_idx].cuda().view(batchsize, pose_size, -1)
        s_temp = s_data[batch_idx].cuda().view(batchsize, pose_size, -1)
        shape_temp_train = train_shape[batch_idx].cuda()
        _, k, c = shape_temp_train.shape
        shape_temp_train = shape_temp_train.view(batchsize, pose_size, k, c)

        y_pred = model(x_temp, s_temp)

        # Random test batch. torch.randint's upper bound is exclusive, so
        # test_datasize (not test_datasize - 1) makes every sample reachable.
        idx_test = torch.randint(0, test_datasize, (poses_per_batch,))
        test_temp = x_test[idx_test].cuda().view(batchsize, pose_size, -1)
        shape_temp_test = test_shape[idx_test].cuda().view(batchsize, pose_size, k, c)

        # NOTE(review): the test poses are scored together with the *training*
        # scores s_temp because no test scores are passed in. Confirm intent.
        y_pred_test = model(test_temp, s_temp)

        l1_loss, order_loss, quantile_loss = loss_fn(
            y_pred, y_temp, shape_temp_train, shape_temp_test, y_pred_test)
        loss = l1_loss + order_loss + quantile_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        l1_loss_sum += l1_loss.item()
        order_loss_sum += order_loss.item()
        quantile_loss_sum += quantile_loss.item()
        acc_sum += cal_order_acc(y_pred, y_temp)

    return (l1_loss_sum / num_batches, order_loss_sum / num_batches,
            quantile_loss_sum / num_batches, acc_sum / num_batches)


def valid_core(x_data, y_data, s_data, model, loss_fn, batchsize, pose_size):
    """Evaluate the model on shuffled data without updating it.

    Batching mirrors ``train_core``: full batches of ``batchsize * pose_size``
    samples viewed as ``(batchsize, pose_size, -1)``; a trailing partial batch
    is dropped. The model is put in eval mode and left there.

    Returns:
        (mean l1 loss, mean order loss, mean pairwise order accuracy) over
        the batches.

    Raises:
        ValueError: if there are fewer samples than one full batch.
    """
    datasize = len(x_data)
    poses_per_batch = batchsize * pose_size
    num_batches = datasize // poses_per_batch
    if num_batches == 0:
        raise ValueError(
            f'valid_core needs at least batchsize * pose_size = {poses_per_batch} '
            f'samples, got {datasize}.')

    # Shuffling groups samples into batches (and decides which ones fall into
    # the dropped partial batch); kept to match the training-side batching.
    index = np.arange(datasize)
    np.random.shuffle(index)

    l1_loss_sum, order_loss_sum, acc_sum = 0., 0., 0.
    model.eval()
    with torch.no_grad():
        for i in tqdm(range(num_batches)):
            batch_idx = index[i * poses_per_batch:(i + 1) * poses_per_batch]
            x_temp = x_data[batch_idx].cuda().view(batchsize, pose_size, -1)
            y_temp = y_data[batch_idx].cuda().view(batchsize, pose_size, -1)
            s_temp = s_data[batch_idx].cuda().view(batchsize, pose_size, -1)

            y_pred = model(x_temp, s_temp)
            l1_loss, order_loss = loss_fn(y_pred, y_temp)

            l1_loss_sum += l1_loss.item()
            order_loss_sum += order_loss.item()
            acc_sum += cal_order_acc(y_pred, y_temp)

    return l1_loss_sum / num_batches, order_loss_sum / num_batches, acc_sum / num_batches


def rescore_valid(cfg, PredictOKSmodel, temp, ori_scores):
    """Rescore poses with a trained OKS-prediction model.

    Args:
        cfg: config exposing the dataset name as ``cfg.DATASET.DATASET``.
        PredictOKSmodel: callable mapping a feature tensor to predicted scores.
        temp: poses, convertible with ``np.array`` to shape ``(N, K, 3)``
            (assumed from get_feature's indexing; TODO confirm with callers).
        ori_scores: original pose scores, reshaped to the prediction's shape.

    Returns:
        list of float: predicted score times the original score, with NaN
        predictions treated as 0.
    """
    from types import SimpleNamespace

    poses = torch.tensor(np.array(temp), dtype=torch.float)
    # get_feature reads the dataset name from ``cfg.dataset['dataset']``;
    # adapt this config (``cfg.DATASET.DATASET``) to that interface.
    feature_cfg = SimpleNamespace(dataset={'dataset': cfg.DATASET.DATASET})
    feature = get_feature(poses, feature_cfg).cuda()

    # no_grad: the output is converted to NumPy, which fails on tensors that
    # require grad.
    with torch.no_grad():
        scores = PredictOKSmodel(feature)
    scores = scores.cpu().numpy()
    scores[np.isnan(scores)] = 0
    mul_scores = scores * np.array(ori_scores).reshape(scores.shape)
    # tolist() yields Python floats (np.float was removed in NumPy 1.24).
    return mul_scores.reshape(-1).tolist()


# Get Rescore training data for RescoreNet
class COCORescoreEval(COCOEval):
    """COCO keypoint evaluator that collects RescoreNet training rows.

    ``evaluateImg`` is overridden to append, for every non-ignored
    ground-truth person, one row
    ``[keypoint xy (17, 2), keypoint values (17, 1), detection score, OKS]``
    built from its best-OKS detection, and to return ``None``.
    ``dumpdataset`` pickles ``self.summary`` (header row first).
    """

    def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
        COCOEval.__init__(self, cocoGt, cocoDt, iouType)
        # Header row; get_joint_and_score skips element 0 of the dumped list.
        self.summary = [['pose', 'pose_heatval', 'score', 'oks']]

    def evaluateImg(self, imgId, catId, aRng, maxDet):
        '''
        Record the best-matching detection for each ground truth of one image.

        Appends rows to ``self.summary`` and returns ``None``. Detections that
        are not the best match of any ground truth are not recorded.
        NOTE(review): pycocotools' evaluate() calls this once per entry of
        ``params.areaRng``, so a person that is not ignored in several ranges
        is recorded several times. Confirm duplicates are intended.
        '''
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return None

        # Mark gts flagged 'ignore' or outside the area range as ignored.
        for g in gt:
            if g['ignore'] or (g['area'] < aRng[0] or g['area'] > aRng[1]):
                g['_ignore'] = 1
            else:
                g['_ignore'] = 0

        # Sort dt highest score first, gt ignored last.
        gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
        gt = [gt[i] for i in gtind]
        dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
        dt = [dt[i] for i in dtind[0:maxDet]]
        # Precomputed OKS matrix indexed [detection, gt], reordered to match gt.
        ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
        if len(ious) == 0:
            return None

        gtIg = np.array([g['_ignore'] for g in gt])
        for gind in range(len(gt)):
            if gtIg[gind] == 1:
                continue
            # Highest-OKS detection for this gt. ``<`` lets later (lower-scored)
            # detections win ties, and one detection may serve several gts.
            iou = 0
            m = -1
            for dind in range(len(dt)):
                if ious[dind, gind] < iou:
                    continue
                iou = ious[dind, gind]
                m = dind
            dtkeypoint = np.array(dt[m]['keypoints']).reshape((17, 3))
            self.summary.append([dtkeypoint[:, :2], dtkeypoint[:, 2:], dt[m]['score'], iou])
        return None

    def dumpdataset(self, data_file):
        """Pickle ``self.summary`` to ``data_file``, closing the file afterwards."""
        with open(data_file, 'wb') as f:
            pickle.dump(self.summary, f)


class CrowdRescoreEval(CrowdposeEval):
    """CrowdPose keypoint evaluator that collects RescoreNet training rows.

    ``evaluateImg`` is overridden to append, for every non-ignored
    ground-truth person, one row
    ``[keypoint xy (14, 2), keypoint values (14, 1), detection score, OKS]``
    built from its best-OKS detection, and to return ``None``.
    ``dumpdataset`` pickles ``self.summary`` (header row first).
    """

    def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
        CrowdposeEval.__init__(self, cocoGt, cocoDt, iouType)
        # Header row naming the four fields every appended row carries;
        # get_joint_and_score skips element 0 of the dumped list.
        self.summary = [['pose', 'pose_heatval', 'score', 'oks']]

    def evaluateImg(self, imgId, catId, aRng, maxDet):
        '''
        Record the best-matching detection for each ground truth of one image.

        Appends rows to ``self.summary`` and returns ``None``. Detections that
        are not the best match of any ground truth are not recorded.
        NOTE(review): if the base evaluate() calls this once per area range
        (as pycocotools does), a person that is not ignored in several ranges
        is recorded several times. Confirm duplicates are intended.
        '''
        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
        if len(gt) == 0 and len(dt) == 0:
            return None

        for g in gt:
            # Area is estimated from the bbox (w * h * 0.53) rather than read
            # from g['area']; the origin of the 0.53 factor isn't visible here.
            tmp_area = g['bbox'][2] * g['bbox'][3] * 0.53
            if g['ignore'] or (tmp_area < aRng[0] or tmp_area > aRng[1]):
                g['_ignore'] = 1
            else:
                g['_ignore'] = 0

        # Sort dt highest score first, gt ignored last.
        gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
        gt = [gt[i] for i in gtind]
        dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
        dt = [dt[i] for i in dtind[0:maxDet]]
        # Precomputed OKS matrix indexed [detection, gt], reordered to match gt.
        ious = self.ious[imgId, catId][:, gtind] if len(
            self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
        if len(ious) == 0:
            return None

        gtIg = np.array([g['_ignore'] for g in gt])
        for gind in range(len(gt)):
            if gtIg[gind] == 1:
                continue
            # Highest-OKS detection for this gt. ``<`` lets later (lower-scored)
            # detections win ties, and one detection may serve several gts.
            iou = 0
            m = -1
            for dind in range(len(dt)):
                if ious[dind, gind] < iou:
                    continue
                iou = ious[dind, gind]
                m = dind
            dtkeypoint = np.array(dt[m]['keypoints']).reshape((14, 3))
            self.summary.append([dtkeypoint[:, :2], dtkeypoint[:, 2:], dt[m]['score'], iou])
        return None

    def dumpdataset(self, data_file):
        """Pickle ``self.summary`` to ``data_file``, closing the file afterwards."""
        with open(data_file, 'wb') as f:
            pickle.dump(self.summary, f)




