import torch
import logging
from typing import Optional
from dataloaders.trajectory_loader import RelayKitchenVideoTrajectoryDataset
from utils import get_split_idx, get_train_val_idx_by_goals
from . import OBS_ELEMENT_GOALS, OBS_ELEMENT_INDICES


def get_goal_fn(
    cfg,
    goal_conditional: Optional[str] = None,
    goal_seq_len: Optional[int] = None,
    seed: Optional[int] = None,
    train_fraction: Optional[float] = None,
):
    if goal_conditional == "future":
        relay_traj = RelayKitchenTrajectoryDataset(
            cfg.env_vars.datasets.relay_kitchen, onehot_goals=True
        )
    elif goal_conditional == "video":
        relay_traj = RelayKitchenVideoTrajectoryDataset(
            cfg.env_vars.datasets.relay_kitchen_video, onehot_goals=True
            )
    #train_idx, val_idx = get_split_idx(  
        #len(relay_traj),
        #seed=seed,
        #train_fraction=train_fraction,
    #)  # same split is used for training
    train_idx, val_idx = get_train_val_idx_by_goals(
            dataset=relay_traj,
            random_seed=seed,
            train_fraction=train_fraction,
            )  # same split is used for training
    if cfg.eval_on == 'eval':
        use_idx = val_idx
    elif cfg.eval_on == 'train':
        use_idx = train_idx
    if goal_conditional is None:
        goal_fn = lambda state: None
    elif goal_conditional == "future":
        assert (
            goal_seq_len is not None
        ), "goal_seq_len must be provided if goal_conditional is 'future'"

        def goal_fn(state, goal_idx, frame_idx):
            logging.info(f"goal_idx: {use_idx[goal_idx]}")
            _, _, _, _, goal, goal_mask = relay_traj[use_idx[goal_idx]]  # seq_len x obs_dim
            return goal, goal_mask

    elif goal_conditional == "onehot":

        def goal_fn(state, goal_idx, frame_idx):
            if frame_idx == 0:
                logging.info(f"goal_idx: {use_idx[goal_idx]}")
            _, _, _, onehot_goals = relay_traj[use_idx[goal_idx]]  # seq_len x obs_dim
            return onehot_goals[min(frame_idx, len(onehot_goals) - 1)]
    elif goal_conditional == "video":
        assert (
            goal_seq_len is not None
        ), "goal_seq_len must be provided if goal_conditional is 'video'"
        def goal_fn(state, goal_idx, frame_idx):
            logging.info(f"goal_idx: {use_idx[goal_idx]}")
            _, _, _, _, goal, goal_mask = relay_traj[use_idx[goal_idx]]
            return goal, goal_mask

    return goal_fn
