import itertools
from typing import Dict, List, Tuple
from glob import glob
import copy
from dataclasses import dataclass, asdict, field
import pandas as pd
import random
import time
import pickle
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from temporal_task_planner.constants.gen_sess_config.lookup import *
from temporal_task_planner.data_structures.action import Action
from temporal_task_planner.data_structures.state import State
from temporal_task_planner.data_structures.temporal_context import TemporalContext
from temporal_task_planner.utils.gen_sess_utils import load_session_json
from temporal_task_planner.utils.extract_from_session import *
from temporal_task_planner.utils.data_structure_utils import (
    construct_act_instance,
    construct_action,
    construct_place_instances_from_keyframe,
    construct_rigid_instances_from_keyframe,
)
from temporal_task_planner.utils.datasetpytorch_utils import (
    get_preference_pairs,
    get_session_pairs,
    get_session_path,
    get_state_action_from_useraction,
    get_userid_list,
    get_session_list,
    save_data,
    get_temporal_context,
)

# Root directory used to resolve dataset paths: three levels above the
# directory that contains this file.
pkg_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)


class DishwasherArrangeDataset(Dataset):
    """Dataset of (inputs, targets) windows over recorded dishwasher sessions.

    Every item corresponds to one ``(session_id, userid)`` pair collected by
    :meth:`get_sess_userid_list`. Inputs are the processed states of up to
    ``context_history + 1`` user actions ending at ``userid`` (``init``
    actions are skipped). Targets are the processed actions of the steps for
    which :meth:`set_action_available` enabled actions (all steps here).
    Predictions are made at the ACTION MASK tokens of the processed inputs.
    """

    def __init__(
        self,
        session_paths: List[str],
        context_history: int,
        pick_only: bool = True,
        num_sessions_limit: int = -1,
        max_place_poses: int = 100,
    ) -> None:
        self.session_paths = session_paths
        self.context_history = context_history
        self.pick_only = pick_only
        # NOTE(review): stored but not applied anywhere in this class; every
        # path in `session_paths` is used. Confirm whether callers rely on it.
        self.num_sessions_limit = num_sessions_limit
        self.max_place_poses = max_place_poses

        self.num_sessions = len(self.session_paths)

        # Clipping bounds and per-axis normalisation statistics used by
        # transform_position.
        self.min_pos = -3
        self.max_pos = 3
        self.norm_pos = {
            "mean": [-1.8851218, 0.646792, 0.58907884],
            "std": [0.62312627, 0.34161338, 0.46390563],
        }
        # session_id -> start userid reported by get_userid_list; must exist
        # before get_sess_userid_list fills it.
        self.start_userid = {}
        self.sess_userid_list = self.get_sess_userid_list()
        self.is_action_available = False

    def get_sess_userid_list(self) -> List[Tuple[int, int]]:
        """Collect every ``(session_id, userid)`` pair usable as a sample.

        Delegates to ``get_userid_list`` per session, which is expected to
        exclude settling/clearing actions and the init action. Also records
        each session's start userid in ``self.start_userid``.
        """
        sess_userid_list = []
        for session_id, session_path in enumerate(self.session_paths):
            userids, start_userid = get_userid_list(session_path=session_path)
            sess_userid_list.extend((session_id, userid) for userid in userids)
            self.start_userid[session_id] = start_userid
        return sess_userid_list

    def set_action_available(self, *args) -> None:
        """Enable action targets for every step (subclasses may restrict)."""
        self.is_action_available = True

    def process_session_keyframes(
        self, session_id: int, userid: int
    ) -> Tuple[Dict, Dict]:
        """Build inputs/targets for the window ending at ``userid``.

        Args:
            session_id, userid: a pair from ``self.sess_userid_list``.
        Returns:
            ``None`` if the session JSON cannot be loaded; otherwise
            ``(inputs, targets)`` where ``inputs`` comes from
            ``TemporalContext.process_states()`` (instance attributes such as
            timestep, category, pose) and ``targets`` from
            ``process_actions()`` (pick track ids a.k.a. 'act', init_pose and
            end_pose for each True action mask).
        """
        data = load_session_json(self.session_paths[session_id])
        if data is None:
            return None
        original_context_budget = (
            min(userid - self.start_userid[session_id], self.context_history) + 1
        )
        context_budget = original_context_budget
        useraction_counter = 0
        temporal_context = TemporalContext()
        # Walk backwards from `userid`, prepending, so states end up in
        # chronological order; init actions do not consume budget.
        while context_budget and useraction_counter <= userid:
            useraction = data["session"]["userActions"][userid - useraction_counter]
            useraction_counter += 1
            if useraction["actionType"] == "init":
                continue
            keyframe = data["session"]["keyframes"][useraction["startFrame"]]
            self.set_action_available(context_budget, original_context_budget)
            state, action = get_state_action_from_useraction(
                keyframe,
                useraction,
                self.is_action_available,
                self.pick_only,
            )
            temporal_context.states.insert(0, state)
            if self.is_action_available:
                temporal_context.actions.insert(0, action)
            context_budget -= 1
        assert len(temporal_context.states), " check if temporal context has states?! "
        inputs = temporal_context.process_states()
        targets = temporal_context.process_actions()
        return inputs, targets

    def transform_position(
        self,
        position: List[float],
        transform_pos_limits: bool = True,
        transform_pos_norm: bool = True,
    ) -> List[float]:
        """Clip and/or normalise the first three (x, y, z) coordinates.

        Args:
            position: sequence of at least 3 numbers; it is not modified.
            transform_pos_limits: clip each coordinate to
                ``[self.min_pos, self.max_pos]``.
            transform_pos_norm: standardise each coordinate with the matching
                entry of ``self.norm_pos["mean"]`` / ``["std"]``.
        Returns:
            A new list with the transformed coordinates; any elements beyond
            index 2 are copied unchanged. Clipping is not invertible, so the
            transform is only bijective when ``transform_pos_limits`` is False.
        """
        transformed = list(position)
        for axis in range(3):
            value = transformed[axis]
            if transform_pos_limits:
                value = min(max(value, self.min_pos), self.max_pos)
            if transform_pos_norm:
                value = (value - self.norm_pos["mean"][axis]) / self.norm_pos[
                    "std"
                ][axis]
            transformed[axis] = value
        return transformed

    def __len__(self) -> int:
        return len(self.sess_userid_list)

    def __getitem__(self, index: int) -> Tuple[Dict, Dict]:
        """Return ``(inputs, targets)`` dicts of tensors for sample ``index``.

        Raises:
            RuntimeError: if the underlying session JSON cannot be loaded.
        """
        session_id, userid = self.sess_userid_list[index]
        processed = self.process_session_keyframes(session_id, userid)
        if processed is None:
            raise RuntimeError(
                f"Could not load session {self.session_paths[session_id]!r}"
            )
        inputs, targets = processed
        inputs_dict = {key: torch.tensor(value) for key, value in inputs.items()}
        targets_dict = {key: torch.tensor(value) for key, value in targets.items()}
        return inputs_dict, targets_dict


class DishwasherArrangeSavedSession(DishwasherArrangeDataset):
    """Evaluation variant: only the newest step of each window gets a target.

    Processes one session without repetitions for symbolic accuracy. For the
    first k steps (k < context window) the inputs are shorter than the
    window; later ones span the full window. Either way exactly one step (the
    most recent) carries an action target.
    """

    def __init__(self, dataset=None, *args, **kwargs) -> None:
        # NOTE(review): `dataset` is unused. As the first positional parameter
        # it swallows the first positional argument meant for the parent
        # (`session_paths`), so pass the parent's arguments by keyword.
        super().__init__(*args, **kwargs)

    def get_session_path(self, session_pathname: str) -> List:
        """Wrap a single session path in a list."""
        return [session_pathname]

    def set_action_available(
        self, curr_context_budget: int, orig_context_budget: int
    ) -> None:
        # The parent walks backwards from `userid`, so the budget is still
        # untouched only for the newest user action of the window.
        self.is_action_available = curr_context_budget == orig_context_budget

def input_pad_fn(batch):
    """Collate ``(inputs, targets)`` samples by padding the input tensors.

    Each ``inputs`` dict is unpacked *positionally*. Its values must be, in
    order: timestep, category (bound to ``bb`` in the original; presumably
    bounding-box features, to be confirmed), pose, action_masks, is_real,
    category_token and instance_token. Every field is right-padded to the
    longest sample in the batch (``batch_first=True``).

    Returns:
        ``(inputs_padded, targets)``. ``inputs_padded`` also holds
        ``src_key_padding_mask``, which is False on real tokens and True on
        padding. ``targets`` is the tuple of per-sample targets, untouched.
    """
    inputs, targets = zip(*batch)
    # Real tokens are False; pad_sequence fills the tail with True.
    key_padding_masks = [
        torch.zeros(len(sample["timestep"])).bool() for sample in inputs
    ]
    (
        timestep,
        bb,
        pose,
        action_masks,
        is_real,
        category_token,
        instance_token,
    ) = zip(*(sample.values() for sample in inputs))
    pad_plan = (
        ("timestep", timestep, 0),
        ("category", bb, 0),
        ("pose", pose, 0),
        ("action_masks", action_masks, False),
        ("is_real", is_real, False),
        ("category_token", category_token, 0),
        ("instance_token", instance_token, 0),
        ("src_key_padding_mask", key_padding_masks, True),
    )
    inputs_padded = {
        name: pad_sequence(sequences, batch_first=True, padding_value=pad_value)
        for name, sequences, pad_value in pad_plan
    }
    return inputs_padded, targets

def pad_fn(batch) -> Tuple[Dict, Dict]:
    """Collate samples: pad inputs via ``input_pad_fn`` and regroup targets.

    Each target dict is unpacked *positionally* as (action_instance,
    init_pose, end_pose). Targets are returned as per-sample lists, not
    padded, because their lengths may differ.
    """
    inputs_padded, targets = input_pad_fn(batch)
    action_instance, init_pose, end_pose = zip(
        *(sample_targets.values() for sample_targets in targets)
    )
    return inputs_padded, {
        "action_instance": list(action_instance),
        "init_pose": list(init_pose),
        "end_pose": list(end_pose),
    }

def preference_classifier_pad_fn(batch):
    """Collate ``(inputs, preference_label)`` samples for the classifier.

    Inputs are padded with ``input_pad_fn``. Labels are combined into one
    tensor with a new leading batch dimension.

    When every label is already a tensor, ``torch.stack`` is used. It keeps
    the dtype/device and also handles non-scalar label tensors, where
    ``torch.tensor(tuple_of_tensors)`` fails. Plain Python labels fall back
    to ``torch.tensor``.
    """
    inputs, targets = input_pad_fn(batch)
    if all(isinstance(target, torch.Tensor) for target in targets):
        targets_collated = torch.stack(targets)
    else:
        targets_collated = torch.tensor(targets)
    return inputs, targets_collated

class PromptSituationDataset:
    """Dataset of ``(prompt, situation)`` session pairs.

    Each item processes both sessions of one ``session_pairs`` entry in full
    (``context_history`` is large enough to cover a whole session) and
    returns their tensorised states/actions. The prompt is built without
    action targets and the situation with them. Processed ``TemporalContext``
    objects are cached next to each session file as ``<stem>.pkl``.

    Expected data layout; subclasses read the preference folder as
    ``path.split("/")[-3]``::

        data/
        |-- pref 1
            |-- n/
                |-- sess 1
                |-- sess 2
                ...
        |-- pref 2
            |-- n/
                |-- sess 1
                ...
        ...
    """

    def __init__(
        self,
        session_pairs: List[Tuple[str, str]],
        pick_only: bool = True,
    ) -> None:
        self.session_pairs = session_pairs
        self.context_history = 1000  # very large number: keep whole sessions
        self.pick_only = pick_only
        # Whether actions are recorded while building a TemporalContext;
        # __getitem__ toggles it between the prompt and the situation.
        self.is_action_available = True

    @staticmethod
    def _cache_path(session_path: str) -> str:
        """Return the pickle cache path for ``session_path``.

        Only the file extension is replaced, so dots in directory names
        (``./data``, ``/home/first.last``) neither truncate the path nor make
        different sessions share one cache file.
        """
        return os.path.splitext(session_path)[0] + ".pkl"

    def process_session_keyframes(
        self,
        session_data,
        userid: int,
        start_userid: int = 0,
        pathname: str = "",
    ):
        """Build a ``TemporalContext`` for the window ending at ``userid``.

        Args:
            session_data: loaded session JSON (or ``None``).
            userid: index of the last user action to include.
            start_userid: first valid userid; bounds the window length.
            pathname: if non-empty, the context is pickled to this path.
        Returns:
            The ``TemporalContext``, or ``None`` if ``session_data`` is None.
        """
        if session_data is None:
            return None
        original_context_budget = min(userid - start_userid, self.context_history) + 1
        context_budget = original_context_budget
        useraction_counter = 0
        temporal_context = TemporalContext()
        user_actions = session_data["session"]["userActions"]
        keyframes = session_data["session"]["keyframes"]
        # Walk backwards from `userid`, prepending, so states end up in
        # chronological order; init actions do not consume budget.
        while context_budget and useraction_counter <= userid:
            useraction = user_actions[userid - useraction_counter]
            useraction_counter += 1
            if useraction["actionType"] == "init":
                continue
            keyframe = keyframes[useraction["startFrame"]]
            state, action = get_state_action_from_useraction(
                keyframe,
                useraction,
                self.is_action_available,
                self.pick_only,
            )
            temporal_context.states.insert(0, state)
            if self.is_action_available:
                temporal_context.actions.insert(0, action)
            context_budget -= 1
        assert len(temporal_context.states), " check if temporal context has states?! "
        if pathname:
            # Cache for future runs; skipped when no path is given.
            with open(pathname, "wb") as f:
                pickle.dump(temporal_context, f)
        return temporal_context

    def process_session(self, session_path: str, pick_only: bool = True):
        """Load (or build and cache) a session's context and tensorise it.

        Args:
            session_path: path of the session JSON file.
            pick_only: value assigned to ``temporal_context.pick_only``
                before processing.
        Returns:
            ``(inputs_dict, targets_dict)`` mapping each key of
            ``process_states()`` / ``process_actions()`` to a tensor.
        Raises:
            ValueError: if the session JSON cannot be loaded.
        """
        pathname = self._cache_path(session_path)
        if os.path.exists(pathname):
            # NOTE: unpickling runs arbitrary code; only load caches written
            # by this pipeline.
            # NOTE(review): the cache key ignores `is_action_available`, so a
            # cache written while actions were disabled (prompt) is reused
            # as-is later. Confirm prompts and situations never share a file.
            with open(pathname, "rb") as f:
                temporal_context = pickle.load(f)
        else:
            session_data = load_session_json(session_path)
            if session_data is None:
                raise ValueError(f"Could not load session file: {session_path}")
            # New sessions have no settling/clearing actions, so start at 0.
            last_userid = len(session_data["session"]["userActions"]) - 1
            temporal_context = self.process_session_keyframes(
                session_data, userid=last_userid, pathname=pathname
            )
        temporal_context.pick_only = pick_only
        inputs = temporal_context.process_states()
        targets = temporal_context.process_actions()
        inputs_dict = {key: torch.tensor(value) for key, value in inputs.items()}
        targets_dict = {key: torch.tensor(value) for key, value in targets.items()}
        return inputs_dict, targets_dict

    def __len__(self):
        return len(self.session_pairs)

    def __getitem__(self, index):
        prompt_filepath, situation_filepath = self.session_pairs[index]
        # Prompt: states only. Situation: states and action targets.
        self.is_action_available = False
        prompt_inputs, prompt_targets = self.process_session(
            session_path=prompt_filepath
        )
        self.is_action_available = True
        situation_inputs, situation_targets = self.process_session(
            session_path=situation_filepath
        )
        inputs = {"prompt": prompt_inputs, "situation": situation_inputs}
        targets = {"prompt": prompt_targets, "situation": situation_targets}
        return inputs, targets


class PromptSituationSinglePrefSingleObj(PromptSituationDataset):
    """Prompt/situation pairs read from ``dataset_path`` (relative to ``pkg_root``)."""

    def __init__(self, dataset_path: str, pick_only: bool) -> None:
        # `Path` is not imported explicitly at module level; import it here so
        # the constructor does not raise NameError.
        from pathlib import Path

        self.dataset_path = dataset_path
        session_pairs = get_session_pairs(Path(pkg_root, dataset_path).as_posix())
        super().__init__(session_pairs, pick_only)


class PromptSituationPickPlace(PromptSituationDataset):
    """Prompt/situation pairs where the situation is one step of a session.

    For every preference folder in ``session_paths``, each session (as the
    prompt) is paired via ``itertools.product`` with every ``(session,
    step)`` of that same folder (as the situation). The prompt is processed in
    full with ``pick_only=True``. The situation is reduced to one step with
    ``pick_only=False``.
    """

    def __init__(
        self,
        session_paths: Dict[str, List[str]],
        pick_only: bool,
        num_sessions_limit: int = 10000,
    ) -> None:
        # NOTE(review): `pick_only` and `num_sessions_limit` are accepted for
        # interface compatibility but ignored. Contexts are always built with
        # self.pick_only = False, and every listed session is used.
        self.context_history = 1000  # large enough to keep whole sessions
        self.is_action_available = True
        self.pick_only = False
        self.session_paths = session_paths
        self.pairs = self.populate_pairs()

    def _build_temporal_context(self, session_path: str):
        """Process ``session_path`` from JSON and cache it as ``<stem>.pkl``.

        Raises:
            ValueError: if the session JSON cannot be loaded.
        """
        session_data = load_session_json(session_path)
        if session_data is None:
            raise ValueError(f"Could not load session file: {session_path}")
        # New sessions have no settling/clearing actions, so start at 0.
        last_userid = len(session_data["session"]["userActions"]) - 1
        return self.process_session_keyframes(
            session_data,
            userid=last_userid,
            pathname=os.path.splitext(session_path)[0] + ".pkl",
        )

    def populate_pairs(self):
        """Build ``[(prompt_path, (situation_path, step)), ...]`` for all folders.

        Uses the module-level ``get_temporal_context`` cache lookup and builds
        (and caches) any missing context. Also stores the list in
        ``self.pairs``.
        """
        self.pairs = []
        for session_list in self.session_paths.values():
            temporal_contexts = []
            for session_path in session_list:
                temporal_context = get_temporal_context(session_path)
                if temporal_context is None:
                    temporal_context = self._build_temporal_context(session_path)
                temporal_contexts.append(temporal_context)
            situation_ids = [
                (session_path, step)
                for session_path, temporal_context in zip(session_list, temporal_contexts)
                for step in range(len(temporal_context.states))
            ]
            # NOTE(review): product() also pairs each session with its own
            # steps (prompt == situation session). Confirm this is intended.
            self.pairs += list(itertools.product(session_list, situation_ids))
        return self.pairs

    def get_temporal_context(self, session_path):
        """Unpickle the cached context of ``session_path`` (``<stem>.pkl``).

        Not used by the other methods, which call the module-level
        ``get_temporal_context`` helper instead.
        """
        pathname = os.path.splitext(session_path)[0] + ".pkl"
        # NOTE: unpickling runs arbitrary code; only load self-written caches.
        with open(pathname, "rb") as f:
            return pickle.load(f)

    def process_session_by_userid(
        self, session_path: str, userid: int, pick_only: bool = False
    ) -> Dict:
        """Tensorise a single step of a session (current context only).

        Args:
            session_path: path of the session JSON file.
            userid: step index into the session's context states (as
                produced by ``populate_pairs``).
            pick_only: value assigned to the context's ``pick_only``; decides
                whether 1 (pick) or 2 (pick and place) act targets are
                expected.
        Returns:
            ``(inputs_dict, targets_dict)`` of tensors for that step.
        """
        temporal_context = get_temporal_context(session_path)
        if temporal_context is None or not len(temporal_context.actions):
            # Missing cache, or a cache written without actions: rebuild it.
            temporal_context = self._build_temporal_context(session_path)
        temporal_context.pick_only = pick_only
        # Presumably a copy restricted to steps [userid, userid + 1); confirm
        # against TemporalContext.init_sub_copy.
        current_temporal_context = temporal_context.init_sub_copy(
            start_id=userid, end_id=userid + 1
        )
        current_temporal_context.pick_only = pick_only
        inputs = current_temporal_context.process_states()
        targets = current_temporal_context.process_actions()
        expected_num_acts = 1 if pick_only else 2
        assert len(targets["act"]) == expected_num_acts, (
            f"expected {expected_num_acts} act targets for pick_only={pick_only}, "
            f"got {len(targets['act'])}"
        )
        inputs_dict = {key: torch.tensor(value) for key, value in inputs.items()}
        targets_dict = {key: torch.tensor(value) for key, value in targets.items()}
        return inputs_dict, targets_dict

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        prompt_filepath, (situation_filepath, temporal_context_id) = self.pairs[index]
        assert prompt_filepath.split("/")[-3] == situation_filepath.split("/")[-3], "Not the same preference for prompt and situation"
        prompt_inputs, prompt_targets = self.process_session(
            session_path=prompt_filepath, pick_only=True
        )
        situation_inputs, situation_targets = self.process_session_by_userid(
            session_path=situation_filepath, userid=temporal_context_id, pick_only=False
        )
        inputs = {"prompt": prompt_inputs, "situation": situation_inputs}
        targets = {"prompt": prompt_targets, "situation": situation_targets}
        return inputs, targets


class SessionPreferenceDataset(PromptSituationDataset):
    """Whole-session dataset labelled with the session's preference folder.

    ``session_paths`` maps a key (the preference folder, as used by the other
    datasets) to a list of session paths; all lists are flattened in order.
    A session's label is its third-from-last path component, encoded with
    ``Vocab``.
    """

    def __init__(self, session_paths: Dict[str, List[str]], pick_only: bool) -> None:
        self.session_paths = [
            path for paths in session_paths.values() for path in paths
        ]
        self.pick_only = pick_only
        # NOTE(review): `Vocab` is not imported explicitly in this module;
        # presumably it arrives via one of the wildcard imports. Confirm.
        self.preference_label_vocab = Vocab()
        self.preference_label_vocab.word2index(
            [self._preference_label(path) for path in self.session_paths], train=True
        )
        self.context_history = 10000  # some large number to have entire session
        self.is_action_available = True

    @staticmethod
    def _preference_label(session_path: str) -> str:
        """Return the preference folder name: third-from-last path component."""
        return session_path.split("/")[-3]

    def __len__(self):
        return len(self.session_paths)

    def __getitem__(self, index):
        """Return ``(inputs, preference_index_tensor)`` for session ``index``."""
        filepath = self.session_paths[index]
        self.is_action_available = True
        inputs, _ = self.process_session(session_path=filepath)
        preference = torch.tensor(
            self.preference_label_vocab.word2index(self._preference_label(filepath))
        )
        return inputs, preference

def _pad_instance_batch(samples):
    """Right-pad per-sample instance dicts into batch-first tensors.

    Each dict's values are unpacked positionally as (timestep, category,
    pose, action_masks, is_real, category_token, instance_token). A
    ``src_key_padding_mask`` is added, False on real tokens and True on
    padding.
    """
    key_padding_masks = [
        torch.zeros(len(sample["timestep"])).bool() for sample in samples
    ]
    timestep, bb, pose, action_masks, is_real, category_token, instance_token = zip(
        *(sample.values() for sample in samples)
    )
    return {
        "timestep": pad_sequence(timestep, batch_first=True, padding_value=0),
        "category": pad_sequence(bb, batch_first=True, padding_value=0),
        "pose": pad_sequence(pose, batch_first=True, padding_value=0),
        "action_masks": pad_sequence(
            action_masks, batch_first=True, padding_value=False
        ),
        "is_real": pad_sequence(is_real, batch_first=True, padding_value=False),
        "category_token": pad_sequence(
            category_token, batch_first=True, padding_value=0
        ),
        "instance_token": pad_sequence(
            instance_token, batch_first=True, padding_value=0
        ),
        "src_key_padding_mask": pad_sequence(
            key_padding_masks, batch_first=True, padding_value=True
        ),
    }


def prompt_pad_fn(batch) -> Tuple[Dict, Dict]:
    """Collate ``({"prompt", "situation"} inputs, targets)`` samples.

    The prompt and situation inputs are padded independently. Only the
    situation targets are collated: each is unpacked positionally as
    (action_instance, init_pose, end_pose) and kept as per-sample lists.
    Prompt targets are dropped.
    """
    inputs, targets = zip(*batch)
    inputs_dict = {
        "prompt": _pad_instance_batch([sample["prompt"] for sample in inputs]),
        "situation": _pad_instance_batch([sample["situation"] for sample in inputs]),
    }
    situation_targets = [sample_targets["situation"] for sample_targets in targets]
    action_instance, init_pose, end_pose = zip(
        *(target.values() for target in situation_targets)
    )
    targets_collated = {
        "action_instance": list(action_instance),
        "init_pose": list(init_pose),
        "end_pose": list(end_pose),
    }
    return inputs_dict, targets_collated


if __name__ == "__main__":
    import os
    import json

    pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    """
    every target is of context window length or that of the userid 
        - to utilize parallelized training in Transformers 
    """
    # with open('session_split_num_prefs-4_num_obj-5,6,7.json') as f:
    #     session_paths= json.load(f)
    with open('session_list_num_pref-2_num_obj-6_small.json', 'r') as f:
        preference_session_paths = json.load(f)
        
    dataset = PromptSituationPickPlace(
        session_paths=preference_session_paths['val'],
        pick_only=False,
    )
    print(len(dataset))
    for index in [0, 1]:
        # print(index) 
        inputs, targets = dataset.__getitem__(index)
        print(inputs, targets) 
    """
    # Padding fn test!
    """
    loader = DataLoader(
        dataset,
        batch_size=256, #512, #128,
        shuffle=False,
        num_workers=4,
        collate_fn=prompt_pad_fn,
    )
    # inp, tgt = next(iter(loader))
    start_time = time.time()
    for idx, (inp, tgt) in enumerate(loader):
        print(idx)
    end_time = time.time()
    print('1st loop: ', end_time - start_time)

    start_time = time.time()
    for idx, (inp, tgt) in enumerate(loader):
        print(idx)
    end_time = time.time()
    print('2nd loop: ',end_time - start_time)
    print(inp, tgt)
    print("done")

    for name in ["full", "partial"]:
        dataset = DishwasherArrangeDataset(
            session_paths=get_session_path(
                os.path.join(
                    pkg_root, f"artifacts_sample/{name}_visibility/session_json/"
                )
            ),
            context_history=5,
            num_sessions_limit=5,
            pick_only=True,
        )
        # check if the self.sess_userid_list[index] for index from 0 to len
        # produces a valid session and user id
        for index in [0, 5, 7]:
            print(index)
            inputs, targets = dataset.__getitem__(index)
            print("inputs timestep : ", inputs["timestep"])
            print("targets act : ", targets["act"])
            assert (
                len(targets["act"]) == min(index, dataset.context_history) + 1
            ), "number of targets == current and previous states till context window"
            assert inputs["timestep"][-1] == len(
                targets["act"]
            ), "every state should have a target pick instance"
            assert inputs["action_masks"].sum() == len(
                targets["act"]
            ), "number of act instance tokens should match number of targets"

        """
        every target is of size 1; apt padding for inputs
        """
        sess_dataset = DishwasherArrangeSavedSession(
            session_paths=[dataset.session_paths[0]],
            context_history=5,
            pick_only=True,
        )
        for index in [0, 1]:
            print(index)
            # print(inputs, targets)
            inputs, targets = sess_dataset.__getitem__(index)
            print("inputs timestep : ", inputs["timestep"])
            print("targets act : ", targets["act"])
            assert len(targets["act"]) == 1, "Single ACT token for eval needed!"
            assert (
                inputs["timestep"][-1] == inputs["timestep"][targets["act"]]
            ), "Target pick instance timestep should be the most recent timestep"
            assert inputs["action_masks"].sum() == len(
                targets["act"]
            ), "number of act instance tokens should match number of targets"
        """
        # Padding fn test!
        """
        loader = DataLoader(
            dataset,
            batch_size=4,
            shuffle=False,
            collate_fn=pad_fn,
        )
        inp, tgt = next(iter(loader))
        print(inp, tgt)
        print("done")
