import hydra
import numpy as np 
import json
import logging 
import matplotlib.pyplot as plt
import os
import openai
import re
import subprocess
from pathlib import Path
import shutil
import time 
from typing import Dict, List
import datetime
import math
import torch
import torch.nn as nn

from similarity_utils import (
    get_bge_model,
    compute_semantic_similarity,
    compute_string_similarity,
    normalize_reward_name,
    get_reward_description,
    normalize_reward_items,
    find_similar_groups
)

def extract_reward_components(code_string: str) -> Dict[str, float]:
    """Extract weight/temperature constants and reward keys from source text.

    Two kinds of information are pulled out of ``code_string`` with regexes:

    * Assignments of a numeric literal to a name ending in ``_temp`` or
      ``_weight`` (optionally annotated ``: float``). Each is stored under
      the key ``"<prefix>_weight"``, where ``<prefix>`` is the name without
      its suffix. ``foo_temp`` and ``foo_weight`` therefore share one key
      and the later match wins.
    * The quoted keys of the first dict literal matching ``reward_dict = {...}``,
      ``reward_components = {...}`` or ``return <expr>, {...}``, stored as a
      list under ``"reward_items"``.

    Args:
        code_string: Source text to scan.

    Returns:
        Dict mapping ``"<prefix>_weight"`` to a float, plus (when found) a
        ``"reward_items"`` entry holding a list of key strings.
    """
    components = {}

    # The optional exponent keeps literals such as ``1e-3`` intact; without
    # it only the mantissa ("1") would be captured and parsed as 1.0.
    number = r'[-\d.]+(?:[eE][-+]?\d+)?'
    weight_pattern = (
        rf'(\w+)(?:_temp|_weight)[\s:]*=\s*({number})'
        rf'|(\w+)(?:_temp|_weight):\s*float\s*=\s*({number})'
    )

    for match in re.findall(weight_pattern, code_string):
        # Exactly one alternative of the pattern matched; the other pair of
        # groups is empty.
        if match[0]:
            name, value = match[0], match[1]
        else:
            name, value = match[2], match[3]

        value = value.strip()
        try:
            parsed = float(value)
        except ValueError:
            # Tolerate literals with extra dots (e.g. "1.2.3") by keeping
            # only the first two dot-separated parts.
            parts = value.split('.')
            if len(parts) > 2:
                value = parts[0] + '.' + parts[1]
            try:
                parsed = float(value)
            except ValueError as e:
                print(f"Warning: Error parsing weight value ({name}: {value}): {str(e)}")
                continue

        components[f"{name}_weight"] = parsed

    reward_dict_patterns = [
        r'reward_dict\s*=\s*{([^}]*)}',
        r'reward_components\s*=\s*{([^}]*)}',
        r'return\s+[^,]+,\s*{([^}]*)}'
    ]

    # Use the first pattern whose dict body contains at least one quoted key.
    for pattern in reward_dict_patterns:
        reward_items = re.search(pattern, code_string)
        if reward_items:
            reward_keys = re.findall(r'["\']([^"\']+)["\']:', reward_items.group(1))
            if reward_keys:
                components['reward_items'] = reward_keys
                break

    return components


def calculate_reward_frequencies(samples):
    """Score every distinct reward item by how rarely a similar item appears.

    For each distinct name found in the samples' ``'reward_items'`` lists,
    count the samples that contain at least one item whose string similarity
    (on normalized names) or semantic similarity (on descriptions) reaches
    0.9, then store ``1 - count / len(samples)``.

    Args:
        samples: Sequence of mappings, each with a ``'reward_items'`` iterable
            of reward names.

    Returns:
        Dict mapping each distinct reward name to ``1 - (matching samples /
        total samples)``.
    """
    encoder, tok = get_bge_model()
    bge_pair = (encoder, tok)

    sample_total = len(samples)
    distinct_rewards = list({item for entry in samples for item in entry['reward_items']})

    def _is_similar(target_norm, target_desc, candidate):
        # True when either similarity measure reaches the 0.9 threshold.
        cand_norm = normalize_reward_name(candidate)
        cand_desc = get_reward_description(candidate)
        lexical = compute_string_similarity(target_norm, cand_norm)
        semantic = compute_semantic_similarity(target_desc, cand_desc, bge_pair)
        return max(lexical, semantic) >= 0.9

    result = {}
    for reward in distinct_rewards:
        target_norm = normalize_reward_name(reward)
        target_desc = get_reward_description(reward)

        # A sample counts once, as soon as any of its items matches.
        hits = sum(
            1
            for entry in samples
            if any(
                _is_similar(target_norm, target_desc, candidate)
                for candidate in entry['reward_items']
            )
        )

        result[reward] = 1 - (hits / sample_total)

    return result

def calculate_usefulness_scores(samples, frequencies):
    """Compute a per-sample score from the values in ``frequencies``.

    A sample's score is the mean of ``1.0 - frequencies.get(item, 0.0)`` over
    its ``'reward_items'``; a sample with no items scores 0.0. Items missing
    from ``frequencies`` therefore contribute 1.0.

    Args:
        samples: Sequence of mappings, each with a ``'reward_items'`` list.
        frequencies: Mapping from reward item name to a numeric value.

    Returns:
        A tuple ``(usefulness_scores, score_std, score_range)`` where
        ``usefulness_scores`` is a list of dicts with keys ``'index'``,
        ``'usefulness_score'`` and ``'reward_items'`` (one per sample, in
        input order), ``score_std`` is the standard deviation of the scores
        and ``score_range`` is max minus min. Both statistics are 0.0 when
        ``samples`` is empty.
    """
    usefulness_scores = []

    for idx, sample in enumerate(samples):
        score = 0.0
        reward_items = sample['reward_items']
        total_items = len(reward_items)

        if total_items > 0:
            for item in reward_items:
                score += 1.0 - frequencies.get(item, 0.0)
            score /= total_items

        usefulness_scores.append({
            'index': idx,
            'usefulness_score': score,
            'reward_items': reward_items
        })

    scores = np.array([entry['usefulness_score'] for entry in usefulness_scores])

    if len(scores) > 0:
        score_std = np.std(scores)
        score_range = np.max(scores) - np.min(scores)
    else:
        # np.std of an empty array yields nan plus a RuntimeWarning, and
        # np.max/np.min raise; report zero spread for both instead.
        score_std = 0.0
        score_range = 0.0

    return usefulness_scores, score_std, score_range

def update_reward_weights(code: str, new_weights: Dict[str, float]) -> str:
    """Rewrite numeric weight/temperature assignments in ``code``.

    For every key in ``new_weights`` that ends in ``_weight``, assignments to
    both ``<prefix>_temp`` and ``<prefix>_weight`` (optionally annotated
    ``: float``) have their numeric literal replaced by the new value
    formatted with four decimal places. Keys without the ``_weight`` suffix
    are validated but otherwise ignored.

    Args:
        code: Source text to rewrite.
        new_weights: Mapping from ``"<prefix>_weight"`` to the new numeric
            value. It is not modified.

    Returns:
        The rewritten source text.

    Raises:
        ValueError: If any value cannot be converted to a finite float. All
            values are checked before any substitution is made.
    """
    # Format into a local dict so the caller's mapping keeps its numeric
    # values instead of being overwritten with strings.
    formatted = {}
    for name, value in new_weights.items():
        try:
            float_value = float(value)
            if not math.isfinite(float_value):
                raise ValueError(f"Weight value is not finite: {float_value}")
        except (ValueError, TypeError) as e:
            raise ValueError(f"Undefined {name}: {value} - {str(e)}") from e
        formatted[name] = f"{float_value:.4f}"

    # Optional exponent so that a literal like ``1e-3`` is replaced as a
    # whole rather than leaving a dangling ``e-3`` after the new value.
    number = r"[-\d.]+(?:[eE][-+]?\d+)?"

    for name, value in formatted.items():
        if not name.endswith('_weight'):
            continue
        base_name = re.escape(name[:-len('_weight')])
        for suffix in ('_temp', '_weight'):
            # ``\b`` stops e.g. ``vel_weight`` from matching inside
            # ``ang_vel_weight``. Group 2 captures an optional ``: float``
            # annotation so it is kept in the output (an unmatched group
            # expands to the empty string).
            pattern = rf"\b({base_name}{suffix})(:\s*float)?\s*=\s*{number}"
            code = re.sub(pattern, rf"\g<1>\g<2> = {value}", code)

    return code