import random
import numpy as np
from collections import deque, namedtuple
import torch
from scipy.special import expit
from src.data.her_data_generator import create_history_treatment_goal_samples

# One stored replay transition: the history before the action, the action
# itself, the reward assigned for `goal`, the history after the action, the
# goal the reward was computed against, and whether that goal was reached.
Experience = namedtuple(
    'Experience',
    ('history_dict', 'action', 'reward', 'next_history_dict', 'goal', 'done'),
)

class EmpiricalSampler:
    """Sampler that draws from the empirical distribution of observed values.

    The distinct values of the input array become the support, and each one
    is drawn with probability proportional to how often it occurred.
    """

    def __init__(self, outputs: np.ndarray):
        """Record the distinct values of `outputs` and their relative frequencies."""
        # np.unique works on the flattened input and returns sorted distinct values.
        distinct, occurrences = np.unique(outputs, return_counts=True)
        total_count = occurrences.sum()
        self.values = distinct
        self.probs = occurrences / total_count

    def sample(self, size=1) -> np.ndarray:
        """Draw `size` values (with replacement) using NumPy's global RNG."""
        drawn = np.random.choice(self.values, size=size, replace=True, p=self.probs)
        return drawn

def _extend_history(history_dict, future_dict, n_future_steps):
    """Return a history dict extended by the first `n_future_steps` future steps.

    Every NumPy array with at least two dimensions is treated as a sequence
    whose axis 1 is extended with `future_dict[key][:, :n_future_steps]`.
    With `n_future_steps == 0` such arrays are copied so that the result does
    not alias the source history. All other values are passed through by
    reference.

    Raises:
        KeyError: if an extended key is missing from `future_dict`.
    """
    # NOTE(review): the original comments describe the join as applying to
    # 3-D arrays, but the condition admits any array with >= 2 dims. Confirm
    # that 2-D arrays in these dicts really are indexed by time on axis 1.
    extended = {}
    for key, value in history_dict.items():
        if not (isinstance(value, np.ndarray) and value.ndim >= 2):
            extended[key] = value
        elif n_future_steps > 0:
            extended[key] = np.concatenate(
                [value, future_dict[key][:, :n_future_steps]], axis=1
            )
        else:
            extended[key] = value.copy()
    return extended


def collect_her_samples(dataset_collection, replay_buffer, min_history_length=15, max_history_length=30, future_length=5):
    """
    Fill a HER replay buffer with episodes built from the training data.

    Each sample produced by `create_history_treatment_goal_samples` becomes
    one episode: for every future step `t`, a transition is emitted whose
    state is the history extended by the first `t` future steps, whose next
    state is the history extended by `t + 1` future steps, and whose action is
    `future_dict['current_treatments'][0, t]`. After the last step the episode
    is handed to `replay_buffer.process_episode()`.

    Samples whose future dict has no 'current_treatments' entry are skipped.

    Args:
        dataset_collection: object exposing `train_f.data` (only this
            attribute is read here)
        replay_buffer: buffer providing `add_to_episode`, `process_episode`
            and `__len__`
        min_history_length: minimum history length passed to the sample generator
        max_history_length: maximum history length passed to the sample generator
        future_length: number of future steps passed to the sample generator

    Returns:
        The same `replay_buffer` object, after all episodes were processed.
    """
    print("Use her to collect training samples...")

    train_data = dataset_collection.train_f.data
    train_samples = create_history_treatment_goal_samples(
        train_data,
        min_history_length=min_history_length,
        max_history_length=max_history_length,
        future_length=future_length
    )

    print(f"{len(train_samples)} samples created from training data")

    episodes_added = 0
    for history_dict, future_dict, goal in train_samples:
        if 'current_treatments' not in future_dict:
            continue

        future_actions = future_dict['current_treatments']

        for t in range(future_actions.shape[1]):
            # State before the action: history plus the future steps before t.
            curr_history = _extend_history(history_dict, future_dict, t)
            # State after the action: history plus future steps up to and including t.
            next_history = _extend_history(history_dict, future_dict, t + 1)
            action = future_actions[0, t].copy()

            replay_buffer.add_to_episode(curr_history, action, next_history, goal)

        # Store the finished episode and let the buffer apply goal relabelling.
        replay_buffer.process_episode()
        episodes_added += 1

        if episodes_added % 100 == 0:
            print(f"{episodes_added} of {len(train_samples)} episodes processed")

    print(f"Complete her sample collection, add {episodes_added} episodes in total")
    print(f"Playback buffer size: {len(replay_buffer)}")

    return replay_buffer

class HERReplayBuffer:
    """Replay buffer with Hindsight Experience Replay (HER) goal relabelling.

    Transitions of one episode are staged with `add_to_episode` and then
    committed by `process_episode`, which stores every transition once with
    its original goal and up to `k_future` more times with goals taken from
    later points of the same episode.
    """

    def __init__(self, capacity, k_future=4, reward_threshold=5e-3, reward_mode='combined'):
        """
        Create an empty buffer.

        Args:
            capacity: maximum number of stored transitions (oldest are dropped first)
            k_future: maximum number of hindsight goals sampled per transition
            reward_threshold: a goal counts as reached when the distance to it
                is strictly below this value
            reward_mode: reward calculation mode, one of:
                         'binary': 0 when the goal is reached, otherwise -1
                         'distance': negative distance
                         'exp': exp(-10 * distance) - 1
                         'progress': scaled change of distance w.r.t. the previous step
                         'step': tiered penalties based on multiples of the threshold
                         'combined': -2 * distance, plus 10 when the goal is reached
                         'sigmoid': -sigmoid(10 * (distance - reward_threshold))
                         Any other value behaves like 'binary'.
        """
        self.buffer = deque(maxlen=capacity)
        self.episode_buffer = []  # staged (history, action, next_history, goal) tuples
        self.k_future = k_future
        self.reward_threshold = reward_threshold
        self.reward_mode = reward_mode
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.previous_distances = {}

    def add(self, history_dict, action, reward, next_history_dict, goal, done=False):
        """Append a single, fully specified transition to the buffer."""
        self.buffer.append(Experience(history_dict, action, reward, next_history_dict, goal, done))

    def add_to_episode(self, history_dict, action, next_history_dict, goal):
        """Stage a transition of the current episode (committed by `process_episode`)."""
        self.episode_buffer.append((history_dict, action, next_history_dict, goal))

    def _compute_reward(self, distance, done, prev_distance=None):
        """Return the reward for `distance` according to `self.reward_mode`.

        Args:
            distance: distance between the achieved output and the goal
            done: whether the goal counts as reached
            prev_distance: distance to the same goal one step earlier; only
                used by 'progress' mode, where None falls back to `-distance`
        """
        mode = self.reward_mode

        if mode == 'distance':
            return -distance

        if mode == 'exp':
            return np.exp(-10 * distance) - 1  # factor 10 sharpens the decay

        if mode == 'progress':
            if prev_distance is None:
                return -distance
            progress = prev_distance - distance  # positive means we got closer
            # Moving away is penalised more strongly than approaching is rewarded.
            return progress * 2.0 if progress > 0 else progress * 3.0

        if mode == 'step':
            if done:
                return 0.0
            if distance < self.reward_threshold * 3:
                return -0.3
            if distance < self.reward_threshold * 10:
                return -0.7
            return -1.0

        if mode == 'combined':
            goal_bonus = 10.0 if done else 0.0
            return -distance * 2.0 + goal_bonus

        if mode == 'sigmoid':
            return -expit(1e1 * (distance - self.reward_threshold))

        # 'binary' and any unrecognised mode.
        return 0.0 if done else -1.0

    def _evaluate_goal(self, achieved, goal, prev_achieved=None):
        """Return `(reward, done)` for an achieved output measured against `goal`.

        `prev_achieved` is the output achieved one step earlier in the same
        episode (None for the first step). It is only consulted in 'progress'
        mode, where the previous distance to the *same* goal is derived from
        it directly, so hindsight goals get a progress reward as well.
        """
        distance = np.linalg.norm(achieved - goal)
        done = distance < self.reward_threshold

        prev_distance = None
        if self.reward_mode == 'progress' and prev_achieved is not None:
            prev_distance = np.linalg.norm(prev_achieved - goal)

        return self._compute_reward(distance, done, prev_distance), done

    def process_episode(self):
        """Commit the staged episode to the buffer, applying HER relabelling.

        For every staged transition the achieved output is read from its
        next-history. The transition is stored with its original goal, and
        then with up to `k_future` hindsight goals sampled (without
        replacement) from later transitions of the episode. The staging
        buffer is emptied afterwards. Does nothing when no transition is
        staged.
        """
        if not self.episode_buffer:
            return

        episode = self.episode_buffer
        prev_achieved = None  # output achieved by the previous step, if any

        for t, (history_dict, action, next_history_dict, goal) in enumerate(episode):
            achieved = self._get_current_output(next_history_dict)

            # Original goal first.
            reward, done = self._evaluate_goal(achieved, goal, prev_achieved)
            self.add(history_dict, action, reward, next_history_dict, goal, done)

            # Hindsight goals drawn from strictly later transitions.
            later_steps = range(t + 1, len(episode))
            if later_steps:
                chosen = random.sample(later_steps, min(self.k_future, len(later_steps)))
            else:
                chosen = []

            for future_idx in chosen:
                # NOTE(review): the goal is read from element 0 of the future
                # transition, i.e. its pre-transition history, not from its
                # next-history (element 2). Confirm this is the intended source.
                future_history_dict = episode[future_idx][0]
                new_goal = self._get_current_output(future_history_dict)

                new_reward, new_done = self._evaluate_goal(achieved, new_goal, prev_achieved)
                self.add(history_dict, action, new_reward, next_history_dict, new_goal, new_done)

            prev_achieved = achieved

        self.episode_buffer = []

    def _get_current_output(self, history_dict):
        """Return a copy of the last-step output of the first batch entry.

        Reads `history_dict['outputs']` when present, otherwise
        `history_dict['prev_outputs']`.

        Raises:
            ValueError: if neither key is present.
        """
        if 'outputs' in history_dict:
            return history_dict['outputs'][0, -1].copy()
        if 'prev_outputs' in history_dict:
            return history_dict['prev_outputs'][0, -1].copy()
        raise ValueError("Output feature not found in historical dictionary")

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions, or None if fewer are stored."""
        if len(self.buffer) < batch_size:
            return None
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def clear(self):
        """Drop all stored and staged transitions."""
        self.buffer.clear()
        self.episode_buffer = []

def _estimate_hit_ratio(dataset_collection, threshold, min_history_length, max_history_length,
                        future_length, capacity, k_future, reward_mode):
    """Fill a fresh HER buffer using `threshold` and return `(hit_ratio, total)`.

    A stored transition counts as a hit when its `done` flag is set, i.e. its
    distance to the goal was below `threshold`. For the 'binary' and 'step'
    reward modes this is the same set of transitions as `reward == 0`.
    """
    replay_buffer = HERReplayBuffer(capacity=capacity, k_future=k_future,
                                    reward_threshold=threshold, reward_mode=reward_mode)
    collect_her_samples(dataset_collection, replay_buffer,
                        min_history_length=min_history_length,
                        max_history_length=max_history_length,
                        future_length=future_length)

    total = len(replay_buffer.buffer)
    hit = sum(1 for item in replay_buffer.buffer if item.done)
    hit_ratio = hit / total if total > 0 else 0
    return hit_ratio, total


def search_reward_threshold_adaptive(dataset_collection,
                                     min_history_length=15, max_history_length=30, future_length=5,
                                     capacity=10000, k_future=4, reward_mode='binary',
                                     initial_threshold=1e-3, target_hit_ratio=0.1,
                                     max_iter=10, hit_tolerance=0.01,
                                     max_high_limit=1.0, min_low_limit=1e-6):
    """
    Search for a reward threshold whose goal-hit ratio is close to a target.

    Procedure:
    1. Measure the hit ratio at `initial_threshold`.
    2. If it is below the target, grow the upper bound by factors of 10
       (capped at `max_high_limit`) until the hit ratio reaches the target,
       giving the bracket [initial_threshold, high].
    3. If instead it is above the target, shrink the lower bound by factors of
       10 (while above `min_low_limit`) until the hit ratio is no higher than
       the target, giving the bracket [low, initial_threshold].
    4. Bisect the bracket for at most `max_iter` iterations.

    Exactly one of steps 2 and 3 runs, chosen from the initial measurement, so
    a bracket found while expanding is never replaced by the shrink phase.

    Every measurement rebuilds a replay buffer from `dataset_collection`, so
    the cost is one full sample collection per evaluated threshold.

    Args:
        dataset_collection: forwarded to `collect_her_samples`
        min_history_length, max_history_length, future_length: forwarded to
            `collect_her_samples`
        capacity, k_future, reward_mode: forwarded to `HERReplayBuffer`
        initial_threshold: first threshold to evaluate
        target_hit_ratio: desired fraction of transitions that reach their goal
        max_iter: maximum number of bisection iterations
        hit_tolerance: bisection stops once |hit_ratio - target| is below this
        max_high_limit: cap for the upper bound during expansion
        min_low_limit: expansion of the lower bound stops at or below this value

    Returns:
        The threshold at which bisection converged, or the midpoint of the
        final bracket when `max_iter` was exhausted.
    """
    def evaluate(threshold):
        return _estimate_hit_ratio(dataset_collection, threshold,
                                   min_history_length, max_history_length,
                                   future_length, capacity, k_future, reward_mode)

    # 1. Initial measurement
    hit_ratio, total = evaluate(initial_threshold)
    print(f"Initial threshold={initial_threshold}, hit/all={hit_ratio:.4f}, total={total}")

    low = initial_threshold
    high = initial_threshold * 10

    if hit_ratio < target_hit_ratio:
        # 2. Hit ratio too low: push the upper bound outwards.
        while high < max_high_limit:
            hit_ratio, total = evaluate(high)
            print(f"Expand high: threshold={high}, hit/all={hit_ratio:.4f}, total={total}")

            if hit_ratio >= target_hit_ratio:
                break

            high = min(high * 10, max_high_limit)

    elif hit_ratio > target_hit_ratio:
        # 3. Hit ratio too high: push the lower bound inwards.
        low = initial_threshold / 10
        high = initial_threshold
        while low > min_low_limit:
            hit_ratio, total = evaluate(low)
            print(f"Shrink low: threshold={low}, hit/all={hit_ratio:.4f}, total={total}")

            if hit_ratio <= target_hit_ratio:
                break

            low /= 10

    # 4. Bisection inside [low, high]
    for i in range(max_iter):
        mid = (low + high) / 2

        hit_ratio, total = evaluate(mid)
        print(f"Iteration {i+1}: threshold={mid:.6f}, hit/all={hit_ratio:.4f}, total={total}")

        if abs(hit_ratio - target_hit_ratio) < hit_tolerance:
            print(f"Converged at iteration {i+1} with hit_ratio={hit_ratio:.4f}")
            return mid

        # A larger threshold can only mark more transitions as reached.
        if hit_ratio < target_hit_ratio:
            low = mid
        else:
            high = mid

    return (low + high) / 2

