import numpy as np
from collections import deque


class RewardTrendMonitor:
    """Track the trend of an accuracy reward and gate a length reward on it.

    Each call to ``update`` folds the new reward into an exponentially
    weighted moving average (EWMA), appends that EWMA to a bounded history,
    and classifies the least-squares slope of the most recent
    ``trend_window`` EWMA values as ``'increasing'``, ``'decreasing'`` or
    ``'stable'``. Until enough data is available, ``'insufficient_data'``
    is reported instead.
    """

    # Trend labels for which should_apply_length_reward() returns True.
    _LENGTH_REWARD_TRENDS = frozenset({'increasing', 'stable', 'insufficient_data'})

    def __init__(self, window_size=100, decay_rate=0.1, trend_window=5, threshold=0.001, min_steps=10):
        """
        Initialize the reward trend monitor.

        Args:
            window_size: Maximum number of EWMA values kept in the history
                (``None`` keeps an unbounded history). Must be at least
                ``trend_window``, otherwise a trend could never be computed.
            decay_rate: Weight given to the *previous* EWMA value in
                ``ewma = decay_rate * ewma + (1 - decay_rate) * reward``.
                Must lie in [0, 1]; larger values smooth more strongly.
            trend_window: Number of most recent EWMA values used to fit the
                trend line. Must be at least 2 (a line needs two points).
            threshold: Non-negative slope magnitude above which the trend is
                reported as increasing/decreasing rather than stable.
            min_steps: Number of updates before a trend is reported; earlier
                updates return ``'insufficient_data'``.

        Raises:
            ValueError: If any argument is outside its valid range.
        """
        if trend_window < 2:
            raise ValueError(f"trend_window must be >= 2, got {trend_window}")
        if window_size is not None and window_size < trend_window:
            raise ValueError(
                f"window_size ({window_size}) must be >= trend_window ({trend_window})"
            )
        # Written as `not (...)` so that NaN is rejected too.
        if not 0.0 <= decay_rate <= 1.0:
            raise ValueError(f"decay_rate must be in [0, 1], got {decay_rate}")
        if threshold < 0:
            raise ValueError(f"threshold must be >= 0, got {threshold}")

        self.window_size = window_size
        self.decay_rate = decay_rate
        self.trend_window = trend_window
        self.threshold = threshold
        self.min_steps = min_steps

        # Bounded history of EWMA values (oldest values are dropped first).
        self.ewma_values = deque(maxlen=window_size)
        # Most recent EWMA value; None until the first update.
        self.current_ewma = None
        # Number of accepted updates so far.
        self.step_count = 0

    def update(self, accuracy_reward):
        """
        Fold a new reward into the EWMA and return the current trend.

        Args:
            accuracy_reward: Scalar accuracy reward for the current step.

        Returns:
            One of 'increasing', 'decreasing', 'stable', 'insufficient_data'.

        Raises:
            ValueError: If the reward is not finite. A NaN/inf would otherwise
                poison the EWMA permanently and freeze the trend at 'stable'.
                The monitor's state is left unchanged in that case.
        """
        reward = float(accuracy_reward)
        if not np.isfinite(reward):
            raise ValueError(f"accuracy_reward must be finite, got {accuracy_reward!r}")

        self.step_count += 1

        if self.current_ewma is None:
            # The first observation seeds the average.
            self.current_ewma = reward
        else:
            self.current_ewma = self.decay_rate * self.current_ewma + (1 - self.decay_rate) * reward

        self.ewma_values.append(self.current_ewma)

        if self.step_count < self.min_steps:
            return 'insufficient_data'

        return self._calculate_trend()

    def _calculate_trend(self):
        """
        Classify the slope of the last ``trend_window`` EWMA values.

        Returns:
            'increasing' / 'decreasing' when the fitted slope exceeds
            ``threshold`` in magnitude, 'stable' otherwise, and
            'insufficient_data' when fewer than ``trend_window`` EWMA values
            have been recorded (no meaningful line can be fitted yet).
        """
        if len(self.ewma_values) < self.trend_window:
            return 'insufficient_data'

        recent = np.array(self.ewma_values, dtype=float)[-self.trend_window:]
        x = np.arange(self.trend_window)

        # Leading coefficient of a degree-1 least-squares fit = slope per step.
        slope = np.polyfit(x, recent, 1)[0]

        if slope > self.threshold:
            return 'increasing'
        if slope < -self.threshold:
            return 'decreasing'
        return 'stable'

    def should_apply_length_reward(self, accuracy_reward):
        """
        Update the monitor and decide whether a length reward should apply.

        Args:
            accuracy_reward: Scalar accuracy reward for the current step.

        Returns:
            bool: False only when the accuracy trend is 'decreasing'; True for
            'increasing', 'stable' and 'insufficient_data'.
        """
        return self.update(accuracy_reward) in self._LENGTH_REWARD_TRENDS


# Usage example
if __name__ == "__main__":
    # Seeded generator so the simulated run is reproducible.
    rng = np.random.default_rng(0)

    trend_monitor = RewardTrendMonitor(
        window_size=100,
        decay_rate=0.9,
        trend_window=20,
        threshold=0.001,
        min_steps=50,
    )

    # Simulated training run.
    for step in range(200):
        # Simulated accuracy reward: high for the first 150 steps, then a drop.
        base_reward = 0.8 if step < 150 else 0.5
        accuracy_reward = base_reward + 0.1 * rng.random()

        apply_length_reward = trend_monitor.should_apply_length_reward(accuracy_reward)

        # Placeholder length reward; in practice this would depend on the
        # response length.
        length_reward = 0.1 if apply_length_reward else 0.0
        total_reward = accuracy_reward + length_reward

        # Report the trend the monitor actually used for this step: update()
        # returns 'insufficient_data' until min_steps updates have been seen,
        # which calling _calculate_trend() on its own would not reflect.
        if trend_monitor.step_count < trend_monitor.min_steps:
            trend = 'insufficient_data'
        else:
            trend = trend_monitor._calculate_trend()

        print(f"Step {step}: Accuracy={accuracy_reward:.3f}, "
              f"EWMA={trend_monitor.current_ewma:.3f}, "
              f"Trend={trend}, "
              f"Apply Length Reward={apply_length_reward}, "
              f"Total Reward={total_reward:.3f}")