import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Callable, Optional
import gymnasium as gym
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from umfavi.data.utils import prepare_episodes, extract_segments_from_episodes
from umfavi.types import DataKey, FeedbackType, Trajectory
from umfavi.utils.policies import Expert, QValueModel, TabularQValueModel
from umfavi.utils.torch_utils import to_torch
import random
from umfavi.envs.env_types import TabularEnv


def print_stop_stats(episodes: list[Trajectory], name: str) -> None:
    """Print summary statistics about the episodes backing a stop dataset.

    Reports the number of episodes and, when there is at least one episode,
    the mean, standard deviation, minimum and maximum of

    * the per-episode return (NaN-ignoring sum of ``ep[DataKey.REWS]``), and
    * the per-episode length (``len(ep[DataKey.REWS])``).

    Args:
        episodes: Episodes to summarise; each one is indexed with
            ``DataKey.REWS``.
        name: Label printed in the summary header.
    """
    print("-" * 80)
    print(f"DATA SUMMARY Stops: {name}")
    print("-" * 80)

    print(f"Episodes: {len(episodes)}")
    if len(episodes) == 0:
        # np.min / np.max raise ValueError on zero-size arrays and np.mean
        # only yields NaN plus a warning, so there is nothing useful to report.
        return

    cum_rews = np.array([np.nansum(ep[DataKey.REWS]) for ep in episodes])
    lengths = np.array([len(ep[DataKey.REWS]) for ep in episodes])

    print(f"Return: {np.mean(cum_rews):.1f} ± {np.std(cum_rews):.1f} [{np.min(cum_rews):.1f}, {np.max(cum_rews):.1f}]")
    print(f"Length: {np.mean(lengths):.1f} ± {np.std(lengths):.1f} [{np.min(lengths)}, {np.max(lengths)}]")


def print_stop_diagnostics(
    stop_times: np.ndarray,
    segment_len: int,
    max_regrets: np.ndarray,
    lambd: float,
    c: float,
    regret_ref: float,
    regret_discount: float,
    name: str
) -> None:
    """Print diagnostic statistics about simulated stop times.

    The report has three sections: the distribution of per-segment maximum
    cumulative regret, the lambda calibration values, and the distribution of
    stop times (negative entries are counted as censored, i.e. no stop).

    Args:
        stop_times: One stop index per segment; a negative value means the
            segment was censored.
        segment_len: Segment length, used to express stop times relative to
            the segment (``t / T``).
        max_regrets: Maximum cumulative regret of each segment.
        lambd: Calibrated lambda value (printed only).
        c: Calibration scale (printed only).
        regret_ref: Reference regret used for calibration (printed only).
        regret_discount: Regret discount factor (printed only).
        name: Label printed in the header.
    """
    stop_times = np.asarray(stop_times)
    max_regrets = np.asarray(max_regrets)
    rule = "=" * 80

    print(rule)
    print(f"STOP DIAGNOSTICS: {name}")
    print(rule)

    # Regret distribution
    print("\n[1] MAX CUMULATIVE REGRET (per segment)")
    if max_regrets.size > 0:
        print(f"    Mean:   {np.mean(max_regrets):.2f}")
        print(f"    Std:    {np.std(max_regrets):.2f}")
        print(f"    25th:   {np.percentile(max_regrets, 25):.2f}")
        print(f"    50th:   {np.percentile(max_regrets, 50):.2f}")
        print(f"    75th:   {np.percentile(max_regrets, 75):.2f}")
        print(f"    Max:    {np.max(max_regrets):.2f}")
    else:
        # np.max raises on zero-size arrays; report the absence explicitly.
        print("    (no segments)")

    # Lambda calibration
    print("\n[2] LAMBDA CALIBRATION")
    print(f"    c (scale):          {c}")
    print(f"    regret_ref:         {regret_ref:.2f}")
    print(f"    regret_discount:    {regret_discount:.2f}")
    print(f"    lambda:             {lambd:.6f}")

    # Stop time distribution
    num_segments = len(stop_times)
    valid_stops = stop_times[stop_times >= 0]
    censored = np.sum(stop_times < 0)

    # Guard the percentages against an empty dataset (division by zero).
    stopped_pct = len(valid_stops) / num_segments * 100 if num_segments else 0.0
    censored_pct = censored / num_segments * 100 if num_segments else 0.0

    print(f"\n[3] STOP TIME DISTRIBUTION (segment_len={segment_len})")
    print(f"    Stopped:    {len(valid_stops)} ({stopped_pct:.1f}%)")
    print(f"    Censored:   {censored} ({censored_pct:.1f}%)")

    if len(valid_stops) > 0:
        print(f"    Mean stop:  {np.mean(valid_stops):.1f}")
        print(f"    Std stop:   {np.std(valid_stops):.1f}")
        print(f"    Min stop:   {np.min(valid_stops)}")
        print(f"    Max stop:   {np.max(valid_stops)}")

        # Relative stop position (stop_time / segment_len); skipped when the
        # segment length is not positive, since the ratio is undefined.
        if segment_len > 0:
            rel_stops = valid_stops / segment_len
            print(f"    Relative (t/T): {np.mean(rel_stops):.2f} ± {np.std(rel_stops):.2f}")

    print(rule)


class StopDataset(Dataset):
    """
    Dataset for stop feedback learning using segments.

    Extracts random segments from episodes and simulates human termination
    behavior based on cumulative regret within each segment. At every valid
    step ``t`` a stop is sampled with hazard ``h_t = 1 - exp(-lambda * R_t)``,
    where ``R_t`` is the discounted cumulative regret and ``lambda`` is
    calibrated from the regret distribution of the extracted segments.
    """

    def __init__(
        self,
        num_episodes: int,
        num_samples: int,
        segment_len: int,
        q_model: QValueModel,
        policy: Expert,
        make_env_fn: Callable[[], gym.Env],
        device: str,
        base_seed: int,
        c: float = 1.0,
        regret_percentile: float = 75.0,
        regret_discount: float = 0.9,
        gamma: float = 0.99,
        obs_transform: Optional[Callable] = None,
        act_transform: Optional[Callable] = None,
        name: Optional[str] = "train",
        step_offset: int = 1,
        subsample_factor: int = 1,
        min_reward_threshold: Optional[float] = None,
        td_error_weight: float = 1.0,
    ):
        """
        Initialize stop dataset.

        All data is generated eagerly in the constructor (episode collection,
        segment extraction, regret computation and stop simulation).

        Args:
            num_episodes: Number of episodes to collect for segment extraction
            num_samples: Number of segments to extract
            segment_len: Length of each segment
            q_model: Q-value model for computing regret
            policy: Policy to generate trajectories (should be suboptimal for diversity)
            make_env_fn: Factory function for creating environment
            device: Device to place tensors on
            base_seed: Base seed for reproducibility; seeds episode collection,
                segment sampling and stop simulation
            c: Calibration constant for lambda (higher = more aggressive stopping);
                lambda is computed as ``c / regret_ref``
            regret_percentile: Percentile of the per-segment *maximum* regrets
                used as the reference regret ``regret_ref``
            regret_discount: Discount factor for old regret (0-1). Lower = faster forgetting.
            gamma: Discount factor for TD-error. Not used inside this class;
                it is attached to every item under ``DataKey.GAMMA``.
            obs_transform: Forwarded unchanged to ``prepare_episodes``
            act_transform: Forwarded unchanged to ``prepare_episodes``
            name: Label used in the printed summaries
            step_offset: Forwarded unchanged to ``prepare_episodes``
            subsample_factor: Forwarded unchanged to ``prepare_episodes``
            min_reward_threshold: Forwarded unchanged to ``prepare_episodes``
            td_error_weight: Not used inside this class; it is attached to
                every item under ``DataKey.TD_ERROR_WEIGHT``.

        Raises:
            ValueError: If no segments could be extracted.
        """
        super().__init__()
        self.num_episodes = num_episodes
        self.num_samples = num_samples
        self.segment_len = segment_len
        self.q_model = q_model
        self.make_env_fn = make_env_fn
        self.device = device
        self.base_seed = base_seed
        self.c = c
        self.regret_percentile = regret_percentile
        self.regret_discount = regret_discount
        self.gamma = gamma
        self.obs_transform = obs_transform
        self.act_transform = act_transform
        self.name = name
        self.step_offset = step_offset
        self.subsample_factor = subsample_factor
        self.min_reward_threshold = min_reward_threshold
        self.td_error_weight = td_error_weight
        # RNG used for segment sampling
        self.generator = random.Random(base_seed)

        # Generate stop data
        self.data, self.lambd = self.generate_stops(policy=policy)

        # Scalar attributes, pre-built once so __getitem__ does not allocate
        self._lambda = torch.tensor(self.lambd, dtype=torch.float32, device=device)
        self._regret_discount = torch.tensor(regret_discount, dtype=torch.float32, device=device)
        self._gamma = torch.tensor(gamma, dtype=torch.float32, device=device)
        self._td_error_weight = torch.tensor(td_error_weight, dtype=torch.float32, device=device)

    @staticmethod
    def _max_regrets(cumsum_regrets: list[np.ndarray]) -> np.ndarray:
        """Return the maximum regret of every non-empty regret curve."""
        return np.array([cr.max() for cr in cumsum_regrets if len(cr) > 0])

    def compute_segment_regrets(self, segment: Trajectory) -> np.ndarray:
        """Compute discounted cumulative regret for a segment.

        Uses the recursive formula: R_t = regret_discount * R_{t-1} + r_t
        Regret starts at 0 at the beginning of each segment.

        The instantaneous regret ``r_t`` of a valid step is the gap between the
        highest Q-value and the Q-value of the taken action, clipped at 0.
        Invalid steps contribute ``r_t = 0``, so the accumulated regret only
        decays there.

        Args:
            segment: Segment providing ``DataKey.ACTS``, ``DataKey.VALID`` and
                either ``DataKey.STATES`` (tabular Q-model) or ``DataKey.OBS``.

        Returns:
            Float array with one cumulative regret value per step; empty for a
            zero-length segment.
        """
        # Tabular Q-models are queried with states, all others with observations.
        if isinstance(self.q_model, TabularQValueModel):
            obs = segment[DataKey.STATES]
        else:
            obs = segment[DataKey.OBS]
        acts = segment[DataKey.ACTS]
        valid = segment[DataKey.VALID]
        T = len(obs)

        discounted_cumsum = np.zeros(T)
        running = 0.0
        for t in range(T):
            instant_regret = 0.0
            if valid[t]:
                q_values = self.q_model.q_values(obs[t]).squeeze()
                a_star = np.argmax(q_values)
                a = int(np.asarray(acts[t]).item())
                # float() keeps the accumulation in double precision
                # regardless of the Q-model's output dtype.
                instant_regret = float(max(0.0, q_values[a_star] - q_values[a]))
            running = self.regret_discount * running + instant_regret
            discounted_cumsum[t] = running

        return discounted_cumsum

    def calibrate_lambda(self, cumsum_regrets: list[np.ndarray]) -> tuple[float, float]:
        """Calibrate lambda based on max regret distribution across segments.

        The reference regret is the ``regret_percentile``-th percentile of the
        per-segment maximum regrets, and ``lambda = c / regret_ref``.

        Args:
            cumsum_regrets: One cumulative-regret curve per segment.

        Returns:
            ``(lambd, regret_ref)``. ``regret_ref`` falls back to 1.0 when the
            percentile is (near) zero, NaN, or when there is no regret data.
        """
        # Use the max regret within each segment (not final, since regret can decay)
        max_regrets = self._max_regrets(cumsum_regrets)
        if max_regrets.size > 0:
            regret_ref = float(np.percentile(max_regrets, self.regret_percentile))
        else:
            regret_ref = 0.0

        # Avoid division by zero and never propagate a NaN lambda
        if np.isnan(regret_ref) or regret_ref < 1e-8:
            regret_ref = 1.0

        lambd = float(self.c / regret_ref)
        return lambd, regret_ref

    def simulate_stop(self, cum_regret: np.ndarray, valid: np.ndarray, lambd: float, rng: np.random.Generator) -> int:
        """Simulate stop time for a single segment using the hazard model.

        At each valid step ``t`` a stop occurs with probability
        ``h_t = 1 - exp(-lambd * cum_regret[t])``. Invalid steps are skipped
        and consume no random draw.

        Args:
            cum_regret: Cumulative regret per step.
            valid: Per-step validity mask, same length as ``cum_regret``.
            lambd: Hazard rate scale.
            rng: Random generator supplying the uniform draws.

        Returns:
            Stop time (0 to segment_len-1) or -1 if no stop (censored)
        """
        T = len(cum_regret)

        # The draw order (one uniform per valid step, stopping at the first
        # hit) is kept as-is so that seeded runs stay reproducible.
        for t in range(T):
            if not valid[t]:
                continue
            R_t = cum_regret[t]
            h_t = 1.0 - np.exp(-lambd * R_t)

            # Sample stop event
            if rng.random() < h_t:
                return t

        return -1  # Censored (no stop occurred)

    def generate_stops(self, policy: Expert) -> tuple[dict, float]:
        """Generate segments and simulate stop times.

        Collects episodes with ``policy``, extracts segments, computes their
        cumulative regrets, calibrates lambda, samples one stop time per
        segment and prints summary/diagnostic reports along the way.

        Args:
            policy: Policy used to collect the episodes.

        Returns:
            ``(tensors, lambd)`` where ``tensors`` maps every segment key to a
            tensor stacked over segments (first axis = segment) plus
            ``DataKey.STOP_TIME`` holding one stop time per segment.

        Raises:
            ValueError: If no segments could be extracted.
        """
        # Collect episodes
        episodes = prepare_episodes(
            policy=policy,
            num_episodes=self.num_episodes,
            make_env_fn=self.make_env_fn,
            base_seed=self.base_seed,
            step_offset=self.step_offset,
            subsample_factor=self.subsample_factor,
            obs_transform=self.obs_transform,
            act_transform=self.act_transform,
            min_reward_threshold=self.min_reward_threshold,
        )

        print_stop_stats(episodes, self.name)

        # Extract random segments
        segments = extract_segments_from_episodes(
            episodes, self.segment_len, self.num_samples, rng=self.generator
        )
        if len(segments) == 0:
            # Fail early with context instead of an opaque IndexError or numpy
            # reduction error further down.
            raise ValueError(
                f"StopDataset '{self.name}': no segments could be extracted "
                f"(num_episodes={self.num_episodes}, "
                f"num_samples={self.num_samples}, "
                f"segment_len={self.segment_len})"
            )

        # Compute cumulative regrets for each segment
        cumsum_regrets = [self.compute_segment_regrets(seg) for seg in segments]

        # Calibrate lambda based on segment regrets
        lambd, regret_ref = self.calibrate_lambda(cumsum_regrets)

        # Simulate stop times for each segment
        np_rng = np.random.default_rng(self.base_seed)
        stop_times = []
        for seg, cum_regret in zip(segments, cumsum_regrets):
            stop_t = self.simulate_stop(cum_regret, seg[DataKey.VALID], lambd, np_rng)
            stop_times.append(stop_t)
            seg[DataKey.STOP_TIME] = stop_t

        stop_times = np.array(stop_times)

        # Print diagnostics
        print_stop_diagnostics(
            stop_times, self.segment_len, self._max_regrets(cumsum_regrets),
            lambd, self.c, regret_ref, self.regret_discount, self.name
        )

        # Create contiguous tensors of shape (num_samples, segment_len, ...)
        tensors = {
            k: to_torch(np.stack([seg[k] for seg in segments], axis=0), self.device)
            for k in segments[0].keys() if k != DataKey.STOP_TIME
        }
        # Stop times are scalars per segment
        tensors[DataKey.STOP_TIME] = to_torch(stop_times, self.device)

        return tensors, lambd

    def __len__(self):
        """Return the number of segments in the dataset."""
        return self.data[DataKey.STOP_TIME].shape[0]

    def __getitem__(self, idx) -> dict:
        """Get a single stop feedback sample (segment with stop time).

        The returned dict holds the dataset-wide scalars (feedback type,
        lambda, regret discount, gamma, TD-error weight) followed by the
        ``idx``-th entry of every stored tensor.
        """
        item_dict = {
            DataKey.FEEDBACK_TYPE: FeedbackType.STOP,
            DataKey.LAMBDA: self._lambda,
            DataKey.REGRET_DISCOUNT: self._regret_discount,
            DataKey.GAMMA: self._gamma,
            DataKey.TD_ERROR_WEIGHT: self._td_error_weight,
        }
        item_dict.update((k, v[idx]) for k, v in self.data.items())
        return item_dict
