import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Callable, Optional
import gymnasium as gym
from umfavi.data.utils import prepare_episodes, extract_segments_from_episodes
from umfavi.utils.math import sigmoid
from umfavi.utils.gym import get_undiscounted_return
from umfavi.types import DataKey, FeedbackType, Trajectory
from umfavi.utils.policies import Expert
from umfavi.utils.torch_utils import to_torch
import random


def print_preference_stats(prefs: list[Trajectory], name: str) -> None:
    """Print summary statistics of per-trajectory undiscounted returns.

    Each trajectory's return is the NaN-ignoring sum of its per-step rewards.
    The mean, standard deviation, minimum and maximum over trajectories are
    printed on one line.

    Args:
        prefs: Trajectories (e.g. freshly collected episodes) that expose their
            per-step rewards under ``DataKey.REWS``.
        name: Dataset name for display.
    """
    print("-"*80)
    print(f"DATA SUMMARY Preferences: {name}")
    print("-"*80)

    # numpy's min/max reductions raise on zero-size input, so bail out early.
    if len(prefs) == 0:
        print("No trajectories to summarize.")
        return

    # nansum skips NaN rewards (e.g. padded steps) instead of propagating them.
    cum_rews = np.array([np.nansum(p[DataKey.REWS]) for p in prefs])

    mean_reward = np.nanmean(cum_rews)
    std_reward = np.nanstd(cum_rews)
    min_reward = np.nanmin(cum_rews)
    max_reward = np.nanmax(cum_rews)

    print(f"Mean reward: {mean_reward:.1f} ± {std_reward:.1f} [{min_reward:.1f}, {max_reward:.1f}]")


def print_preference_pair_diagnostics(prefs: list[dict], beta: float, name: str) -> None:
    """Print detailed diagnostic statistics about preference pairs.

    Sections printed:
      [1] distribution of the ``DataKey.PREFERENCE`` values (histogram plus
          "weak" / "strong" signal fractions),
      [2] per-pair return difference summed over steps, and logits ``beta * diff``,
      [2b] the same difference with each side averaged over its valid steps,
      [3] fraction of valid steps per pair,
      [4] summed return of the first and second trajectory of each pair.

    Args:
        prefs: Preference pair dictionaries. Each must provide
            ``DataKey.PREFERENCE``, ``DataKey.REWS`` and ``DataKey.VALID``;
            rewards and validity masks are indexed with ``[0]`` / ``[1]`` for the
            first / second trajectory of the pair.
        beta: Rationality parameter used for preference generation.
        name: Dataset name for display.
    """
    print("="*80)
    print(f"PREFERENCE PAIR DIAGNOSTICS: {name}")
    print("="*80)

    # min/max reductions below raise on zero-size arrays, so handle "no pairs" explicitly.
    if len(prefs) == 0:
        print("    No preference pairs to summarize.")
        print("="*80)
        return

    # Single pass over the pairs collecting everything the sections need.
    preference_probs = []
    return_diffs = []
    mean_return_diffs = []
    validity_ratios = []
    r1_list = []
    r2_list = []

    for pref_pair in prefs:
        preference_probs.append(pref_pair[DataKey.PREFERENCE])

        rewards = pref_pair[DataKey.REWS]
        valid = pref_pair[DataKey.VALID]

        # nansum ignores NaN rewards (e.g. padded steps).
        r1 = np.nansum(rewards[0])
        r2 = np.nansum(rewards[1])
        r1_list.append(r1)
        r2_list.append(r2)
        return_diffs.append(r1 - r2)

        # Per-step average over valid steps; max(..., 1) guards against all-invalid segments.
        r1_mean = r1 / max(np.sum(valid[0]), 1)
        r2_mean = r2 / max(np.sum(valid[1]), 1)
        mean_return_diffs.append(r1_mean - r2_mean)

        validity_ratios.append(np.mean(valid))

    preference_probs = np.array(preference_probs)
    return_diffs = np.array(return_diffs)
    mean_return_diffs = np.array(mean_return_diffs)
    validity_ratios = np.array(validity_ratios)
    r1_vals = np.array(r1_list)
    r2_vals = np.array(r2_list)

    # 1. Preference probability distribution
    # NOTE(review): PreferenceDataset stores a *sampled* 0/1 label under
    # DataKey.PREFERENCE, so for that dataset this section describes labels, not
    # probabilities (the weak-signal fraction is then always 0%) — confirm intent.
    print("\n[1] PREFERENCE PROBABILITY DISTRIBUTION")
    print(f"    Beta (rationality): {beta}")
    print(f"    Mean p: {np.mean(preference_probs):.4f}")
    print(f"    Std p:  {np.std(preference_probs):.4f}")
    print(f"    Min p:  {np.min(preference_probs):.4f}")
    print(f"    Max p:  {np.max(preference_probs):.4f}")

    bins = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    hist, _ = np.histogram(preference_probs, bins=bins)
    bar_scale = max(hist.max(), 1)
    last_bin = len(hist) - 1
    print("\n    Preference probability histogram:")
    for i, count in enumerate(hist):
        # np.histogram bins are half-open except the last, which also includes 1.0.
        closing = "]" if i == last_bin else ")"
        bar = "█" * (count * 50 // bar_scale)
        print(f"    [{bins[i]:.1f}-{bins[i+1]:.1f}{closing}: {count:5d} {bar}")

    # Weak signal: values concentrated around 0.5.
    weak_signal_ratio = np.mean((preference_probs > 0.4) & (preference_probs < 0.6))
    strong_signal_ratio = np.mean((preference_probs < 0.2) | (preference_probs > 0.8))
    print(f"\n    ⚠️  Weak signal (p in [0.4, 0.6]): {weak_signal_ratio*100:.1f}%")
    print(f"    ✓  Strong signal (p < 0.2 or p > 0.8): {strong_signal_ratio*100:.1f}%")

    # 2. Return differences (SUM-based)
    print("\n[2] RETURN DIFFERENCES - SUM (r1 - r2)")
    print(f"    Mean diff:  {np.mean(return_diffs):.2f}")
    print(f"    Std diff:   {np.std(return_diffs):.2f}")
    print(f"    Min diff:   {np.min(return_diffs):.2f}")
    print(f"    Max diff:   {np.max(return_diffs):.2f}")
    print(f"    Abs mean:   {np.mean(np.abs(return_diffs)):.2f}")

    logits_sum = beta * return_diffs
    print("\n    Logits SUM-based (beta * sum_diff):")
    print(f"    Mean logit: {np.mean(logits_sum):.2f}")
    print(f"    Std logit:  {np.std(logits_sum):.2f}")
    print(f"    Min logit:  {np.min(logits_sum):.2f}")
    print(f"    Max logit:  {np.max(logits_sum):.2f}")

    # |logit| > 10 puts the sigmoid within ~5e-5 of 0 or 1.
    saturated_sum = np.sum(np.abs(logits_sum) > 10)
    print(f"    ⚠️  Saturated (|logit| > 10): {saturated_sum} ({saturated_sum/len(logits_sum)*100:.1f}%)")

    # 2b. MEAN-based return differences (normalized by valid steps)
    # NOTE(review): PreferenceDataset divides returns by segment_len rather than
    # by the valid-step count when sampling labels, so these logits can differ
    # from the ones actually used when segments contain invalid steps.
    print("\n[2b] RETURN DIFFERENCES - MEAN (normalized by valid steps)")
    print(f"    Mean diff:  {np.mean(mean_return_diffs):.4f}")
    print(f"    Std diff:   {np.std(mean_return_diffs):.4f}")
    print(f"    Min diff:   {np.min(mean_return_diffs):.4f}")
    print(f"    Max diff:   {np.max(mean_return_diffs):.4f}")
    print(f"    Abs mean:   {np.mean(np.abs(mean_return_diffs)):.4f}")

    logits_mean = beta * mean_return_diffs
    print("\n    Logits MEAN-based (beta * mean_diff) [RECOMMENDED]:")
    print(f"    Mean logit: {np.mean(logits_mean):.4f}")
    print(f"    Std logit:  {np.std(logits_mean):.4f}")
    print(f"    Min logit:  {np.min(logits_mean):.4f}")
    print(f"    Max logit:  {np.max(logits_mean):.4f}")

    saturated_mean = np.sum(np.abs(logits_mean) > 10)
    print(f"    ✓  Saturated (|logit| > 10): {saturated_mean} ({saturated_mean/len(logits_mean)*100:.1f}%)")

    # 3. Segment validity
    print("\n[3] SEGMENT VALIDITY")
    print(f"    Mean validity ratio: {np.mean(validity_ratios):.4f}")
    print(f"    Min validity ratio:  {np.min(validity_ratios):.4f}")
    print(f"    Max validity ratio:  {np.max(validity_ratios):.4f}")

    low_validity = np.sum(validity_ratios < 0.5)
    print(f"    ⚠️  Segments with <50% valid steps: {low_validity} ({low_validity/len(validity_ratios)*100:.1f}%)")

    # 4. Individual trajectory returns
    print("\n[4] TRAJECTORY RETURNS")
    print(f"    Traj 1: {np.mean(r1_vals):.2f} ± {np.std(r1_vals):.2f} [{np.min(r1_vals):.2f}, {np.max(r1_vals):.2f}]")
    print(f"    Traj 2: {np.mean(r2_vals):.2f} ± {np.std(r2_vals):.2f} [{np.min(r2_vals):.2f}, {np.max(r2_vals):.2f}]")

    print("="*80)

class PreferenceDataset(Dataset):
    """
    Dataset for preference learning with trajectory pairs and simulated preferences.

    Each item is one pair of trajectory segments whose per-step fields are stacked
    along an axis of size 2 (first / second segment). Each item also carries a 0/1
    preference label drawn with P(label = 1) = sigmoid(beta * (r1 - r2)), where
    r1 and r2 are the pair's length-normalized undiscounted returns. A label of 1
    therefore indicates that the first segment was preferred.
    """
    def __init__(
        self,
        num_episodes: int,
        num_pref_pairs: int,
        segment_len: Optional[int],
        policy: Expert,
        make_env_fn: Callable[[], gym.Env],
        device: str,
        base_seed: int,  # base seed for episode collection (needs to be collision-free with other datasets)
        beta: float = 1.0,
        gamma: float = 0.99,
        obs_transform: Optional[Callable] = None,
        act_transform: Optional[Callable] = None,
        name: Optional[str] = "train",
        step_offset: int = 1,
        subsample_factor: int = 1,
        min_reward_threshold: Optional[float] = None,
        td_error_weight: float = 1.0
    ):
        """
        Initialize preference dataset.

        Args:
            num_episodes: Number of episodes to collect with ``policy``.
            num_pref_pairs: Number of segment pairs to build. It is ignored (and
                overwritten) when ``segment_len`` is None.
            segment_len: Length of the segments cut from episodes, or None to pair
                whole episodes.
            policy: Policy used to roll out the episodes.
            make_env_fn: Factory for the environment used during collection.
            device: Torch device the dataset tensors are placed on.
            base_seed: Seed for episode collection. It also seeds the RNG used for
                segment extraction and preference sampling.
            beta: Rationality (inverse temperature) of the simulated labeller.
            gamma: Discount factor exposed with every item.
            obs_transform: Optional observation transform forwarded to episode collection.
            act_transform: Optional action transform forwarded to episode collection.
            name: Dataset name used in printed diagnostics.
            step_offset: Forwarded to ``prepare_episodes``.
            subsample_factor: Forwarded to ``prepare_episodes``.
            min_reward_threshold: Forwarded to ``prepare_episodes``.
            td_error_weight: Weight exposed with every item under ``DataKey.TD_ERROR_WEIGHT``.
        """
        self.num_episodes = num_episodes
        self.num_pref_pairs = num_pref_pairs
        self.segment_len = segment_len
        self.make_env_fn = make_env_fn
        self.device = device
        self.base_seed = base_seed
        self.obs_transform = obs_transform
        self.act_transform = act_transform
        self.name = name
        self.step_offset = step_offset
        self.subsample_factor = subsample_factor
        self.min_reward_threshold = min_reward_threshold
        self.beta = beta
        self.gamma = gamma
        self.td_error_weight = td_error_weight
        # Seeded RNG shared by segment extraction and preference sampling.
        self.generator = random.Random(base_seed)

        # Generate trajectory pairs and preferences
        self.data = self.generate_preferences(policy=policy)

        # Scalar attributes returned with every item
        self._rationality = torch.tensor(beta, dtype=torch.float32, device=device)
        self._gamma = torch.tensor(gamma, dtype=torch.float32, device=device)
        self._td_error_weight = torch.tensor(td_error_weight, dtype=torch.float32, device=device)

    def generate_preferences(self, policy: Expert) -> dict:
        """
        Roll out episodes, pair up segments and sample one preference label per pair.

        Args:
            policy: Policy used to collect the episodes.

        Returns:
            Dict mapping each data key to a tensor on ``self.device`` whose leading
            dimension indexes preference pairs. Keys taken from the segments have a
            second axis of size 2 (first / second segment). ``DataKey.PREFERENCE``
            holds one 0/1 label per pair, where 1 means the first segment was preferred.

        Raises:
            ValueError: If fewer than two segments are available to form a pair.
        """
        # Collect episodes to extract segments from
        episodes = prepare_episodes(
            policy=policy,
            num_episodes=self.num_episodes,
            make_env_fn=self.make_env_fn,
            base_seed=self.base_seed,
            step_offset=self.step_offset,
            subsample_factor=self.subsample_factor,
            obs_transform=self.obs_transform,
            act_transform=self.act_transform,
            min_reward_threshold=self.min_reward_threshold
        )

        print_preference_stats(episodes, self.name)

        if self.segment_len is None:
            print("Using full episodes for preference dataset. n_pref_samples will be ignored (n_episodes is used instead)")
            # NOTE(review): stacking whole episodes below assumes prepare_episodes
            # returns them at a common (padded) length — confirm.
            self.num_pref_pairs = len(episodes) // 2
            segments = episodes
        else:
            segments = extract_segments_from_episodes(episodes, self.segment_len, 2*self.num_pref_pairs, rng=self.generator)

        # Consecutive segments form a pair; a trailing unpaired segment is dropped
        # instead of raising IndexError.
        num_pairs = len(segments) // 2
        if num_pairs == 0:
            raise ValueError(
                f"PreferenceDataset '{self.name}' needs at least two segments to form a "
                f"preference pair, got {len(segments)}."
            )

        prefs = [self._make_pref_pair(segments[2 * i], segments[2 * i + 1]) for i in range(num_pairs)]

        # Print detailed diagnostics for preference pairs
        print_preference_pair_diagnostics(prefs, self.beta, self.name)

        # Create contiguous tensors with a leading num_pairs dimension
        return {
            k: to_torch(np.stack([pref_pair[k] for pref_pair in prefs], axis=0), self.device)
            for k in prefs[0]
        }

    def _return_normalizer(self, seg) -> float:
        """Return the divisor applied to a segment's return before comparing it."""
        if self.segment_len is not None:
            return float(self.segment_len)
        # Full episodes have no fixed length: use the number of valid steps,
        # clamped to 1 so an all-invalid episode cannot divide by zero.
        return max(float(np.sum(seg[DataKey.VALID])), 1.0)

    def _make_pref_pair(self, seg1, seg2) -> dict:
        """Stack two segments key-by-key and sample a preference label for the pair."""
        pref_pair = {k: np.stack([seg1[k], seg2[k]], axis=0) for k in seg1}
        # Length-normalized returns keep beta * (r1 - r2) away from sigmoid saturation.
        r1 = get_undiscounted_return(seg1) / self._return_normalizer(seg1)
        r2 = get_undiscounted_return(seg2) / self._return_normalizer(seg2)
        p = sigmoid(self.beta * (r1 - r2))
        # Bernoulli(p) draw from the dataset's own seeded RNG (instead of numpy's
        # global one), so the labels are reproducible from base_seed.
        pref_pair[DataKey.PREFERENCE] = int(self.generator.random() < p)
        return pref_pair

    def __len__(self):
        return self.data[DataKey.PREFERENCE].shape[0]

    def __getitem__(self, idx) -> dict:
        """
        Gets a single preference feedback sample.

        Returns:
            Dictionary with the scalar feedback attributes (feedback type,
            rationality, gamma, TD-error weight) plus entry ``idx`` of every
            tensor in ``self.data``.
        """
        # Scalars
        item_dict = {
            DataKey.FEEDBACK_TYPE: FeedbackType.PREFERENCE,
            DataKey.RATIONALITY: self._rationality,
            DataKey.GAMMA: self._gamma,
            DataKey.TD_ERROR_WEIGHT: self._td_error_weight,
        }

        # Add remaining fields from data
        for k in self.data.keys():
            item_dict[k] = self.data[k][idx]

        return item_dict