# Copyright (c) Anonymous Organization.
# Inspired from https://github.com/gaoyuezhou/dino_wm
# Licensed under the MIT License

import abc
from typing import List, Optional, Sequence

import torch
from einops import rearrange
from torch import default_generator, randperm
from torch.utils.data import Dataset


# https://github.com/JaidedAI/EasyOCR/issues/1243
def _accumulate(iterable, fn=lambda x, y: x + y):
    "Return running totals"
    # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        return
    yield total
    for element in it:
        total = fn(total, element)
        yield total


class TrajDataset(Dataset, abc.ABC):
    """
    Abstract base class for datasets that can report per-trajectory lengths.

    Combines ``torch.utils.data.Dataset`` with ``abc.ABC``; concrete
    subclasses must implement ``get_seq_length``.
    """

    @abc.abstractmethod
    def get_seq_length(self, idx):
        """
        Returns the length of the idx-th trajectory.

        Args:
            idx: Index of the trajectory within this dataset.
        """
        raise NotImplementedError


class TrajSubset(Dataset):
    """
    A cleaner implementation of trajectory subset that maintains direct access to samples.

    Item access and ``get_seq_length`` are remapped through ``indices``; any
    attribute that is not defined on the subset itself is looked up on the
    wrapped dataset.

    Args:
        dataset (TrajDataset): The source dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        # Store direct reference to filtered samples for easier debugging
        self.filtered_samples = self._collect_filtered_samples()

    def _collect_filtered_samples(self):
        """Return the wrapped dataset's ``samples`` restricted to ``indices``, or None if it has no ``samples``."""
        if hasattr(self.dataset, "samples"):
            return [self.dataset.samples[i] for i in self.indices]
        return None

    def __getitem__(self, idx, subtask=None):
        """Return item ``indices[idx]`` of the wrapped dataset, forwarding ``subtask`` as a keyword."""
        return self.dataset.__getitem__(self.indices[idx], subtask=subtask)

    def __len__(self):
        return len(self.indices)

    def get_seq_length(self, idx):
        """Return the length of the ``idx``-th trajectory of this subset."""
        return self.dataset.get_seq_length(self.indices[idx])

    # Forward any other attribute access to the underlying dataset
    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup has failed. If ``dataset``
        # itself is not set yet (e.g. an instance created through __new__
        # while being copied or unpickled), reading ``self.dataset`` here
        # would re-enter __getattr__ forever, so fail with AttributeError.
        try:
            wrapped = self.__dict__["dataset"]
        except KeyError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") from None
        return getattr(wrapped, name)

    # Maintain proper serialization for multi-worker dataloaders
    def __getstate__(self):
        # filtered_samples is derived data, so it is rebuilt on load instead of being stored.
        return self.dataset, self.indices

    def __setstate__(self, state):
        self.dataset, self.indices = state
        # Rebuild filtered_samples after deserialization
        self.filtered_samples = self._collect_filtered_samples()


class TrajSlicerDataset(TrajDataset):
    """
    Fixed-length windows ("slices") cut out of the trajectories of ``dataset``.

    Each window covers ``num_frames * frameskip`` consecutive raw timesteps of
    one trajectory. Windows are enumerated for every valid start position and
    shuffled once at construction time.

    Args:
        dataset: Source of trajectories. Indexing it must yield a 5-tuple
            ``(obs, act, state, reward, extra)`` in which ``obs`` is a mapping
            of time-indexed sequences; it must also provide
            ``get_seq_length``, ``proprio_dim``, ``action_dim`` and
            ``state_dim``.
        num_frames: Number of observation frames per window.
        frameskip: Stride between the observation/state/reward timesteps kept
            inside a window.
        action_skip: Stride between the action timesteps kept inside a window.
        process_actions: How the actions falling between two kept frames are
            combined when ``frameskip >= action_skip``: ``"concat"`` flattens
            them into one vector per frame, ``"sum"`` adds them up. Any other
            value leaves the sliced actions ungrouped.
        generator: Optional RNG used to shuffle the windows.
    """

    def __init__(
        self,
        dataset: TrajDataset,
        num_frames: int,
        frameskip: int = 1,
        action_skip: int = 1,
        process_actions: str = "concat",
        generator: Optional[torch.Generator] = None,
    ):
        self.dataset = dataset
        self.num_frames = num_frames
        self.frameskip = frameskip
        self.action_skip = action_skip
        self.process_actions = process_actions

        # Number of raw timesteps one window spans.
        span = num_frames * frameskip
        self.slices = []
        for i in range(len(self.dataset)):
            seq_len = self.dataset.get_seq_length(i)
            # Compare against the full span (not just num_frames) so that every
            # trajectory that cannot host a single window is reported.
            if seq_len < span:
                print(
                    f"Ignored short sequence #{i}: len={seq_len}, "
                    f"num_frames={num_frames}, frameskip={frameskip}"
                )
                continue
            # slice indices follow convention [start, end)
            self.slices += [(i, start, start + span) for start in range(seq_len - span + 1)]

        # randomly permute the slices
        order = torch.randperm(len(self.slices), generator=generator).tolist()
        self.slices = [self.slices[i] for i in order]

        self.proprio_dim = self.dataset.proprio_dim
        if self.process_actions == "concat" and self.frameskip >= self.action_skip:
            # Several raw actions are flattened into one vector per frame.
            self.action_dim = self.dataset.action_dim * (self.frameskip // self.action_skip)
        else:
            self.action_dim = self.dataset.action_dim
        self.state_dim = self.dataset.state_dim

    def get_seq_length(self, idx: int) -> int:
        """Return the length of any window, which is always ``num_frames``."""
        return self.num_frames

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th window as ``(obs, act, state, reward)``.

        time dimension for act can be lower than obs and state, if self.frameskip < self.action_skip.
        """
        i, start, end = self.slices[idx]
        obs, act, state, reward, _ = self.dataset[i]
        # To avoid collator errors, create dummy tensors if state or reward from the traj_dset are None
        if reward is None:
            reward = torch.zeros(act.shape[0], dtype=torch.float32)
        if state is None:
            state = torch.zeros(act.shape[0], dtype=torch.float32)

        # Build a new dict rather than overwriting entries of the one handed
        # out by the wrapped dataset, which may be cached or shared.
        obs = {key: value[start : end : self.frameskip] for key, value in obs.items()}
        act = act[start : end : self.action_skip]
        reward = reward[start : end : self.frameskip]
        state = state[start : end : self.frameskip]

        # The patterns below require act to be 2-D: (time, action_dim).
        if self.frameskip < self.action_skip:
            # Fewer action steps than frames; applied regardless of process_actions.
            act = rearrange(act, "(n f) d -> n (f d)", n=self.num_frames * self.frameskip // self.action_skip)
        elif self.process_actions == "concat":
            act = rearrange(act, "(n f) d -> n (f d)", n=self.num_frames)  # concat actions
        elif self.process_actions == "sum":
            act = rearrange(act, "(n f) d -> n f d", n=self.num_frames)
            act = act.sum(dim=1)  # Sum along the frame dimension
        return obs, act, state, reward


def random_split_traj(
    dataset: TrajDataset,
    lengths: Sequence[int],
    generator: Optional[torch.Generator] = default_generator,
    traj_subset: bool = True,
) -> List[TrajSubset]:
    """
    Randomly partition ``dataset`` into non-overlapping subsets of the given sizes.

    Args:
        dataset: Dataset to split.
        lengths: Size of each subset; must sum to ``len(dataset)``.
        generator: RNG used for the index permutation.
        traj_subset: Must be True; the subsets are returned as ``TrajSubset``
            objects. No other subset type is implemented.

    Returns:
        One ``TrajSubset`` per entry of ``lengths``, in the same order.

    Raises:
        ValueError: If ``lengths`` does not sum to ``len(dataset)``.
        NotImplementedError: If ``traj_subset`` is False.
    """
    total = sum(lengths)
    if total != len(dataset):  # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    if not traj_subset:
        # Previously this path fell off the end and returned None, which only
        # surfaced later as an unpacking TypeError in the caller.
        raise NotImplementedError("random_split_traj only supports traj_subset=True")

    indices = randperm(total, generator=generator).tolist()
    # Consecutive chunks of the permutation; the running totals give each chunk's end offset.
    splits = [indices[offset - length : offset] for offset, length in zip(_accumulate(lengths), lengths)]
    print(splits)
    return [TrajSubset(dataset, split) for split in splits]


def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42, traj_subset=True):
    """
    Split ``dataset`` into a train part and a validation part with a seeded shuffle.

    Args:
        dataset: Dataset to split.
        train_fraction: Fraction of items assigned to the train part; the
            count is truncated to an integer and the remainder goes to
            validation.
        random_seed: Seed of the generator that drives the shuffle.
        traj_subset: Forwarded unchanged to ``random_split_traj``.

    Returns:
        Tuple ``(train_set, val_set)``.
    """
    total = len(dataset)
    n_train = int(train_fraction * total)
    seeded_rng = torch.Generator().manual_seed(random_seed)
    train_set, val_set = random_split_traj(
        dataset,
        [n_train, total - n_train],
        generator=seeded_rng,
        traj_subset=traj_subset,
    )
    return train_set, val_set


def get_train_val_sliced(
    traj_dataset: TrajDataset,
    train_fraction: float = 0.9,
    random_seed: int = 42,
    num_frames: int = 10,
    num_frames_val: int = None,
    frameskip: int = 1,
    action_skip: int = 1,
    traj_subset: bool = True,
    process_actions: str = "concat",
):
    """
    Split a trajectory dataset into train/val parts and wrap each in a ``TrajSlicerDataset``.

    Args:
        traj_dataset: Dataset of whole trajectories.
        train_fraction: Fraction of trajectories assigned to the train part.
        random_seed: Seed used for the split and for shuffling both slicers.
        num_frames: Window length of the train slicer.
        num_frames_val: Window length of the val slicer; a falsy value
            (``None`` or ``0``) means "use ``num_frames``".
        frameskip: Forwarded to both slicers.
        action_skip: Forwarded to both slicers.
        traj_subset: Forwarded to ``split_traj_datasets``.
        process_actions: Forwarded to both slicers.

    Returns:
        Tuple ``(train, val, train_slices, val_slices)``.
    """
    train, val = split_traj_datasets(
        traj_dataset,
        train_fraction=train_fraction,
        random_seed=random_seed,
        traj_subset=traj_subset,
    )

    def _sliced(subset, window):
        # Each slicer gets its own freshly seeded generator.
        rng = torch.Generator().manual_seed(random_seed)
        return TrajSlicerDataset(
            subset,
            window,
            frameskip,
            action_skip,
            generator=rng,
            process_actions=process_actions,
        )

    train_slices = _sliced(train, num_frames)
    val_slices = _sliced(val, num_frames_val or num_frames)
    return train, val, train_slices, val_slices
