"""
Simplified LIBERO dataloader that preloads full episodes into memory, with basic batching support.
"""

import os

from pathlib import Path
from typing import Dict, Any, Tuple, Optional, List

from dataclasses import dataclass

import numpy as np
import torch
from torch.utils.data import Dataset
import tensorflow as tf
import tensorflow_datasets as tfds
from torchvision import transforms
from PIL import Image

@dataclass
class LiberoDataConfig:
    """Settings that control which LIBERO suite is loaded and how it is sampled.

    Fields are grouped by concern: where the data lives and which split to
    read, the target image resolution, how many future actions form one
    training target, and the declared state/action vector widths.
    """

    # --- Where the data lives / what to read ---
    task_suite_name: str = "libero_spatial"  # suite name; combined with "_no_noops" into a directory name
    data_root_path: str = "./data/libero"  # root directory containing the suite folders
    split: str = "train"  # dataset split to read
    debug: bool = False  # True -> only a handful of episodes are loaded

    # --- Image preprocessing ---
    image_size: Tuple[int, int] = (224, 224)  # target size passed to the resize transform

    # --- Action chunking ---
    horizon: int = 20  # number of consecutive actions returned per sample

    # --- Declared vector widths ---
    state_dim: int = 8
    action_dim: int = 7


class LiberoDataset(Dataset):
    """LIBERO dataset with full memory preloading.

    At construction time every episode of the configured suite is read from
    ``<data_root_path>/<suite>_no_noops/1.0.0`` and kept in memory as NumPy
    arrays. Each item is one start step of an episode together with the next
    ``horizon`` actions. Only start steps that have a full horizon of future
    actions are indexed.
    """

    def __init__(self, config: "LiberoDataConfig"):
        """Build the dataset and eagerly load all episodes into memory.

        Args:
            config: Loading configuration (suite name, data root, split,
                image size, action horizon, debug flag).

        Raises:
            ValueError: If ``config.horizon`` is smaller than 1. A zero or
                negative horizon would index start steps past the episode end.
        """
        if config.horizon < 1:
            raise ValueError(f"horizon must be >= 1, got {config.horizon}")

        self.config = config
        self.data_root_path = Path(config.data_root_path)
        self.task_suite_name = config.task_suite_name
        self.split = config.split
        self.image_size = config.image_size
        self.horizon = config.horizon

        # Image transforms
        self.image_transform = transforms.Compose(
            [
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
            ]
        )

        # Initialize dataset loading
        self._init_full_loading()

        print(f"Initialized LIBERO dataset: {self.task_suite_name}")
        print(f"Split: {self.split}, Episodes: {self.total_episodes}")
        print(f"Total transitions: {self.total_transitions}")
        print("Loading mode: Full memory")

    def _init_full_loading(self):
        """Preload every episode into memory and build the transition index.

        Populates ``all_episodes`` (one dict per kept episode),
        ``transition_to_episode`` (one ``(episode_idx, step_idx)`` pair per
        valid start step), ``total_episodes`` and ``total_transitions``.
        """
        self.all_episodes = []
        self.transition_to_episode = []  # Maps transition index to (episode_idx, step_idx)

        print("Loading full dataset into memory...")

        # Only the configured suite is loaded; the loop keeps multi-suite loading simple.
        for suite_name in [self.task_suite_name]:
            suite_episode_count = self._load_suite(suite_name)
            print(f"\nLoaded {suite_episode_count} episodes from {suite_name}")

        self.total_episodes = len(self.all_episodes)
        self.total_transitions = len(self.transition_to_episode)

        print(f"\nTotal loaded: {self.total_episodes} episodes with {self.total_transitions} total transitions into memory")

    def _load_suite(self, suite_name: str) -> int:
        """Read one suite's episodes from disk, store and index them; return how many were kept."""
        data_path = self.data_root_path / f"{suite_name}_no_noops" / "1.0.0"
        print(f"Loading suite: {suite_name}")

        builder = tfds.builder_from_directory(str(data_path))
        tf_dataset = builder.as_dataset(split=self.split)

        suite_episode_count = 0
        for episode_data in tf_dataset:
            # Debug mode caps the total (across suites) at the first 6 episodes.
            if self.config.debug and len(self.all_episodes) > 5:
                break

            print(f"Loading episode {len(self.all_episodes) + 1} from {suite_name}...", end="\r")

            episode = self._build_episode(list(episode_data["steps"]), suite_name)
            if episode is None:
                # An empty episode has no instruction and no valid start step.
                print(f"\nSkipping empty episode from {suite_name}")
                continue

            # Index by position in all_episodes so skipped episodes never leave gaps.
            episode_idx = len(self.all_episodes)
            self.all_episodes.append(episode)
            self._index_episode(episode_idx, episode["episode_length"])
            suite_episode_count += 1

        return suite_episode_count

    def _build_episode(self, steps: List[Any], suite_name: str) -> Optional[Dict[str, Any]]:
        """Stack one episode's per-step images, states and actions into NumPy arrays.

        Args:
            steps: Step records, each with ``observation/image``,
                ``observation/state``, ``action`` and ``language_instruction``
                entries.
            suite_name: Suite the episode came from. It is stored on the episode.

        Returns:
            The episode dict, or ``None`` when ``steps`` is empty.
        """
        if not steps:
            return None

        images = [self._to_numpy(step["observation"]["image"]) for step in steps]
        states = [self._to_numpy(step["observation"]["state"]) for step in steps]
        actions = [self._to_numpy(step["action"]) for step in steps]

        # The instruction is taken from the final step.
        instruction = self._to_numpy(steps[-1]["language_instruction"])
        if isinstance(instruction, bytes):
            instruction = instruction.decode("utf-8")

        return {
            "images": np.array(images),
            "states": np.array(states),
            "actions": np.array(actions),
            "episode_length": len(steps),
            "task_description": str(instruction),
            "suite_name": suite_name,  # Track which suite this episode came from
        }

    def _index_episode(self, episode_idx: int, episode_length: int) -> None:
        """Register every start step of an episode that has a full ``horizon`` of future actions."""
        num_starts = max(0, episode_length - self.horizon + 1)
        self.transition_to_episode.extend((episode_idx, step_idx) for step_idx in range(num_starts))

    @staticmethod
    def _to_numpy(value: Any) -> Any:
        """Return ``value.numpy()`` for eager tensors; pass anything else (e.g. ndarrays) through."""
        return value.numpy() if hasattr(value, "numpy") else value

    def __len__(self) -> int:
        """Return total number of transitions."""
        return self.total_transitions

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get one transition with an action sequence of length ``horizon``.

        Returns:
            Dict with ``images`` (the transformed image), ``states``
            (float32 tensor), ``actions`` (float32 tensor of ``horizon``
            rows), ``task_descriptions`` (str) and ``task_ids`` (always 0).
        """
        episode_idx, step_idx = self.transition_to_episode[idx]
        episode = self.all_episodes[episode_idx]

        image = episode["images"][step_idx]
        state = episode["states"][step_idx]
        task_description = episode["task_description"]

        end_idx = min(step_idx + self.horizon, episode["episode_length"])
        action_sequence = episode["actions"][step_idx:end_idx]

        # Defensive padding: the index only holds full-horizon starts, but a
        # short window is extended by repeating its last action.
        padding_length = self.horizon - len(action_sequence)
        if padding_length > 0:
            if len(action_sequence) > 0:
                last_action = action_sequence[-1:]
            else:
                last_action = np.zeros_like(episode["actions"][:1])
            padding = np.repeat(last_action, padding_length, axis=0)
            action_sequence = np.concatenate([action_sequence, padding], axis=0)

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))
        image = self.image_transform(image)

        # torch.tensor copies, so callers cannot mutate the cached episode arrays.
        state = torch.tensor(state, dtype=torch.float32)
        action_sequence = torch.tensor(action_sequence, dtype=torch.float32)  # Shape: (horizon, action_dim)

        return {
            "images": image,
            "states": state,
            "actions": action_sequence,
            "task_descriptions": task_description,
            "task_ids": torch.tensor(0, dtype=torch.long),  # Simplified
        }


class LiberoDataLoader:
    """Simple wrapper for creating a PyTorch DataLoader and reporting its settings.

    Iterating the wrapper iterates the underlying ``DataLoader``; ``len()``
    returns its number of batches.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        **dataloader_kwargs: Any,
    ):
        """Create the DataLoader.

        Args:
            dataset: Map-style dataset to batch (e.g. a ``LiberoDataset``).
            batch_size: Samples per batch.
            shuffle: Whether to shuffle samples.
            num_workers: Number of loader worker processes (0 = main process).
            **dataloader_kwargs: Additional ``torch.utils.data.DataLoader``
                options (e.g. ``drop_last``, ``collate_fn``). ``pin_memory``
                defaults to ``torch.cuda.is_available()`` unless given here.
        """
        from torch.utils.data import DataLoader

        # Default pinning to CUDA availability; an explicit pin_memory kwarg wins
        # instead of raising "got multiple values for keyword argument".
        dataloader_kwargs.setdefault("pin_memory", torch.cuda.is_available())

        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **dataloader_kwargs,
        )

        print("Created LIBERO DataLoader:")
        print(f"  Batch size: {batch_size}")
        print(f"  Shuffle: {shuffle}")
        print(f"  Num workers: {num_workers}")

    def __iter__(self):
        """Iterate over batches of the wrapped DataLoader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches per epoch."""
        return len(self.dataloader)
