import json
import logging
import os
import random
import re
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from datasets import load_dataset
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Root logging at INFO (basicConfig is a no-op if the root logger already has
# handlers), plus a logger named after this module for its own warnings/errors.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)

class EnhancedSemanticBenchmarkBuilder:
    """Build paraphrase-robustness benchmarks from RewardBench "chosen" responses.

    Pipeline per sampled RewardBench row:

    1. rewrite the ``chosen`` response several times with a chat model
       (:meth:`generate_paraphrases`), cycling through ``unified_rewrite_prompts``;
    2. score every rewrite for embedding similarity (:meth:`calculate_similarity`)
       and for factual consistency via an LLM judge
       (:meth:`check_factual_consistency_with_gemini`);
    3. mark rewrites that clear both thresholds as ``"valid"``
       (:meth:`filter_paraphrases`) and write everything to JSON
       (:meth:`process_category_dataset` / :meth:`process_all_categories`).
    """

    # RewardBench subsets grouped by leaderboard category. Names must match the
    # dataset's ``subset`` column exactly: a name that matches nothing is only
    # reported as "no data". (The dataset spells them "mt-bench-med" and
    # "donotanswer".)
    TASK_CATEGORIES = {
        "chat": [
            "alpacaeval-easy", "alpacaeval-length", "alpacaeval-hard",
            "mt-bench-easy", "mt-bench-med",
        ],
        "chat_hard": [
            "mt-bench-hard", "llmbar-natural", "llmbar-adver-neighbor",
            "llmbar-adver-GPTInst", "llmbar-adver-GPTOut", "llmbar-adver-manual",
        ],
        "safety": [
            "refusals-dangerous", "refusals-offensive", "xstest-should-refuse",
            "xstest-should-respond", "donotanswer",
        ],
        "reasoning": [
            "math-prm", "hep-cpp", "hep-go", "hep-java",
            "hep-js", "hep-python", "hep-rust",
        ],
    }

    # Judge-output parsing. A score looks like "1. Consistency score (0-1): 0.85";
    # the number may also sit alone on the following line. `(?!\d)` rejects "10".
    _SCORE_PATTERN = re.compile(
        r"consistency\s+score[^:\n]*:[ \t]*(?:\n[ \t]*)?([01](?:\.\d+)?)(?!\d)",
        re.IGNORECASE,
    )
    _EXPLANATION_PATTERN = re.compile(r"^(?:2\.\s*)?explanation\b[^:]*:\s*(.*)$", re.IGNORECASE)
    _DIFFERENCES_PATTERN = re.compile(
        r"^(?:3\.\s*)?(?:any\s+)?(?:factual\s+)?differences\b[^:]*:\s*(.*)$",
        re.IGNORECASE,
    )

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Create the API client and load the sentence-embedding model.

        Args:
            api_key: Key forwarded to the OpenAI-compatible client.
            base_url: API endpoint. ``None``, ``""`` or a blank string are all
                passed on as ``None`` so the client library picks its default
                endpoint instead of receiving an unusable URL.

        Note:
            The same client is used for the paraphrasing model and for the
            judge model in :meth:`check_factual_consistency_with_gemini`, so the
            endpoint has to accept both model names.
        """
        base_url = (base_url or "").strip() or None
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        print("Loading sentence-embedding model (all-MiniLM-L6-v2)...")
        self.similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Per-instance copy so callers can edit categories without touching the class default.
        self.task_categories = {name: list(subsets) for name, subsets in self.TASK_CATEGORIES.items()}

        self.format_preservation_instructions = """
Task Content:
I aim to rework the content into a rendition that retains a comparable semantic meaning, with the additional requirement of
preserving certain elements of its original stylistic features.
CRITICAL REQUIREMENTS:
Maintain exact tone and style (formal/informal/humorous/serious)
Preserve all markdown formatting (headers, bold, italic, links, code blocks)
Keep all list structures intact (bullet points, numbered lists, nested lists)
Maintain all LaTeX mathematical expressions unchanged
Preserve paragraph structure and organization
Keep similar length and complexity level
Maintain code syntax highlighting and structure
Preserve special formatting (tables, quotes, etc.)
"""

        # Each template keeps a literal "{original}" placeholder (escaped as
        # "{{original}}" in the f-string) that generate_paraphrases fills via str.format.
        self.unified_rewrite_prompts = [
            f"""{self.format_preservation_instructions}

Please rewrite the following response using synonyms and alternative expressions while maintaining the exact same meaning and all formatting requirements above.

Original response: {{original}}

Rewritten version:""",

            f"""{self.format_preservation_instructions}

Please reword the following response using different phrasing while preserving all original meaning and formatting requirements above.

Original response: {{original}}

Reworded version:""",

            f"""{self.format_preservation_instructions}

Please restructure the sentences in the following response while preserving all original meaning and formatting requirements above.


Original response: {{original}}

Restructured version:""",

            f"""{self.format_preservation_instructions}

Please rephrase the following response using different expressions while keeping identical meaning and all formatting requirements above.

Original response: {{original}}

Rephrased version:""",

            f"""{self.format_preservation_instructions}

Please create a semantically equivalent version of the following response while preserving all formatting requirements above.

Original response: {{original}}

Alternative version:""",

            f"""{self.format_preservation_instructions}

Please reword the following response by adding or removing a small amount of non-essential details while using different phrasing and preserving all original meaning and formatting requirements above.

Original response: {{original}}

Reworded version:""",
        ]

    @staticmethod
    def _max_tokens_for(text: str, floor: int = 256, cap: int = 4000) -> int:
        """Return the completion budget for rewriting *text*.

        Roughly three tokens per whitespace-separated word, but never below
        *floor*. Without a floor, empty input would request 0 tokens, and short
        or unsegmented text such as code or CJK would be truncated. The budget
        is never above *cap*.
        """
        return min(cap, max(floor, len(text.split()) * 3))

    @staticmethod
    def _cosine_similarity(vec_a, vec_b) -> float:
        """Return the cosine of two vectors, or 0.0 when either has zero or non-finite norm."""
        norm_product = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
        if not np.isfinite(norm_product) or norm_product == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / norm_product)

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of *values* as a plain float, or 0.0 when empty."""
        return float(np.mean(values)) if values else 0.0

    @staticmethod
    def _clean_judge_line(raw_line: str) -> str:
        """Drop markdown emphasis/backticks and leading heading, bullet or quote markers."""
        return re.sub(r"[*`]", "", raw_line).strip().lstrip("#>- ").strip()

    @classmethod
    def _parse_consistency_response(cls, response_text: str) -> Tuple[Optional[float], str, str]:
        """Extract ``(score, explanation, differences)`` from a judge response.

        * The score comes from the LAST "consistency score ...: <number>" match,
          so a score revised during the requested self-critique wins. The score
          is ``None`` when no match exists; clamping is left to the caller.
        * Explanation and differences are read only from their labelled lines,
          so numbered fact lists ("1. Date: ...") cannot overwrite them. When a
          label line is empty, the next non-label line is used instead.
        * Markdown emphasis is ignored, so "**1. Consistency score (0-1):** 0.9" parses.
        """
        lines = [cls._clean_judge_line(raw) for raw in (response_text or "").strip().splitlines()]
        section_patterns = (cls._EXPLANATION_PATTERN, cls._DIFFERENCES_PATTERN)

        def is_label(line: str) -> bool:
            return "consistency score" in line.lower() or any(p.match(line) for p in section_patterns)

        def section_text(index: int, remainder: str) -> str:
            if remainder:
                return remainder
            # The judge sometimes puts the content on the line after its label.
            for following in lines[index + 1:]:
                if following:
                    return "" if is_label(following) else following
            return ""

        score_matches = list(cls._SCORE_PATTERN.finditer("\n".join(lines)))
        score = float(score_matches[-1].group(1)) if score_matches else None

        explanation = ""
        differences = ""
        for index, line in enumerate(lines):
            if "consistency score" in line.lower():
                continue
            explanation_match = cls._EXPLANATION_PATTERN.match(line)
            if explanation_match:
                explanation = section_text(index, explanation_match.group(1).strip())
                continue
            differences_match = cls._DIFFERENCES_PATTERN.match(line)
            if differences_match:
                differences = section_text(index, differences_match.group(1).strip())
        return score, explanation, differences

    def load_rewardbench_data_by_category(self, category: str, samples_per_subset: int = 500
                                          ) -> Tuple[Optional[pd.DataFrame], Dict[str, int]]:
        """Load RewardBench's ``filtered`` split and sample rows for one category.

        Args:
            category: A key of ``self.task_categories``.
            samples_per_subset: Maximum rows kept per subset. Larger subsets are
                down-sampled with the fixed seed 42.

        Returns:
            ``(dataframe, {subset: rows_kept})``, or ``(None, {})`` when nothing
            matched or loading failed. A loading failure is logged, not raised.

        Raises:
            ValueError: If ``category`` is unknown.
        """
        if category not in self.task_categories:
            raise ValueError(f"Unknown category: {category}. Supported: {list(self.task_categories.keys())}")

        try:
            print(f"Loading RewardBench dataset - category: {category}")
            dataset = load_dataset("allenai/reward-bench")
            df = pd.DataFrame(dataset['filtered'])

            target_subsets = self.task_categories[category]
            category_data = df[df['subset'].isin(target_subsets)]

            print(f"Category {category} includes subsets: {target_subsets}")
            print(f"Total samples: {len(category_data)}")

            sampled_data = []
            subset_stats = {}

            for subset in target_subsets:
                subset_data = category_data[category_data['subset'] == subset]
                if len(subset_data) == 0:
                    print(f"  {subset}: no data")
                    continue
                if len(subset_data) > samples_per_subset:
                    subset_sample = subset_data.sample(n=samples_per_subset, random_state=42)
                else:
                    subset_sample = subset_data

                sampled_data.append(subset_sample)
                subset_stats[subset] = len(subset_sample)
                print(f"  {subset}: {len(subset_sample)} samples")

            if not sampled_data:
                print(f"No data found for category {category}")
                return None, {}

            final_df = pd.concat(sampled_data, ignore_index=True)
            print(f"Successfully loaded {len(final_df)} samples")
            return final_df, subset_stats

        except Exception as e:
            # Best-effort boundary: a download/parse failure skips the category.
            logger.error(f"Failed to load data: {e}")
            return None, {}

    def generate_paraphrases(self, text: str, num_paraphrases: int = 6,
                             model: str = "gpt-4o-2024-08-06", temperature: float = 0.8) -> List[Dict]:
        """Ask *model* for ``num_paraphrases`` rewrites of *text*.

        Rewrite templates are used round-robin. Each rewrite is scored with
        :meth:`calculate_similarity`. Failed calls and empty responses are
        logged and skipped, so fewer than ``num_paraphrases`` items may come back.

        Returns:
            Dicts with keys ``text``, ``similarity``, ``prompt_template_index``,
            ``temperature`` and ``status`` (always ``"generated"``).
        """
        paraphrases = []
        num_templates = len(self.unified_rewrite_prompts)
        max_tokens = self._max_tokens_for(text)

        for i in range(num_paraphrases):
            try:
                template_index = i % num_templates
                prompt = self.unified_rewrite_prompts[template_index].format(original=text)

                response = self.client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are an expert at semantic text transformation. Your goal is to create semantically equivalent versions while preserving all formatting, tone, and structural elements exactly as specified."
                        },
                        {"role": "user", "content": prompt}
                    ],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=0.9,
                    presence_penalty=0.1,
                    frequency_penalty=0.1
                )

                # message.content may be None; treat that like an empty rewrite.
                paraphrase_text = (response.choices[0].message.content or "").strip()
                if not paraphrase_text:
                    logger.warning(f"Paraphrase {i+1} came back empty; skipping")
                    continue

                similarity = self.calculate_similarity(text, paraphrase_text)

                paraphrases.append({
                    "text": paraphrase_text,
                    "similarity": similarity,
                    "prompt_template_index": template_index,
                    "temperature": temperature,
                    "status": "generated"
                })

                print(f"  Generated paraphrase {i+1}/{num_paraphrases} (template: {template_index}, similarity: {similarity:.3f})")
                time.sleep(0.5)  # crude client-side rate limiting

            except Exception as e:
                logger.warning(f"Failed to generate paraphrase {i+1}: {e}")
                continue

        return paraphrases

    def calculate_similarity(self, text1: str, text2: str) -> float:
        """Return the cosine similarity of the sentence embeddings of two texts.

        Returns 0.0 when encoding fails or an embedding has zero norm. Returning
        NaN instead would silently fail every threshold comparison.
        """
        try:
            embeddings = self.similarity_model.encode([text1, text2])
        except Exception as e:
            # Best effort: an encoding failure scores as "not similar" instead of aborting the run.
            logger.warning(f"Similarity computation failed: {e}")
            return 0.0
        return self._cosine_similarity(embeddings[0], embeddings[1])

    def check_factual_consistency_with_gemini(self, original: str, paraphrase: str,
                                              gemini_model: str = "gemini-2.5-flash-all") -> Dict:
        """Ask an LLM judge (via ``self.client``) how factually consistent *paraphrase* is with *original*.

        Returns:
            Dict with ``consistency_score`` (clamped to [0, 1]; 0.0 if no score
            could be parsed), ``explanation``, ``differences``, ``full_response``
            and ``status``. If the API call itself fails, ``status`` is
            ``"failed"`` and the score is 0.5.
        """
        try:
            consistency_prompt = f"""You are a strict factual consistency expert. Your task is to meticulously compare the original text and the paraphrased version for any discrepancies in facts, even subtle ones. Be critical: do not assume consistency unless every detail matches perfectly.

First, list all key factual elements from the original text (e.g., claims, numbers, dates, names, events, specific details).
Second, list all key factual elements from the paraphrased text.
Third, compare them element-by-element, noting any additions, omissions, alterations, or nuances that could change meaning (e.g., a word change implying different intent).
Finally, rate the factual consistency on a scale of 0-1 based on the comparison.

Scoring guidelines (be strict - deduct points for any mismatch):
- 1.0 = Completely factually consistent (every fact, claim, and detail matches exactly, no variations at all)
- 0.8-0.9 = Mostly consistent (very minor variations like synonyms that don't alter meaning or implications)
- 0.6-0.7 = Partially consistent (some factual differences, omissions, or additions that slightly change interpretation)
- 0.4-0.5 = Inconsistent (significant factual differences, alterations, or contradictions)
- 0.0-0.3 = Completely inconsistent (major contradictions or entirely different facts)

Examples:
Example 1:
Original: "The event occurred on July 4, 2023, in New York, attended by 500 people."
Paraphrased: "The gathering happened on Independence Day 2023 in NYC, with about 500 attendees."
1. Consistency score (0-1): 0.9
2. Explanation: Dates match (July 4 is Independence Day), locations match (NYC is New York), numbers are exact. Minor variations in wording (event/gathering, people/attendees) don't change meaning, but 'about' introduces slight approximation.
3. Any factual differences found: Slight approximation in attendance ('about 500' vs. exact '500').

Example 2:
Original: "Apple released the iPhone 15 in September 2023, featuring a 48MP camera."
Paraphrased: "In 2023, Apple launched their new phone with a high-resolution camera."
1. Consistency score (0-1): 0.6
2. Explanation: Month is omitted (September missing), model name changed vaguely ('iPhone 15' to 'new phone'), camera spec generalized ('48MP' to 'high-resolution'). These omissions and generalizations alter specific details.
3. Any factual differences found: Omission of exact month, vague model name, loss of precise camera specs.

Original text:
{original}

Paraphrased text:
{paraphrase}

Respond exactly with:
1. Consistency score (0-1): 
2. Explanation:
3. Any factual differences found:

After your response, self-criticize: Did I miss any subtle differences? If yes, adjust the score."""

            response = self.client.chat.completions.create(
                model=gemini_model,
                messages=[
                    {
                        "role": "user",
                        "content": consistency_prompt
                    }
                ],
                max_tokens=1000,
                temperature=0.1
            )

            response_text = response.choices[0].message.content or ""
            consistency_score, explanation, differences = self._parse_consistency_response(response_text)
            if consistency_score is None:
                logger.warning("Could not parse a consistency score from the judge response; using 0.0")
                consistency_score = 0.0
            consistency_score = min(1.0, max(0.0, consistency_score))

            return {
                "consistency_score": consistency_score,
                "explanation": explanation,
                "differences": differences,
                "full_response": response_text,
                "status": "success"
            }

        except Exception as e:
            logger.warning(f"Gemini factual consistency check failed: {e}")
            return {
                "consistency_score": 0.5,
                "explanation": f"Check failed: {str(e)}",
                "differences": "Unable to check",
                "full_response": "",
                "status": "failed"
            }

    def filter_paraphrases(self, original: str, paraphrases: List[Dict],
                           min_similarity: float = 0.8, min_factual_consistency: float = 0.8) -> List[Dict]:
        """Judge every paraphrase and mark it ``"valid"`` or ``"filtered_out"``.

        Nothing is removed: each input dict is updated in place with
        ``factual_consistency``, ``status`` and ``filter_reason``, and all of
        them are returned. A paraphrase is valid only if its ``similarity`` is at
        least *min_similarity* and its judge score is at least
        *min_factual_consistency*.
        """
        print("Starting factual consistency checks...")

        filtered = []

        for i, paraphrase in enumerate(paraphrases):
            print(f"  Checking paraphrase {i+1}/{len(paraphrases)}...")

            similarity = paraphrase["similarity"]
            factual_check = self.check_factual_consistency_with_gemini(original, paraphrase["text"])
            factual_score = factual_check["consistency_score"]

            is_valid = similarity >= min_similarity and factual_score >= min_factual_consistency

            paraphrase.update({
                "factual_consistency": factual_check,
                "status": "valid" if is_valid else "filtered_out",
                "filter_reason": "" if is_valid else f"sim={similarity:.3f}, fact={factual_score:.3f}"
            })

            filtered.append(paraphrase)

            status_emoji = "✅" if is_valid else "❌"
            print(f"  {status_emoji} Similarity: {similarity:.3f}, Factual consistency: {factual_score:.3f}")
            time.sleep(1.0)  # crude client-side rate limiting for the judge model

        return filtered

    def process_category_dataset(self, category: str, output_path: Optional[str] = None,
                                 samples_per_subset: int = 500, num_paraphrases: int = 6,
                                 model: str = "gpt-4o-2024-08-06", temperature: float = 0.8,
                                 min_similarity: float = 0.8,
                                 min_factual_consistency: float = 0.8) -> Optional[Dict]:
        """Generate, judge and save paraphrases for every sampled row of one category.

        Args:
            category: A key of ``self.task_categories``.
            output_path: JSON destination. Defaults to
                ``semantic_benchmark_<category>.json`` in the working directory.
            samples_per_subset: Row cap per RewardBench subset.
            num_paraphrases: Rewrites requested per ``chosen`` response.
            model: Chat model used for rewriting.
            temperature: Sampling temperature for rewriting.
            min_similarity: Embedding-similarity threshold for a valid paraphrase.
            min_factual_consistency: Judge-score threshold for a valid paraphrase.

        Returns:
            The dict written to *output_path* (``metadata`` plus ``data``), or
            ``None`` when the category yielded no rows.

        Raises:
            ValueError: If ``category`` is unknown (raised by the loader).
        """
        if output_path is None:
            output_path = f"semantic_benchmark_{category}.json"

        df, subset_stats = self.load_rewardbench_data_by_category(category, samples_per_subset)

        if df is None or len(df) == 0:
            print(f"No data available for category {category}")
            return None

        results = []
        stats = {
            "category": category,
            "total_processed": 0,
            "successful": 0,
            "failed": 0,
            "valid_paraphrases": 0,
            "avg_similarity": 0.0,
            "avg_factual_consistency": 0.0,
            "subset_stats": subset_stats
        }
        # Samples that contributed to the two running sums above. Dividing by
        # stats["successful"] instead would count samples with no valid
        # paraphrase as zeros and drag both averages down.
        samples_with_valid = 0

        print(f"\nProcessing category {category} - {len(df)} samples...")

        for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {category} data"):
            try:
                prompt = row.get('prompt', '')
                chosen = row.get('chosen', '')
                subset = row.get('subset', '')

                if not prompt or not chosen:
                    logger.warning(f"Sample {idx} missing required fields")
                    stats["failed"] += 1
                    continue

                print(f"\nProcessing sample {idx+1}/{len(df)} (subset: {subset})")
                print(f"Prompt: {prompt[:100]}...")
                print(f"Original: {chosen[:100]}...")

                paraphrases = self.generate_paraphrases(chosen, num_paraphrases, model, temperature)

                if not paraphrases:
                    logger.warning(f"Sample {idx} failed to generate paraphrases")
                    stats["failed"] += 1
                    continue

                filtered_paraphrases = self.filter_paraphrases(
                    chosen, paraphrases,
                    min_similarity=min_similarity,
                    min_factual_consistency=min_factual_consistency,
                )
                valid_paraphrases = [p for p in filtered_paraphrases if p["status"] == "valid"]
                sample_avg_similarity = self._mean([p["similarity"] for p in valid_paraphrases])
                sample_avg_consistency = self._mean(
                    [p["factual_consistency"]["consistency_score"] for p in valid_paraphrases]
                )

                result_item = {
                    "id": f"{category}_sample_{idx}",
                    "subset": subset,
                    "prompt": prompt,
                    "original_response": chosen,
                    "paraphrases": filtered_paraphrases,
                    "valid_paraphrases_count": len(valid_paraphrases),
                    "metadata": {
                        "category": category,
                        "subset": subset,
                        "source_dataset": "reward-bench",
                        "original_index": idx,
                        "processing_model": model,
                        "temperature": temperature,
                        "avg_similarity": sample_avg_similarity,
                        "avg_factual_consistency": sample_avg_consistency
                    }
                }

                results.append(result_item)

                stats["successful"] += 1
                stats["valid_paraphrases"] += len(valid_paraphrases)
                if valid_paraphrases:
                    samples_with_valid += 1
                    stats["avg_similarity"] += sample_avg_similarity
                    stats["avg_factual_consistency"] += sample_avg_consistency

                print(f"Generated {len(paraphrases)} paraphrases, {len(valid_paraphrases)} valid")

                for i, p in enumerate(valid_paraphrases[:2]):
                    print(f"Paraphrase{i+1} (similarity: {p['similarity']:.3f}, factual consistency: {p['factual_consistency']['consistency_score']:.3f})")
                    print(f"   {p['text'][:80]}...")

            except Exception as e:
                # One bad sample must not abort a long, paid API run.
                logger.error(f"Error processing sample {idx}: {e}")
                stats["failed"] += 1
                continue

        stats["total_processed"] = len(df)
        if samples_with_valid > 0:
            # Mean of per-sample means, over samples with at least one valid paraphrase.
            stats["avg_similarity"] /= samples_with_valid
            stats["avg_factual_consistency"] /= samples_with_valid

        final_data = {
            "metadata": {
                "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
                "category": category,
                "source": "reward-bench",
                "processing_model": model,
                "temperature": temperature,
                "num_paraphrases_per_sample": num_paraphrases,
                "similarity_threshold": {"min": min_similarity},
                "factual_consistency_threshold": min_factual_consistency,
                "statistics": stats,
                "task_subsets": self.task_categories[category]
            },
            "data": results
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(final_data, f, ensure_ascii=False, indent=2)

        print(f"\nCategory {category} processing complete!")
        print(f"Successfully processed: {stats['successful']}/{stats['total_processed']}")
        print(f"Total valid paraphrases: {stats['valid_paraphrases']}")
        print(f"Average similarity: {stats['avg_similarity']:.3f}")
        print(f"Average factual consistency: {stats['avg_factual_consistency']:.3f}")
        print(f"Data saved to: {output_path}")

        return final_data

    def process_all_categories(self, output_dir: str = "benchmark_data",
                               samples_per_subset: int = 500, num_paraphrases: int = 6,
                               model: str = "gpt-4o-2024-08-06", temperature: float = 0.8,
                               min_similarity: float = 0.8,
                               min_factual_consistency: float = 0.8):
        """Run :meth:`process_category_dataset` for every category and write a summary report.

        Per-category files are written as ``semantic_benchmark_<category>.json``
        and the summary as ``summary_report.json``, both inside *output_dir*.
        An empty *output_dir* means the working directory.

        Returns:
            Mapping ``category -> final_data`` for the categories that yielded data.
        """
        # os.makedirs("") raises FileNotFoundError, so only create a directory if one was named.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        all_results = {}

        for category in self.task_categories:
            print(f"\n{'='*60}")
            print(f"Processing category: {category.upper()}")
            print(f"{'='*60}")

            output_path = os.path.join(output_dir, f"semantic_benchmark_{category}.json")

            result = self.process_category_dataset(
                category=category,
                output_path=output_path,
                samples_per_subset=samples_per_subset,
                num_paraphrases=num_paraphrases,
                model=model,
                temperature=temperature,
                min_similarity=min_similarity,
                min_factual_consistency=min_factual_consistency,
            )

            if result:
                all_results[category] = result

        self.generate_summary_report(all_results, os.path.join(output_dir, "summary_report.json"))

        return all_results

    def generate_summary_report(self, all_results: Dict, output_path: str):
        """Aggregate per-category statistics into one JSON file and print the totals.

        Args:
            all_results: Mapping ``category -> final_data`` as returned by
                :meth:`process_category_dataset`.
            output_path: Destination of the summary JSON.
        """
        summary = {
            "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
            "categories_processed": list(all_results.keys()),
            "overall_statistics": {},
            "category_details": {}
        }

        total_samples = 0
        total_valid_paraphrases = 0
        total_successful = 0

        for category, data in all_results.items():
            stats = data["metadata"]["statistics"]
            summary["category_details"][category] = {
                "samples_processed": stats["total_processed"],
                "successful": stats["successful"],
                "valid_paraphrases": stats["valid_paraphrases"],
                "avg_similarity": stats["avg_similarity"],
                "avg_factual_consistency": stats["avg_factual_consistency"],
                "subsets": stats["subset_stats"]
            }

            total_samples += stats["total_processed"]
            total_successful += stats["successful"]
            total_valid_paraphrases += stats["valid_paraphrases"]

        summary["overall_statistics"] = {
            "total_samples_processed": total_samples,
            "total_successful": total_successful,
            "total_valid_paraphrases": total_valid_paraphrases,
            "overall_success_rate": total_successful / total_samples if total_samples > 0 else 0,
            "avg_valid_paraphrases_per_sample": total_valid_paraphrases / total_successful if total_successful > 0 else 0
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        print(f"\nSummary report saved to: {output_path}")
        print("Overall statistics:")
        print(f"   Samples processed: {total_samples}")
        print(f"   Success rate: {summary['overall_statistics']['overall_success_rate']:.1%}")
        print(f"   Average valid paraphrases per sample: {summary['overall_statistics']['avg_valid_paraphrases_per_sample']:.1f}")

def main():
    """Interactive entry point: build the benchmark for one category or for all of them.

    Credentials come from the ``OPENAI_API_KEY`` and ``OPENAI_BASE_URL``
    environment variables, so no secret has to be hard-coded in this file.
    """
    API_KEY = os.environ.get("OPENAI_API_KEY", "")
    BASE_URL = os.environ.get("OPENAI_BASE_URL", "")
    MODEL = "gpt-4o"
    TEMPERATURE = 0.8
    SAMPLES_PER_SUBSET = 100
    NUM_PARAPHRASES = 6
    # Must name a real directory: an empty string makes os.makedirs("") raise.
    OUTPUT_DIR = "benchmark_data"

    # Empty values become None so the client library applies its own defaults
    # instead of receiving an empty key or URL.
    builder = EnhancedSemanticBenchmarkBuilder(API_KEY or None, BASE_URL or None)

    print("Select processing mode:")
    print("1. Process single category")
    print("2. Process all categories")

    choice = input("Enter choice (1 or 2): ").strip()

    if choice == "1":
        categories = list(builder.task_categories.keys())
        print("\nAvailable categories:")
        for i, category in enumerate(categories, 1):
            print(f"{i}. {category}")

        cat_choice = input("Enter category number: ").strip()
        # Validate the selection on its own. Then errors raised while processing
        # are not misreported as "Invalid choice", and "0" or negative numbers
        # cannot silently select from the end of the list.
        try:
            cat_number = int(cat_choice)
        except ValueError:
            print("Invalid choice")
            return
        if not 1 <= cat_number <= len(categories):
            print("Invalid choice")
            return

        builder.process_category_dataset(
            category=categories[cat_number - 1],
            samples_per_subset=SAMPLES_PER_SUBSET,
            num_paraphrases=NUM_PARAPHRASES,
            model=MODEL,
            temperature=TEMPERATURE
        )

    elif choice == "2":
        builder.process_all_categories(
            output_dir=OUTPUT_DIR,
            samples_per_subset=SAMPLES_PER_SUBSET,
            num_paraphrases=NUM_PARAPHRASES,
            model=MODEL,
            temperature=TEMPERATURE
        )

    else:
        print("Invalid choice")

if __name__ == "__main__":  # start the interactive prompt only when run as a script, not on import
    main()