import random
import numpy as np
from collections import deque, namedtuple
import torch
from scipy.special import expit
from src.data.her_data_generator import create_history_treatment_goal_samples
# One stored transition: the history before the action, the action, the
# scalar reward, the history after the action, the goal used to compute the
# reward, and whether that goal counts as reached.
Experience = namedtuple('Experience', 'history_dict action reward next_history_dict goal done')

class EmpiricalSampler:
    """Draw samples from the empirical distribution of a set of observations.

    The distinct values found in ``outputs`` form the support, and each one is
    weighted by its relative frequency in ``outputs``.
    """

    def __init__(self, outputs: np.ndarray):
        support, frequency = np.unique(outputs, return_counts=True)
        total = frequency.sum()
        self.values = support
        self.probs = frequency / total

    def sample(self, size=1) -> np.ndarray:
        """Return ``size`` draws, with replacement, from ``self.values``."""
        return np.random.choice(self.values, size=size, replace=True, p=self.probs)

def _extend_history(history_dict, future_dict, n_future_steps):
    """Return a new dict with the first ``n_future_steps`` future steps appended to each array.

    Arrays in ``history_dict`` with two or more dimensions are concatenated along
    axis 1 with ``future_dict[key][:, :n_future_steps]``. Axis 1 is presumably the
    time axis (TODO confirm against the data generator). When ``n_future_steps``
    is 0 such arrays are copied instead. All other values are passed through
    unchanged.
    """
    extended = {}
    for key, value in history_dict.items():
        if isinstance(value, np.ndarray) and value.ndim >= 2:
            if n_future_steps > 0:
                extended[key] = np.concatenate(
                    [value, future_dict[key][:, :n_future_steps]], axis=1)
            else:
                extended[key] = value.copy()
        else:
            extended[key] = value
    return extended


def collect_her_samples(dataset_collection, replay_buffer, min_history_length=15, max_history_length=30, future_length=5):
    """
    Collect training transitions from a dataset into a HER replay buffer.

    Each (history, future, goal) sample produced by
    ``create_history_treatment_goal_samples`` becomes one episode. Step ``t`` of
    the episode uses the history extended by the first ``t`` future steps as its
    current state, and the history extended by ``t + 1`` future steps as its
    next state. Its action is ``future['current_treatments'][0, t]``. Each
    episode is committed with ``replay_buffer.process_episode()``.

    Args:
        dataset_collection: object exposing ``train_f.data`` (the training split).
        replay_buffer: HER replay buffer providing ``add_to_episode``,
            ``process_episode`` and ``__len__``.
        min_history_length: minimum history length passed to the sample generator.
        max_history_length: maximum history length passed to the sample generator.
        future_length: number of future steps passed to the sample generator.

    Returns:
        The same ``replay_buffer``, now filled.
    """
    print("使用HER收集训练样本...")
    train_data = dataset_collection.train_f.data
    train_samples = create_history_treatment_goal_samples(
        train_data,
        min_history_length=min_history_length,
        max_history_length=max_history_length,
        future_length=future_length
    )

    print(f"从训练数据中创建了 {len(train_samples)} 个样本")
    episodes_added = 0
    for history_dict, future_dict, goal in train_samples:
        # Without future treatments there are no actions to replay.
        if 'current_treatments' not in future_dict:
            continue

        future_actions = future_dict['current_treatments']
        steps = future_actions.shape[1]
        if steps == 0:
            # An empty future window yields no transitions; don't count it as an episode.
            continue

        for t in range(steps):
            curr_history = _extend_history(history_dict, future_dict, t)
            next_history = _extend_history(history_dict, future_dict, t + 1)
            action = future_actions[0, t].copy()
            replay_buffer.add_to_episode(curr_history, action, next_history, goal)
        replay_buffer.process_episode()
        episodes_added += 1

        if episodes_added % 100 == 0:
            print(f"已处理 {episodes_added}/{len(train_samples)} 个episodes")

    print(f"完成HER样本收集，共添加 {episodes_added} 个episodes")
    print(f"回放缓冲区大小: {len(replay_buffer)}")

    return replay_buffer

class HERReplayBuffer:
    """Experience replay buffer with Hindsight Experience Replay (HER) relabeling.

    Transitions of one episode are staged with ``add_to_episode`` and committed
    with ``process_episode``. That call stores every transition once with its
    original goal. It also stores up to ``k_future`` extra copies whose goal is
    replaced by an output actually achieved from that transition onward in the
    same episode (the HER "future" strategy).
    """

    def __init__(self, capacity, k_future=4, reward_threshold=5e-3, reward_mode='combined'):
        """
        Args:
            capacity: maximum number of stored transitions (a bounded deque, so
                the oldest entries are dropped once full).
            k_future: maximum number of relabeled goals generated per transition.
            reward_threshold: a goal counts as reached when the Euclidean distance
                between the achieved output and the goal is below this value.
            reward_mode: reward shaping, one of
                'binary':   0 when the goal is reached, else -1
                'distance': negative distance
                'exp':      exp(-10 * distance) - 1
                'progress': change in distance versus the previous step for the
                            same goal (x2 when improving, x3 when regressing);
                            -distance when no previous distance is known
                'step':     tiered penalty (0 / -0.3 / -0.7 / -1) by distance
                'combined': -2 * distance, plus 10 when the goal is reached
                'sigmoid':  -sigmoid(10 * (distance - reward_threshold))
                Any other value falls back to 'binary'.
        """
        self.buffer = deque(maxlen=capacity)
        self.episode_buffer = []  # staged (history, action, next_history, goal) tuples
        self.k_future = k_future
        self.reward_threshold = reward_threshold
        self.reward_mode = reward_mode
        self.previous_distances = {}  # not read anywhere inside this class

    def add(self, history_dict, action, reward, next_history_dict, goal, done=False):
        """Store a single, already-rewarded transition."""
        self.buffer.append(Experience(history_dict, action, reward, next_history_dict, goal, done))

    def add_to_episode(self, history_dict, action, next_history_dict, goal):
        """Stage a transition for the current episode; rewards are computed in ``process_episode``."""
        self.episode_buffer.append((history_dict, action, next_history_dict, goal))

    def _compute_reward(self, distance, done, prev_distance=None):
        """Return the reward for ``distance`` to the goal according to ``self.reward_mode``."""
        if self.reward_mode == 'binary':
            return 0.0 if done else -1.0

        elif self.reward_mode == 'distance':
            return -distance

        elif self.reward_mode == 'exp':
            return np.exp(-10 * distance) - 1

        elif self.reward_mode == 'progress':
            if prev_distance is None:
                return -distance
            progress = prev_distance - distance
            # Penalize regressing more strongly than improving is rewarded.
            if progress > 0:
                return progress * 2.0
            else:
                return progress * 3.0

        elif self.reward_mode == 'step':
            if done:
                return 0.0
            elif distance < self.reward_threshold * 3:
                return -0.3
            elif distance < self.reward_threshold * 10:
                return -0.7
            else:
                return -1.0

        elif self.reward_mode == 'combined':
            distance_reward = -distance * 2.0
            goal_reward = 10.0 if done else 0.0
            return distance_reward + goal_reward

        elif self.reward_mode == 'sigmoid':
            return -expit(1e1 * (distance - self.reward_threshold))

        else:
            return 0.0 if done else -1.0

    def process_episode(self):
        """Reward and store the staged episode, adding HER-relabeled copies, then reset it.

        For transition ``t`` the relabeled goals are the outputs reached after
        transitions ``t .. T-1`` (that is, the ``next_history_dict`` of each of
        those transitions). This covers every achieved output from the state
        right after step ``t`` up to the final state of the episode.
        """
        if not self.episode_buffer:
            return
        n_steps = len(self.episode_buffer)
        # (step, goal) -> distance, used as prev_distance by the 'progress' mode.
        step_distances = {}
        for t, (history_dict, action, next_history_dict, goal) in enumerate(self.episode_buffer):
            current_output = self._get_current_output(next_history_dict)
            distance = np.linalg.norm(current_output - goal)
            done = distance < self.reward_threshold
            goal_key = tuple(goal.flatten())
            prev_distance = step_distances.get((t - 1, goal_key), None)
            reward = self._compute_reward(distance, done, prev_distance)
            step_distances[(t, goal_key)] = distance
            self.add(history_dict, action, reward, next_history_dict, goal, done)

            # Sample future transitions, including t itself, whose achieved
            # output becomes a substitute goal.
            future_idx_range = range(t, n_steps)
            future_idxs = random.sample(future_idx_range, min(self.k_future, len(future_idx_range)))

            for future_idx in future_idxs:
                # Element 2 is next_history_dict: the output achieved *after* that step.
                future_next_history_dict = self.episode_buffer[future_idx][2]
                new_goal = self._get_current_output(future_next_history_dict)
                new_distance = np.linalg.norm(current_output - new_goal)
                new_done = new_distance < self.reward_threshold
                new_goal_key = tuple(new_goal.flatten())
                prev_new_distance = step_distances.get((t - 1, new_goal_key), None)
                new_reward = self._compute_reward(new_distance, new_done, prev_new_distance)
                step_distances[(t, new_goal_key)] = new_distance
                self.add(history_dict, action, new_reward, next_history_dict, new_goal, new_done)
        self.episode_buffer = []

    def _get_current_output(self, history_dict):
        """Return a copy of the last output (index ``[0, -1]``) of 'outputs', or of 'prev_outputs' as a fallback.

        Raises:
            ValueError: if neither key is present.
        """
        if 'outputs' in history_dict:
            return history_dict['outputs'][0, -1].copy()
        elif 'prev_outputs' in history_dict:
            return history_dict['prev_outputs'][0, -1].copy()
        else:
            raise ValueError("历史字典中未找到输出特征")

    def sample(self, batch_size):
        """Return ``batch_size`` random transitions without replacement, or None if too few are stored."""
        if len(self.buffer) < batch_size:
            return None

        batch = random.sample(self.buffer, batch_size)
        return batch

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop all stored transitions and any staged episode."""
        self.buffer.clear()
        self.episode_buffer = []

def search_reward_threshold_adaptive(dataset_collection,
                                     min_history_length=15, max_history_length=30, future_length=5,
                                     capacity=10000, k_future=4, reward_mode='binary',
                                     initial_threshold=1e-3, target_hit_ratio=0.1,
                                     max_iter=10, hit_tolerance=0.01,
                                     max_high_limit=1.0, min_low_limit=1e-6):
    """
    Adaptively search for a ``reward_threshold`` whose goal-hit ratio matches a target.

    The hit ratio is the fraction of stored transitions (original and
    HER-relabeled) whose ``done`` flag is set. ``done`` is used instead of the
    reward value, so the count is correct for every ``reward_mode``. Because
    ``done`` means ``distance < threshold``, raising the threshold can only add
    hits for a fixed set of transitions. This is what makes bisection
    meaningful. HER goal sampling is random, so ratios are somewhat noisy.

    Steps:
      1. Evaluate ``initial_threshold``. Return it if it is already within
         ``hit_tolerance`` of the target.
      2. If the ratio is too low, grow the upper bound by factors of 10, capped
         at ``max_high_limit``. Each probe that still misses becomes the new
         lower bound.
      3. If the ratio is too high, shrink the lower bound by factors of 10,
         floored at ``min_low_limit``. Each probe that still overshoots becomes
         the new upper bound.
      4. Bisect in ``[low, high]`` for up to ``max_iter`` iterations.

    Every evaluation rebuilds a replay buffer and re-runs ``collect_her_samples``
    on the whole training split, so each probe is expensive.

    Returns:
        The chosen threshold. If the target cannot be bracketed within
        ``[min_low_limit, max_high_limit]``, the limit that was hit is returned.
    """

    def _evaluate(threshold):
        # Build a fresh buffer at `threshold` and return (hit_ratio, total transitions).
        buffer = HERReplayBuffer(capacity=capacity, k_future=k_future,
                                 reward_threshold=threshold, reward_mode=reward_mode)
        collect_her_samples(dataset_collection, buffer,
                            min_history_length=min_history_length,
                            max_history_length=max_history_length,
                            future_length=future_length)
        total = len(buffer.buffer)
        hit = sum(1 for item in buffer.buffer if item.done)
        return (hit / total if total > 0 else 0), total

    hit_ratio, total = _evaluate(initial_threshold)
    print(f"Initial threshold={initial_threshold}, hit/all={hit_ratio:.4f}, total={total}")
    if abs(hit_ratio - target_hit_ratio) < hit_tolerance:
        return initial_threshold

    low = initial_threshold
    high = initial_threshold * 10
    if hit_ratio < target_hit_ratio:
        high = min(high, max_high_limit)
        while high > low:
            hit_ratio, total = _evaluate(high)
            print(f"Expand high: threshold={high}, hit/all={hit_ratio:.4f}, total={total}")
            if hit_ratio >= target_hit_ratio:
                break
            # Still under target: the answer lies above this probe.
            low = high
            high = min(high * 10, max_high_limit)
    elif hit_ratio > target_hit_ratio:
        high = initial_threshold
        low = max(initial_threshold / 10, min_low_limit)
        while low < high:
            hit_ratio, total = _evaluate(low)
            print(f"Shrink low: threshold={low}, hit/all={hit_ratio:.4f}, total={total}")
            if hit_ratio <= target_hit_ratio:
                break
            # Still over target: the answer lies below this probe.
            high = low
            low = max(low / 10, min_low_limit)

    if high <= low:
        # The bracket collapsed onto a search limit: the target is unreachable.
        print(f"Target hit ratio not reachable within limits; returning threshold={high}")
        return high

    for i in range(max_iter):
        mid = (low + high) / 2
        hit_ratio, total = _evaluate(mid)

        print(f"Iteration {i+1}: threshold={mid:.6f}, hit/all={hit_ratio:.4f}, total={total}")

        if abs(hit_ratio - target_hit_ratio) < hit_tolerance:
            print(f"Converged at iteration {i+1} with hit_ratio={hit_ratio:.4f}")
            return mid

        if hit_ratio < target_hit_ratio:
            low = mid
        else:
            high = mid

    return (low + high) / 2

