import os
import json
from tqdm import tqdm

import math
import torch
import numpy as np
from .rescore import get_feature
#from .rescore import rescore_valid
from sklearn.preprocessing import QuantileTransformer


def quantile_strench(y_temp):
    """Map every batch element's values onto a uniform output distribution.

    A fresh sklearn ``QuantileTransformer(output_distribution='uniform')`` is
    fitted independently for each element along the first axis, using as many
    quantiles as that element has rows.

    Args:
        y_temp: tensor laid out as (b, n, 1); it is copied to CPU/numpy for
            the transform.

    Returns:
        The stacked transformed values as a tensor placed on ``y_temp``'s
        original device.
    """
    transformed = [
        QuantileTransformer(output_distribution='uniform', n_quantiles=len(item)).fit_transform(item)
        for item in y_temp.cpu().numpy()
    ]
    return torch.from_numpy(np.stack(transformed, axis=0)).to(y_temp.device)

def load_test_gt_in_coco_format(cfg, args):
    if 'coco' in cfg.dataset['dataset']:
        from pycocotools.coco import COCO

        coco_gt = COCO(os.path.join(cfg.dataset['root'], 'annotations', 'person_keypoints_val2017.json'))
        coco_dt = coco_gt.loadRes(os.path.join(cfg.dataset['json_root'], cfg.dataset['test'], f'{args.hpe_model}.json'))
    elif 'crowd' in cfg.dataset['dataset']:
        from crowdposetools.coco import COCO

        coco_gt = COCO(os.path.join(cfg.dataset['root'], 'annotations', 'crowdpose_test.json'))
        coco_dt = coco_gt.loadRes(os.path.join(cfg.dataset['json_root'], cfg.dataset['test'], f'{args.hpe_model}.json'))
    elif 'exlpose' in cfg.dataset['dataset']:
        from crowdposetools.coco import COCO

        coco_gt = COCO(os.path.join(cfg.dataset['root'], 'Annotations', 'ExLPose-OC_test_{}.json'.format(args.hpe_dataset.upper())))
        # print(os.path.join(cfg.dataset['root'], 'Annotations', 'ExLPose-OC_test_{}.json'.format(args.hpe_dataset.upper())))
        # input()
        coco_dt = coco_gt.loadRes(os.path.join(cfg.dataset['json_root'], cfg.dataset['test'], f'{args.hpe_model}.json'))
        # print(os.path.join(os.path.join(cfg.dataset['json_root'], cfg.dataset['test'], f'{args.hpe_model}.json')))
        # input()
    else:
        print('Only COCO and CrowdPose is supported.')
        assert False
    return coco_gt, coco_dt

def score_correction_cfg(new, old, cfg):
    """Combine a rescoring-model output with a detection's original score.

    Scores below ``cfg.test['threshold']`` are passed through untouched.
    Otherwise the result is ``new * old`` when ``cfg.test['multiply']`` is
    truthy, and ``new + 1.0`` (the original score is dropped) when it is not.

    Args:
        new: score produced by the rescoring model.
        old: original detection score.
        cfg: config object whose ``test`` dict has 'threshold' and 'multiply'.

    Returns:
        The corrected score.
    """
    # Both settings are looked up up front, so a missing key always raises.
    threshold, multiply = cfg.test['threshold'], cfg.test['multiply']
    if old < threshold:
        return old
    return new * old if multiply else new + 1.0

# # coco: swahr 0.2, others 0.1, crowdpose: swahr 0.2+* old
# def score_correction_bottomup(new, old):
#     if old < 0.5:
#        return old
#     else:
#        return new * old

# # coco:0.6, crowdpose
# def score_correction_onestage(new, old):
#     if old < 0.5:
#        return old
#     else:
#        return new * old

# def score_correction_topdown(new, old):
#     if old < 0.5:
#         return old
#     else:
#         return new * old

def evaluate(cfg, coco_gt, res_file):
    """Evaluate the keypoint detections in ``res_file`` against ``coco_gt``.

    Uses pycocotools' COCOeval for datasets whose name contains 'coco' and
    crowdposetools' COCOeval for names containing 'crowd' or 'exl'. The
    evaluate/accumulate/summarize pipeline prints the summary table.

    Args:
        cfg: config object whose ``dataset['dataset']`` names the dataset.
        coco_gt: ground-truth COCO-style API object.
        res_file: path to a JSON list of detection results.

    Raises:
        ValueError: if the dataset name matches none of the supported datasets.
    """
    dataset = cfg.dataset['dataset']
    if 'coco' in dataset:
        from pycocotools.cocoeval import COCOeval
    # Fix: `'crowd' or 'exl' in dataset` was always true ('crowd' is truthy),
    # so every dataset fell through here and the error branch was unreachable.
    elif 'crowd' in dataset or 'exl' in dataset:
        from crowdposetools.cocoeval import COCOeval
    else:
        raise ValueError(
            f'Unsupported dataset {dataset!r}: only COCO, CrowdPose and ExLPose are supported.')

    coco_dt = coco_gt.loadRes(res_file)
    coco_eval = COCOeval(coco_gt, coco_dt, 'keypoints')
    # NOTE(review): in pycocotools a non-None useSegm overrides iouType, so
    # None presumably keeps the 'keypoints' IoU type — confirm for crowdposetools.
    coco_eval.params.useSegm = None
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

def load_all_dts(cfg, coco_dt):
    """Collect every detection in ``coco_dt`` and convert it into model inputs.

    Detections are gathered image by image, following the key order of
    ``coco_dt.imgs``. Each detection's flat ``keypoints`` list
    ([x, y, v] per joint) is reshaped to (num_joints, 3).

    Args:
        cfg: config whose ``dataset['num_joints']`` gives the joint count;
            also forwarded to ``get_feature``.
        coco_dt: COCO-style detections object (``imgs``, ``getAnnIds``,
            ``loadAnns``).

    Returns:
        features: ``get_feature`` applied to the (N, num_joints, 3) float32
            joint tensor.
        dtscores: (N, 1) float64 tensor of the original detection scores.
        visible_scores: (N,) numpy array holding each detection's mean third
            joint value over joints where it exceeds 0.2. It is NaN (with a
            numpy RuntimeWarning) when no joint qualifies.
        all_dts: the detection dicts, row-aligned with the outputs above.
    """
    img_ids = np.array(list(coco_dt.imgs.keys())).astype(int)
    all_dts = []
    for img_id in tqdm(img_ids):
        all_dts.extend(coco_dt.loadAnns(coco_dt.getAnnIds(img_id)))

    n_dets = len(all_dts)
    # Reshaping with the explicit row count makes a wrong keypoint length an
    # error instead of silently producing a different number of rows.
    kpts = np.asarray([det['keypoints'] for det in all_dts], dtype=float)
    kpts = kpts.reshape(n_dets, cfg.dataset['num_joints'], 3)
    scores = np.asarray([det['score'] for det in all_dts], dtype=float).reshape(n_dets, 1)

    joints = torch.tensor(kpts).float()
    features = get_feature(joints, cfg)
    dtscores = torch.tensor(scores)

    # Average the third per-joint value, ignoring joints at or below 0.2.
    third = joints[:, :, -1]
    visible_scores = np.nanmean(np.where(third > 0.2, third, np.nan), axis=1)

    return features, dtscores, visible_scores, all_dts

def rescore_json(cfg, args, model):
    """Rescore all test detections with ``model``, save them, then evaluate.

    The corrected detections are written to
    ``output/<cfg.test['save_path']>/rescored.json`` and then scored with
    ``evaluate``.

    Args:
        cfg: config with ``test`` keys 'save_path', 'batch_size', 'pose_size',
            plus whatever the loader/evaluation helpers read.
        args: command-line namespace forwarded to the GT/detection loader.
        model: torch module called as ``model(features, scores)`` on CUDA
            tensors and returning one new score per detection.
    """
    out_dir = os.path.join('output', cfg.test['save_path'])
    res_file = os.path.join(out_dir, 'rescored.json')
    # makedirs (unlike mkdir) also creates 'output/' itself when it is missing.
    os.makedirs(out_dir, exist_ok=True)

    coco_gt, coco_dt = load_test_gt_in_coco_format(cfg, args)
    x_data, s_data, _visible_scores, all_dts = load_all_dts(cfg, coco_dt)

    datasize = len(x_data)
    pose_per_batch = cfg.test['batch_size'] * cfg.test['pose_size']

    model.eval()
    new_json = []
    # Stepping by the batch size never yields an empty trailing batch. The
    # old `range(datasize // ppb + 1)` fed the model empty tensors whenever
    # datasize was an exact multiple of ppb (or zero).
    for start in tqdm(range(0, datasize, pose_per_batch)):
        stop = min(start + pose_per_batch, datasize)
        x_temp = x_data[start:stop].cuda()
        s_temp = s_data[start:stop].cuda()
        with torch.no_grad():
            new_scores = model(x_temp, s_temp)

        for dt, new_s in zip(all_dts[start:stop], new_scores):
            # The dict is updated in place; the original score is read first.
            dt['score'] = float(score_correction_cfg(new_s, dt['score'], cfg))
            new_json.append(dt)

    with open(res_file, 'w') as f:
        json.dump(new_json, f, sort_keys=True, indent=4)

    evaluate(cfg, coco_gt, res_file)


def order_acc(y, y_hat):
    """Unimplemented placeholder (presumably an ordering-accuracy metric).

    Both arguments are ignored and the function always returns ``None``.
    """
    return None