import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, Dataset, Subset
from pathlib import Path
import numpy as np
from typing import Union, Callable, Optional, Sequence, List, Any
from tqdm import tqdm
import abc
from torch import default_generator, randperm
from torch._utils import _accumulate
import logging
import einops

from utils import (
    shuffle_along_axis,
    transpose_batch_timestep,
    eval_mode,
    get_goal_name_in_order,
)

# Task names, kept in this fixed order. Not referenced by the code visible
# in this module.
ALL_TASKS = [
    "bottom burner",
    "top burner",
    "light switch",
    "slide cabinet",
    "hinge cabinet",
    "microwave",
    "kettle",
]

class TrajectoryDataset(Dataset, abc.ABC):
    """
    Abstract interface for datasets whose items are whole trajectories.

    Subclasses must implement the three accessors declared below.

    Nominal item layout: TrajectoryDataset[i] returns (observations, actions, mask)
        observations: Tensor[T, ...], T frames of observations
        actions: Tensor[T, ...], T frames of actions
        mask: Tensor[T]: 0: invalid; 1: valid

    NOTE(review): MetaWorldVideoTrajectoryDataset in this file returns a
    6-tuple (observations, actions, goals, masks, videos, video_masks)
    instead of the 3-tuple described above.
    """

    @abc.abstractmethod
    def get_seq_length(self):
        """
        Returns the sequence length of the dataset as a whole (takes no index).

        NOTE(review): the concrete implementation in this file returns a
        (goal_seq_length, obs_seq_length) tuple taken from the mask shapes.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_idx_length(self, idx):
        """
        Returns the length of the idx-th trajectory.

        NOTE(review): the concrete implementation in this file returns a
        (goal_seq_length, obs_seq_length) tuple computed from the mask sums.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_all_actions(self):
        """
        Returns all actions from all trajectories, concatenated on dim 0 (time).
        """
        raise NotImplementedError


class TrajectorySubset(TrajectoryDataset, Subset):
    """
    View onto a trajectory dataset restricted to a chosen set of indices.

    Args:
        dataset (TrajectoryDataset): the full dataset being wrapped
        indices (sequence): positions in ``dataset`` that make up this subset
    """

    def __init__(self, dataset: TrajectoryDataset, indices: Sequence[int]):
        # Only Subset needs initialising; it records ``dataset`` and ``indices``.
        Subset.__init__(self, dataset, indices)

    def get_seq_length(self):
        """Return the wrapped dataset's sequence length, unchanged."""
        parent = self.dataset
        return parent.get_seq_length()

    def get_idx_length(self, idx):
        """Return the length of the idx-th trajectory of this subset."""
        # Translate the subset-local position into the parent's index space.
        parent_idx = self.indices[idx]
        return self.dataset.get_idx_length(parent_idx)

    def get_all_actions(self):
        """
        Return the wrapped dataset's concatenated actions.

        NOTE(review): this delegates to the whole parent dataset and is not
        restricted to ``self.indices`` -- confirm this is intended.
        """
        parent = self.dataset
        return parent.get_all_actions()


class MetaWorldVideoTrajectoryDataset(TensorDataset, TrajectoryDataset):
    """
    Trajectory dataset loaded from a directory of pre-computed arrays.

    Files read from ``data_directory``:
        obs_seqs.npy, act_seqs.npy, existence_mask.npy, image_seqs.npy
            (loaded with ``np.load``)
        onehot_goals.pth
            (loaded with ``torch.load``)

    ``dataset[i]`` returns the 6-tuple
        (observations, actions, goals, masks, videos, video_masks)
    where every element is the i-th entry along dim 0 of the stored tensor.
    Sequences are returned at full stored length; the masks mark which
    timesteps are valid (0: invalid; 1: valid).
    """

    def __init__(self, data_directory, device="cpu", onehot_goals=False):
        """
        Args:
            data_directory: directory containing the files listed on the class.
            device: device the tensors handed to ``TensorDataset`` are moved to.
            onehot_goals: accepted by the signature but not used in this
                constructor; the goals file is always loaded.
        """
        data_directory = Path(data_directory)

        # The "b t ... -> t b ..." pattern swaps the first two axes.
        # NOTE(review): presumably the files are stored time-major and this
        # makes them trajectory-major, so that dim 0 lines up with `videos`
        # (whose first two axes are left untouched below) -- confirm.
        observations = torch.from_numpy(np.load(data_directory / "obs_seqs.npy"))
        observations = einops.rearrange(observations, "b t ... -> t b ...")

        actions = torch.from_numpy(np.load(data_directory / "act_seqs.npy"))
        actions = einops.rearrange(actions, "b t ... -> t b ...")

        goals = torch.load(data_directory / "onehot_goals.pth")
        goals = einops.rearrange(goals, "b t ... -> t b ...")

        masks = torch.from_numpy(np.load(data_directory / "existence_mask.npy"))
        masks = einops.rearrange(masks, "b t ... -> t b ...")

        # Videos only have their third axis moved to the end.
        videos = torch.from_numpy(np.load(data_directory / "image_seqs.npy"))
        videos = einops.rearrange(videos, "b t c w h -> b t w h c")

        # The video mask is the same tensor as the observation mask.
        video_masks = masks
        # These attributes keep the masks as loaded (before the device move
        # and float cast applied to the TensorDataset copies below).
        self.masks = masks
        self.video_masks = video_masks
        tensors = [observations, actions, goals, masks]
        tensors = [t.to(device).float() for t in tensors]
        # Videos keep their loaded dtype; only the video mask is cast to float.
        tensors = tensors + [videos.to(device)] + [video_masks.to(device).float()]
        TensorDataset.__init__(self, *tensors)
        self.actions = self.tensors[1]

    def get_seq_length(self):
        """
        Return ``(goal_seq_length, obs_seq_length)`` read from dim 1 of the masks.

        Also prints both values to stdout on every call.
        """
        obs_seq_lenth = self.masks.shape[1]
        goal_seq_lenth = self.video_masks.shape[1]
        print(f" goal seq lenth {goal_seq_lenth}, obs seq lenth {obs_seq_lenth}")
        return (goal_seq_lenth, obs_seq_lenth)

    def get_idx_length(self, idx):
        """
        Return ``(goal_seq_length, obs_seq_length)`` for trajectory ``idx``,
        computed as the sums of that trajectory's video mask and mask.
        """
        obs_seq_lenth = int(self.masks[idx].sum().item())
        goal_seq_lenth = int(self.video_masks[idx].sum().item())
        return (goal_seq_lenth, obs_seq_lenth)

    def get_all_actions(self):
        """
        Return the actions of all trajectories concatenated on dim 0.

        For each trajectory only the first ``mask.sum()`` timesteps are kept,
        which drops invalid actions when valid steps come first in the mask.
        """
        valid_actions = [
            self.actions[i, : int(mask.sum()), :]
            for i, mask in enumerate(self.masks)
        ]
        return torch.cat(valid_actions, dim=0)

    def __getitem__(self, idx):
        """
        Return (observations, actions, goals, masks, videos, video_masks)
        for ``idx``, in the order the tensors were given to ``TensorDataset``.
        """
        # Sequences are returned at full stored length; consumers use the
        # masks to find the valid part, so no per-item length is computed here.
        return tuple(tensor[idx] for tensor in self.tensors)

class TrajectorySlicerDataset(TrajectoryDataset):
    """
    Sliding-window view over a trajectory dataset.

    One item is produced for every valid timestep of every trajectory; each
    item is a fixed-size, zero-padded window of observations, frames and
    actions together with a longer window of frames used as the sub-goal.
    """

    def __init__(
        self,
        dataset: TrajectoryDataset,
        window: int,
        subgoal_lenth: int,
        future_conditional: bool = False,
        video_conditional: bool = False,
        video_directory: str = None,
        min_future_sep: int = 0,
        future_seq_len: Optional[int] = None,
        only_sample_tail: bool = False,
        transform: Optional[Callable] = None,
        video_dataset: Optional = None,
    ):
        """
        Slice a trajectory dataset into unique (but overlapping) sequences of
        at most `window` timesteps.

        dataset: a trajectory dataset that satisfies:
            dataset.get_idx_length(i) returns a pair whose second element is
                the number of valid timesteps of trajectory i
            dataset[i] = (observations, actions, goals, mask, video, video_mask)
                with time on dim 0 of every element except goals, which is
                ignored here
        window: int
            number of timesteps to include in each slice
        subgoal_lenth: int
            extra timesteps the sub-goal window may extend beyond `window`
        future_conditional, video_conditional: bool = False
            stored on the instance; if either is True, `future_seq_len` must
            be given
        video_directory: accepted but not stored or used by this class
        min_future_sep, future_seq_len, only_sample_tail, transform,
        video_dataset:
            stored on the instance; not used by the code in this class

        For every trajectory i and every valid timestep j the slice
            (i, start, j + 1, start, goal_end)
        is recorded, where start = max(j - window + 1, 0) and goal_end is
        start + window + subgoal_lenth clipped to the trajectory length.
        """
        if future_conditional or video_conditional:
            assert future_seq_len is not None, "must specify a future_seq_len"
        self.dataset = dataset
        self.window = window
        self.subgoal_lenth = subgoal_lenth
        self.future_conditional = future_conditional
        self.video_conditional = video_conditional
        self.video_dataset = video_dataset
        self.min_future_sep = min_future_sep
        self.future_seq_len = future_seq_len
        self.only_sample_tail = only_sample_tail
        self.transform = transform
        self.slices = []
        for i in range(len(self.dataset)):  # type: ignore
            _, T_OBS = self.dataset.get_idx_length(i)  # avoid reading actual seq (slow)
            for j in range(T_OBS):
                # The observation window ends at timestep j (inclusive) and
                # reaches back at most `window` steps.
                traj_end_idx = j
                traj_start_idx = max(traj_end_idx - window + 1, 0)
                # The sub-goal window shares the start and may run
                # `subgoal_lenth` steps longer, clipped to the last valid step.
                goal_start_idx = traj_start_idx
                goal_end_idx = min(goal_start_idx + window + subgoal_lenth - 1, T_OBS - 1)
                # End indices are stored exclusive (hence the + 1).
                self.slices.append(
                    (i, traj_start_idx, traj_end_idx + 1, goal_start_idx, goal_end_idx + 1)
                )

    def get_seq_length(self) -> tuple:
        """Return the wrapped dataset's sequence length, unchanged."""
        return self.dataset.get_seq_length()

    def get_idx_length(self, idx):
        """
        Return the wrapped dataset's length for index ``idx``.

        NOTE(review): ``idx`` is passed straight through as a trajectory index
        of the wrapped dataset, not interpreted as a slice index -- confirm.
        """
        return self.dataset.get_idx_length(idx)

    def get_all_actions(self) -> torch.Tensor:
        """Return the wrapped dataset's concatenated actions."""
        return self.dataset.get_all_actions()

    def __len__(self):
        """Number of slices (one per valid timestep of every trajectory)."""
        return len(self.slices)

    def __getitem__(self, idx):
        """
        Return the idx-th slice as the tuple
            (self_obs, img_obs, act, obs_mask, subgoal, subgoal_mask,
             video, video_mask)

        The first four have `window` timesteps and the sub-goal pair has
        `window + subgoal_lenth`; unused trailing timesteps are zero. The last
        two are the whole trajectory's video and video mask, unsliced.
        """
        # Wrapped items are [observations, actions, goal, mask, video, video_mask].
        i, traj_start_idx, traj_end_idx, goal_start_idx, goal_end_idx = self.slices[idx]
        observation, action, _, mask, video, video_mask = self.dataset[i]

        traj_lenth = traj_end_idx - traj_start_idx
        traj_subgoal_lenth = goal_end_idx - goal_start_idx
        subgoal_window = self.window + self.subgoal_lenth
        # Per-frame shape of the video, i.e. everything after the time axis.
        frame_shape = tuple(video.shape[1:])

        self_obs = torch.zeros((self.window, observation.shape[1]), dtype=torch.float)
        self_obs[:traj_lenth] = observation[traj_start_idx:traj_end_idx]

        img_obs = torch.zeros((self.window, *frame_shape), dtype=torch.uint8)
        img_obs[:traj_lenth] = video[traj_start_idx:traj_end_idx]

        act = torch.zeros((self.window, action.shape[1]), dtype=torch.float)
        act[:traj_lenth] = action[traj_start_idx:traj_end_idx]

        obs_mask = torch.zeros((self.window), dtype=torch.float)
        obs_mask[:traj_lenth] = mask[traj_start_idx:traj_end_idx]

        subgoal = torch.zeros((subgoal_window, *frame_shape), dtype=torch.uint8)
        subgoal[:traj_subgoal_lenth] = video[goal_start_idx:goal_end_idx]

        subgoal_mask = torch.zeros((subgoal_window), dtype=torch.float)
        subgoal_mask[:traj_subgoal_lenth] = video_mask[goal_start_idx:goal_end_idx]

        return (
            self_obs,
            img_obs,
            act,
            obs_mask,
            subgoal,
            subgoal_mask,
            video,
            video_mask,
        )

def get_train_val_sliced_by_goal(
    traj_dataset: TrajectoryDataset,
    train_fraction: float = 0.9,
    random_seed: int = 42,
    device: Union[str, torch.device] = "cpu",
    window_size: int = 10,
    subgoal_lenth: int = 10,
    future_conditional: bool = False,
    video_conditional: bool = False,
    video_directory: str = None,
    min_future_sep: int = 0,
    future_seq_len: Optional[int] = None,
    only_sample_tail: bool = False,
    transform: Optional[Callable[[Any], Any]] = None,
):
    """
    Split ``traj_dataset`` by goal into train/val subsets and wrap each in a
    TrajectorySlicerDataset built with identical slicing options.

    ``device`` and ``video_directory`` are accepted but not used in this
    function. Returns ``(train_slices, val_slices)``.
    """
    train_subset, val_subset = split_traj_datasets_by_goal(
        traj_dataset,
        train_fraction=train_fraction,
        random_seed=random_seed,
    )

    def make_slicer(subset):
        # Both splits are sliced with exactly the same settings.
        return TrajectorySlicerDataset(
            subset,
            window=window_size,
            subgoal_lenth=subgoal_lenth,
            future_conditional=future_conditional,
            video_conditional=video_conditional,
            min_future_sep=min_future_sep,
            future_seq_len=future_seq_len,
            only_sample_tail=only_sample_tail,
            transform=transform,
        )

    train_slices = make_slicer(train_subset)
    val_slices = make_slicer(val_subset)
    return train_slices, val_slices

def split_traj_datasets_by_goal(dataset, train_fraction=0.95, random_seed=42):
    """
    Split a trajectory dataset into train/val subsets so that the two
    subsets contain disjoint goal keys.

    Each item of ``dataset`` must unpack into six elements, the third being
    the goals passed to ``get_goal_name_in_order``. Trajectories whose goal
    key does not have length 4 are left out of both subsets.

    Args:
        dataset: the trajectory dataset to split.
        train_fraction: fraction of distinct goal keys assigned to training;
            the validation count is ``int((1 - train_fraction) * n)``.
        random_seed: seed for the permutation that picks which goal keys go
            to which split.

    Returns:
        (train_set, val_set) as TrajectorySubset views of ``dataset``.
    """
    # Group trajectory indices by goal key; dict order is first-seen order.
    goal_dict = {}
    for idx, data in enumerate(dataset):
        _, _, onehot_goals, _, _, _ = data
        goal_key = get_goal_name_in_order(onehot_goals)
        if len(goal_key) != 4:
            continue
        goal_dict.setdefault(goal_key, []).append(idx)

    num_goal_types = len(goal_dict)
    indices = randperm(num_goal_types, generator=torch.Generator().manual_seed(random_seed)).tolist()

    # Split on an explicit train count. Slicing with a negative validation
    # count (indices[:-k] / indices[-k:]) breaks when k == 0: it yields an
    # empty train set and puts every goal into validation.
    num_val = int((1 - train_fraction) * num_goal_types)
    num_val = min(max(num_val, 0), num_goal_types)
    num_train = num_goal_types - num_val
    train_goals = set(indices[:num_train])
    val_goals = set(indices[num_train:])
    assert len(train_goals) + len(val_goals) == num_goal_types, "sum of split should be the same as total"
    if num_val == 0 and num_goal_types > 0:
        logging.warning(
            "Validation split is empty: %d goal types with train_fraction=%s",
            num_goal_types,
            train_fraction,
        )

    train_indices = []
    val_indices = []
    # Position i of a key in goal_dict is what the permutation refers to.
    for i, (key, traj_indices) in enumerate(goal_dict.items()):
        if i in train_goals:
            logging.info(f"Train goal: {key}")
            train_indices += traj_indices
        elif i in val_goals:
            logging.info(f"Eval goal: {key}")
            val_indices += traj_indices
    train_set = TrajectorySubset(dataset, train_indices)
    val_set = TrajectorySubset(dataset, val_indices)
    return train_set, val_set

def get_relay_kitchen_video_train_val_by_goals(
    data_directory,
    train_fraction=0.9,
    random_seed=42,
    device="cpu",
    window_size=10,
    subgoal_lenth=10,
    goal_conditional: Optional[str] = None,
    future_seq_len: Optional[int] = None,
    min_future_sep: int = 0,
):
    """
    Load a MetaWorldVideoTrajectoryDataset from ``data_directory`` and return
    its goal-based ``(train_slices, val_slices)`` split.

    ``goal_conditional`` must be None or one of "future", "onehot", "video";
    it selects which conditioning flag is switched on.

    NOTE(review): ``device`` is forwarded to the splitting helper only; the
    dataset itself is constructed with its default device.
    """
    allowed_modes = ["future", "onehot", "video"]
    assert goal_conditional is None or goal_conditional in allowed_modes
    print("get train val dataset by goal split")

    traj_dataset = MetaWorldVideoTrajectoryDataset(
        data_directory, onehot_goals=(goal_conditional == "onehot")
    )
    use_future = goal_conditional == "future"
    use_video = goal_conditional == "video"
    return get_train_val_sliced_by_goal(
        traj_dataset,
        train_fraction=train_fraction,
        random_seed=random_seed,
        device=device,
        window_size=window_size,
        subgoal_lenth=subgoal_lenth,
        future_conditional=use_future,
        video_conditional=use_video,
        video_directory=Path(data_directory) / "videos",
        min_future_sep=min_future_sep,
        future_seq_len=future_seq_len,
    )
# Running this module as a script does nothing; it is meant to be imported.
if __name__ == '__main__':
    pass