import sys
sys.path.append('..')

import glob
import pickle
from collections import defaultdict

from configuration.config import *
from data.hand_pca_transform import pose_to_pca
from data.load_interaction import get_interaction_segments

"""Load results generated by pigraph"""
def load_pigraph(result_dir):
    """Load results generated by pigraph.

    Walks ``result_dir/<interaction>/<scene>/<combination>.pkl``. Each pickle
    holds a mapping of SMPL-X parameter arrays; the number of frames is taken
    from ``len(params['transl'])`` and every frame becomes one record.

    Args:
        result_dir: directory (a ``Path``-like object with ``iterdir``) whose
            sub-directories are interaction names.

    Returns:
        ``defaultdict(list)`` keyed by ``'<scene>_<combination>'``. Each value
        lists one dict per frame with ``scene``, ``interaction``, ``gender``
        (``'neutral'``), ``object_combination`` and every parameter whose key
        is in ``smplx_param_names``, sliced with ``[[frame]]`` and moved via
        ``.cpu()`` (presumably torch tensors -- confirm). The hand poses are
        then replaced by the output of ``pose_to_pca``.
    """
    grouped = defaultdict(list)

    for interaction_dir in result_dir.iterdir():
        if not interaction_dir.is_dir():
            continue
        for scene_dir in interaction_dir.iterdir():
            if not scene_dir.is_dir():
                continue
            for pkl_path in scene_dir.iterdir():
                if not pkl_path.name.endswith('.pkl'):
                    continue
                combination = pkl_path.name[:-4]
                with pkl_path.open('rb') as handle:
                    params = pickle.load(handle)
                for frame in range(len(params['transl'])):
                    record = {'scene': scene_dir.name, 'interaction': interaction_dir.name,
                              'gender': 'neutral', 'object_combination': combination}
                    # Keep a leading axis of size 1 by indexing with a one-element list.
                    record.update({name: params[name][[frame]].cpu()
                                   for name in params if name in smplx_param_names})
                    record['left_hand_pose'], record['right_hand_pose'] = pose_to_pca(
                        record['left_hand_pose'], record['right_hand_pose'], gender=record['gender'])
                    # Index the defaultdict only here so empty pickles add no key.
                    grouped[scene_dir.name + '_' + combination].append(record)
    return grouped

"""Load results generated using POSA"""
def load_posa(result_dir):
    """Load results generated using POSA.

    Walks ``result_dir/<scene>/<file>.pkl``. Each pickle is indexed with
    ``0 .. len-1`` and every element is a parameter mapping.

    Args:
        result_dir: directory (``Path``-like, with ``iterdir``) whose
            sub-directories are scene names.

    Returns:
        ``defaultdict(list)`` keyed by the interaction name, i.e. the file name
        up to its first ``'.'``; results from different scenes with the same
        interaction share one list. Each record holds ``scene`` plus the
        entries whose key is in ``smplx_param_names``.
        NOTE(review): unlike the other loaders, records carry no
        ``interaction``/``gender``/``object_combination`` keys.

    The collected keys are printed before returning.
    """
    by_interaction = defaultdict(list)

    for scene_dir in result_dir.iterdir():
        if not scene_dir.is_dir():
            continue
        for pkl_path in scene_dir.iterdir():
            if not pkl_path.name.endswith('.pkl'):
                continue
            interaction_name = pkl_path.name.split('.')[0]
            with pkl_path.open('rb') as handle:
                samples = pickle.load(handle)
            # Index positionally: the container type of the pickle is not known here.
            for position in range(len(samples)):
                sample = samples[position]
                entry = {'scene': scene_dir.name}
                entry.update((name, sample[name]) for name in sample if name in smplx_param_names)
                by_interaction[interaction_name].append(entry)
    print(by_interaction.keys())
    return by_interaction

""" Load pseudo ground truth PROX interaction data from test set"""
def load_prox():
    """Load pseudo ground-truth PROX interaction data from the test set.

    Reads ``data/test.pkl`` under ``project_folder`` and, for every entry of
    ``interaction_names`` (atomic interactions joined by ``'+'``), gathers the
    matching records with ``get_interaction_segments(..., mode='verb-noun')``.

    Each record becomes one dict holding ``scene``, ``interaction``,
    ``object_combination`` and the record's ``smplx_param`` entries. ``gender``
    defaults to ``'neutral'`` when the record has none, and both hand poses
    are truncated to their first ``num_pca_comps`` columns.

    Returns:
        ``defaultdict(list)`` keyed by ``'<scene>_<combination>'``, where the
        combination is the ``'+'``-joined ``'<atomic>-<object index>'`` parts.
        The collected keys are printed before returning.
    """
    # Records known to be wrong, keyed as '<scene>_<combination>'.
    # Built once, outside the loops, as a set for O(1) membership tests.
    wrong_combinations = frozenset({
        'MPH1Library_sit down-chair-5',
        'MPH1Library_step up-chair-6',
        'MPH1Library_stand up-chair-6',
        'MPH1Library_stand up-chair-5',
        'MPH1Library_step down-chair-8',
    })

    with open(Path(project_folder) / 'data' / 'test.pkl', 'rb') as data_file:
        test_data = pickle.load(data_file)

    results = defaultdict(list)
    for interaction in interaction_names:
        atomics = interaction.split('+')
        # Hand the callee its own list so it can never alias ``atomics``.
        interaction_data = get_interaction_segments(interaction.split('+'), test_data, mode='verb-noun')
        for record in interaction_data:
            scene_name = record['scene_name']
            labels = record['interaction_labels']
            # Object index attached to each atomic interaction, found via the label's position.
            obj_ids = [record['interaction_obj_idx'][labels.index(atomic)] for atomic in atomics]
            # str() (not an f-string) keeps the exact original formatting of the object ids.
            combination_name = '+'.join(atomic + '-' + str(obj_id) for atomic, obj_id in zip(atomics, obj_ids))
            key = scene_name + '_' + combination_name
            if key in wrong_combinations:  # filter wrong records
                continue
            interaction_param = {'scene': scene_name, 'interaction': interaction,
                                 'object_combination': combination_name}
            interaction_param.update(record['smplx_param'])
            interaction_param.setdefault('gender', 'neutral')
            interaction_param['left_hand_pose'] = interaction_param['left_hand_pose'][:, :num_pca_comps]
            interaction_param['right_hand_pose'] = interaction_param['right_hand_pose'][:, :num_pca_comps]
            results[key].append(interaction_param)

    print(results.keys())
    return results


"""Load results generated by our method."""
def load_results(result_dir):
    """Load results generated by our method.

    Walks ``result_dir/<interaction>/<scene>/<combination>.pkl``. Each pickle
    is indexed with ``0 .. len-1`` and every element is a parameter mapping.

    Args:
        result_dir: directory (``Path``-like, with ``iterdir``) whose
            sub-directories are interaction names.

    Returns:
        ``defaultdict(list)`` keyed by ``'<scene>_<combination>'``. Each value
        lists one dict per sample with ``scene``, ``interaction``, ``gender``
        (``'neutral'`` unless a same-named entry is copied over),
        ``object_combination`` and every entry whose key is in
        ``smplx_param_names``.
    """
    collected = defaultdict(list)

    for interaction_dir in result_dir.iterdir():
        if not interaction_dir.is_dir():
            continue
        for scene_dir in interaction_dir.iterdir():
            if not scene_dir.is_dir():
                continue
            for pkl_path in scene_dir.iterdir():
                if not pkl_path.name.endswith('.pkl'):
                    continue
                combination = pkl_path.name[:-4]
                with pkl_path.open('rb') as handle:
                    samples = pickle.load(handle)
                # Index positionally: the container type of the pickle is not known here.
                for position in range(len(samples)):
                    sample = samples[position]
                    entry = {
                        'scene': scene_dir.name,
                        'interaction': interaction_dir.name,
                        'gender': 'neutral',
                        'object_combination': combination,
                    }
                    entry.update((name, sample[name]) for name in sample if name in smplx_param_names)
                    # Index the defaultdict only here so empty pickles add no key.
                    collected[scene_dir.name + '_' + combination].append(entry)
    return collected

# Root directory containing the generated results of every synthesis method.
# NOTE(review): machine-specific absolute path; adjust for your environment.
RESULTS_ROOT = Path('/home/kaizhao/projects/scene_graph/results')

# dict of interaction results from different sources. Used in render_results.py and eval_results.py.
# Every loader below runs at import time, so importing this module reads all result files.
synthesis_results_dict = {
    'prox': load_prox(),
    'pigraph_no_penetration': load_pigraph(RESULTS_ROOT / 'pigraph_normal'),
    'POSA_best1': load_results(RESULTS_ROOT / 'POSA_IPoser_best1'),
    'floor_eval_try1_pene20_noseed_lr0.01_optimization': load_results(
        RESULTS_ROOT / 'two_stage' / 'floor_eval_try1_pene20_noseed_lr0.01' / 'optimization_after_get_body'),
}
