import pandas as pd
import numpy as np
from typing import List, Tuple
import nltk
import os
import argparse
from sentence_transformers import SentenceTransformer, util
import json
import google.generativeai as genai
from google.api_core.exceptions import GoogleAPIError
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from dotenv import load_dotenv

# Read settings from a .env file before anything below looks at the environment.
load_dotenv()

# Download the NLTK resources this script requests at import time, quietly
# and in a fixed order.
for _nltk_resource in ('words', 'wordnet', 'punkt', 'omw-1.4'):
    nltk.download(_nltk_resource, quiet=True)

# Shared scorers, built once at import time and reused for every row.
sentence_transformer_model = SentenceTransformer("all-mpnet-base-v2")
rouge_calculator = rouge_scorer.RougeScorer(
    ['rouge1', 'rouge2', 'rougeL'],
    use_stemmer=True,
)

# Set up the Gemini client. `genai_model` stays None when configuration fails,
# which gemini_inference() checks before every call.
genai_model = None
try:
    genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
    genai_model = genai.GenerativeModel(
        "gemini-2.0-flash",
        generation_config=dict(
            temperature=0.1,
            top_p=0.95,
            max_output_tokens=500,
        ),
    )
except Exception as e:
    print(f"Error configuring Gemini: {e}")

def gemini_inference(prompt):
    """Run `prompt` through the module-level Gemini model and return its text.

    Returns the stripped response text on success. When the model was never
    initialised, or the call raises GoogleAPIError/ValueError, a small JSON
    string with an "error" field is returned instead of raising.
    """
    if genai_model:
        try:
            reply = genai_model.generate_content(prompt)
            reply_text = reply.text.strip()
        except (GoogleAPIError, ValueError) as e:
            print(f"API Error during inference: {e}")
            reply_text = '{"error": "API call failed"}'
        return reply_text

    return '{"error": "Gemini model not initialized"}'

def get_variation_prompt(variation_type: str, question: str, model_response: str) -> str:
    """
    Build the Gemini prompt asking whether a response shows a given variation pattern.

    The prompt embeds the question, the model response and a textual definition
    of `variation_type`, and requests a JSON reply with a single
    "pattern_match" field. Unrecognised variation types get a placeholder
    description instead of raising.
    """
    counterfactual_description = "\n".join((
        "The 'CounterFactual' pattern has two distinct features:",
        "1. The 'model_answer' correctly follows the high-level format or task requested in the 'question' (e.g., it produces a poem, a list, or a summary as requested).",
        "2. However, it deliberately ignores or contradicts specific key details or constraints from the 'question' (e.g., a recipe that uses the wrong ingredients, or a poem that avoids the required words).",
    ))

    # One description per supported pattern name.
    pattern_descriptions = dict(
        WordFlipped="The 'WordFlipped' pattern means the words in the 'model_answer' are in a jumbled or syntactically incorrect order. For example, 'Paris is the capital of France' might become '. Paris is France of capital The'.",
        CharFlipped="The 'CharFlipped' pattern means the characters within most words of the 'model_answer' are reversed. For example, 'The capital of France is in Paris.' might become '.ehT latipac fo ecnarF si ni siraP'.",
        Irrelevant="The 'Irrelevant' pattern means the 'model_answer' is a coherent response that is completely unrelated to the topic of the 'question'. It answers a different question entirely.",
        CounterFactual=counterfactual_description,
    )

    if variation_type in pattern_descriptions:
        description = pattern_descriptions[variation_type]
    else:
        description = "Unknown variation type. Cannot evaluate."

    # Single template: task framing, the inputs, the pattern definition, then
    # the required answer format.
    return f"""You are a text pattern analyst. Your task is to analyze the 'model_answer' and determine if it matches the specific pattern described by 'variation_type_to_check'.

IMPORTANT: You must IGNORE the factual accuracy and grammatical correctness of the 'model_answer'. Your focus is ONLY on the described pattern.

question: "{question}"
model_answer: "{model_response}"

---
PATTERN DEFINITION
variation_type_to_check: "{variation_type}"
pattern_description: "{description}"
---

Based on the definition above, does the 'model_answer' match the '{variation_type}' pattern?

Respond ONLY with a JSON object in this exact format:
{{
  "pattern_match": "Yes/No"
}}
"""

def _load_json_object(text: str):
    """Return the first JSON object (dict) that can be decoded from `text`, else None."""
    stripped = text.strip()
    candidates = [stripped]

    # Fall back to the outermost {...} span. This covers Markdown code fences
    # (closed or not) and replies that wrap the JSON in extra prose.
    first_brace = stripped.find('{')
    last_brace = stripped.rfind('}')
    if first_brace != -1 and last_brace > first_brace:
        candidates.append(stripped[first_brace:last_brace + 1])

    for candidate in candidates:
        try:
            decoded = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(decoded, dict):
            return decoded
    return None


def parse_json_response(response_text: str, keys: List[str], default_value: str) -> dict:
    """Safely parses JSON from a string and extracts specified keys.

    Args:
        response_text: Raw model reply. It may be bare JSON, JSON inside a
            Markdown code fence, or JSON surrounded by other text.
        keys: Names of the fields to extract.
        default_value: Value used for every key that is missing, has an
            unexpected value, or when no JSON object can be decoded.

    Returns:
        A dict with exactly the requested keys. A field is taken from the
        reply only when its value is one of "Correct", "Incorrect", "Yes" or
        "No"; otherwise it keeps `default_value`.
    """
    allowed_values = ("Correct", "Incorrect", "Yes", "No")
    results = {key: default_value for key in keys}

    # Non-string input (e.g. None) cannot contain JSON; treat it as a parse failure.
    data = _load_json_object(response_text) if isinstance(response_text, str) else None
    if data is None:
        print(f"Warning: Failed to parse JSON: {response_text}")
        return results

    for key in keys:
        value = data.get(key)
        if isinstance(value, str) and value in allowed_values:
            results[key] = value
    return results


# Columns every input CSV must provide; checked before any model is invoked.
_REQUIRED_COLUMNS = ('Question', 'ModelAnswer', 'ModelResponse', 'Answer')


def _detect_variation_type(csv_name: str) -> str:
    """Return the first known variation type named in `csv_name`, else "Normal"."""
    lowered = csv_name.lower()
    for v_type in ("WordFlipped", "CharFlipped", "Irrelevant", "CounterFactual"):
        if v_type.lower() in lowered:
            return v_type
    return "Normal"


def _evaluate_row(row, variation_type: str, smoothing) -> dict:
    """Score one input row with the similarity, overlap and Gemini-judged metrics."""
    question = row['Question']
    model_answer = str(row['ModelAnswer'])
    model_response = str(row['ModelResponse'])
    actual_answer = str(row['Answer'])

    # Embedding-based similarity between the model answer and the reference.
    response_embedding = sentence_transformer_model.encode(model_answer, convert_to_tensor=True)
    actual_answer_embedding = sentence_transformer_model.encode(actual_answer, convert_to_tensor=True)
    cos_score = util.cos_sim(response_embedding, actual_answer_embedding).item()

    # Whitespace-tokenised overlap metrics; BLEU and METEOR are scaled to 0-100.
    ref_tokens = [actual_answer.split()]
    cand_tokens = model_answer.split()
    bleu = sentence_bleu(ref_tokens, cand_tokens, smoothing_function=smoothing) * 100
    meteor = meteor_score(ref_tokens, cand_tokens) * 100
    rouge_scores = rouge_calculator.score(actual_answer, model_answer)

    # accuracy and grammatical correctness prompt
    prompt_accuracy = f"""You are an evaluator. Compare the 'model_answer' to the 'actual_answer' for the given 'question'.
- Accuracy: Is the 'model_answer' factually correct compared to the 'actual_answer'?
- Grammatical Correctness: Is the 'model_answer' grammatically correct?

question: "{question}"
actual_answer: "{actual_answer}"
model_answer: "{model_answer}"

Respond ONLY with a JSON object in this exact format:
{{
  "accuracy": "Correct/Incorrect",
  "grammatical_correctness": "Correct/Incorrect"
}}"""
    eval_result_text = gemini_inference(prompt_accuracy)
    eval_results = parse_json_response(eval_result_text, ["accuracy", "grammatical_correctness"], "Incorrect")

    # The pattern check runs on the raw ModelResponse and only for variation files.
    variation_match = "N/A"
    if variation_type != "Normal":
        prompt_variation = get_variation_prompt(variation_type, question, model_response)
        variation_result_text = gemini_inference(prompt_variation)
        variation_results = parse_json_response(variation_result_text, ["pattern_match"], "No")
        variation_match = variation_results["pattern_match"]

    return {
        'Question': question,
        'Answer': actual_answer,
        'ModelResponse': model_response,
        'ModelAnswer': model_answer,
        'semantic_similarity': cos_score,
        'bleu': bleu,
        'meteor': meteor,
        'rouge1': rouge_scores['rouge1'].fmeasure * 100,
        'rouge2': rouge_scores['rouge2'].fmeasure * 100,
        'rougeL': rouge_scores['rougeL'].fmeasure * 100,
        'accuracy': eval_results['accuracy'],
        'grammatical_correctness': eval_results['grammatical_correctness'],
        'variation_type': variation_type,
        'variation_check': variation_match,
    }


def _build_summary(output_df, variation_type: str) -> dict:
    """Aggregate the per-row results into the single AVERAGE_METRICS row."""
    summary = {'Question': 'AVERAGE_METRICS'}
    for metric in ('semantic_similarity', 'bleu', 'meteor', 'rouge1', 'rouge2', 'rougeL'):
        summary[metric] = output_df[metric].mean()
    summary['accuracy'] = f"{(output_df['accuracy'] == 'Correct').mean() * 100:.2f}%"
    summary['grammatical_correctness'] = f"{(output_df['grammatical_correctness'] == 'Correct').mean() * 100:.2f}%"
    summary['variation_type'] = variation_type
    summary['variation_check'] = (
        f"{(output_df['variation_check'] == 'Yes').mean() * 100:.2f}%"
        if variation_type != "Normal" else "N/A"
    )
    return summary


def _update_summary_file(summary_path: str, csv_name: str, summary: dict) -> None:
    """Insert or replace the row for `csv_name` in the shared summary CSV."""
    summary_rows = pd.DataFrame([{'csv_name': csv_name, **summary}])
    if os.path.exists(summary_path):
        # Read csv_name as text: with inferred dtypes a numeric-looking name
        # would never equal the string csv_name and its old row would pile up.
        existing = pd.read_csv(summary_path, dtype={'csv_name': str})
        existing = existing[existing['csv_name'] != csv_name]
        summary_rows = pd.concat([existing, summary_rows], ignore_index=True)
    summary_rows.to_csv(summary_path, index=False)


def compute_metrics(input_csv: str, output_folder: str,
                    summary_filename: str = "SummaryMetrics_Irrelevant_25_50_75_New.csv"):
    """Computes all metrics for a single input CSV file.

    Writes `Evaluated_<name>.csv` (per-row metrics plus an AVERAGE_METRICS row)
    to `output_folder` and inserts or replaces this file's row in the shared
    summary CSV.

    Args:
        input_csv: Path to a CSV with the columns Question, Answer,
            ModelResponse and ModelAnswer. A variation type is inferred from
            the file name (case-insensitive substring match).
        output_folder: Existing directory that receives the output files.
        summary_filename: Name of the shared summary CSV inside
            `output_folder`. Defaults to the previously hard-coded name.

    Raises:
        KeyError: If a required column is missing from `input_csv`.
    """
    df = pd.read_csv(input_csv)
    # Strip only the extension; str.replace would also drop any ".csv" found
    # in the middle of the name.
    csv_name = os.path.splitext(os.path.basename(input_csv))[0]

    # Fail fast, before any embedding or API call is made.
    missing_columns = [column for column in _REQUIRED_COLUMNS if column not in df.columns]
    if missing_columns:
        raise KeyError(f"{input_csv} is missing required column(s): {', '.join(missing_columns)}")

    # A header-only file has nothing to average; skip it rather than crash on
    # the metric columns of an empty result frame.
    if df.empty:
        print(f"No rows found in {input_csv}; skipping.")
        return

    variation_type = _detect_variation_type(csv_name)
    chencherry = SmoothingFunction().method7
    total_rows = len(df)

    results = []
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        print(f"Processing row {position}/{total_rows} from {csv_name}.csv...")
        results.append(_evaluate_row(row, variation_type, chencherry))

    output_df = pd.DataFrame(results)
    summary = _build_summary(output_df, variation_type)
    final_df = pd.concat([output_df, pd.DataFrame([summary])], ignore_index=True)

    # Save detailed output file
    output_filename = os.path.join(output_folder, f"Evaluated_{csv_name}.csv")
    final_df.to_csv(output_filename, index=False)
    print(f"\nDetailed results saved to {output_filename}")

    # Save or update a single summary file for all runs
    summary_path = os.path.join(output_folder, summary_filename)
    _update_summary_file(summary_path, csv_name, summary)
    print(f"Summary metrics updated in {summary_path}\n")


def main():
    """Command-line entry point: evaluate one CSV file or every CSV in a folder.

    Parses `-i/--input` and `-o/--output`, creates the output folder if
    needed, then calls compute_metrics() once per input file. An input that
    is neither a directory nor a `.csv` file only produces an error message.
    """
    cli = argparse.ArgumentParser(description="Evaluate model answers in CSV files using similarity and Gemini.")
    cli.add_argument("-i", "--input", required=True, help="Path to a single input CSV file or a folder containing CSV files.")
    cli.add_argument("-o", "--output", required=True, help="Path to the output folder where results will be saved.")
    options = cli.parse_args()

    source = options.input
    destination = options.output

    # Ensure output directory exists
    os.makedirs(destination, exist_ok=True)

    # Folder mode: every directory entry whose name ends with ".csv".
    if os.path.isdir(source):
        print(f"Processing all CSV files in folder: {source}")
        csv_paths = [
            os.path.join(source, entry)
            for entry in os.listdir(source)
            if entry.endswith('.csv')
        ]
        for csv_path in csv_paths:
            compute_metrics(csv_path, destination)
        return

    # Anything that is not an existing ".csv" file is rejected with a message.
    if not (os.path.isfile(source) and source.endswith('.csv')):
        print(f"Error: Input path '{source}' is not a valid CSV file or directory.")
        return

    print(f"Processing single file: {source}")
    compute_metrics(source, destination)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()