from typing import List

from ltsgns_mp.envs.eval_iterator import EvalIterator
from ltsgns_mp.envs.train_iterator import get_train_iterator
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict
from ltsgns_mp.envs.trajectory_collection import TrajectoryCollection


class Env:
    """Container pairing an environment's trajectory splits with their data iterators.

    Exposes one training iterator over the ``keys.TRAIN`` split and a mapping from
    context size to an ``EvalIterator`` over the configured evaluation split.
    """

    def __init__(self, config: ConfigDict, train_iterator_config: ConfigDict, evaluation_config: ConfigDict,
                 traj_dict: dict[str, TrajectoryCollection], device):
        """Store the configuration and trajectories, then build all iterators.

        Args:
            config: Environment configuration; stored as ``self.config`` and not read here.
            train_iterator_config: Passed unchanged to ``get_train_iterator``.
            evaluation_config: Must provide ``eval_only``, ``eval_iterator`` (with an
                ``evaluation_split`` key into ``traj_dict``) and ``evaluator`` (with
                ``context_test_sizes`` and ``context_val_sizes``).
            traj_dict: Trajectory collections keyed by split name; must contain ``keys.TRAIN``.
            device: Forwarded to every iterator constructed here.
        """
        self.config = config
        self.traj_dict = traj_dict
        train_split = traj_dict[keys.TRAIN]
        # Length of the first training trajectory.
        # NOTE(review): assumes every trajectory shares this length — TODO confirm.
        self.trajectory_length = len(train_split[0])
        self.device = device

        self.train_iterator = get_train_iterator(train_iterator_config, train_split, device)

        eval_iter_config = evaluation_config.eval_iterator
        eval_split = self.traj_dict[eval_iter_config.evaluation_split]

        # Evaluation-only runs use the test context sizes; otherwise use the validation ones.
        sizes = (evaluation_config.evaluator.context_test_sizes if evaluation_config.eval_only
                 else evaluation_config.evaluator.context_val_sizes)

        # One evaluation iterator per context size; a repeated size keeps the last iterator built.
        self.eval_iterators = {
            size: EvalIterator(eval_iter_config, eval_split, size, device)
            for size in sizes
        }


