import copy
from glob import glob
from typing import Dict, List, Tuple
import itertools
import pickle
import os
import torch
from temporal_task_planner.data_structures.action import Action
from temporal_task_planner.data_structures.state import State
from temporal_task_planner.data_structures.temporal_context import TemporalContext
from temporal_task_planner.utils.data_structure_utils import (
    construct_act_instance,
    construct_action,
    construct_place_instances_from_keyframe,
    construct_rigid_instances_from_keyframe,
)
from temporal_task_planner.utils.extract_from_session import *
from temporal_task_planner.utils.gen_sess_utils import load_session_json
from temporal_task_planner.data_structures.instance import Instance


def get_session_path(session_pathname: str, num_sessions_limit: int = 10000) -> List:
    """Return the ``sess_<N>.json`` files in ``session_pathname``, sorted by N.

    Args:
        session_pathname: Directory to search (non-recursive).
        num_sessions_limit: Maximum number of paths to return.

    Returns:
        At most ``num_sessions_limit`` paths, ordered numerically by the
        integer N in ``sess_<N>.json`` (so ``sess_2`` precedes ``sess_10``).
    """
    session_list = glob(session_pathname + "/sess_*.json")
    # Use os.path.basename rather than splitting on "/": glob joins results
    # with os.path.join, i.e. with "\\" on Windows, where a "/" split would
    # leave the directory in the stem and make int() fail.
    session_list.sort(
        key=lambda path: int(os.path.basename(path).split(".")[0][len("sess_"):])
    )
    return session_list[:num_sessions_limit]

def get_session_list(folder_name, num_objects_per_rack, num_sessions_limit=10000):
    """Return the ``sess_<N>.json`` files under ``folder_name/num_objects_per_rack``.

    Args:
        folder_name: Root folder containing one sub-folder per rack size.
        num_objects_per_rack: Name of the sub-folder to search (e.g. ``5``).
        num_sessions_limit: Maximum number of paths to return.

    Returns:
        At most ``num_sessions_limit`` paths, ordered numerically by N.
    """
    session_list = glob(f"{folder_name}/{num_objects_per_rack}/sess_*.json")
    # basename is separator-agnostic; splitting on "/" breaks when glob
    # returns "\\"-joined paths (Windows).
    session_list.sort(
        key=lambda path: int(os.path.basename(path).split(".")[0][len("sess_"):])
    )
    return session_list[:num_sessions_limit]

def get_session_pairs(folder_name, list_num_objects_per_rack, num_sessions_limit=10000):
    """Return every ordered (session, session) pair across the given rack sizes.

    Args:
        folder_name: Root folder with one sub-folder per rack size.
        list_num_objects_per_rack: Rack-size sub-folders to collect sessions from.
        num_sessions_limit: Per-rack-size cap forwarded to ``get_session_list``.

    Returns:
        ``list(itertools.product(sessions, sessions))`` over the sessions of
        all requested rack sizes, including self-pairs.
    """
    session_list = []
    for num_objects_per_rack in list_num_objects_per_rack:
        # Accumulate across rack sizes (a plain assignment would keep only the
        # last one) and forward the caller's limit instead of a fixed value.
        session_list += get_session_list(
            folder_name, num_objects_per_rack, num_sessions_limit=num_sessions_limit
        )
    return list(itertools.product(session_list, session_list))

def get_temporal_context(session_path):
    """Load the pickled temporal context stored next to a session JSON file.

    The context is looked up at ``session_path`` with its final extension
    replaced by ``.pkl`` (``.../sess_3.json`` -> ``.../sess_3.pkl``).

    Returns:
        The unpickled object, or ``None`` if that ``.pkl`` file does not exist.
    """
    # splitext strips only the final extension; splitting on the first "."
    # truncates paths such as "./data/sess_3.json" or "data/v1.0/sess_3.json".
    pathname = os.path.splitext(session_path)[0] + ".pkl"
    if not os.path.exists(pathname):
        return None
    # SECURITY: pickle.load can execute arbitrary code -- only load files
    # from a trusted dataset.
    with open(pathname, "rb") as f:
        return pickle.load(f)

def get_preference_pairs(
    preference_folders: List[str],
    list_num_objects_per_rack: List[int] = [5, 6],  # read-only; never mutated
    num_sessions_limit: int = 10000,
) -> List:
    """Collect session pairs for every preference folder and rack size.

    Pairs are formed only between sessions of the same preference folder and
    the same rack size (``<folder>/<num_objects_per_rack>/sess_*.json``).

    Args:
        preference_folders: Folders to scan (presumably ``pref_*`` folders --
            confirm with the dataset layout). A single path string is also
            accepted.
        list_num_objects_per_rack: Rack-size sub-folders to visit in each folder.
        num_sessions_limit: Per folder/rack-size cap on the number of sessions.

    Returns:
        A flat list of ``(session_path, session_path)`` tuples, self-pairs included.
    """
    # A bare string would otherwise be iterated one character at a time.
    if isinstance(preference_folders, str):
        preference_folders = [preference_folders]
    session_pairs = []
    for folder_name in preference_folders:
        for num_objects_per_rack in list_num_objects_per_rack:
            # Let get_session_list build "<folder>/<n>/"; pre-appending n here
            # would produce "<folder>/<n>//<n>/".
            session_list = get_session_list(
                folder_name, num_objects_per_rack, num_sessions_limit
            )
            session_pairs.extend(itertools.product(session_list, session_list))
    return session_pairs


def get_userid_list(session_path: str):
    """Return the indices of relevant user actions in a session.

    Actions before the first one whose ``articulatedObj`` is
    ``"ktc_dishwasher_:0000"`` are skipped (settling/clearing). From that
    action onward, every action whose ``actionType`` is not ``"init"`` is kept.

    Returns:
        ``None`` if the session JSON could not be loaded. Otherwise
        ``(userid_list, start_userid)``, where ``start_userid`` is the index
        of the first dishwasher action, or ``None`` (with an empty list) if
        the session has none.
    """
    data = load_session_json(session_path)
    if data is None:
        return None
    sess_userid_list = []
    # Initialised so sessions without a dishwasher action no longer raise
    # UnboundLocalError at the return statement.
    start_userid = None
    for userid, useraction in enumerate(data["session"]["userActions"]):
        # Only look for the start marker until it has been found.
        if start_userid is None and useraction["articulatedObj"] == "ktc_dishwasher_:0000":
            start_userid = userid
        # Ignore init actions.
        if start_userid is not None and useraction["actionType"] != "init":
            sess_userid_list.append(userid)
    return sess_userid_list, start_userid


def save_data(save_dict, path_name):
    """Serialise ``save_dict`` to ``path_name`` with ``torch.save`` and report the path."""
    torch.save(obj=save_dict, f=path_name)
    print("saving at {}".format(path_name))

def get_attribute_batch(instances: List[Instance], attribute: str):
    """Read ``attribute`` from each instance, in order, and wrap the values in ``torch.tensor``."""
    values = [getattr(inst, attribute) for inst in instances]
    batch = torch.tensor(values)
    return batch


def get_state_action_from_useraction(
    keyframe: Dict,
    useraction: Dict,
    is_action_available: bool = True,
    pick_only: bool = False,
    relative_timestep: int = 0,
) -> Tuple[State, Action]:
    """Build the state (and optionally the action) at the start of a user action.

    Args:
        keyframe: Keyframe the state's rigid and place instances are built from.
        useraction: User-action record supplying the tracked instance, target
            poses and feasible pick/place sets.
        is_action_available: When True, the action is also constructed, and the
            act instance is marked as available and to-be-predicted.
        pick_only: Forwarded to ``State``.
        relative_timestep: Timestep forwarded to every instance constructor.

    Returns:
        ``(state, action)``; ``action`` is ``None`` when ``is_action_available``
        is False.
    """
    # Per-action annotations (helpers presumably come from the
    # extract_from_session star import -- confirm).
    tracked = get_instance_to_track(useraction)
    poses = get_target_instance_poses(useraction)
    picks = get_feasible_picks(useraction)
    places = get_feasible_place(useraction)

    state = State(pick_only=pick_only)
    state.rigid_instances = construct_rigid_instances_from_keyframe(
        keyframe, picks, relative_timestep
    )
    state.place_instances = construct_place_instances_from_keyframe(
        keyframe, tracked, poses["endPose"], picks, places, relative_timestep
    )
    # NOTE(review): one flag drives both "available" and "to be predicted";
    # confirm this coupling is intended.
    state.act_instances = construct_act_instance(
        is_action_available=is_action_available,
        is_action_to_be_predicted=is_action_available,
        relative_timestep=relative_timestep,
    )

    action = construct_action(useraction) if is_action_available else None
    return state, action


if __name__ == "__main__":
    # get_preference_pairs expects a list of folders; passing the bare
    # string "data/" would iterate it one character at a time.
    pairs = get_preference_pairs(
        ["data/"], list_num_objects_per_rack=[6], num_sessions_limit=10
    )
    print(f"found {len(pairs)} session pairs")
