import torch
import numpy as np
from typing import Dict, List, Tuple
import re
from collections import defaultdict
import random

def random_output(prob_minus_one):
    """Draw a random sign: ``1`` with probability *prob_minus_one*, else ``-1``.

    NOTE: despite its name, the parameter is the probability of returning
    ``1`` (the draw is ``random.random() < prob_minus_one``).

    Raises:
        ValueError: if *prob_minus_one* is not within [0, 1] (NaN included).
    """
    if 0 <= prob_minus_one <= 1:
        # random.random() is in [0, 1): p == 1 always yields 1, p == 0 always -1.
        return -1 if random.random() >= prob_minus_one else 1
    raise ValueError("0~1")

def calculate_retention_rate(positive_rate, completed_rate):
    """Map batch statistics to a retention probability.

    Returns ``1`` when ``positive_rate < 0.4`` or ``completed_rate < 0.1``.
    Otherwise it decreases linearly as ``1 - 1.5 * completed_rate`` while
    ``completed_rate < 0.6``, and is ``0.1`` beyond that. The two pieces
    meet at 0.1 when ``completed_rate == 0.6``.
    """
    if positive_rate < 0.4 or completed_rate < 0.1:
        return 1
    return 1 - completed_rate * 1.5 if completed_rate < 0.6 else 0.1
def reward_func(reward_metadata: Dict) -> torch.Tensor:
    """Compute one scalar reward per turn-level sample for a batch of rollouts.

    Samples are grouped by ``task_id`` and then by the ``rollout_idx`` in each
    sample's trajectory info. Each rollout first gets a trajectory-level reward:

    - a task with a single rollout gets +1 (completed) or -1;
    - otherwise each rollout gets a raw score, centred with a leave-one-out
      baseline across that task's rollouts.

    Each sample not flagged ``think_too_long`` then gets its turn reward plus a
    logic bonus (``calculate_logic_bonus``). If ``detect_repetition_penalty``
    flagged the sample, this is multiplied by its (negative) repetition
    multiplier. The result is finally reshaped according to the sign of the
    rollout's trajectory reward. Samples flagged ``think_too_long`` are set to
    -3.0. The output is clamped to [-3.0, 3.0].

    Args:
        reward_metadata: dict holding equal-length sequences under the keys
            ``'task_ids'``, ``'turn_indices'``, ``'trajectory_infos'``,
            ``'env_feedbacks'`` and ``'str_info'``.
            - Each trajectory info must provide ``'rollout_idx'``,
              ``'task_completed'`` and ``'total_turns'``.
            - Each env feedback must provide ``'turn_reward'`` and
              ``'think_too_long'``.
            - Each str_info is handed to ``detect_repetition_penalty``, which
              reads its ``'response'``.

    Returns:
        A float32 tensor of shape ``(total_samples,)``, aligned with the input
        order.

    Note:
        ``random_output`` is called once per non-``think_too_long`` sample, so
        the result depends on the state of Python's ``random`` module.
    """

    task_ids = reward_metadata['task_ids']
    turn_indices = reward_metadata['turn_indices']
    trajectory_infos = reward_metadata['trajectory_infos']
    env_feedbacks = reward_metadata['env_feedbacks']
    str_infos = reward_metadata['str_info'] 
    total_samples = len(task_ids)
    # All per-sample sequences must be aligned with task_ids.
    assert len(turn_indices) == total_samples
    assert len(trajectory_infos) == total_samples
    assert len(env_feedbacks) == total_samples
    assert len(str_infos) == total_samples

    # Grouping: task_id -> rollout_idx -> rollout record.
    task_rollout_map = {}  
    for idx in range(total_samples):
        task_id = task_ids[idx]
        traj_info = trajectory_infos[idx]
        rollout_idx = traj_info['rollout_idx']
        str_info = str_infos[idx]


        if task_id not in task_rollout_map:
            task_rollout_map[task_id] = {}

        # Rollout-level fields are taken from the first sample seen for this
        # (task_id, rollout_idx) pair.
        if rollout_idx not in task_rollout_map[task_id]:
            task_rollout_map[task_id][rollout_idx] = {
                'completed': traj_info['task_completed'],
                'total_turns': traj_info['total_turns'],
                'samples': [],
                # Placeholder; replaced below by the dict returned from
                # detect_repetition_penalty.
                'repetition_penalties': [] 
            }


        task_rollout_map[task_id][rollout_idx]['samples'].append({
            'global_idx': idx,
            'turn_idx': turn_indices[idx],
            'turn_reward': env_feedbacks[idx]['turn_reward'],
            'think_too_long': env_feedbacks[idx]['think_too_long'],
            'str_info': str_info,
        })

    all_rollouts = [data for task in task_rollout_map.values() for data in task.values()]

    # Fraction of rollouts that completed their task, and fraction of all
    # samples (think_too_long ones included) with a positive turn reward.
    batch_success_rate = sum(r['completed'] for r in all_rollouts) / len(all_rollouts) if all_rollouts else 0.0
    positive_turn_rate = sum(1 for f in env_feedbacks if f['turn_reward'] > 0) / total_samples if total_samples > 0 else 0.0


    # Probability that random_output() returns 1 below. That is the chance a
    # positive turn reward inside a negative-reward trajectory is kept (damped)
    # instead of being converted into a penalty.
    turn_reward_rate = calculate_retention_rate(positive_turn_rate, batch_success_rate)
    print(f"Batch Success Rate: {batch_success_rate:.4f}, Positive Turn Rate: {positive_turn_rate:.4f}, Turn Reward Rate: {turn_reward_rate:.4f}")


    # global_idx -> negative repetition multiplier, merged across all rollouts
    # (global_idx values are unique, so there are no collisions).
    repetition_penalties_all = {} 
    for task_id, rollouts in task_rollout_map.items():
        for rollout_idx, rollout_data in rollouts.items():
            penalties, repetition_count= detect_repetition_penalty(rollout_data)
            rollout_data['repetition_penalties'] = penalties
            rollout_data['repetition_count'] = repetition_count 
            repetition_penalties_all.update(penalties) 
            if penalties:
                print(f"Task {task_id} Rollout {rollout_idx}: 检测到 {repetition_count} 个重复样本")

    print(f'重复样本总数: {len(repetition_penalties_all)}')
    print(f"任务数量: {len(task_rollout_map)}")
    # NOTE(review): on an empty batch np.mean([]) returns nan and emits a
    # RuntimeWarning; this only affects the log line. The value averaged is
    # per rollout, not per task.
    print(f"每个任务的平均turns: {np.mean([rollout['total_turns'] for r in task_rollout_map.values() for rollout in r.values()]):.2f}")


    # (task_id, rollout_idx) -> trajectory-level reward.
    trajectory_rewards = {}  

    for task_id, rollouts in task_rollout_map.items():

        rollout_results = []
        for rollout_idx, rollout_data in rollouts.items():
            rollout_results.append({
                'rollout_idx': rollout_idx,
                'completed': rollout_data['completed'],
                'total_turns': rollout_data['total_turns'],
                # Errors = negative turn rewards, excluding think_too_long turns.
                'num_errors': sum(1 for s in rollout_data['samples'] 
                                if s['turn_reward'] < 0 and not s['think_too_long']),
                'repetition_count': rollout_data['repetition_count'] 
            })


        n_rollouts = len(rollout_results)

        if n_rollouts == 1:
            # Only one rollout for this task: no group baseline exists, so use
            # +1 for completed and -1 otherwise.
            rollout = rollout_results[0]
            if rollout['completed']:
                grpo_reward = 1
            else:
                grpo_reward = -1
            trajectory_rewards[(task_id, rollout['rollout_idx'])] = grpo_reward
        else:
            # Raw score per rollout:
            # - 1.0 if completed;
            # - otherwise -1.0, plus partial credit for turns taken (capped at
            #   0.5), minus capped error and repetition penalties.
            scores = []
            for rollout in rollout_results:
                if rollout['completed']:

                    score = 1.0
                else:

                    progress_score = min(rollout['total_turns'] / 16, 0.5)
                    error_penalty = min(rollout['num_errors'] * 0.08, 1.0)
                    repetition_penalty = min(rollout['repetition_count'] * 0.1, 0.8)  
                    score = -1.0 + progress_score - error_penalty - repetition_penalty
                scores.append(score)


            mean_score = np.mean(scores)
            # std_score only appears in the log line below; rewards are not
            # divided by it.
            std_score = np.std(scores) if np.std(scores) > 0 else 1.0

            rollout_records = []
            for i, rollout in enumerate(rollout_results):
                # n/(n-1) * (s_i - mean) equals s_i minus the mean score of the
                # other n-1 rollouts (leave-one-out baseline).
                normalized_score = (n_rollouts / (n_rollouts - 1)) * (scores[i] - mean_score)

                trajectory_rewards[(task_id, rollout['rollout_idx'])] = normalized_score
                rollout_records.append(normalized_score)
            print(f"Task {task_id} GRPO Rewards: {rollout_records}, mean score: {mean_score:.4f}, std: {std_score:.4f}")





    final_rewards = torch.zeros(total_samples, dtype=torch.float32)

    # Per-sample shaping. Every sample is written: valid samples in the inner
    # loop below, think_too_long samples right after it.
    for task_id, rollouts in task_rollout_map.items():
        for rollout_idx, rollout_data in rollouts.items():

            traj_reward = trajectory_rewards[(task_id, rollout_idx)]

            sorted_samples = sorted(rollout_data['samples'], key=lambda x: x['turn_idx'])

            # think_too_long samples are excluded from turn-reward shaping, so
            # the logic bonus compares each valid turn with the previous
            # *valid* turn.
            valid_samples = [s for s in sorted_samples if not s['think_too_long']]
            think_too_long_samples = [s for s in sorted_samples if s['think_too_long']]

            if valid_samples:
                turn_rewards = [s['turn_reward'] for s in valid_samples]

                # No smoothing is applied at present; this aliases turn_rewards.
                smoothed_rewards = turn_rewards 

                logic_bonuses = calculate_logic_bonus(turn_rewards)

                for i, sample in enumerate(valid_samples):
                    global_idx = sample['global_idx']
                    base_reward = smoothed_rewards[i] + logic_bonuses[i]

                    if global_idx in repetition_penalties_all:
                        repetition_penalty = repetition_penalties_all[global_idx]

                        # Repetition multipliers are negative, so this turns a
                        # positive base reward negative.
                        base_reward = base_reward * repetition_penalty


                        # Guarantee a repeated sample never ends up positive.
                        if base_reward >= 0:
                            base_reward = -abs(base_reward)  
                    else:

                        base_reward = base_reward

                    # 1 with probability turn_reward_rate, else -1. Drawn for
                    # every valid sample, even when the branch below ignores it.
                    abs_rate = random_output(turn_reward_rate)  

                    if traj_reward >= 0: 
                        if base_reward < 0:
                            # Non-negative trajectory: keep the turn penalty,
                            # damped by 0.5 / (1 + traj_reward).
                            penalty_factor = 1 / (1 + traj_reward) 
                            final_rewards[global_idx] = base_reward * 0.5 * penalty_factor
                        else:
                            # Scale a non-negative turn reward by the trajectory
                            # reward (gives 0 when traj_reward == 0).
                            final_rewards[global_idx] = base_reward * traj_reward
                    else:  
                        if base_reward > 0:
                            if abs_rate < 0:
                                # Turn the positive reward into a penalty
                                # proportional to traj_reward / base_reward.
                                # Small base rewards give large magnitudes,
                                # which the final clamp bounds at -3.0.
                                final_rewards[global_idx] = 0.3*(1/base_reward)* traj_reward
                                # Defensive: base_reward > 0 and traj_reward < 0
                                # already make the value above non-positive.
                                if final_rewards[global_idx] > 0:
                                    final_rewards[global_idx] = -abs(final_rewards[global_idx])
                            else:
                                # Keep the positive turn reward, damped by
                                # 0.3 / (1 + |traj_reward|).
                                penalty_factor = 1 / (1 + abs(traj_reward))
                                final_rewards[global_idx] = 0.3 * base_reward * penalty_factor

                        else:
                            # Negative trajectory and non-positive turn reward:
                            # scale by |traj_reward|.
                            final_rewards[global_idx] = base_reward * abs(traj_reward)

            # Overlong reasoning gets the minimum reward regardless of outcome.
            for sample in think_too_long_samples:
                final_rewards[sample['global_idx']] = -3.0

    final_rewards = torch.clamp(final_rewards, min=-3.0, max=3.0)

    assert final_rewards.shape == (total_samples,)
    assert not torch.any(torch.isnan(final_rewards))


    # NOTE(review): with 0 samples the mean/std below print nan, and with 1
    # sample the std prints nan. The returned tensor is unaffected.
    print(f"GRPO Reward statistics:")
    print(f"  Total samples: {total_samples}")
    print(f"  Total tasks: {len(task_rollout_map)}")
    print(f"  Avg rollouts per task: {np.mean([len(r) for r in task_rollout_map.values()]):.2f}")
    print(f"  Positive rewards: {(final_rewards > 0).sum().item()}")
    print(f"  Negative rewards: {(final_rewards < 0).sum().item()}")
    print(f"  Mean reward: {final_rewards.mean().item():.4f}")
    print(f"  Std reward: {final_rewards.std().item():.4f}")

    return final_rewards



def calculate_logic_bonus(rewards: List[float]) -> List[float]:
    """Return a per-turn bonus derived from each reward and its predecessor.

    Index 0 always gets ``0.0``. For every later index, when the previous
    reward is negative:

    - the bonus is ``+0.1`` if the current reward is positive (recovery
      after an error);
    - it is ``-0.1`` if the current reward is also negative (consecutive
      errors).

    In all other cases the bonus is ``0.0``. The output has the same length
    as *rewards*.
    """
    if len(rewards) == 0:
        return []
    result = [0.0]
    for previous, current in zip(rewards, rewards[1:]):
        if previous < 0 and current > 0:
            result.append(0.1)
        elif previous < 0 and current < 0:
            result.append(-0.1)
        else:
            result.append(0.0)
    return result

def detect_repetition_patterns(samples: List[Dict]) -> Dict[int, int]:
    """Label runs of identical adjacent responses with their 1-based run position.

    Each sample's ``'response'`` is compared to its predecessor's after
    ``str.strip()``. For every run of two or more consecutive identical
    responses, the members of the run are mapped (by their index in
    *samples*) to 1, 2, 3, ... in order. Indices that are not part of such a
    run are absent from the result.
    """
    positions: Dict[int, int] = {}
    streak = 0
    for cur_idx, (earlier, later) in enumerate(zip(samples, samples[1:]), start=1):
        earlier_text = earlier['response'].strip()
        later_text = later['response'].strip()
        if later_text != earlier_text:
            # Run broken: the next match starts a fresh run.
            streak = 0
            continue
        streak += 1
        prev_idx = cur_idx - 1
        positions[prev_idx] = max(positions.get(prev_idx, 1), streak)
        positions[cur_idx] = streak + 1
    return positions

def calculate_repetition_penalty(
    consecutive_count: int, 
    base_penalty: float = 0.2,
    penalty_growth_rate: float = 1.5
) -> float:
    """Return an exponentially growing penalty for a run of repeated outputs.

    Args:
        consecutive_count: length of the repetition run being penalised.
        base_penalty: scale factor applied to the exponential term.
        penalty_growth_rate: multiplicative growth per additional repetition.

    Returns:
        ``0.0`` when ``consecutive_count <= 2``. Otherwise it returns
        ``base_penalty * penalty_growth_rate ** (consecutive_count - 1)``,
        so the first non-zero value (at a count of 3) is
        ``base_penalty * penalty_growth_rate ** 2``.
    """
    if consecutive_count <= 2:
        # Runs of length <= 2 are tolerated. Return a float to honour the
        # declared ``-> float`` return type.
        return 0.0

    # Exponential growth, starting once the run length reaches 3.
    return base_penalty * (penalty_growth_rate ** (consecutive_count - 1))
def extract_code_from_response(response: str) -> str:
    """Strip blank lines and ``#`` comments from *response*.

    Every line is trimmed, everything from its first ``#`` onward is dropped,
    and lines that end up empty are skipped. The surviving fragments are
    re-joined with newlines. This is a plain-text heuristic: a ``#`` inside a
    string literal is also treated as the start of a comment.
    """
    fragments = (
        raw_line.strip().split('#')[0].strip()
        for raw_line in response.strip().split('\n')
    )
    return '\n'.join(fragment for fragment in fragments if fragment)


def detect_repetition_penalty(rollout_data: dict) -> Tuple[Dict[int, float], int]:
    """Find repeated code among a rollout's positively rewarded turns.

    Which samples count:
    - only samples with ``turn_reward > 0`` and a false ``think_too_long``;
    - processed in ``turn_idx`` order.

    How responses are compared:
    - each response (``sample['str_info']['response']``) is normalised with
      ``extract_code_from_response``;
    - samples whose normalised code is empty are ignored.

    The k-th occurrence (k >= 2) of the same normalised code gets the
    multiplier ``-0.6 * (k - 1)``. Occurrences do not need to be consecutive.

    Args:
        rollout_data: rollout record whose ``'samples'`` list holds dicts with
            ``'turn_reward'``, ``'think_too_long'``, ``'turn_idx'``,
            ``'global_idx'`` and ``'str_info'``.

    Returns:
        ``(penalties, repetition_count)``:
        - ``penalties`` maps a repeated sample's ``global_idx`` to its
          negative multiplier;
        - ``repetition_count == len(penalties)``.
    """
    penalties: Dict[int, float] = {}

    positive_samples = [s for s in rollout_data['samples']
                        if s['turn_reward'] > 0 and not s['think_too_long']]

    if len(positive_samples) < 2:
        # A single positive turn cannot repeat another one.
        return penalties, 0

    positive_samples.sort(key=lambda x: x['turn_idx'])

    code_sequence = [
        (sample['global_idx'], extract_code_from_response(sample['str_info']['response']))
        for sample in positive_samples
    ]

    code_occurrences: Dict[str, int] = defaultdict(int)
    for global_idx, code in code_sequence:
        if not code:
            # Comment-only or empty responses are not compared.
            continue

        code_occurrences[code] += 1
        occurrence_count = code_occurrences[code]

        # The first occurrence is free; each later one is penalised harder.
        if occurrence_count >= 2:
            penalties[global_idx] = -0.6 * (occurrence_count - 1)

    return penalties, len(penalties)