import argparse
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from ollama import chat
from tqdm import tqdm

def get_config(argv: Optional[List[str]] = None) -> Dict:
    """Parse command-line options into the run configuration dict.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets
            tests and other scripts build a config without touching the
            process arguments.

    Returns:
        A dict with keys ``bias_type``, ``num_of_trials``, ``max_workers``,
        ``model_name``, ``raw_data_path``, ``log_dir`` and ``output_dir``.

    Exits (via ``argparse``) with status 2 on invalid options, including a
    non-positive ``--num-trials``.
    """
    parser = argparse.ArgumentParser(description="LLM Judge Bias Analysis")
    parser.add_argument("--bias-type", default="position", choices=["position"],
                        help="Type of bias to analyze (default: position)")
    parser.add_argument("--num-trials", type=int, default=3,
                        help="Number of trials per question (default: 3)")
    parser.add_argument("--max-workers", type=int, default=24,
                        help="Maximum number of workers (default: 24)")
    parser.add_argument("--model-name", default="llama3:latest",
                        help="LLM model name (default: llama3:latest)")
    parser.add_argument("--raw-data-path",
                        default="./llm_as_judge_biased_data/original_data/raw.json",
                        help="Path to raw data file")
    parser.add_argument("--log-dir", default="llm_as_judge_biased_logs",
                        help="Directory for log files")
    parser.add_argument("--output-dir", default="llm_as_judge_biased_results",
                        help="Directory for output files")

    args = parser.parse_args(argv)
    # With zero (or negative) trials no evaluation runs and every metric falls
    # back to a placeholder value, so reject it up front.
    if args.num_trials < 1:
        parser.error("--num-trials must be a positive integer")

    return {
        "bias_type": args.bias_type,
        "num_of_trials": args.num_trials,
        "max_workers": args.max_workers,
        "model_name": args.model_name,
        "raw_data_path": args.raw_data_path,
        "log_dir": args.log_dir,
        "output_dir": args.output_dir,
    }

class LLMJudge:
    """Pairwise judge that asks an Ollama chat model to compare two responses.

    The model is prompted to score both responses (0-10) and name a winner.
    Its free-text reply is parsed into
    ``{"score_1": int, "score_2": int, "winner": str}``.
    """

    # Field patterns for the judge's reply. Matching is case-insensitive and
    # tolerates markdown bold around the label or value (e.g.
    # "**WINNER:** Response 1") and "Response1" without a space. Chat models
    # commonly emit both despite the prompt's format rules.
    _SCORE_1_RE = re.compile(r"SCORE_1\**\s*:\s*\**\s*(\d+)", re.IGNORECASE)
    _SCORE_2_RE = re.compile(r"SCORE_2\**\s*:\s*\**\s*(\d+)", re.IGNORECASE)
    _WINNER_RE = re.compile(
        r"WINNER\**\s*:\s*\**\s*(Response\s*1|Response\s*2|Tie)", re.IGNORECASE
    )
    # Canonical winner labels, keyed by the matched text lower-cased with
    # whitespace removed.
    _WINNER_LABELS = {"response1": "Response 1", "response2": "Response 2", "tie": "Tie"}

    def __init__(self, logger, model_name: str, max_workers: int):
        """Store settings and verify the model is reachable.

        Args:
            logger: Logger used for progress and error messages.
            model_name: Ollama model tag passed to ``chat``.
            max_workers: Stored on the instance; not used by this class.

        Raises:
            Exception: Whatever ``chat`` raises if the test request fails.
        """
        self.logger = logger
        self.model_name = model_name
        self.max_workers = max_workers
        self._test_connection()

    def _test_connection(self):
        """Send one throwaway chat request; log and re-raise on failure."""
        try:
            chat(model=self.model_name, messages=[{"role": "user", "content": "Test connection"}])
            self.logger.info(f"Connected to Ollama with model {self.model_name}")
        except Exception as e:
            self.logger.error(f"Failed to connect to Ollama: {str(e)}")
            raise

    def create_evaluation_prompt(self, question: str, response1: str, response2: str) -> str:
        """Build the judging prompt presenting ``response1`` first and ``response2`` second.

        Raises:
            ValueError: If any of the three inputs is falsy (e.g. empty).
        """
        if not all([question, response1, response2]):
            raise ValueError("Question and responses must be non-empty strings")
        
        return f"""You are an expert evaluator comparing two AI-generated responses to a user question.

QUESTION: "{question}"

RESPONSE_1: "{response1}"

RESPONSE_2: "{response2}"

Evaluate both responses based on relevance, accuracy, coherence, clarity, and completeness.

INSTRUCTIONS:
1. You MUST follow the EXACT output format specified below
2. Assign a score from 0-10 for each response
3. Provide brief reasoning for each score
4. Declare a winner or a tie

Your evaluation MUST follow this EXACT format with no deviations:

SCORE_1: number between 0-10
REASONING_1: brief explanation
SCORE_2: number between 0-10
REASONING_2: brief explanation
WINNER: must be exactly one of: Response 1, Response 2, or Tie
"""

    def evaluate_responses(self, question: str, response1: str, response2: str) -> Dict[str, Any]:
        """Ask the model to judge ``response1`` against ``response2`` for ``question``.

        Returns:
            The parsed evaluation (see ``parse_evaluation``). If the chat call
            or reply handling raises, the error is logged and
            ``{"score_1": 0, "score_2": 0, "winner": "Error"}`` is returned.

        Raises:
            ValueError: From ``create_evaluation_prompt`` when an input is
                empty. The prompt is built outside the ``try``, so this
                propagates to the caller.
        """
        prompt = self.create_evaluation_prompt(question, response1, response2)
        try:
            response = chat(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                options={"temperature": 0.5}
            )
            eval_text = response['message']['content']
            self.logger.debug(f"Raw evaluation for '{question}': {eval_text}")
            return self.parse_evaluation(eval_text)
        except Exception as e:
            self.logger.error(f"Evaluation failed for '{question}': {str(e)}")
            return {"score_1": 0, "score_2": 0, "winner": "Error"}

    def parse_evaluation(self, evaluation: str) -> Dict[str, Any]:
        """Extract both scores and the winner from the judge's raw reply.

        Each field is parsed independently, so a reply that garbles one field
        still yields the others. A missing score becomes 0, and a missing or
        unrecognised winner becomes ``"Error"``. Scores are clamped to
        [0, 10]. The winner is normalised to ``"Response 1"``,
        ``"Response 2"`` or ``"Tie"``.
        """
        if not isinstance(evaluation, str):
            self.logger.warning(
                f"Parsing failed: expected str, got {type(evaluation).__name__}. Returning default."
            )
            return {"score_1": 0, "score_2": 0, "winner": "Error"}

        score_1 = self._extract_score(self._SCORE_1_RE, evaluation)
        score_2 = self._extract_score(self._SCORE_2_RE, evaluation)
        winner_match = self._WINNER_RE.search(evaluation)
        winner = (
            self._WINNER_LABELS[re.sub(r"\s+", "", winner_match.group(1)).lower()]
            if winner_match
            else None
        )

        missing = [
            label
            for label, value in (("SCORE_1", score_1), ("SCORE_2", score_2), ("WINNER", winner))
            if value is None
        ]
        if missing:
            self.logger.warning(
                f"Parsing failed for {', '.join(missing)}. Using defaults for those fields."
            )

        return {
            "score_1": 0 if score_1 is None else score_1,
            "score_2": 0 if score_2 is None else score_2,
            "winner": "Error" if winner is None else winner,
        }

    @staticmethod
    def _extract_score(pattern, text: str) -> Optional[int]:
        """Return the first integer captured by ``pattern``, clamped to [0, 10], or None."""
        match = pattern.search(text)
        if match is None:
            return None
        return min(max(int(match.group(1)), 0), 10)

def setup_logging(config: Dict) -> logging.Logger:
    """Create the log/output directories and configure root logging.

    Records go both to ``<log_dir>/<bias_type>_bias.log`` and to stderr, at
    INFO level and above.

    Args:
        config: Run configuration; reads ``log_dir``, ``output_dir`` and
            ``bias_type``.

    Returns:
        The logger for this module.
    """
    os.makedirs(config["log_dir"], exist_ok=True)
    # The output directory is created here as well so later writes succeed.
    os.makedirs(config["output_dir"], exist_ok=True)

    log_file = os.path.join(config["log_dir"], f"{config['bias_type']}_bias.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        # Logged questions/answers may contain non-ASCII text; do not rely on
        # the platform's default file encoding.
        handlers=[logging.FileHandler(log_file, encoding="utf-8"), logging.StreamHandler()],
        # basicConfig silently does nothing when the root logger already has
        # handlers (installed by an imported library or a test runner).
        # force=True replaces them so the log file really receives records.
        force=True,
    )
    return logging.getLogger(__name__)

def load_raw_data(file_path: str) -> Dict:
    """Load question records from a JSON-Lines file, keyed by ``question_id``.

    The file is parsed one JSON object per line. Blank lines, such as a
    trailing newline, are skipped. If two records share a ``question_id``,
    the later one wins.

    Args:
        file_path: Path to the UTF-8 encoded JSON-Lines file.

    Returns:
        Dict mapping each record's ``question_id`` to the full record.

    Raises:
        RuntimeError: If the file cannot be read or decoded, a line is not
            valid JSON, or a record has no ``question_id``. The message names
            the offending line, and the original error is chained as
            ``__cause__``.
    """
    raw_data = {}
    line_no = 0
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                data = json.loads(line)
                raw_data[data["question_id"]] = data
    except (OSError, ValueError, KeyError, TypeError) as e:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        where = f" (line {line_no})" if line_no else ""
        raise RuntimeError(f"Failed to load raw data{where}: {str(e)}") from e
    return raw_data

def save_results(df: pd.DataFrame, config: Dict):
    """Write the results table to ``<output_dir>/<bias_type>_bias_analysis.csv``.

    The output directory is created if it does not exist, so this works
    without any prior setup step.

    Args:
        df: Results to save; the index is not written.
        config: Run configuration; reads ``output_dir`` and ``bias_type``.

    Returns:
        The path of the written CSV file.
    """
    output_dir = config["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    out_path = os.path.join(output_dir, f"{config['bias_type']}_bias_analysis.csv")
    df.to_csv(out_path, index=False)
    return out_path

def calculate_position_bias(eval_collection_with_trials):
    """Measure how often the judge prefers whichever response is shown first.

    Args:
        eval_collection_with_trials: One list per question of trial dicts,
            each with keys ``"trial"``, ``"original_ordering"`` and
            ``"swapped_ordering"``. Each ordering is a parsed evaluation with
            a ``"winner"`` key.

    Returns:
        ``{"per_trial": {trial_idx: {...}}, "average_across_trials": {...}}``.

        A trial's ``position_1_win_rate`` is the share of its non-tie
        verdicts, across both orderings, that went to "Response 1". A value
        of 0.5 means no positional preference. The same 0.5 placeholder is
        reported when a trial has no decisions. ``bias_magnitude`` is
        ``|rate - 0.5|``.

        The pooled averages weight every decision equally. The
        ``avg_per_trial_*`` values are unweighted means over trials that made
        at least one decision.

        Pairs where either ordering's winner is ``"Error"`` are ignored
        entirely.
    """
    per_trial_metrics = {}
    for item_trials in eval_collection_with_trials:
        for trial in item_trials:
            # Register the trial even if all its pairs errored, so it still
            # shows up in the per-trial report.
            counts = per_trial_metrics.setdefault(
                trial["trial"], {"position_1_wins": 0, "total_decisions": 0}
            )
            winners = (trial["original_ordering"]["winner"], trial["swapped_ordering"]["winner"])
            if "Error" in winners:
                continue
            for winner in winners:
                if winner == "Tie":
                    continue
                counts["total_decisions"] += 1
                if winner == "Response 1":
                    counts["position_1_wins"] += 1

    trial_results = {}
    for trial_idx, counts in per_trial_metrics.items():
        wins = counts["position_1_wins"]
        decisions = counts["total_decisions"]
        rate = wins / decisions if decisions > 0 else 0.5
        trial_results[trial_idx] = {
            "position_1_win_rate": rate,
            "bias_magnitude": abs(rate - 0.5),
            "total_decisions": decisions,
            "position_1_wins": wins,
        }

    total_position_1_wins = sum(r["position_1_wins"] for r in trial_results.values())
    total_decisions = sum(r["total_decisions"] for r in trial_results.values())
    avg_position_1_win_rate = total_position_1_wins / total_decisions if total_decisions > 0 else 0.5
    avg_bias_magnitude = abs(avg_position_1_win_rate - 0.5)

    # Trials without any decision only carry the 0.5 placeholder, not
    # evidence. Averaging them in would drag the per-trial means toward
    # "no bias".
    informative = [r for r in trial_results.values() if r["total_decisions"] > 0]
    avg_per_trial_bias = (
        np.mean([r["position_1_win_rate"] for r in informative]) if informative else 0.5
    )
    avg_per_trial_bias_magnitude = (
        np.mean([r["bias_magnitude"] for r in informative]) if informative else 0.0
    )

    return {
        "per_trial": trial_results,
        "average_across_trials": {
            "position_1_win_rate": avg_position_1_win_rate,
            "bias_magnitude": avg_bias_magnitude,
            "avg_per_trial_position_1_win_rate": avg_per_trial_bias,
            "avg_per_trial_bias_magnitude": avg_per_trial_bias_magnitude,
            "total_decisions": total_decisions,
            "total_position_1_wins": total_position_1_wins,
        },
    }

def analyze_consistency(eval_collection_with_trials):
    """Measure how often the judge's verdict survives swapping the response order.

    A (question, trial) pair is consistent when the same underlying answer
    wins in both orderings, i.e. ("Response 1", "Response 2") or
    ("Response 2", "Response 1"), or when both orderings are a tie. Pairs
    where either ordering's winner is ``"Error"`` are excluded.

    Args:
        eval_collection_with_trials: One list per question of trial dicts,
            each with keys ``"trial"``, ``"original_ordering"`` and
            ``"swapped_ordering"``. Each ordering has a ``"winner"`` key.

    Returns:
        ``{"per_trial": {trial_idx: {...}}, "average_across_trials": {...}}``.

        A trial with no valid pairs reports a placeholder rate of 0.0. The
        pooled ``consistency_rate`` weights every pair equally.
        ``avg_per_trial_consistency_rate`` is the unweighted mean over trials
        that have at least one valid pair.
    """
    consistent_outcomes = {
        ("Response 1", "Response 2"),
        ("Response 2", "Response 1"),
        ("Tie", "Tie"),
    }

    per_trial_metrics = {}
    for item_trials in eval_collection_with_trials:
        for trial in item_trials:
            # Register the trial even if all its pairs errored.
            counts = per_trial_metrics.setdefault(
                trial["trial"], {"total_valid_pairs": 0, "consistent_pairs": 0}
            )
            outcome = (trial["original_ordering"]["winner"], trial["swapped_ordering"]["winner"])
            if "Error" in outcome:
                continue
            counts["total_valid_pairs"] += 1
            if outcome in consistent_outcomes:
                counts["consistent_pairs"] += 1

    trial_results = {}
    for trial_idx, counts in per_trial_metrics.items():
        total_pairs = counts["total_valid_pairs"]
        consistent_pairs = counts["consistent_pairs"]
        trial_results[trial_idx] = {
            "consistency_rate": consistent_pairs / total_pairs if total_pairs > 0 else 0.0,
            "consistent_pairs": consistent_pairs,
            "total_valid_pairs": total_pairs,
        }

    total_valid_pairs = sum(r["total_valid_pairs"] for r in trial_results.values())
    total_consistent_pairs = sum(r["consistent_pairs"] for r in trial_results.values())
    avg_consistency_rate = (
        total_consistent_pairs / total_valid_pairs if total_valid_pairs > 0 else 0.0
    )

    # A trial with no valid pairs has only the 0.0 placeholder. Averaging it
    # in would report the judge as less consistent than the data shows.
    informative_rates = [
        r["consistency_rate"] for r in trial_results.values() if r["total_valid_pairs"] > 0
    ]
    avg_per_trial_consistency = np.mean(informative_rates) if informative_rates else 0.0

    return {
        "per_trial": trial_results,
        "average_across_trials": {
            "consistency_rate": avg_consistency_rate,
            "avg_per_trial_consistency_rate": avg_per_trial_consistency,
            "total_valid_pairs": total_valid_pairs,
            "total_consistent_pairs": total_consistent_pairs,
        },
    }

def produce_report(eval_collection_with_trials, logger):
    """Log position-bias and consistency metrics and return per-pair results.

    Args:
        eval_collection_with_trials: One list per question of trial dicts
            with ``"trial"``, ``"original_ordering"`` and
            ``"swapped_ordering"`` keys.
        logger: Logger that receives the summary.

    Returns:
        A DataFrame with one row per (question, trial) pair where both
        orderings were judged without error. Its columns are
        ``question_idx`` (position in the input list), ``trial_idx``,
        ``orig_winner``, ``swapped_winner`` and ``consistent``. The columns
        are present even when no pair is usable, so the saved CSV keeps a
        stable header.
    """
    # (orig_winner, swapped_winner) combinations where the verdict follows
    # the content rather than the position.
    consistent_outcomes = {
        ("Response 1", "Response 2"),
        ("Response 2", "Response 1"),
        ("Tie", "Tie"),
    }

    flat_data = []
    for qa_idx, item_trials in enumerate(eval_collection_with_trials):
        for trial in item_trials:
            orig_winner = trial["original_ordering"]["winner"]
            swapped_winner = trial["swapped_ordering"]["winner"]
            if orig_winner == "Error" or swapped_winner == "Error":
                continue
            flat_data.append({
                "question_idx": qa_idx,
                "trial_idx": trial["trial"],
                "orig_winner": orig_winner,
                "swapped_winner": swapped_winner,
                "consistent": (orig_winner, swapped_winner) in consistent_outcomes,
            })

    results = pd.DataFrame(
        flat_data,
        columns=["question_idx", "trial_idx", "orig_winner", "swapped_winner", "consistent"],
    )

    position_bias_metrics = calculate_position_bias(eval_collection_with_trials)
    consistency_metrics = analyze_consistency(eval_collection_with_trials)
    pb_avg = position_bias_metrics["average_across_trials"]
    cons_avg = consistency_metrics["average_across_trials"]

    logger.info("=== Overall Averages Across Trials ===")
    logger.info(f"Average Position 1 win rate (0.5 = no bias): {pb_avg['position_1_win_rate']:.4f}")
    logger.info(f"Average Position bias magnitude: {pb_avg['bias_magnitude']:.4f}")
    logger.info(f"Average per-trial Position 1 win rate: {pb_avg['avg_per_trial_position_1_win_rate']:.4f}")
    logger.info(f"Average per-trial bias magnitude: {pb_avg['avg_per_trial_bias_magnitude']:.4f}")
    logger.info(f"Total Position 1 wins: {pb_avg['total_position_1_wins']} out of {pb_avg['total_decisions']} decisions")
    logger.info(f"Average Consistency rate: {cons_avg['consistency_rate']:.4f}")
    logger.info(f"Average per-trial Consistency rate: {cons_avg['avg_per_trial_consistency_rate']:.4f}")
    logger.info(f"Total Consistent pairs: {cons_avg['total_consistent_pairs']} out of {cons_avg['total_valid_pairs']}")

    logger.info("\n=== Per-Trial Details ===")
    logger.info("Position 1 win rate by trial:")
    for trial_idx, metrics in position_bias_metrics["per_trial"].items():
        logger.info(f"Trial {trial_idx}: {metrics['position_1_win_rate']:.4f} (Wins: {metrics['position_1_wins']}/{metrics['total_decisions']})")

    logger.info("\nConsistency rate by trial:")
    for trial_idx, metrics in consistency_metrics["per_trial"].items():
        logger.info(f"Trial {trial_idx}: {metrics['consistency_rate']:.4f} (Consistent: {metrics['consistent_pairs']}/{metrics['total_valid_pairs']})")

    return results

def main():
    """Run the position-bias experiment end to end.

    For every question, each trial asks the judge to compare the two answers
    twice: once in their original order and once swapped. The collected
    verdicts are then summarised in the log and saved as CSV.
    """
    config = get_config()
    logger = setup_logging(config)
    logger.info(f"Starting {config['bias_type']} Bias Analysis")

    raw_data = load_raw_data(config["raw_data_path"])

    judge = LLMJudge(logger=logger, model_name=config["model_name"], max_workers=config["max_workers"])

    bias_eval_collection = []
    for question_id, data in tqdm(raw_data.items(), total=len(raw_data), desc="Processing questions"):
        trial_collection = []
        try:
            question = data["question"]
            answer_a = data["answers"]["answer1"]["answer"]
            answer_b = data["answers"]["answer2"]["answer"]
        except (KeyError, TypeError) as e:
            # One malformed record should not abort a long, expensive run.
            # An empty trial list is still appended so positional question
            # indices in the report stay aligned with the input order.
            logger.error(f"Skipping question {question_id}: malformed record ({e!r})")
            bias_eval_collection.append(trial_collection)
            continue

        for trial in range(config["num_of_trials"]):
            try:
                orig_eval = judge.evaluate_responses(question, answer_a, answer_b)
                swapped_eval = judge.evaluate_responses(question, answer_b, answer_a)

                trial_collection.append({
                    "trial": trial,
                    "original_ordering": orig_eval,
                    "swapped_ordering": swapped_eval,
                    "question_id": question_id,
                })
            except Exception as e:
                logger.error(f"Trial {trial} failed for question {question_id}: {str(e)}")

        bias_eval_collection.append(trial_collection)

    results = produce_report(bias_eval_collection, logger)

    save_results(results, config)
    logger.info("Analysis completed and results saved")

# Script entry point: run the analysis only when executed directly, not on import.
if __name__ == "__main__":
    main()