import time
from contextlib import contextmanager
from typing import Any

import ray
import torch
from omegaconf import OmegaConf

from oat.actor_rollout_worker import ActorRolloutWorker
from oat.actor_worker import ActorWorker
from oat.critic_worker import CriticWorker
from oat.learners.ppo import PPO
from oat.rm.batch import Batch
from oat.rm.reward_model import RewardModel
from oat.types import ActorRolloutRefOutput, CriticOutput, DistributedConfig
from oat.utils.data import zero_pad_sequences
from oat.utils.ipc import MISSING_BATCH, Wait, dump_traceback_on_error
from oat.utils.launcher import get_dist_config, set_seeds
from oat.utils.logging import get_logger, rank0_print
from oat.utils.ray_utils import (
    SCHEDULING_STRATEGY,
    as_remote_generator,
    is_rank_in_placement,
    rank0_placement,
)

# Module-level logger for this file, created via oat's `get_logger` helper.
logger = get_logger(__name__)


def reduce_metrics(metrics_list, prefix='', reduction='mean'):
    """
    Reduce a collection of metric values into a single flat dictionary.

    Args:
        metrics_list: Either a list of metric dictionaries (e.g. one per
            worker or micro-batch), or a dictionary mapping each metric name
            to a list/tuple of values (a bare scalar is treated as a
            one-element list).
        prefix: Prefix prepended to every output key.
        reduction: Reduction method, 'mean' or 'sum'.

    Returns:
        A dictionary mapping ``prefix + key`` to the reduced value. Every key
        that appears in at least one input dictionary is reported, reduced
        over the dictionaries that actually contain it. Keys keep their
        first-appearance order.

    Raises:
        ValueError: If ``reduction`` is not 'mean' or 'sum'.
    """
    # Validate up front so a bad reduction is reported even for empty input.
    if reduction not in ('mean', 'sum'):
        raise ValueError(f"Unknown reduction method: {reduction}")
    if not metrics_list:
        return {}

    # Group values by metric name. Collect keys from *every* dict, not only
    # the first, so metrics reported by a subset of sources are not dropped.
    if isinstance(metrics_list, dict):
        grouped = {
            key: list(vals) if isinstance(vals, (list, tuple)) else [vals]
            for key, vals in metrics_list.items()
        }
    else:
        grouped = {}
        for metrics in metrics_list:
            for key, value in metrics.items():
                grouped.setdefault(key, []).append(value)

    reduced_metrics = {}
    for key, values in grouped.items():
        if not values:
            continue
        total = sum(values)
        reduced_metrics[prefix + key] = total / len(values) if reduction == 'mean' else total

    return reduced_metrics


@contextmanager
def _timer(name, timing_raw):
    """Time the body of a ``with`` statement.

    The elapsed wall-clock seconds (``time.perf_counter`` based) are written
    to ``timing_raw[name]`` when the body exits. This happens even if the
    body raises, and it overwrites any previous value stored under ``name``.
    """
    t0 = time.perf_counter()
    try:
        yield
    finally:
        timing_raw[name] = time.perf_counter() - t0


def compute_timing_metrics(batch: Batch, timing_raw: dict[str, float]) -> dict[str, Any]:
    """
    Turn raw per-phase durations into loggable timing and throughput metrics.

    Args:
        batch: The batch produced this step. Only the shapes of its
            'queries' and 'responses' tensors are read.
        timing_raw: Mapping from phase name to elapsed seconds.

    Returns:
        A dict with the following entries:

        * 'timing/total': the sum of all phases.
        * 'timing/<phase>': the duration of every phase.
        * 'timing/<phase>_pct': the phase's share of the total, present only
          when the total is positive.
        * 'throughput/samples_per_sec' and 'throughput/tokens_per_sec':
          present only when a positive 'generate' duration exists. Token
          throughput uses the padded response width, not the actual
          response lengths.
    """
    total_time = sum(timing_raw.values())
    result = {'timing/total': total_time}

    for phase, seconds in timing_raw.items():
        result[f'timing/{phase}'] = seconds
        if total_time > 0:
            result[f'timing/{phase}_pct'] = seconds / total_time * 100

    gen_seconds = timing_raw.get('generate', 0)
    if not gen_seconds > 0:
        return result

    n_samples = batch.batch['queries'].shape[0]
    padded_len = batch.batch['responses'].shape[1]
    result['throughput/samples_per_sec'] = n_samples / gen_seconds
    result['throughput/tokens_per_sec'] = (n_samples * padded_len) / gen_seconds
    return result


def _masked_token_mean(values: torch.Tensor, mask: torch.Tensor) -> float:
    """Average ``values`` over the positions where ``mask`` is non-zero.

    The denominator is clamped to at least 1, so an all-zero mask yields 0.0
    instead of NaN.
    """
    return ((values * mask).sum() / mask.sum().clamp(min=1)).detach().item()


def _reward_summary(rewards: torch.Tensor, prefix: str, include_std: bool) -> dict[str, float]:
    """Return the mean/max/min (and optionally std) of ``rewards`` under ``prefix``."""
    summary = {
        f'{prefix}/mean': torch.mean(rewards).detach().item(),
        f'{prefix}/max': torch.max(rewards).detach().item(),
        f'{prefix}/min': torch.min(rewards).detach().item(),
    }
    if include_std:
        # Unbiased std: NaN when ``rewards`` holds a single element.
        summary[f'{prefix}/std'] = torch.std(rewards).detach().item()
    return summary


def compute_data_metrics(batch: Batch, use_critic: bool = True, compute_invalid_separately: bool = False) -> dict[str, Any]:
    """
    Compute data-related metrics from the batch.

    The response mask is taken as the last ``responses.shape[-1]`` columns of
    'attention_mask'. Optional per-token tensors ('token_level_scores',
    'token_level_rewards', 'kl', 'advantages', 'values') are summarized only
    when present in ``batch.batch``.

    Args:
        batch: The batch of data.
        use_critic: Whether to include critic-related metrics (advantages, values).
        compute_invalid_separately: When True and 'valid_for_training' is
            present, also report reward statistics separately for valid and
            invalid samples, plus the valid/invalid ratios. The ratios are
            logged on every call, including when every sample is valid.

    Returns:
        Dictionary of data metrics.
    """
    metrics = {}

    queries = batch.batch['queries']
    responses = batch.batch['responses']
    attention_mask = batch.batch['attention_mask']

    # Sequence lengths.
    response_length = responses.shape[-1]
    response_mask = attention_mask[:, -response_length:]
    response_lengths = response_mask.sum(dim=-1)

    metrics['data/query_length'] = queries.shape[-1]
    metrics['data/response_length_mean'] = torch.mean(response_lengths.float()).detach().item()
    metrics['data/response_length_max'] = torch.max(response_lengths).detach().item()
    metrics['data/response_length_min'] = torch.min(response_lengths).detach().item()

    # Token-level scores.
    if 'token_level_scores' in batch.batch:
        metrics['data/token_scores_mean'] = _masked_token_mean(batch.batch['token_level_scores'], response_mask)

    # Sequence-level rewards.
    if 'token_level_rewards' in batch.batch:
        sequence_reward = (batch.batch['token_level_rewards'] * response_mask).sum(dim=-1)

        # Global statistics over all samples, invalid ones included.
        metrics.update(_reward_summary(sequence_reward, 'critic/rewards', include_std=True))

        # Separate statistics for valid / invalid samples.
        if compute_invalid_separately and 'valid_for_training' in batch.batch:
            # Cast to bool: an integer 0/1 mask would otherwise be used as
            # gather indices, and `~` on it would yield negative indices.
            valid_mask = batch.batch['valid_for_training'].bool()
            invalid_mask = ~valid_mask
            total_samples = len(valid_mask)
            n_valid = valid_mask.sum().item()
            n_invalid = invalid_mask.sum().item()

            if n_valid > 0:
                metrics.update(_reward_summary(sequence_reward[valid_mask], 'critic/rewards_valid', include_std=True))
                metrics['critic/rewards_valid/count'] = n_valid

            # Invalid-sample statistics are kept for monitoring only.
            if n_invalid > 0:
                metrics.update(_reward_summary(sequence_reward[invalid_mask], 'critic/rewards_invalid', include_std=False))
                metrics['critic/rewards_invalid/count'] = n_invalid

            # Always log the ratios so the series has no gaps on all-valid steps.
            if total_samples > 0:
                metrics['data_filtering/invalid_ratio'] = n_invalid / total_samples
                metrics['data_filtering/valid_ratio'] = n_valid / total_samples

    # KL divergence.
    if 'kl' in batch.batch:
        metrics['critic/kl_mean'] = _masked_token_mean(batch.batch['kl'], response_mask)

    # Advantages (if using critic).
    if use_critic and 'advantages' in batch.batch:
        metrics['critic/advantages_mean'] = _masked_token_mean(batch.batch['advantages'], response_mask)

    # Values (if using critic).
    if use_critic and 'values' in batch.batch:
        metrics['critic/values_mean'] = _masked_token_mean(batch.batch['values'], response_mask)

    return metrics


def apply_kl_penalty(batch: Batch, kl_ctrl, kl_penalty: str = 'kl') -> tuple[Batch, dict]:
    """
    Subtract a per-token KL penalty from the token-level scores.

    Reads 'queries', 'responses', 'attention_mask', 'token_level_scores',
    'logprobs' and 'ref_logprobs' from ``batch.batch`` and writes two keys:

    * 'kl': ``logprobs - ref_logprobs`` restricted to response tokens
      (zero on padding).
    * 'token_level_rewards': ``token_level_scores - kl_ctrl.value * penalty(kl)``.

    The KL is masked so padded positions receive no penalty. An unmasked
    penalty at padding would otherwise flow back into real tokens through the
    backward return/GAE recursion.

    Args:
        batch: The batch of data (mutated in place).
        kl_ctrl: KL controller. ``kl_ctrl.value`` is the penalty coefficient,
            and ``kl_ctrl.update(current_kl, n_steps=batch_size)`` is called
            once with the batch-mean sequence KL.
        kl_penalty: Type of KL penalty: 'kl', 'abs', 'mse' or 'full'. Only
            sampled-token log-probs are available, so 'full' is computed
            exactly like 'kl'.

    Returns:
        The updated batch and ``{'critic/kl_coef': kl_ctrl.value}``, with the
        coefficient read after the controller update.

    Raises:
        ValueError: If ``kl_ctrl`` is None or ``kl_penalty`` is unknown.
    """
    if kl_ctrl is None:
        raise ValueError("apply_kl_penalty requires a KL controller (config.algorithm.kl_ctrl), got None")

    # Maps the penalty type to the transform applied to the per-token KL.
    penalty_fns = {
        'kl': lambda d: d,
        'abs': torch.abs,
        'mse': torch.square,
        'full': lambda d: d,
    }
    if kl_penalty not in penalty_fns:
        raise ValueError(f"Unknown KL penalty type: {kl_penalty}")

    queries = batch.batch['queries']
    responses = batch.batch['responses']
    attention_mask = batch.batch['attention_mask']
    token_level_scores = batch.batch['token_level_scores']
    logprobs = batch.batch['logprobs']
    ref_logprobs = batch.batch['ref_logprobs']

    response_length = responses.shape[-1]
    response_mask = attention_mask[:, -response_length:]

    # Per-token KL estimate, zeroed on padding.
    kl = (logprobs - ref_logprobs) * response_mask
    batch.batch['kl'] = kl

    # Combine scores and the (negative) KL penalty.
    non_score_rewards = -kl_ctrl.value * penalty_fns[kl_penalty](kl)
    batch.batch['token_level_rewards'] = token_level_scores + non_score_rewards

    # Update the controller with the mean per-sequence KL.
    sequence_kl = kl.sum(dim=-1)
    kl_ctrl.update(sequence_kl.mean().detach().item(), n_steps=queries.shape[0])

    metrics = {
        'critic/kl_coef': kl_ctrl.value,
    }

    return batch, metrics


def compute_advantage(
    batch: Batch,
    adv_estimator: str = 'gae',
    gamma: float = 1.0,
    lam: float = 0.95,
    num_repeat: int = 1,
) -> Batch:
    """
    Compute advantages using the specified estimator.

    The response mask is taken as the last ``responses.shape[-1]`` columns of
    'attention_mask'. Rewards (and, for GAE, values) are zeroed on padding
    *before* the backward recursion. This prevents padded entries from
    leaking into real tokens and makes the first padded slot act as a zero
    bootstrap for the last real token.

    Args:
        batch: The batch of data (mutated in place).
        adv_estimator: Advantage estimator, 'gae' or 'reinforce'.
        gamma: Discount factor.
        lam: GAE lambda parameter (ignored for 'reinforce').
        num_repeat: Number of times each sample is repeated. Currently unused.

    Returns:
        The updated batch. 'advantages' is always written. For 'gae',
        'returns' (advantages + values, the critic's value targets) is also
        written. Both are zero on padding.

    Raises:
        ValueError: If the estimator is unknown, or 'gae' is requested
            without 'values' in the batch.
    """
    responses = batch.batch['responses']
    attention_mask = batch.batch['attention_mask']

    response_length = responses.shape[-1]
    response_mask = attention_mask[:, -response_length:]
    token_level_rewards = batch.batch['token_level_rewards'] * response_mask

    if adv_estimator == 'gae':
        if 'values' not in batch.batch:
            raise ValueError("GAE estimator requires 'values' in batch")

        values = batch.batch['values'] * response_mask
        advantages = torch.zeros_like(token_level_rewards)
        lastgaelam = 0.0

        for t in reversed(range(response_length)):
            next_values = values[:, t + 1] if t < response_length - 1 else 0.0
            delta = token_level_rewards[:, t] + gamma * next_values - values[:, t]
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages[:, t] = lastgaelam

        advantages = advantages * response_mask
        batch.batch['advantages'] = advantages
        batch.batch['returns'] = advantages + values

    elif adv_estimator == 'reinforce':
        # For REINFORCE, advantages are the discounted reward-to-go.
        advantages = torch.zeros_like(token_level_rewards)
        running = 0.0
        for t in reversed(range(response_length)):
            running = token_level_rewards[:, t] + gamma * running
            advantages[:, t] = running

        batch.batch['advantages'] = advantages * response_mask

    else:
        raise ValueError(f"Unknown advantage estimator: {adv_estimator}")

    return batch


class RayPPOTrainer:
    """
    Ray-based PPO trainer that coordinates actor, critic, and rollout workers.

    Each training step runs these stages in order:

    1. Generate rollouts.
    2. Score them with the reward model.
    3. Flag samples whose sequence reward equals one of
       ``INVALID_SEQUENCE_REWARDS`` as invalid.
    4. Apply KL shaping and compute advantages.
    5. Update the critic and actor on the valid samples only.
    6. Sync weights.
    7. Log metrics computed over *all* samples.
    """

    # Sequence-level reward values that flag a sample as invalid for training.
    # Flagged samples stay in the logged metrics but are excluded from the
    # critic/actor updates.
    # NOTE(review): these sentinels presumably correspond to failure codes
    # emitted by the reward model; confirm against RewardModel.
    INVALID_SEQUENCE_REWARDS = (-4.0, -5.0)

    def __init__(self, config):
        self.config = config
        self.global_steps = 0

        # Set random seeds.
        set_seeds(config.runtime.seed)

        # Get distributed configuration.
        self.dist_config: DistributedConfig = get_dist_config(config)

        # Initialize reward model.
        self.reward_fn = RewardModel(config.reward_model)

        # Initialize the KL controller only when it is configured; otherwise
        # kl_ctrl stays None.
        if hasattr(config.algorithm, 'kl_ctrl'):
            from oat.learners.kl_controller import KLController
            self.kl_ctrl = KLController(**OmegaConf.to_container(config.algorithm.kl_ctrl, resolve=True))
        else:
            self.kl_ctrl = None

        # A critic is used iff the config has a non-null 'critic' section.
        self.use_critic = config.get('critic', None) is not None

        # Initialize Ray workers.
        self._init_workers()

        # Metrics logger (distinct from the module-level text logger).
        from oat.utils.logging import WandbLogger
        self.logger = WandbLogger(config, prefix='train')

    def _spawn_worker(self, placement, worker_cls):
        """Launch ``worker_cls`` as a Ray actor if this rank is in ``placement``.

        Returns the actor handle, or None when this rank is not part of the
        placement.
        """
        if not is_rank_in_placement(placement):
            return None
        remote_cls = ray.remote(
            num_cpus=placement.num_cpus_per_worker,
            num_gpus=placement.num_gpus_per_worker,
            scheduling_strategy=SCHEDULING_STRATEGY,
        )(worker_cls)
        return remote_cls.remote(self.config)

    def _init_workers(self):
        """Initialize Ray workers for actor, critic, and rollouts."""
        dist = self.dist_config
        self.actor_wg = self._spawn_worker(dist.actor, ActorWorker)
        # Only consult the critic placement when a critic is configured.
        self.critic_wg = self._spawn_worker(dist.critic, CriticWorker) if self.use_critic else None
        self.actor_rollout_wg = self._spawn_worker(dist.actor_rollout_ref, ActorRolloutWorker)

    @dump_traceback_on_error
    def fit(self):
        """Main training loop: optional critic warmup, training steps, checkpoints."""
        trainer_cfg = self.config.trainer

        # Warm up the critic (actor updates are skipped) if configured.
        if self.use_critic and trainer_cfg.critic_warmup > 0:
            rank0_print(f"Warming up critic for {trainer_cfg.critic_warmup} steps...")
            for _ in range(trainer_cfg.critic_warmup):
                self._training_step(warmup=True)

        rank0_print("Starting main training loop...")
        total_steps = trainer_cfg.total_training_steps
        save_interval = trainer_cfg.save_interval
        last_saved_step = None
        for step in range(total_steps):
            self._training_step(warmup=False)

            # Periodic checkpoint; a non-positive interval disables periodic saves.
            if save_interval > 0 and (step + 1) % save_interval == 0:
                self._save_checkpoint(step + 1)
                last_saved_step = step + 1

        # Final save, skipped if the last periodic save already covered it.
        if last_saved_step != total_steps:
            self._save_checkpoint(total_steps)
        self.logger.finish()

    def _training_step(self, warmup: bool = False):
        """Execute one training step.

        Args:
            warmup: When True, the actor update is skipped (critic-only step).
        """
        timing_raw = {}
        metrics = {}

        # Generate rollouts.
        with _timer('generate', timing_raw):
            batch = self._generate_batch()
        if batch is MISSING_BATCH:
            rank0_print("No batch generated, skipping step")
            return

        # Compute rewards, validity flags and advantages.
        with _timer('adv', timing_raw):
            batch, adv_metrics = self._compute_rewards_and_advantages(batch)
            metrics.update(adv_metrics)

        # Update critic (valid samples only).
        if self.use_critic:
            with _timer('update_critic', timing_raw):
                metrics.update(self._update_on_valid_samples(batch, self.critic_wg, 'update_critic', 'critic'))

        # Update actor (valid samples only, and only after critic warmup).
        if not warmup and self.config.trainer.critic_warmup <= self.global_steps:
            with _timer('update_actor', timing_raw):
                metrics.update(self._update_on_valid_samples(batch, self.actor_rollout_wg, 'update_actor', 'actor'))

        # Sync weights.
        with _timer('sync', timing_raw):
            self._sync_weights()

        # Data metrics cover all samples (invalid ones included), with rewards
        # additionally broken down into valid and invalid subsets.
        metrics.update(compute_data_metrics(
            batch=batch,
            use_critic=self.use_critic,
            compute_invalid_separately=True,
        ))
        metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

        # Log metrics through the trainer's WandbLogger (self.logger), not the
        # module-level `logger` returned by get_logger.
        self.global_steps += 1
        metrics['trainer/global_steps'] = self.global_steps
        self.logger.log(data=metrics, step=self.global_steps)

        # Print progress.
        if self.global_steps % 10 == 0:
            rank0_print(f"Step {self.global_steps}: " +
                        f"reward_mean={metrics.get('critic/rewards/mean', 0):.3f}, " +
                        f"reward_valid_mean={metrics.get('critic/rewards_valid/mean', 0):.3f}, " +
                        f"invalid_ratio={metrics.get('data_filtering/invalid_ratio', 0):.3f}")

    def _compute_rewards_and_advantages(self, batch):
        """Score ``batch``, flag invalid samples, apply KL shaping and compute advantages.

        Returns:
            The updated batch and a dict of KL-penalty metrics. The dict is
            empty when the KL loss is used instead of the reward penalty.
        """
        metrics = {}

        # Token-level reward scores.
        reward_tensor = self.reward_fn(batch)
        batch.batch['token_level_scores'] = reward_tensor

        # Invalid samples are only flagged here, not dropped.
        batch.batch['valid_for_training'] = self._mark_invalid_samples(batch, reward_tensor)

        # Apply the KL penalty to the rewards unless the actor uses a KL loss.
        if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
            batch, kl_metrics = apply_kl_penalty(
                batch,
                kl_ctrl=self.kl_ctrl,
                kl_penalty=self.config.algorithm.kl_penalty,
            )
            metrics.update(kl_metrics)
        else:
            batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

        algo_cfg = self.config.algorithm
        batch = compute_advantage(
            batch,
            adv_estimator=algo_cfg.adv_estimator,
            gamma=algo_cfg.gamma,
            lam=algo_cfg.lam,
            num_repeat=self.config.actor_rollout_ref.rollout.n,
        )
        return batch, metrics

    def _mark_invalid_samples(self, batch, reward_tensor):
        """Return a bool mask that is False for samples with a sentinel sequence reward."""
        response_length = batch.batch['responses'].shape[-1]
        response_mask = batch.batch['attention_mask'][:, -response_length:]
        sequence_rewards = (reward_tensor * response_mask).sum(dim=-1)

        # Build the sentinel tensor in the rewards' dtype/device so isin compares like with like.
        sentinels = torch.tensor(
            self.INVALID_SEQUENCE_REWARDS,
            dtype=sequence_rewards.dtype,
            device=sequence_rewards.device,
        )
        valid_mask = ~torch.isin(sequence_rewards, sentinels)

        filtered_count = (~valid_mask).sum().item()
        if filtered_count > 0:
            rank0_print(f"[Step {self.global_steps}] Marking {filtered_count}/{len(valid_mask)} samples as invalid for training")
        return valid_mask

    def _update_on_valid_samples(self, batch, worker, method_name, role):
        """Run ``worker.<method_name>`` on the valid subset of ``batch``.

        Returns the reduced metrics from the worker output's
        ``meta_info['metrics']``. If no sample is valid, prints a notice,
        returns an empty dict and does not touch ``worker``.
        """
        valid_indices = torch.where(batch.batch['valid_for_training'])[0]
        if len(valid_indices) == 0:
            rank0_print(f"[Step {self.global_steps}] No valid samples for {role} update")
            return {}
        update = getattr(worker, method_name)
        output = ray.get(update.remote(batch.select(valid_indices)))
        return reduce_metrics(output.meta_info['metrics'])

    def _generate_batch(self) -> Batch:
        """Generate a batch of rollouts; returns MISSING_BATCH if none was produced."""
        output_ref = self.actor_rollout_wg.generate_batch.remote()
        output: ActorRolloutRefOutput = ray.get(output_ref)

        if output is None or output.batch is None:
            return MISSING_BATCH

        return output.batch

    def _sync_weights(self):
        """Push actor (and critic, if any) weights to the actor-rollout worker.

        Blocks until the rollout worker has applied the update, so failures
        surface here and the 'sync' timer measures the actual transfer.
        """
        actor_weights = ray.get(self.actor_wg.get_weights.remote()) if self.actor_wg is not None else None
        critic_weights = ray.get(self.critic_wg.get_weights.remote()) if self.critic_wg is not None else None

        if self.actor_rollout_wg is not None and actor_weights is not None:
            ray.get(self.actor_rollout_wg.update_weights.remote(actor_weights, critic_weights))

    def _save_checkpoint(self, step: int):
        """Save actor and critic (when present) checkpoints for ``step``."""
        rank0_print(f"Saving checkpoint at step {step}...")

        for role, worker in (('actor', self.actor_wg), ('critic', self.critic_wg)):
            if worker is None:
                continue
            save_path = f"{self.config.runtime.output_dir}/{role}_step_{step}"
            ray.get(worker.save.remote(save_path))
            rank0_print(f"Saved {role} to {save_path}")


def main():
    """Entry point: build the run config, construct the PPO trainer and train."""
    from oat.utils.launcher import create_config

    RayPPOTrainer(create_config()).fit()


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()