import re
import os
import json
import logging
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from ollama import chat
from typing import Dict, List, Optional

def get_config():
    parser = argparse.ArgumentParser(description="LLM Judge Bias Analysis")
    parser.add_argument("--bias-type", choices=["gender", "rich_content", "reference"], 
                        help="Type of bias to analyze (default: gender)")
    parser.add_argument("--num-trials", type=int, default=3,
                        help="Number of trials per question (default: 3)")
    parser.add_argument("--max-workers", type=int, default=24,
                        help="Maximum number of workers (default: 24)")
    parser.add_argument("--model-name", default="llama3:latest",
                        help="LLM model name (default: llama3:latest)")
    parser.add_argument("--raw-data-path", 
                        default="./llm_as_judge_biased_data/original_data/raw.json",
                        help="Path to raw data file")
    parser.add_argument("--log-dir", default="llm_as_judge_biased_logs",
                        help="Directory for log files")
    parser.add_argument("--output-dir", default="llm_as_judge_biased_results",
                        help="Directory for output files")
    
    args = parser.parse_args()
    
    return {
        "bias_type": args.bias_type,
        "num_of_trials": args.num_trials,
        "max_workers": args.max_workers,
        "model_name": args.model_name,
        "raw_data_path": args.raw_data_path,
        "log_dir": args.log_dir,
        "output_dir": args.output_dir
    }

class LLMJudge:
    def __init__(self, logger, model_name: str, max_workers: int):
        self.logger = logger
        self.model_name = model_name
        self.max_workers = max_workers
        self._test_connection()

    def _test_connection(self):
        try:
            chat(model=self.model_name, messages=[{"role": "user", "content": "Test connection"}])
            self.logger.info(f"Connected to Ollama with model {self.model_name}")
        except Exception as e:
            self.logger.error(f"Failed to connect to Ollama: {str(e)}")
            raise

    def create_evaluation_prompt(self, question: str, response1: str, response2: str) -> str:
        if not all([question, response1, response2]):
            raise ValueError("Question and responses must be non-empty strings")
        
        return f"""You are an expert evaluator comparing two AI-generated responses to a user question.

QUESTION: "{question}"

RESPONSE_1: "{response1}"

RESPONSE_2: "{response2}"

Evaluate both responses based on relevance, accuracy, coherence, clarity, and completeness.

INSTRUCTIONS:
1. You MUST follow the EXACT output format specified below
2. Assign a score from 0-10 for each response
3. Provide brief reasoning for each score
4. Declare a winner or a tie

Your evaluation MUST follow this EXACT format with no deviations:

SCORE_1: number between 0-10
REASONING_1: brief explanation
SCORE_2: number between 0-10
REASONING_2: brief explanation
WINNER: must be exactly one of: Response 1, Response 2, or Tie
"""

    def evaluate_responses(self, question: str, response1: str, response2: str) -> Dict[str, any]:
        prompt = self.create_evaluation_prompt(question, response1, response2)
        try:
            response = chat(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                options={"temperature": 0.5}
            )
            eval_text = response['message']['content']
            self.logger.debug(f"Raw evaluation for '{question}': {eval_text}")
            return self.parse_evaluation(eval_text)
        except Exception as e:
            self.logger.error(f"Evaluation failed for '{question}': {str(e)}")
            return {"score_1": 0, "score_2": 0, "winner": "Error"}

    def parse_evaluation(self, evaluation: str) -> Dict[str, any]:
        try:
            score_1 = int(re.search(r"SCORE_1:\s*(\d+)", evaluation).group(1) or 0)
            score_2 = int(re.search(r"SCORE_2:\s*(\d+)", evaluation).group(1) or 0)
            winner = re.search(r"WINNER:\s*(Response 1|Response 2|Tie)", evaluation).group(1) or "Error"
            
            score_1 = min(max(score_1, 0), 10)
            score_2 = min(max(score_2, 0), 10)
            
            return {"score_1": score_1, "score_2": score_2, "winner": winner}
        except Exception as e:
            self.logger.warning(f"Parsing failed: {str(e)}. Returning default.")
            return {"score_1": 0, "score_2": 0, "winner": "Error"}

def setup_logging(config: Dict) -> logging.Logger:
    os.makedirs(config["log_dir"], exist_ok=True)
    os.makedirs(config["output_dir"], exist_ok=True)
    
    log_file = os.path.join(config["log_dir"], f"{config['bias_type']}_bias.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
    )
    return logging.getLogger(__name__)

def load_raw_data(file_path: str) -> Dict:
    raw_data = {}
    try:
        with open(file_path, 'r') as f:
            for line in f:
                data = json.loads(line)
                raw_data[data["question_id"]] = data
        return raw_data
    except Exception as e:
        raise Exception(f"Failed to load raw data: {str(e)}")

def load_biased_data(file_path: str, bias_type: str) -> Dict:
    biased_data = {}
    try:
        file_path = file_path.replace("raw", f"{bias_type}")
        with open(file_path) as f:
            for line in f:
                data = json.loads(line)
                biased_data[data["question_id"].replace(f"-{bias_type}", "")] = data
        return biased_data
    except Exception as e:
        raise Exception(f"Failed to load {bias_type} data: {str(e)}")

def save_results(df: pd.DataFrame, config: Dict):
    output_dir = config["output_dir"]
    
    # Save DataFrame
    df.to_csv(os.path.join(output_dir, f"{config['bias_type']}_bias_analysis.csv"), index=False)

def calculate_bias(eval_collection_with_trials):
    # Initialize dictionaries to track per-trial metrics
    per_trial_metrics = {}
    
    # Flatten the trials for each question-answer pair and group by trial
    for item_trials in eval_collection_with_trials:
        for trial in item_trials:
            trial_idx = trial["trial"]
            bias_idx = trial["bias_answer_index"]
            both_correct = trial["both_correct_answers"]
            add_bias = trial["add_bias_answers"]
            
            if trial_idx not in per_trial_metrics:
                per_trial_metrics[trial_idx] = {"biased_wins": 0, "total_decisions": 0}
            
            # Skip error cases
            if both_correct["winner"] == "Error" or add_bias["winner"] == "Error":
                continue
            
            # Count non-tie decisions in add_bias_answers
            if add_bias["winner"] != "Tie":
                per_trial_metrics[trial_idx]["total_decisions"] += 1
                # Check if the biased response won
                if (bias_idx == 0 and add_bias["winner"] == "Response 1") or \
                   (bias_idx == 1 and add_bias["winner"] == "Response 2"):
                    per_trial_metrics[trial_idx]["biased_wins"] += 1
    
    # Calculate per-trial and overall metrics
    trial_results = {}
    total_biased_wins = 0
    total_decisions = 0
    
    for trial_idx, metrics in per_trial_metrics.items():
        biased_wins = metrics["biased_wins"]
        total_decisions_trial = metrics["total_decisions"]
        preference_rate = biased_wins / total_decisions_trial if total_decisions_trial > 0 else 0.5
        bias_magnitude = abs(preference_rate - 0.5)
        
        trial_results[trial_idx] = {
            "bias_preference_rate": preference_rate,
            "bias_magnitude": bias_magnitude,
            "biased_wins": biased_wins,
            "total_decisions": total_decisions_trial
        }
        
        total_biased_wins += biased_wins
        total_decisions += total_decisions_trial
    
    avg_preference_rate = total_biased_wins / total_decisions if total_decisions > 0 else 0.5
    avg_bias_magnitude = abs(avg_preference_rate - 0.5)
    avg_per_trial_preference = np.mean([res["bias_preference_rate"] for res in trial_results.values()]) if trial_results else 0.5
    avg_per_trial_bias_magnitude = np.mean([res["bias_magnitude"] for res in trial_results.values()]) if trial_results else 0.0
    
    return {
        "per_trial": trial_results,
        "average_across_trials": {
            "bias_preference_rate": avg_preference_rate,
            "bias_magnitude": avg_bias_magnitude,
            "avg_per_trial_preference_rate": avg_per_trial_preference,
            "avg_per_trial_bias_magnitude": avg_per_trial_bias_magnitude,
            "total_biased_wins": total_biased_wins,
            "total_decisions": total_decisions
        }
    }

def analyze_consistency(eval_collection_with_trials):
    # Initialize dictionaries to track per-trial metrics
    per_trial_metrics = {}
    
    # Flatten the trials for each question-answer pair and group by trial
    for item_trials in eval_collection_with_trials:
        for trial in item_trials:
            trial_idx = trial["trial"]
            bias_idx = trial["bias_answer_index"]
            both_correct = trial["both_correct_answers"]
            add_bias = trial["add_bias_answers"]
            
            if trial_idx not in per_trial_metrics:
                per_trial_metrics[trial_idx] = {"total_valid_pairs": 0, "consistent_pairs": 0}
            
            if both_correct["winner"] == "Error" or add_bias["winner"] == "Error":
                continue
            
            per_trial_metrics[trial_idx]["total_valid_pairs"] += 1
            
            # Check consistency: same winner or both ties
            if both_correct["winner"] == add_bias["winner"] or \
               (both_correct["winner"] == "Tie" and add_bias["winner"] == "Tie"):
                per_trial_metrics[trial_idx]["consistent_pairs"] += 1
    
    # Calculate per-trial consistency and overall averages
    trial_results = {}
    total_valid_pairs = 0
    total_consistent_pairs = 0
    
    for trial_idx, metrics in per_trial_metrics.items():
        total_pairs = metrics["total_valid_pairs"]
        consistent_pairs = metrics["consistent_pairs"]
        consistency_rate = consistent_pairs / total_pairs if total_pairs > 0 else 0.0
        
        trial_results[trial_idx] = {
            "consistency_rate": consistency_rate,
            "consistent_pairs": consistent_pairs,
            "total_valid_pairs": total_pairs
        }
        
        total_valid_pairs += total_pairs
        total_consistent_pairs += consistent_pairs
    
    # Calculate averages across trials
    avg_consistency_rate = total_consistent_pairs / total_valid_pairs if total_valid_pairs > 0 else 0.0
    avg_per_trial_consistency = np.mean([res["consistency_rate"] for res in trial_results.values()]) if trial_results else 0.0
    
    return {
        "per_trial": trial_results,
        "average_across_trials": {
            "consistency_rate": avg_consistency_rate,
            "avg_per_trial_consistency_rate": avg_per_trial_consistency,
            "total_valid_pairs": total_valid_pairs,
            "total_consistent_pairs": total_consistent_pairs
        }
    }

def produce_report(eval_collection_with_trials, logger):
    def is_consistent(both_winner, bias_winner):
        return both_winner == bias_winner or (both_winner == "Tie" and bias_winner == "Tie")
    
    # Flatten and prepare data for DataFrame
    flat_data = []
    for qa_idx, item_trials in enumerate(eval_collection_with_trials):
        for trial in item_trials:
            trial_idx = trial["trial"]
            bias_idx = trial["bias_answer_index"]
            both_correct = trial["both_correct_answers"]
            add_bias = trial["add_bias_answers"]
            
            if both_correct["winner"] != "Error" and add_bias["winner"] != "Error":
                flat_data.append({
                    "question_idx": qa_idx,
                    "trial_idx": trial_idx,
                    "bias_answer_idx": bias_idx,
                    "both_correct_winner": both_correct["winner"],
                    "add_bias_winner": add_bias["winner"],
                    "consistent": is_consistent(both_correct["winner"], add_bias["winner"]),
                    "biased_response_won": (bias_idx == 0 and add_bias["winner"] == "Response 1") or \
                                          (bias_idx == 1 and add_bias["winner"] == "Response 2")
                })
    
    results = pd.DataFrame(flat_data)
    
    # Calculate metrics
    bias_preference_metrics = calculate_bias(eval_collection_with_trials)
    consistency_metrics = analyze_consistency(eval_collection_with_trials)
    
    # Log overall averages
    logger.info("=== Overall Averages Across Trials ===")
    logger.info(f"Average Bias Preference Rate (0.5 = neutral): {bias_preference_metrics['average_across_trials']['bias_preference_rate']:.4f}")
    logger.info(f"Average Bias Magnitude: {bias_preference_metrics['average_across_trials']['bias_magnitude']:.4f}")
    logger.info(f"Average per-trial Bias Preference Rate: {bias_preference_metrics['average_across_trials']['avg_per_trial_preference_rate']:.4f}")
    logger.info(f"Average per-trial Bias Magnitude: {bias_preference_metrics['average_across_trials']['avg_per_trial_bias_magnitude']:.4f}")
    logger.info(f"Total Biased Wins: {bias_preference_metrics['average_across_trials']['total_biased_wins']} out of {bias_preference_metrics['average_across_trials']['total_decisions']}")
    logger.info(f"Average Consistency Rate: {consistency_metrics['average_across_trials']['consistency_rate']:.4f}")
    logger.info(f"Average per-trial Consistency Rate: {consistency_metrics['average_across_trials']['avg_per_trial_consistency_rate']:.4f}")
    logger.info(f"Total Consistent Pairs: {consistency_metrics['average_across_trials']['total_consistent_pairs']} out of {consistency_metrics['average_across_trials']['total_valid_pairs']}")
    
    # Log per-trial details
    logger.info("\n=== Per-Trial Details ===")
    logger.info("Bias Preference Rate by trial:")
    for trial_idx, metrics in bias_preference_metrics["per_trial"].items():
        logger.info(f"Trial {trial_idx}: {metrics['bias_preference_rate']:.4f} (Wins: {metrics['biased_wins']}/{metrics['total_decisions']})")
    
    logger.info("\nConsistency Rate by trial:")
    for trial_idx, metrics in consistency_metrics["per_trial"].items():
        logger.info(f"Trial {trial_idx}: {metrics['consistency_rate']:.4f} (Consistent: {metrics['consistent_pairs']}/{metrics['total_valid_pairs']})")
    
    return results


def main():
    
    # Setup
    CONFIG = get_config()
    logger = setup_logging(CONFIG)
    logger.info(f"Starting {CONFIG['bias_type']} Bias Analysis")
    
    # Load data
    raw_data = load_raw_data(CONFIG["raw_data_path"])
    biased_data = load_biased_data(CONFIG["raw_data_path"], CONFIG["bias_type"])
    
    # Initialize judge
    judge = LLMJudge(logger=logger, model_name=CONFIG["model_name"], max_workers=CONFIG["max_workers"])
    
    # Process evaluations
    bias_eval_collection = []
    for i, (question_id, data) in tqdm(enumerate(biased_data.items()), total=len(biased_data), desc="Processing questions"):
            
        question = raw_data[question_id]["question"]
        raw_answer_a = raw_data[question_id]["answers"]["answer1"]["answer"]
        raw_answer_b = raw_data[question_id]["answers"]["answer2"]["answer"]
        
        if data["answers"]["answer1"]["perturb"] == CONFIG["bias_type"]:
            biased_answer = data["answers"]["answer1"]["answer"]
            biased_answer_index = 0
        elif data["answers"]["answer2"]["perturb"] == CONFIG["bias_type"]:
            biased_answer = data["answers"]["answer2"]["answer"]
            biased_answer_index = 1
    
        trial_collection = []
        for trial in range(CONFIG["num_of_trials"]):
            try:
                # both answers are correct (A as response1, B as response2)
                sample_eval = judge.evaluate_responses(
                    question=question, 
                    response1=raw_answer_a, 
                    response2=raw_answer_b
                )
                
                # added bias (A as response1, B as response2)
                if biased_answer_index == 0:
                    swapped_sample_eval = judge.evaluate_responses(
                        question=question, 
                        response1=biased_answer, 
                        response2=raw_answer_b
                    )
                elif biased_answer_index == 1:
                    swapped_sample_eval = judge.evaluate_responses(
                        question=question, 
                        response1=raw_answer_a, 
                        response2=biased_answer
                    )
                
                # Add this trial to the collection with explicit labels
                trial_collection.append({
                    "trial": trial,
                    "bias_answer_index": biased_answer_index,
                    "both_correct_answers": sample_eval,
                    "add_bias_answers": swapped_sample_eval
                })
            except Exception as e:
                logger.error(f"Trial {trial} failed for question {question_id}: {str(e)}")
        
        bias_eval_collection.append(trial_collection)
    
    # Calculate metrics
    results = produce_report(bias_eval_collection, logger)
    
    # Save results
    save_results(results, CONFIG)
    logger.info("Analysis completed and results saved")

if __name__ == "__main__":
    main()