import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Callable, Optional
import gymnasium as gym
from umfavi.data.utils import prepare_episodes, extract_segments_from_episodes
from umfavi.utils.gym import get_undiscounted_return
from umfavi.types import DataKey, FeedbackType, Trajectory
from umfavi.utils.policies import Expert
from umfavi.utils.torch_utils import to_torch
import random


def print_rating_stats(segments: list[Trajectory], name: str) -> None:
    """Summarize the per-segment reward totals of a rating dataset on stdout.

    Each segment's ``DataKey.REWS`` entry is summed with NaNs ignored; the
    segment count and the mean ± std [min, max] of those totals are printed
    underneath a banner carrying ``name``.

    Args:
        segments: List of trajectory segments
        name: Dataset name shown in the banner
    """
    rule = "-" * 80
    for banner_line in (rule, f"DATA SUMMARY Ratings: {name}", rule):
        print(banner_line)

    totals = np.array([np.nansum(segment[DataKey.REWS]) for segment in segments])

    # NaN-aware reductions, evaluated in the same order as before.
    centre = np.nanmean(totals)
    spread = np.nanstd(totals)
    lowest = np.nanmin(totals)
    highest = np.nanmax(totals)

    print(f"Number of segments: {len(segments)}")
    print(
        f"Mean reward: {centre:.1f} ± {spread:.1f} "
        f"[{lowest:.1f}, {highest:.1f}]"
    )


def print_rating_diagnostics(
    segments: list[dict],
    ratings: np.ndarray,
    cutpoints: np.ndarray,
    num_categories: int,
    name: str,
    returns: Optional[np.ndarray] = None,
) -> None:
    """Print detailed diagnostic statistics about ratings.

    Sections printed, in order: [1] return distribution, [2] cutpoints with
    the percentile each one corresponds to, [3] segment count per rating
    category, [4] return statistics within each category, [5] fraction of
    valid steps per segment.

    Args:
        segments: List of segment dictionaries. Each must provide
            ``DataKey.VALID``, and also ``DataKey.REWS`` when ``returns`` is
            not given.
        ratings: Array of assigned ordinal ratings (non-negative integers),
            one per segment.
        cutpoints: Array of cutpoints used for rating assignment
        num_categories: Number of ordinal categories
        name: Dataset name for display
        returns: Optional per-segment values the ratings were derived from.
            When omitted, the NaN-ignoring sum of each segment's
            ``DataKey.REWS`` is used. Pass the exact values that were fed
            to the cutpoint computation (e.g. length-normalized returns) so
            that sections [1] and [4] are in the same units as the cutpoints.

    Raises:
        ValueError: If ``returns`` is given and its length differs from the
            number of segments.
    """
    # Validate before printing anything so a bad call leaves no partial output.
    if returns is not None:
        returns = np.asarray(returns, dtype=float).reshape(-1)
        if len(returns) != len(segments):
            raise ValueError(
                f"returns has {len(returns)} entries but there are "
                f"{len(segments)} segments"
            )

    print("=" * 80)
    print(f"RATING DIAGNOSTICS: {name}")
    print("=" * 80)

    # Every statistic below is a reduction (min/max raise on empty arrays) or
    # a division by the segment count, so bail out early on empty input.
    if len(segments) == 0:
        print("    No segments to summarize.")
        print("=" * 80)
        return

    # Compute returns (raw reward sums unless the caller supplied them)
    if returns is None:
        returns = np.array([np.nansum(seg[DataKey.REWS]) for seg in segments])

    # Coerce so that `ratings == k` below is an element-wise mask even if a
    # plain list was passed.
    ratings = np.asarray(ratings)

    # Validity ratios: fraction of steps flagged valid in each segment
    validity_ratios = np.array([np.mean(seg[DataKey.VALID]) for seg in segments])

    # 1. Return distribution
    print("\n[1] RETURN DISTRIBUTION")
    print(f"    Mean:   {np.mean(returns):.2f}")
    print(f"    Std:    {np.std(returns):.2f}")
    print(f"    Min:    {np.min(returns):.2f}")
    print(f"    Max:    {np.max(returns):.2f}")

    # 2. Cutpoints
    print("\n[2] CUTPOINTS (quantile-based)")
    print(f"    Number of categories: {num_categories}")
    print(f"    Cutpoint values: {cutpoints}")

    # Show quantile boundaries: cutpoint i sits at the (i+1)/K quantile.
    quantiles = [100 * (i + 1) / num_categories for i in range(num_categories - 1)]
    for i, (q, c) in enumerate(zip(quantiles, cutpoints)):
        print(f"    θ_{i+1} ({q:.0f}th percentile): {c:.2f}")

    # 3. Rating distribution
    print("\n[3] RATING DISTRIBUTION")
    rating_counts = np.bincount(ratings, minlength=num_categories)
    for k in range(num_categories):
        pct = rating_counts[k] / len(ratings) * 100
        bar = "█" * int(pct / 2)  # one block per 2%
        print(f"    Category {k}: {rating_counts[k]:5d} ({pct:5.1f}%) {bar}")

    # 4. Returns per category (empty categories are skipped)
    print("\n[4] RETURNS PER CATEGORY")
    for k in range(num_categories):
        mask = ratings == k
        if np.any(mask):
            cat_returns = returns[mask]
            print(f"    Category {k}: {np.mean(cat_returns):.2f} ± {np.std(cat_returns):.2f} [{np.min(cat_returns):.2f}, {np.max(cat_returns):.2f}]")

    # 5. Segment validity
    print("\n[5] SEGMENT VALIDITY")
    print(f"    Mean validity ratio: {np.mean(validity_ratios):.4f}")
    print(f"    Min validity ratio:  {np.min(validity_ratios):.4f}")
    print(f"    Max validity ratio:  {np.max(validity_ratios):.4f}")

    low_validity = np.sum(validity_ratios < 0.5)
    print(f"    ⚠️  Segments with <50% valid steps: {low_validity} ({low_validity/len(validity_ratios)*100:.1f}%)")

    print("=" * 80)


def compute_quantile_cutpoints(returns: np.ndarray, num_categories: int) -> np.ndarray:
    """Return the K-1 values that split ``returns`` into K equal-mass bins.

    The cutpoints are the percentiles 100*1/K, 100*2/K, ..., 100*(K-1)/K of
    ``returns``, computed with ``np.percentile``'s default (linear)
    interpolation. Repeated return values can produce repeated cutpoints.

    Args:
        returns: Array of trajectory returns
        num_categories: Number of ordinal categories (K)

    Returns:
        Array of K-1 cutpoints
    """
    # (i / K) * 100 for i = 1..K-1, same operation order as a scalar loop.
    percentile_levels = np.arange(1, num_categories) / num_categories * 100
    return np.percentile(returns, percentile_levels)


def assign_ratings(returns: np.ndarray, cutpoints: np.ndarray) -> np.ndarray:
    """Map each return to the index of the cutpoint interval containing it.

    With K-1 cutpoints c_0 <= c_1 <= ... <= c_{K-2}, a return r receives
    rating k when c_{k-1} < r <= c_k, using c_{-1} = -inf and
    c_{K-1} = +inf. A return equal to a cutpoint therefore lands in the
    lower category.

    Args:
        returns: Array of trajectory returns
        cutpoints: Array of K-1 non-decreasing cutpoints (repeated values are
            allowed; np.digitize rejects non-monotonic bins)

    Returns:
        Array of ordinal ratings in [0, K-1], same shape as ``returns``
    """
    bin_edges = np.asarray(cutpoints)
    # right=True closes each interval on its upper edge: (c_{k-1}, c_k].
    return np.digitize(returns, bin_edges, right=True)


class RatingDataset(Dataset):
    """
    Dataset for ordinal rating learning with trajectory segments.

    Ratings are assigned from quantile-based cutpoints of segment returns
    (normalized by segment length), which gives approximately equal samples
    per category; ties among returns can leave some categories short.

    Each item is a dict holding the feedback type, discount factor and
    TD-error weight plus one entry per stacked per-segment field (including
    ``DataKey.RATING``).
    """

    def __init__(
        self,
        num_episodes: int,
        num_samples: int,
        segment_len: Optional[int],
        policy: Expert,
        make_env_fn: Callable[[], gym.Env],
        device: str,
        base_seed: int,
        num_categories: int = 5,
        gamma: float = 0.99,
        obs_transform: Optional[Callable] = None,
        act_transform: Optional[Callable] = None,
        name: Optional[str] = "train",
        step_offset: int = 1,
        subsample_factor: int = 1,
        min_reward_threshold: Optional[float] = None,
        td_error_weight: float = 1.0,
    ):
        """
        Initialize rating dataset.

        Args:
            num_episodes: Number of episodes to collect from policy
            num_samples: Number of rating samples to generate (overwritten with
                the episode count when ``segment_len`` is None)
            segment_len: Length of trajectory segments (None for full episodes)
            policy: Expert policy to collect trajectories from
            make_env_fn: Factory function for creating environment
            device: Device to place tensors on
            base_seed: Base seed for episode collection (needs to be collision-free with other datasets)
            num_categories: Number of ordinal categories (default 5 for Likert scale)
            gamma: Discount factor
            obs_transform: Optional observation transform
            act_transform: Optional action transform
            name: Dataset name for display
            step_offset: Offset for next_obs computation
            subsample_factor: Subsampling factor for trajectories
            min_reward_threshold: Optional minimum reward threshold for filtering
            td_error_weight: Weight for TD-error regularization

        Raises:
            ValueError: If no segments are available to rate.
        """
        self.num_episodes = num_episodes
        self.num_samples = num_samples
        self.segment_len = segment_len
        self.make_env_fn = make_env_fn
        self.device = device
        self.obs_transform = obs_transform
        self.act_transform = act_transform
        self.name = name
        self.step_offset = step_offset
        self.subsample_factor = subsample_factor
        self.min_reward_threshold = min_reward_threshold
        self.base_seed = base_seed
        self.num_categories = num_categories
        self.gamma = gamma
        self.td_error_weight = td_error_weight
        # Seeded from base_seed and handed to extract_segments_from_episodes as
        # its rng, so segment sampling is reproducible for a given seed
        # (assuming that helper draws only from this generator).
        self.generator = random.Random(base_seed)

        # Generate ratings
        self.data, self.cutpoints = self.generate_ratings(policy=policy)

        # Scalar tensors built once and shared by every item returned from
        # __getitem__.
        self._gamma = torch.tensor(gamma, dtype=torch.float32, device=device)
        self._td_error_weight = torch.tensor(td_error_weight, dtype=torch.float32, device=device)

    def generate_ratings(self, policy: Expert) -> tuple[dict, np.ndarray]:
        """
        Generate trajectory segments and assign ordinal ratings.

        Episodes are collected with ``prepare_episodes``. Either the full
        episodes are used (``segment_len is None``, which also resets
        ``num_samples`` to the episode count), or segments are drawn from them
        with ``extract_segments_from_episodes``. Each segment is rated by the
        quantile bin of its length-normalized undiscounted return, and every
        per-segment field is stacked into one tensor per key.

        Args:
            policy: Expert policy to collect trajectories from

        Returns:
            Tuple of (data tensors, cutpoints) where:
            - data: Dict of tensors stacked along a leading sample axis,
              e.g. shape (num_samples, segment_len, ...) for per-step fields
            - cutpoints: Array of K-1 cutpoints used for rating assignment

        Raises:
            ValueError: If no segments are available.
        """
        # Collect episodes to extract segments from
        episodes = prepare_episodes(
            policy=policy,
            num_episodes=self.num_episodes,
            make_env_fn=self.make_env_fn,
            base_seed=self.base_seed,
            step_offset=self.step_offset,
            subsample_factor=self.subsample_factor,
            obs_transform=self.obs_transform,
            act_transform=self.act_transform,
            min_reward_threshold=self.min_reward_threshold
        )

        # Extract segments
        if self.segment_len is None:
            print("Using full episodes for rating dataset. n_rating_samples will be ignored (n_episodes is used instead)")
            self.num_samples = len(episodes)
            segments = episodes
        else:
            segments = extract_segments_from_episodes(episodes, self.segment_len, self.num_samples, rng=self.generator)

        # Everything below (summary stats, quantiles, stacking keyed on
        # segments[0]) needs at least one segment; fail with a clear message
        # instead of an obscure reduction/index error further down.
        if len(segments) == 0:
            raise ValueError(
                f"No segments available for rating dataset '{self.name}' "
                f"(num_episodes={self.num_episodes}, segment_len={self.segment_len}, "
                f"num_samples={self.num_samples})"
            )

        print_rating_stats(segments, self.name)

        # Compute returns for all segments (normalized by segment length to match decoder).
        # With full episodes, the longest episode's length is the normalizer.
        norm_len = (
            self.segment_len
            if self.segment_len is not None
            else max(len(seg[DataKey.REWS]) for seg in segments)
        )
        returns = np.array([get_undiscounted_return(seg) / norm_len for seg in segments])

        # Compute quantile-based cutpoints
        cutpoints = compute_quantile_cutpoints(returns, self.num_categories)

        # Assign ordinal ratings based on cutpoints
        ratings = assign_ratings(returns, cutpoints)

        # Add ratings to segments (mutates the segment dicts in place)
        for seg, rating in zip(segments, ratings):
            seg[DataKey.RATING] = rating

        # NOTE: the diagnostics derive their return statistics from the
        # segments' rewards, not from the length-normalized `returns` used for
        # the cutpoints, so the printed returns and cutpoints may differ in scale.
        print_rating_diagnostics(segments, ratings, cutpoints, self.num_categories, self.name)

        # Create contiguous tensors of shape (num_samples, segment_len, ...)
        tensors = {
            k: to_torch(np.stack([seg[k] for seg in segments], axis=0), self.device)
            for k in segments[0].keys()
        }

        return tensors, cutpoints

    def __len__(self):
        return self.data[DataKey.RATING].shape[0]

    def __getitem__(self, idx) -> dict:
        """
        Gets a single rating feedback sample.

        Args:
            idx: Index into the leading sample axis of every stored tensor.

        Returns:
            Dictionary with the shared scalars (feedback type, gamma, TD-error
            weight) and the ``idx``-th slice of every stored field.
        """
        # Scalars
        item_dict = {
            DataKey.FEEDBACK_TYPE: FeedbackType.RATING,
            DataKey.GAMMA: self._gamma,
            DataKey.TD_ERROR_WEIGHT: self._td_error_weight,
        }

        # Add remaining fields from data (data entries take precedence on key clashes)
        item_dict.update({k: v[idx] for k, v in self.data.items()})

        return item_dict

