#!/usr/bin/env python
# coding: utf-8

import argparse
import json
import multiprocessing as mp
import os
import shelve
import time
import getpass
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple


import anthropic
from datasets import Dataset
from tqdm import tqdm

# Constants
# Upper bound on the evaluator's reply length; passed as ``max_tokens`` to the
# Messages API by submit_messages().
# NOTE(review): the reply is a JSON object with a free-text "explanation"; a
# saved result whose stop_reason is "max_tokens" holds truncated (invalid)
# JSON — consider raising this if that shows up.
MAX_TOKENS = 200
# On-disk ``datasets.Dataset`` (opened with Dataset.load_from_disk) whose
# "prompt" column supplies the user turn for each conversation. The path is
# relative, so it resolves against the current working directory, not this
# script's location.
DATASET_PATH = "../data/ai_conversation_starters/"

# System prompt for the evaluator model. It requires the reply to be a single
# bare JSON object with the keys fluency_score, theme, consistency_score and
# explanation. This text is part of the experiment: edit it deliberately,
# because any change alters the scores.
SYSTEM_PROMPT = '''You are an AI evaluation assistant that always responds in valid JSON format.

For each evaluation task, you must return your analysis in the following JSON structure:
{
  "fluency_score": <integer between 1 and 10>,
  "theme": <string description of the identified theme, or "NONE" if none is found>,
  "consistency_score": <integer representing count of responses matching the theme>,
  "explanation": <string with brief reasoning for your scores>
}

Before responding, verify that your output is valid JSON. Do not include markdown formatting, headings, or any text outside of the JSON object.'''

# Task instructions sent as the first text block of the user message, ahead of
# the <document> built by create_document(). The text, including its trailing
# spaces, is reproduced verbatim on every request.
# NOTE(review): the "truncated at 512 tokens" statement must match the
# generation settings that produced the stored responses; this file does not
# show those settings — confirm before reusing with other data.
TASK_PROMPT = '''I am evaluating an unsupervised method for eliciting diverse personas from an LLM. Each persona tends to be obsessed with a certain topic, bringing it up in response to almost everything, even when the topic is not relevant to the user's request. ***This is to be expected***, so don't be surprised when you observe abrupt transitions between the user's topic and the persona's preferred topic.

Some personas are less obsessive but still may follow a consistent style across responses (e.g., always speaking in verse, speaking consistently in {first,second} person, speaking more colloquially than a normal chat-bot, or using idiosyncratic, but still fluent punctuation /  formatting).

I need you to evaluate the responses on two categories:
    1. **Consistency** - How consistently the LLM incorporates a single topic/theme/style in its responses.
        *Note*: If the persona seems to bring up multiple distinct themes, report consistency for the most common theme.
        *Note*: Gibberish, repetitive or broken language does *not* count as a specific theme. The theme should relate to a certain topic or style of communication.
    2. **Fluency** - How fluent are the LLM's responses? Evaluate fluency *ONLY* on things like grammar/punctuation/readability. 
        *NOTE*: It's OK if the responses splice together seemingly incoherent themes/topics - *DO NOT* penalize the model for abrupt transitions between the user's topic and the persona's preferred topic, although you should reward responses which more creatively incorporate the persona's preferred topic in the context of the user's prompt.
        *Note*: The LLM's responses have been truncated at 512 tokens. DO NOT penalize fluency simply because the response has been truncated. 
        *Note*: Highly repetitive responses should be penalized.
        *Note*: Very short responses should be penalized.

Analyze the document and provide your evaluation in the required JSON format.
'''


def create_document(responses: List[str], dataset: Dataset) -> str:
    """Render responses and their prompts as one evaluation document.

    Conversation ``i`` pairs row ``i`` of the dataset's ``"prompt"`` column
    (the user turn) with ``responses[i]`` (the assistant turn). Each
    conversation starts with a ``###<CONVERSATION i>###`` header so the
    evaluator can tell them apart.

    Args:
        responses: Assistant responses, aligned by index with ``dataset`` rows.
        dataset: Dataset with a ``"prompt"`` column. It needs at least
            ``len(responses)`` rows; ``select`` fails otherwise.

    Returns:
        The concatenated document, or ``""`` when ``responses`` is empty.
    """
    # Fetch all needed prompts with a single select(). The previous approach
    # called select([i]) once per conversation, which builds a new Dataset
    # object each time.
    prompts = dataset.select(range(len(responses)))["prompt"]

    # Collect the pieces and join once, instead of repeated string `+=`
    # (quadratic in the worst case). str.join still raises TypeError on
    # non-string prompts/responses, just like `+=` did.
    pieces: List[str] = []
    for i, (prompt, response) in enumerate(zip(prompts, responses)):
        pieces.extend((
            f"###<CONVERSATION {i}>###\n",
            "\n<User>\n",
            prompt,
            "\n\n<Assistant>\n",
            response,
            "\n\n",
        ))
    return "".join(pieces)


def submit_messages(
    client: anthropic.Anthropic,
    document: str,
    *,
    model: str = "claude-3-7-sonnet-20250219",
    max_tokens: int = MAX_TOKENS,
    temperature: float = 0.0,
) -> anthropic.types.Message:
    """Ask the evaluator model to score one document.

    The request uses SYSTEM_PROMPT as the system prompt. The user message has
    two text blocks: TASK_PROMPT, then the document wrapped in
    ``<document>...</document>`` tags.

    Args:
        client: Authenticated Anthropic client.
        document: Text produced by create_document().
        model: Evaluator model name. The default is the model this script
            always used.
        max_tokens: Reply length cap (default MAX_TOKENS).
        temperature: Sampling temperature. The default is 0.0, so repeated
            evaluations stay as reproducible as the API allows.

    Returns:
        The raw ``Message`` returned by ``client.messages.create``.
    """
    user_content = [
        {"type": "text", "text": TASK_PROMPT},
        {"type": "text", "text": f"<document>\n{document}\n</document>"},
    ]
    return client.messages.create(
        model=model,
        system=SYSTEM_PROMPT,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=[{"role": "user", "content": user_content}],
        # NOTE(review): this header only opts in to the prompt-caching beta.
        # No content block here carries ``cache_control``, so the request
        # does not mark anything for caching.
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    )


def _load_checkpoint(process_file: Path, process_id: int) -> List[Dict]:
    """Return result records saved by an earlier run of this process, or []."""
    if not process_file.exists():
        return []
    try:
        with open(process_file, 'r') as f:
            return json.load(f)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        # A half-written checkpoint would otherwise abort this batch on every
        # rerun. Set it aside for inspection and start the batch over.
        backup = process_file.with_name(process_file.name + ".corrupt")
        print(f"Unreadable checkpoint {process_file} in process {process_id} ({e}); "
              f"moved to {backup}, starting this batch from scratch")
        os.replace(process_file, backup)
        return []


def _write_json_atomic(path: Path, data: List[Dict]) -> None:
    """Write ``data`` to ``path`` as indented JSON via a temp file plus os.replace.

    The final rename means a crash mid-write leaves the previous checkpoint
    intact rather than a truncated file.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, 'w') as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, path)


def _message_to_record(key: str, message: anthropic.types.Message) -> Dict:
    """Flatten an API reply into the JSON-serialisable record saved per key."""
    usage = message.usage
    return {
        'key': key,
        'content': message.content[0].text,
        'model': message.model,
        'stop_reason': message.stop_reason,
        'usage': {
            'input_tokens': usage.input_tokens,
            'output_tokens': usage.output_tokens,
            'cache_creation_input_tokens': usage.cache_creation_input_tokens,
            'cache_read_input_tokens': usage.cache_read_input_tokens,
        },
        'timestamp': time.time(),
    }


def process_key_batch(
    keys: List[str],
    process_id: int,
    api_key: str,
    output_dir: Path,
    db_path: str
) -> List[Dict]:
    """Evaluate a batch of shelve keys in one worker process, checkpointing as it goes.

    For each key, the stored responses ``db[key]`` are rendered with
    create_document(), sent to the evaluator with submit_messages(), and the
    reply is recorded. After each successful key, the full result list is
    rewritten to ``output_dir/process_{process_id}_results.json``. A rerun
    skips keys already present in that file.

    Args:
        keys: Shelve keys assigned to this worker.
        process_id: Worker index; it names the checkpoint file and sets the
            tqdm bar position.
        api_key: Anthropic API key (a fresh client is created per process).
        output_dir: Directory holding the per-process checkpoint files.
        db_path: Path of the shelve database, opened read-only.

    Returns:
        Every record in this worker's checkpoint: records loaded from an
        earlier run plus the new ones. A failure on a single key is printed
        and that key is not recorded, so a later run with the same batching
        retries it. Any other error is printed, and the records collected so
        far are returned.
    """
    client = anthropic.Anthropic(api_key=api_key)
    dataset = Dataset.load_from_disk(DATASET_PATH)
    process_file = output_dir / f"process_{process_id}_results.json"
    results: List[Dict] = []

    with shelve.open(db_path, "r") as db:
        try:
            results.extend(_load_checkpoint(process_file, process_id))
            processed_keys = {r['key'] for r in results}

            for key in tqdm(keys, desc=f"Process {process_id}", position=process_id):
                if key in processed_keys:
                    continue
                try:
                    responses = db[key]
                    document = create_document(responses, dataset)
                    message = submit_messages(client, document)
                    results.append(_message_to_record(key, message))
                    # Checkpoint after every key so an interrupted run loses at
                    # most the request in flight.
                    _write_json_atomic(process_file, results)
                except Exception as e:
                    # Best effort per key: report and move on to the next key.
                    print(f"Error processing key {key} in process {process_id}: {e}")
        except Exception as e:
            # Worker boundary: report, and still return whatever was collected.
            print(f"Error in process {process_id}: {e}")

    return results


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser(description='Auto interpretation script for LLM responses')
    parser.add_argument('--db-path', type=str, required=True, help='Path to the results database')
    parser.add_argument('--filter', type=str, help='Filter expression for keys (e.g., "key.startswith(\'scale_1.0\')")')
    parser.add_argument('--limit', type=int, help='Maximum number of keys to process')
    parser.add_argument('--processes', type=int, default=4, help='Number of parallel processes')
    parser.add_argument('--output-dir', type=str, required=True, help='Output directory for results')
    args = parser.parse_args()
    # Reject values that would crash later (division by zero when batching)
    # or silently mean something else (a negative slice bound).
    if args.processes < 1:
        parser.error('--processes must be at least 1')
    if args.limit is not None and args.limit < 0:
        parser.error('--limit must be non-negative')
    return args


def _select_keys(all_keys: List[str], filter_expr: Optional[str], limit: Optional[int]) -> List[str]:
    """Apply the optional ``--filter`` expression and ``--limit`` to the database keys."""
    if filter_expr:
        # SECURITY: eval() runs arbitrary Python taken from the command line,
        # so only pass trusted expressions. The expression is compiled once and
        # evaluated with ``key`` bound to each candidate. Module-level names
        # stay visible through globals().
        predicate = compile(filter_expr, '<filter>', 'eval')
        selected = [key for key in all_keys if eval(predicate, globals(), {'key': key})]
    else:
        selected = list(all_keys)
    # Use ``is not None`` so that ``--limit 0`` means "no keys", not "no limit".
    if limit is not None:
        selected = selected[:limit]
    return selected


def _split_into_batches(keys: List[str], processes: int) -> List[List[str]]:
    """Split ``keys`` into at most ``processes`` contiguous batches."""
    # The ``// processes + 1`` sizing is deliberately unchanged. Each batch's
    # index becomes the worker's process_id, which names its checkpoint file,
    # so a different split would re-map keys to other checkpoints on resume.
    batch_size = len(keys) // processes + 1
    return [keys[i:i + batch_size] for i in range(0, len(keys), batch_size)]


def _latest_result_per_key(results: List[Dict]) -> List[Dict]:
    """Collapse duplicates to one record per key (latest timestamp), sorted by key.

    Each worker reloads its whole checkpoint file. When a rerun partitions keys
    differently (other --processes/--filter/--limit), the same key can
    therefore be returned by two workers.
    """
    latest: Dict[str, Dict] = {}
    for result in results:
        current = latest.get(result['key'])
        if current is None or result['timestamp'] >= current['timestamp']:
            latest[result['key']] = result
    return [latest[key] for key in sorted(latest)]


def _write_consolidated_results(output_dir: Path, all_results: List[Dict]) -> Tuple[Path, Path]:
    """Write the content and metadata JSON files; return their paths."""
    content_json = output_dir / 'all_results_content.json'
    content_data = [
        {
            'key': result['key'],
            'content': result['content'],
            'timestamp': result['timestamp'],
        }
        for result in all_results
    ]
    with open(content_json, 'w') as f:
        json.dump(content_data, f, indent=2)

    metadata_json = output_dir / 'all_results_metadata.json'
    metadata_data = [
        {
            'key': result['key'],
            'model': result['model'],
            'stop_reason': result['stop_reason'],
            'usage': result['usage'],
            'timestamp': result['timestamp'],
        }
        for result in all_results
    ]
    with open(metadata_json, 'w') as f:
        json.dump(metadata_data, f, indent=2)

    return content_json, metadata_json


def main():
    """Evaluate the selected database entries in parallel and write consolidated results.

    Keys come from the shelve database at ``--db-path`` and can be narrowed by
    ``--filter``/``--limit``. They are split across ``--processes`` workers
    running process_key_batch(). The merged records are written to
    ``all_results_content.json`` and ``all_results_metadata.json`` in
    ``--output-dir``.
    """
    args = _parse_args()

    # Prefer the environment variable; fall back to an interactive prompt.
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        api_key = getpass.getpass("Enter your API key: ")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with shelve.open(args.db_path, "r") as db:
        all_keys = list(db.keys())

    filtered_keys = _select_keys(all_keys, args.filter, args.limit)
    print(f"Processing {len(filtered_keys)} keys with {args.processes} processes")
    key_batches = _split_into_batches(filtered_keys, args.processes)

    all_results = []
    with ProcessPoolExecutor(max_workers=args.processes) as executor:
        future_to_batch = {
            executor.submit(process_key_batch, batch, i, api_key, output_dir, args.db_path): i
            for i, batch in enumerate(key_batches)
        }
        for future in as_completed(future_to_batch):
            process_id = future_to_batch[future]
            try:
                results = future.result()
            except Exception as e:
                # One failed worker must not discard the other workers' results.
                print(f"Process {process_id} generated an exception: {e}")
            else:
                all_results.extend(results)
                print(f"Process {process_id} completed with {len(results)} results")

    unique_results = _latest_result_per_key(all_results)
    if len(unique_results) != len(all_results):
        print(f"Dropped {len(all_results) - len(unique_results)} duplicate results")
    print(f"Saving {len(unique_results)} total results")
    content_json, metadata_json = _write_consolidated_results(output_dir, unique_results)

    print(f"Results saved to {output_dir}")
    print(f"Content: {content_json}")
    print(f"Metadata: {metadata_json}")


if __name__ == "__main__":
    # Script entry point; importing this module (including by worker
    # processes) does not start a run. main() returns None, so the exit
    # status is 0 unless an exception escapes.
    raise SystemExit(main())