"""
Dataset that provides FULL state trajectories for trajectory encoding.

Unlike StitchedSequenceDataset which repeats s0 for all timesteps (designed for diffusion policy),
this dataset provides the complete state sequence [s0, s1, s2, ..., sT] needed for encoding
trajectory shape.
"""
import torch
import numpy as np
from collections import namedtuple

# Sample container shared with StitchedSequenceDataset: `actions` tensor plus a
# `conditions` dict (this module stores the state sequence under 'state').
Batch = namedtuple('Batch', 'actions conditions')


class FullTrajectoryDataset(torch.utils.data.Dataset):
    """
    Dataset that provides full state trajectories from start to end.

    Each sample contains:
    - states: (horizon, state_dim) - full state sequence [s0, s1, ..., s_{T-1}]
    - actions: (horizon, action_dim) - full action sequence [a0, a1, ..., a_{T-1}]

    This is different from StitchedSequenceDataset which repeats s0 for conditioning.

    When ``horizon_steps`` is None, every episode is kept regardless of its
    length, so samples may differ in length and default batching of them will
    fail unless all episodes happen to share one length.
    """

    def __init__(self, dataset_path, horizon_steps=None, device='cpu', max_n_episodes=None):
        """
        Args:
            dataset_path: Path (str or os.PathLike) to a .npz file with 'states',
                'actions', 'traj_lengths'. Any other extension is read with
                pickle and must provide the same keys.
            horizon_steps: If None, use full trajectory length. If specified, only trajectories
                          matching this length are included.
            device: Stored as ``self.device``. NOTE: the tensors are created on
                    CPU and are not moved to this device here.
            max_n_episodes: Maximum number of episodes to load (the first ones).

        Raises:
            ValueError: If the selected ``traj_lengths`` add up to more steps
                than ``states`` or ``actions`` provide.
        """
        self.dataset_path = dataset_path
        self.horizon_steps = horizon_steps
        self.device = device

        states_np, actions_np, traj_lengths = self._load_arrays(dataset_path)

        if max_n_episodes is not None:
            traj_lengths = traj_lengths[:max_n_episodes]

        self.state_dim = states_np.shape[1]
        self.action_dim = actions_np.shape[1]

        # Episodes are stored back to back in the flat arrays. If the lengths
        # overrun them, slicing would silently return truncated/empty trajectories.
        total_steps = int(np.sum(traj_lengths))
        if total_steps > len(states_np) or total_steps > len(actions_np):
            raise ValueError(
                f"traj_lengths sum to {total_steps} steps, but 'states' has "
                f"{len(states_np)} rows and 'actions' has {len(actions_np)} rows"
            )

        # Split into individual trajectories
        self.trajectories = []
        self.episode_indices = []  # Original episode index of each kept trajectory
        start_idx = 0

        for episode_idx, traj_len in enumerate(traj_lengths):
            end_idx = start_idx + traj_len

            # If horizon_steps is specified, only trajectories of exactly that
            # length are kept; skipped episodes still advance the cursor.
            if horizon_steps is None or traj_len == horizon_steps:
                # .float() is a no-op for float32 input, so the tensors may share
                # memory with the numpy arrays; __getitem__ clones before returning.
                self.trajectories.append({
                    'states': torch.from_numpy(states_np[start_idx:end_idx]).float(),
                    'actions': torch.from_numpy(actions_np[start_idx:end_idx]).float()
                })
                self.episode_indices.append(episode_idx)

            start_idx = end_idx

        print(f"Loaded {len(self.trajectories)} trajectories from {dataset_path}")
        if len(self.trajectories) > 0:
            print(f"Trajectory length: {self.trajectories[0]['states'].shape[0]}")
            print(f"State dim: {self.state_dim}, Action dim: {self.action_dim}")

    @staticmethod
    def _load_arrays(dataset_path):
        """Return ``(states, actions, traj_lengths)`` read from ``dataset_path``."""
        if str(dataset_path).endswith('.npz'):
            # np.load on an .npz returns an NpzFile that holds the file open;
            # the context manager closes it once the arrays are materialized.
            with np.load(dataset_path, allow_pickle=False) as data:
                return data['states'], data['actions'], data['traj_lengths']
        import pickle
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(dataset_path, 'rb') as f:
            data = pickle.load(f)
        return data['states'], data['actions'], data['traj_lengths']

    def __len__(self):
        """Number of kept trajectories."""
        return len(self.trajectories)

    def __getitem__(self, idx):
        """
        Returns full trajectory.

        Returns:
            Batch(actions, conditions) where:
            - actions: (horizon, action_dim) - full action sequence
            - conditions['state']: (horizon, state_dim) - full STATE sequence (not repeated s0!)
        """
        traj = self.trajectories[idx]

        # Clone so callers cannot mutate the cached tensors (or the numpy
        # buffers they may share memory with).
        states = traj['states'].clone()  # (horizon, state_dim)
        actions = traj['actions'].clone()  # (horizon, action_dim)

        # Return in same format as StitchedSequenceDataset for compatibility
        conditions = {'state': states}

        return Batch(actions=actions, conditions=conditions)


class FullTrajectoryDatasetWithWindows(torch.utils.data.Dataset):
    """
    Like FullTrajectoryDataset but creates overlapping windows from trajectories.
    Each window contains the FULL state sequence within that window, not repeated s0.

    Use this if you want to train on shorter segments while still providing state evolution.
    Windows never cross episode boundaries; trailing steps that do not fill a
    whole window are dropped, and episodes shorter than the window yield none.
    """

    def __init__(self, dataset_path, horizon_steps, stride=None, device='cpu', max_n_episodes=None):
        """
        Args:
            dataset_path: Path (str or os.PathLike) to a .npz file with 'states',
                'actions', 'traj_lengths'. Any other extension is read with
                pickle and must provide the same keys.
            horizon_steps: Length of each window (must be > 0)
            stride: Step size between windows (must be > 0). If None, uses
                    horizon_steps (non-overlapping)
            device: Stored as ``self.device``. NOTE: the tensors are created on
                    CPU and are not moved to this device here.
            max_n_episodes: Maximum number of episodes to load (the first ones).

        Raises:
            ValueError: If ``horizon_steps`` or ``stride`` is not positive, or if
                the selected ``traj_lengths`` add up to more steps than
                ``states`` or ``actions`` provide.
        """
        self.dataset_path = dataset_path
        self.horizon_steps = horizon_steps
        self.stride = stride if stride is not None else horizon_steps
        self.device = device

        # Validate up front: a zero stride used to fail later inside range(), a
        # negative one silently produced no windows, and a non-positive horizon
        # produced degenerate windows.
        if horizon_steps <= 0:
            raise ValueError(f"horizon_steps must be positive, got {horizon_steps}")
        if self.stride <= 0:
            raise ValueError(f"stride must be positive, got {self.stride}")

        states_np, actions_np, traj_lengths = self._load_arrays(dataset_path)

        if max_n_episodes is not None:
            traj_lengths = traj_lengths[:max_n_episodes]

        self.state_dim = states_np.shape[1]
        self.action_dim = actions_np.shape[1]

        # Episodes are stored back to back in the flat arrays. If the lengths
        # overrun them, slicing would silently return truncated windows.
        total_steps = int(np.sum(traj_lengths))
        if total_steps > len(states_np) or total_steps > len(actions_np):
            raise ValueError(
                f"traj_lengths sum to {total_steps} steps, but 'states' has "
                f"{len(states_np)} rows and 'actions' has {len(actions_np)} rows"
            )

        # Create windows from trajectories
        self.windows = []
        start_idx = 0

        for traj_len in traj_lengths:
            end_idx = start_idx + traj_len

            # Window starts stay inside this episode so no window spans two episodes.
            for window_start in range(start_idx, end_idx - horizon_steps + 1, self.stride):
                window_end = window_start + horizon_steps

                # .float() is a no-op for float32 input, so the tensors may share
                # memory with the numpy arrays; __getitem__ clones before returning.
                self.windows.append({
                    'states': torch.from_numpy(states_np[window_start:window_end]).float(),
                    'actions': torch.from_numpy(actions_np[window_start:window_end]).float()
                })

            start_idx = end_idx

        print(f"Created {len(self.windows)} windows from {len(traj_lengths)} trajectories")
        print(f"Window length: {horizon_steps}, Stride: {self.stride}")

    @staticmethod
    def _load_arrays(dataset_path):
        """Return ``(states, actions, traj_lengths)`` read from ``dataset_path``."""
        if str(dataset_path).endswith('.npz'):
            # np.load on an .npz returns an NpzFile that holds the file open;
            # the context manager closes it once the arrays are materialized.
            with np.load(dataset_path, allow_pickle=False) as data:
                return data['states'], data['actions'], data['traj_lengths']
        import pickle
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(dataset_path, 'rb') as f:
            data = pickle.load(f)
        return data['states'], data['actions'], data['traj_lengths']

    def __len__(self):
        """Number of windows across all episodes."""
        return len(self.windows)

    def __getitem__(self, idx):
        """
        Return one window as Batch(actions, conditions) where:
        - actions: (horizon_steps, action_dim)
        - conditions['state']: (horizon_steps, state_dim) - FULL state sequence
        """
        window = self.windows[idx]

        # Clone so callers cannot mutate the cached tensors.
        states = window['states'].clone()  # (horizon, state_dim) - FULL state sequence
        actions = window['actions'].clone()  # (horizon, action_dim)

        conditions = {'state': states}

        return Batch(actions=actions, conditions=conditions)
