from typing import List

import torch

from torch.utils.data import (
    ConcatDataset,
    WeightedRandomSampler,
)

from transfer.datasets.single_task_dataset import SingleTaskDataset


class MultiTaskDataset(ConcatDataset):
    """Concatenation of several single-task datasets.

    Indexing and length come from ``ConcatDataset``. In addition the
    instance exposes ``state_dim`` and ``act_dim``, both copied from the
    first dataset in the list.
    """

    def __init__(
        self,
        datasets: List["SingleTaskDataset"],
    ):
        """Concatenate ``datasets`` and record the first dataset's dimensions.

        Args:
            datasets: Datasets to concatenate. Each one must provide
                ``state_dim``, ``act_dim`` and ``traj_lens`` attributes.
        """
        super().__init__(datasets)

        # Only the first dataset is consulted; nothing here checks that the
        # remaining datasets report the same dimensions.
        first = self.datasets[0]
        self.state_dim = first.state_dim
        self.act_dim = first.act_dim

    def create_weighted_sampler(self) -> WeightedRandomSampler:
        """Build a sampler that gives every dataset the same total weight.

        Within one dataset each entry is weighted by its share of that
        dataset's summed ``traj_lens``; the shares of all datasets are then
        concatenated and divided by the number of datasets, so the complete
        weight vector sums to 1 and each dataset contributes ``1 / K`` of it.

        ``traj_lens`` has to be a NumPy array, because the per-dataset
        weights are converted with ``torch.from_numpy``.

        NOTE(review): the sampler yields positions into the concatenated
        weight vector. That matches the indices of this dataset only if
        ``len(dataset) == len(dataset.traj_lens)`` holds for every
        sub-dataset -- confirm against ``SingleTaskDataset``.

        Returns:
            A ``WeightedRandomSampler`` over the float32 weights that draws
            as many samples per pass as there are weights.
        """
        num_datasets = len(self.datasets)
        # ndarray.sum() keeps the reduction inside NumPy instead of looping
        # over the array with the Python builtin.
        per_dataset = [
            torch.from_numpy(dataset.traj_lens / dataset.traj_lens.sum())
            for dataset in self.datasets
        ]
        sample_weights = torch.cat(per_dataset) / num_datasets
        return WeightedRandomSampler(sample_weights.float(), len(sample_weights))
