import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Callable, Optional
import gymnasium as gym
from umfavi.utils.gym import get_undiscounted_return
from umfavi.types import DataKey, FeedbackType
from .utils import prepare_episodes
from umfavi.types import Trajectory
from umfavi.utils.torch_utils import to_torch


def print_demonstration_stats(episodes: list[Trajectory], name: str) -> None:
    """Print concise statistics about the demonstration dataset.

    Prints a header, then the number of demonstrations and total step count,
    followed by ``mean ± std [min, max]`` of the undiscounted return and of
    the episode length.

    Args:
        episodes: Demonstration trajectories. Each must contain a
            ``DataKey.REWS`` entry; its length is used as the episode length.
        name: Label shown in the header (e.g. ``"train"``).
    """
    print("-"*80)
    print(f"DATA SUMMARY Demonstrations: {name}")
    print("-"*80)

    # numpy's min/max raise on empty arrays (and mean/std yield NaN), so
    # report an empty dataset explicitly instead of crashing.
    if not episodes:
        print("Demos: 0 | Steps: 0")
        return

    returns = np.array([get_undiscounted_return(episode) for episode in episodes])
    lengths = np.array([len(episode[DataKey.REWS]) for episode in episodes])
    print(f"Demos: {len(returns)} | Steps: {lengths.sum()}")
    print(f"Return: {returns.mean():.1f} ± {returns.std():.1f} [{returns.min():.1f}, {returns.max():.1f}]")
    print(f"Length: {lengths.mean():.1f} ± {lengths.std():.1f} [{lengths.min()}, {lengths.max()}]")


class DemonstrationDataset(Dataset):
    """
    Dataset for demonstration learning with expert trajectories.

    Demonstration episodes are generated once, at construction time, via
    ``prepare_episodes``. They are then concatenated along the first axis into
    one contiguous tensor per data key. Indexing therefore addresses individual
    transitions, flattened across all episodes, not whole trajectories.
    Per-episode lengths are kept in ``episode_lengths`` so that episode
    boundaries can be recovered.
    """
    def __init__(
        self, 
        num_demonstrations: int,   
        policy: Callable,
        make_env_fn: Callable[[], gym.Env],
        device: str,
        base_seed: int,
        subsample_factor: int = 1,
        beta: float = 1.0,
        gamma: float = 0.99,
        td_error_weight: float = 1.0,
        num_steps: Optional[int] = None,
        obs_transform: Optional[Callable] = None,
        act_transform: Optional[Callable] = None,
        name: Optional[str] = "train",
        step_offset: int = 1,
        min_reward_threshold: Optional[float] = None
    ):
        """
        Initialize demonstration dataset.

        Args:
            num_demonstrations: Number of demonstration trajectories to generate.
            policy: Expert policy used to generate demonstrations.
            make_env_fn: Zero-argument factory returning a Gymnasium environment.
            device: Device to store tensors on (e.g. 'cpu' or 'cuda').
            base_seed: Base seed forwarded to ``prepare_episodes``.
            subsample_factor: Subsampling factor forwarded to ``prepare_episodes``.
            beta: Rationality parameter of the demonstrator. It is stored as
                ``self.rationality`` and attached to every sample under
                ``DataKey.RATIONALITY``.
            gamma: Discount factor. It is attached to every sample under
                ``DataKey.GAMMA``.
            td_error_weight: Weight for the TD-error constraint in
                demonstrations. It is attached to every sample under
                ``DataKey.TD_ERROR_WEIGHT``.
            num_steps: Number of time-steps (optional), stored as
                ``self.num_steps``. NOTE(review): it is not forwarded to
                ``prepare_episodes`` by this class, so it does not currently
                affect episode generation. Confirm whether that is intended.
            obs_transform: Optional transformation for observations, forwarded
                to ``prepare_episodes``.
            act_transform: Optional transformation for actions, forwarded to
                ``prepare_episodes``.
            name: Label used when printing dataset statistics.
            step_offset: Temporal offset of the consecutive (s, a) pair,
                forwarded to ``prepare_episodes``.
            min_reward_threshold: If a sampled demonstration has a reward lower
                than this threshold, it is resampled (handled by
                ``prepare_episodes``).
        """
        self.num_demonstrations = num_demonstrations
        self.num_steps = num_steps
        self.make_env_fn = make_env_fn
        self.device = device
        self.base_seed = base_seed
        self.rationality = beta
        self.gamma = gamma
        self.td_error_weight = td_error_weight
        self.obs_transform = obs_transform
        self.act_transform = act_transform
        self.name = name
        self.step_offset = step_offset
        self.subsample_factor = subsample_factor
        self.min_reward_threshold = min_reward_threshold

        # Generate demonstrations (also populates self.episode_lengths).
        self.data = self.generate_demonstrations(policy=policy)

        # Scalar attributes, pre-built once so __getitem__ can reuse them.
        self._rationality = torch.tensor(self.rationality, dtype=torch.float32, device=device)
        self._gamma = torch.tensor(self.gamma, dtype=torch.float32, device=device)
        self._td_error_weight = torch.tensor(self.td_error_weight, dtype=torch.float32, device=device)
    
    def generate_demonstrations(self, policy: Callable) -> dict:
        """
        Generate expert demonstration trajectories and flatten them into tensors.

        As a side effect, sets ``self.episode_lengths`` to the per-episode
        transition counts, in the same order in which the episodes are
        concatenated.

        Args:
            policy: Expert policy passed to ``prepare_episodes``.

        Returns:
            Dictionary mapping each ``DataKey`` present in the episodes to a
            tensor on ``self.device``. Each tensor holds all episodes
            concatenated along the first axis.

        Raises:
            ValueError: If no episodes were produced, or if the arrays within
                one episode do not all have the same length.
        """

        episodes = prepare_episodes(
            policy=policy,
            num_episodes=self.num_demonstrations,
            make_env_fn=self.make_env_fn,
            base_seed=self.base_seed,
            step_offset=self.step_offset,
            subsample_factor=self.subsample_factor,
            obs_transform=self.obs_transform,
            act_transform=self.act_transform,
            print_stat_fn=lambda eps: print_demonstration_stats(eps, self.name),
            min_reward_threshold=self.min_reward_threshold
        )

        # Without this guard, `episodes[0]` below fails with an opaque IndexError.
        if not episodes:
            raise ValueError(
                f"No demonstration episodes were generated for dataset '{self.name}' "
                f"(num_demonstrations={self.num_demonstrations})."
            )

        # Arrays belonging to the same trajectory must have the same length, so
        # that one flat index selects aligned entries across every key. Raise
        # instead of assert so the check is not stripped under `python -O`.
        # Save lengths for easier indexing later.
        self.episode_lengths = []
        for i, episode in enumerate(episodes):
            first_len = len(episode[DataKey.OBS])
            if any(len(episode[k]) != first_len for k in episode.keys()):
                raise ValueError(f"Lengths of data for trajectory {i} don't match.")
            self.episode_lengths.append(first_len)

        # Create contiguous tensors, one per key.
        return {
            k: to_torch(np.concatenate([e[k] for e in episodes], axis=0), self.device)
            for k in episodes[0].keys()
        }
    
    def __len__(self):
        """Return the total number of transitions, not trajectories."""
        return int(self.data[DataKey.OBS].shape[0])
    
    def __getitem__(self, idx) -> dict:
        """
        Get a single transition sample.

        Trajectories and time-steps are treated independently. The sample
        contains every data key of the dataset indexed at ``idx``, presumably
        including the offset next-step (s', a') entries produced by
        ``prepare_episodes`` (confirm against that helper). It also contains
        the shared scalar metadata.

        Args:
            idx: Index of the transition (0 to ``len(self) - 1``).

        Returns:
            Dictionary keyed by ``DataKey`` members.
        """
        
        # Build the sample dictionary
        item_dict = {
            # Metadata
            DataKey.FEEDBACK_TYPE: FeedbackType.DEMONSTRATION,
            DataKey.RATIONALITY: self._rationality,
            DataKey.GAMMA: self._gamma,
            DataKey.TD_ERROR_WEIGHT: self._td_error_weight,
        }
            
        # Add remaining fields from data
        for k in self.data.keys():
            item_dict[k] = self.data[k][idx]
        
        return item_dict