import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Callable, Optional
import gymnasium as gym
from umfavi.data.utils import prepare_episodes, extract_segments_from_episodes
from umfavi.utils.gym import get_undiscounted_return
from umfavi.types import DataKey, FeedbackType, Trajectory
from umfavi.utils.policies import Expert
from umfavi.utils.torch_utils import to_torch
import random


def softmax_1d(x: np.ndarray) -> np.ndarray:
    """Return the softmax of a 1D array.

    The maximum entry is subtracted before exponentiating so that the
    largest exponent is exactly zero, which avoids overflow for large inputs.

    Args:
        x: Array of logits.

    Returns:
        Array of the same shape whose entries are exp(x_i) / sum_j exp(x_j).
    """
    weights = np.exp(x - np.max(x))
    return weights / np.sum(weights)


def sample_plackett_luce_ranking(returns: np.ndarray, beta: float, rng: random.Random) -> np.ndarray:
    """Draw one ranking of k items from a Plackett-Luce distribution.

    Positions are filled from best to worst. At every position, one of the
    not-yet-placed items is drawn with probability given by the softmax of
    ``beta * return`` over the items still available. A larger ``beta``
    therefore concentrates the draw on the highest-return items.

    Args:
        returns: Array of per-item returns, shape (k,).
        beta: Rationality parameter scaling the returns before the softmax.
        rng: ``random.Random`` instance used for every draw (reproducibility).

    Returns:
        int64 array of shape (k,) where entry j is the rank given to item j
        (0 = best, k-1 = worst).
    """
    num_items = len(returns)
    pool = list(range(num_items))  # indices of items not yet placed
    order = []  # order[p] = index of the item placed at position p

    while len(order) < num_items:
        # Selection probabilities restricted to the items still in the pool
        weights = softmax_1d(beta * returns[pool]).tolist()

        # Draw a position inside the pool, then move that item to the ordering
        pick = rng.choices(range(len(pool)), weights=weights)[0]
        order.append(pool.pop(pick))

    # Invert the permutation: from "item at position p" to "position of item j"
    ranks = np.zeros(num_items, dtype=np.int64)
    for position, item in enumerate(order):
        ranks[item] = position

    return ranks


def print_ranking_stats(episodes: list[Trajectory], name: str) -> None:
    """Print a short summary of the episodes backing a ranking dataset.

    The summary reports the number of episodes and the mean, standard
    deviation, minimum and maximum of the per-episode reward sums. NaN
    rewards are ignored both when summing an episode and when aggregating.

    Args:
        episodes: List of trajectory episodes; each is indexed with
            ``DataKey.REWS`` to obtain its rewards.
        name: Dataset name shown in the header.
    """
    divider = "-" * 80
    print(divider)
    print(f"DATA SUMMARY Rankings: {name}")
    print(divider)

    # One NaN-ignoring reward total per episode
    episode_totals = np.array([np.nansum(episode[DataKey.REWS]) for episode in episodes])

    # All statistics are computed before anything else is printed
    avg, spread, lowest, highest = (
        stat(episode_totals) for stat in (np.nanmean, np.nanstd, np.nanmin, np.nanmax)
    )

    print(f"Number of episodes: {len(episodes)}")
    print(f"Mean reward: {avg:.1f} +/- {spread:.1f} [{lowest:.1f}, {highest:.1f}]")


def print_ranking_diagnostics(
    rankings: list[dict],
    beta: float,
    num_ranked_items: int,
    name: str
) -> None:
    """Print a multi-section diagnostic report for a set of sampled rankings.

    Sections: (1) the ranking parameters, (2) the distribution of segment
    returns, (3) the spread (max - min) of returns inside each ranking,
    (4) the Kendall tau between the sampled ranks and the ordering implied
    by the segment returns, and (5) the fraction of valid steps per segment.

    Args:
        rankings: List of ranking dictionaries. Each is indexed with
            ``DataKey.REWS``, ``DataKey.VALID`` and ``DataKey.RANKING``.
        beta: Rationality parameter used for ranking generation (reported only).
        num_ranked_items: Number of items per ranking (k).
        name: Dataset name shown in the header.
    """
    banner = "=" * 80
    print(banner)
    print(f"RANKING DIAGNOSTICS: {name}")
    print(banner)

    print("\n[1] RANKING PARAMETERS")
    print(f"    Beta (rationality): {beta}")
    print(f"    Number of ranked items (k): {num_ranked_items}")
    print(f"    Number of ranking samples: {len(rankings)}")

    item_ids = range(num_ranked_items)
    # Number of unordered item pairs in one ranking
    n_pairs = num_ranked_items * (num_ranked_items - 1) / 2

    pooled_returns = []   # every segment return across all rankings
    spreads = []          # max - min return inside each ranking
    valid_fractions = []  # mean of the validity mask, one entry per segment
    taus = []             # one Kendall tau per ranking

    for entry in rankings:
        seg_rewards = entry[DataKey.REWS]      # indexed per item below
        seg_valid = entry[DataKey.VALID]       # indexed per item below
        sampled_ranks = entry[DataKey.RANKING]

        # Raw (unnormalized) NaN-ignoring reward sum of each segment
        seg_returns = np.array([np.nansum(seg_rewards[i]) for i in item_ids])
        pooled_returns += seg_returns.tolist()
        spreads.append(np.max(seg_returns) - np.min(seg_returns))

        valid_fractions.extend(np.mean(seg_valid[i]) for i in item_ids)

        # Reference ranks: the highest return gets rank 0
        true_ranks = np.argsort(np.argsort(-seg_returns))

        # Compare the sign of the rank difference on every pair i < j;
        # a pair is concordant when both orderings agree on that sign.
        upper_i, upper_j = np.triu_indices(num_ranked_items, k=1)
        sampled = np.asarray(sampled_ranks)
        true_sign = np.sign(true_ranks[upper_i] - true_ranks[upper_j])
        sampled_sign = np.sign(sampled[upper_i] - sampled[upper_j])
        concordant = int(np.sum(true_sign == sampled_sign))
        discordant = len(upper_i) - concordant

        taus.append((concordant - discordant) / n_pairs if n_pairs > 0 else 0)

    pooled_returns = np.array(pooled_returns)
    spreads = np.array(spreads)
    valid_fractions = np.array(valid_fractions)
    taus = np.array(taus)

    # Section 2: distribution of all segment returns
    print("\n[2] SEGMENT RETURN DISTRIBUTION")
    print(f"    Mean:   {pooled_returns.mean():.2f}")
    print(f"    Std:    {pooled_returns.std():.2f}")
    print(f"    Min:    {pooled_returns.min():.2f}")
    print(f"    Max:    {pooled_returns.max():.2f}")

    # Section 3: how far apart the returns inside one ranking are
    print("\n[3] RETURN SPREAD WITHIN RANKINGS (max - min)")
    print(f"    Mean spread: {spreads.mean():.2f}")
    print(f"    Std spread:  {spreads.std():.2f}")
    print(f"    Min spread:  {spreads.min():.2f}")
    print(f"    Max spread:  {spreads.max():.2f}")

    # Section 4: agreement between sampled and return-implied orderings
    print("\n[4] RANKING QUALITY (Kendall tau with true ordering)")
    print(f"    Mean tau: {taus.mean():.4f}")
    print(f"    Std tau:  {taus.std():.4f}")
    print(f"    Min tau:  {taus.min():.4f}")
    print(f"    Max tau:  {taus.max():.4f}")
    perfect_rankings = np.sum(taus == 1.0)
    perfect_pct = perfect_rankings / len(taus) * 100
    print(f"    Perfect rankings (tau=1): {perfect_rankings} ({perfect_pct:.1f}%)")

    # Section 5: share of valid steps per segment
    print("\n[5] SEGMENT VALIDITY")
    print(f"    Mean validity ratio: {valid_fractions.mean():.4f}")
    print(f"    Min validity ratio:  {valid_fractions.min():.4f}")
    print(f"    Max validity ratio:  {valid_fractions.max():.4f}")

    low_validity = np.sum(valid_fractions < 0.5)
    low_pct = low_validity / len(valid_fractions) * 100
    print(f"    Segments with <50% valid steps: {low_validity} ({low_pct:.1f}%)")

    print(banner)


class RankingDataset(Dataset):
    """
    Dataset for ranking learning with trajectory segments using the Plackett-Luce model.
    
    This is a generalization of preference learning (k=2) to rankings over k segments.
    Rankings are sampled stochastically from the Plackett-Luce distribution, where
    the probability of a ranking depends on the cumulative rewards of the segments.
    """
    def __init__(
        self, 
        num_episodes: int,  
        num_ranking_samples: int, 
        segment_len: Optional[int],
        num_ranked_items: int,
        policy: Expert,
        make_env_fn: Callable[[], gym.Env],
        device: str,
        base_seed: int,
        beta: float = 1.0,
        gamma: float = 0.99,
        obs_transform: Optional[Callable] = None,
        act_transform: Optional[Callable] = None,
        name: Optional[str] = "train",
        step_offset: int = 1,
        subsample_factor: int = 1,
        min_reward_threshold: Optional[float] = None,
        td_error_weight: float = 1.0
    ):
        """
        Initialize ranking dataset.
        
        Args:
            num_episodes: Number of episodes to collect from policy
            num_ranking_samples: Number of ranking samples to generate
            segment_len: Length of trajectory segments (None for full episodes)
            num_ranked_items: Number of items (segments) per ranking (k); must be >= 1
            policy: Expert policy to collect trajectories from
            make_env_fn: Factory function for creating environment
            device: Device to place tensors on
            base_seed: Base seed for episode collection; also seeds the
                random generator used for segment extraction and ranking sampling
            beta: Rationality parameter for Plackett-Luce model (higher = more deterministic)
            gamma: Discount factor
            obs_transform: Optional observation transform
            act_transform: Optional action transform
            name: Dataset name for display
            step_offset: Offset for next_obs computation
            subsample_factor: Subsampling factor for trajectories
            min_reward_threshold: Optional minimum reward threshold for filtering
            td_error_weight: Weight for TD-error regularization

        Raises:
            ValueError: If ``num_ranked_items`` is smaller than 1, or if no
                complete group of ``num_ranked_items`` segments can be formed.
        """
        super().__init__()

        # Fail early with a clear message: a group size of zero would otherwise
        # surface later as an opaque range()/division error.
        if num_ranked_items < 1:
            raise ValueError(
                f"num_ranked_items must be >= 1, got {num_ranked_items}"
            )

        self.num_episodes = num_episodes
        self.num_ranking_samples = num_ranking_samples
        self.segment_len = segment_len
        self.num_ranked_items = num_ranked_items
        self.make_env_fn = make_env_fn
        self.device = device
        self.base_seed = base_seed
        self.obs_transform = obs_transform
        self.act_transform = act_transform
        self.name = name
        self.step_offset = step_offset
        self.subsample_factor = subsample_factor
        self.min_reward_threshold = min_reward_threshold
        self.beta = beta
        self.gamma = gamma
        self.td_error_weight = td_error_weight
        self.generator = random.Random(base_seed)

        # Generate rankings
        self.data = self.generate_rankings(policy=policy)

        # Scalar attributes, stored once as tensors and shared by every item
        self._rationality = torch.tensor(beta, dtype=torch.float32, device=device)
        self._gamma = torch.tensor(gamma, dtype=torch.float32, device=device)
        self._td_error_weight = torch.tensor(td_error_weight, dtype=torch.float32, device=device)
    
    def generate_rankings(self, policy: Expert) -> dict:
        """
        Generate segment groups and sample rankings from Plackett-Luce distribution.

        When ``segment_len`` is None, whole episodes are used as the ranked
        items and ``self.num_ranking_samples`` is overwritten with the number
        of complete groups the collected episodes allow.

        Args:
            policy: Expert policy passed to ``prepare_episodes``.
        
        Returns:
            Dictionary of tensors with shape (num_samples, k, segment_len, ...)
            where k is num_ranked_items.

        Raises:
            ValueError: If no complete group of ``num_ranked_items`` segments
                could be formed.
        """
        # Collect episodes to extract segments from
        episodes = prepare_episodes(
            policy=policy,
            num_episodes=self.num_episodes,
            make_env_fn=self.make_env_fn,
            base_seed=self.base_seed,
            step_offset=self.step_offset,
            subsample_factor=self.subsample_factor,
            obs_transform=self.obs_transform,
            act_transform=self.act_transform,
            min_reward_threshold=self.min_reward_threshold
        )

        print_ranking_stats(episodes, self.name)

        # Extract segments - we need k segments per ranking sample
        if self.segment_len is None:
            print("Using full episodes for ranking dataset. num_ranking_samples will be adjusted.")
            self.num_ranking_samples = len(episodes) // self.num_ranked_items
            total_segments_needed = self.num_ranked_items * self.num_ranking_samples
            segments = episodes[:total_segments_needed]
        else:
            total_segments_needed = self.num_ranked_items * self.num_ranking_samples
            segments = extract_segments_from_episodes(
                episodes, self.segment_len, total_segments_needed, rng=self.generator
            )

        # Group segments into ranking sets and sample rankings
        rankings = []
        for start in range(0, len(segments), self.num_ranked_items):
            segment_group = segments[start:start + self.num_ranked_items]
            
            if len(segment_group) < self.num_ranked_items:
                break  # Not enough segments for a complete ranking
            
            # Stack segment data: shape (k, segment_len, ...)
            ranking_dict = {
                key: np.stack([seg[key] for seg in segment_group], axis=0)
                for key in segment_group[0]
            }
            
            # Compute returns for each segment (normalized by segment length).
            # For full episodes the longest reward sequence in the group is used.
            if self.segment_len is not None:
                effective_segment_len = self.segment_len
            else:
                effective_segment_len = max(
                    len(seg[DataKey.REWS]) for seg in segment_group
                )
            returns = np.array([
                get_undiscounted_return(seg) / effective_segment_len 
                for seg in segment_group
            ])
            
            # Sample ranking from Plackett-Luce distribution
            ranking_dict[DataKey.RANKING] = sample_plackett_luce_ranking(
                returns, self.beta, self.generator
            )
            
            rankings.append(ranking_dict)

        # Without this check an empty result fails inside the diagnostics
        # printer with a zero-size reduction error that hides the real cause.
        if not rankings:
            raise ValueError(
                f"No complete ranking groups could be formed for dataset '{self.name}': "
                f"got {len(segments)} segment(s) but each ranking needs "
                f"{self.num_ranked_items}."
            )
        
        # Print detailed diagnostics
        print_ranking_diagnostics(rankings, self.beta, self.num_ranked_items, self.name)
    
        # Create contiguous tensors of shape (num_samples, k, segment_len, ...)
        tensors = {
            key: to_torch(np.stack([ranking[key] for ranking in rankings], axis=0), self.device) 
            for key in rankings[0]
        }
        
        return tensors

    def __len__(self):
        """Return the number of ranking samples (first dimension of the RANKING tensor)."""
        return self.data[DataKey.RANKING].shape[0]
    
    def __getitem__(self, idx) -> dict:
        """
        Gets a single ranking feedback sample.

        Returns:
            Dictionary with ranking data including:
            - FEEDBACK_TYPE: FeedbackType.RANKING
            - RATIONALITY, GAMMA, TD_ERROR_WEIGHT: Scalar tensors shared by all items
            - RANKING: Rank assignments for each segment, shape (k,)
            - OBS, ACTS, REWS, etc.: Segment data, shape (k, segment_len, ...)
        """
        # Scalars
        item_dict = {
            DataKey.FEEDBACK_TYPE: FeedbackType.RANKING,
            DataKey.RATIONALITY: self._rationality,
            DataKey.GAMMA: self._gamma,
            DataKey.TD_ERROR_WEIGHT: self._td_error_weight,
        }

        # Add remaining fields from data, sliced at the requested index
        for key, values in self.data.items():
            item_dict[key] = values[idx]
        
        return item_dict
