import os
import argparse
import pandas as pd
import numpy as np
import glob
import asyncio
from datetime import datetime
from tqdm import tqdm
from litellm import completion
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment at import
# time (presumably this is where the LLM provider credentials live -- confirm
# against the deployment setup).
load_dotenv()

import json

async def check_coverage_async(gt_claim, pred_claims, model="gpt-5"):
    """Ask an LLM judge whether ``gt_claim`` is covered by ``pred_claims``.

    The blocking ``completion`` call is run in a worker thread via
    ``asyncio.to_thread`` so several judgements can be awaited concurrently.

    Args:
        gt_claim: Text of the ground-truth claim.
        pred_claims: Candidate claim texts; they are shown to the judge with
            0-based ``[i]`` prefixes.
        model: Model name passed to ``completion``.

    Returns:
        A ``(covered, matched_indices)`` tuple. ``covered`` is always a real
        ``bool``. ``matched_indices`` is a list of unique ints, each a valid
        0-based position in ``pred_claims``. ``(False, [])`` is returned when
        there are no candidates, when the call or JSON parsing fails, or when
        the reply is not a JSON object.
    """
    if not pred_claims:
        return False, []

    pred_text = "\n".join([f"[{i}] {c}" for i, c in enumerate(pred_claims)])

    prompt = f"""You are an expert fact-checking evaluator.
Your task is to determine if the "Ground Truth Claim" is fully covered by the "Candidate Claims".

Ground Truth Claim:
{gt_claim}

Candidate Claims:
{pred_text}

Instructions:
1. Determine if the core information of the Ground Truth Claim is present in the Candidate Claims.
2. Ignore minor formatting differences.
3. If covered, identify which Candidate Claim indices (0-based) contribute to the coverage.

Return a JSON object with:
- "covered": boolean (true if covered, false otherwise)
- "matched_indices": list of integers (indices of candidates that cover the GT)
"""

    try:
        # No temperature is passed, so the provider default applies.
        response = await asyncio.to_thread(
            completion,
            model=model,
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content
        data = json.loads(content)
    except Exception as e:
        # Best effort: a failed judgement is scored as "not covered".
        print(f"Error in LLM call: {e}")
        return False, []

    if not isinstance(data, dict):
        print(f"Error in LLM call: expected a JSON object, got {type(data).__name__}")
        return False, []

    # The judge is asked for a boolean, but a string such as "false" would be
    # truthy if returned as-is, so normalise it explicitly.
    covered = data.get("covered", False)
    if isinstance(covered, str):
        covered = covered.strip().lower() == "true"
    covered = bool(covered)

    raw_indices = data.get("matched_indices")
    if raw_indices is None:
        raw_indices = []
    elif not isinstance(raw_indices, (list, tuple)):
        raw_indices = [raw_indices]

    # Keep only unique, in-range integer positions so callers can index the
    # candidate list without further checks.
    matched_indices = []
    for item in raw_indices:
        if isinstance(item, bool):
            continue
        if isinstance(item, float) and item.is_integer():
            item = int(item)
        elif isinstance(item, str):
            try:
                item = int(item.strip())
            except ValueError:
                continue
        if not isinstance(item, int):
            continue
        if 0 <= item < len(pred_claims) and item not in matched_indices:
            matched_indices.append(item)

    return covered, matched_indices

async def evaluate_batch(batch, model):
    """Judge every item of ``batch`` concurrently.

    Args:
        batch: Iterable of ``(gt_claim, pred_claims, gt_id)`` triples. The
            ``gt_id`` element is not used here; callers rely on the output
            order to map results back to their items.
        model: Model name forwarded to ``check_coverage_async``.

    Returns:
        List of ``check_coverage_async`` results, in the same order as
        ``batch``.
    """
    pending = [
        check_coverage_async(claim, candidates, model)
        for claim, candidates, _gt_id in batch
    ]
    return await asyncio.gather(*pending)

def _judge_group_key(sample, position, group_by):
    """Return the key used to pair ground-truth rows with prediction rows."""
    if group_by == 'sentence':
        return (sample, position)
    if group_by == 'paragraph':
        # The paragraph is the part of the position before the first '.'.
        return (sample, str(position).split('.')[0])
    # Unknown grouping: every row falls into the same (None) bucket.
    return None


def _valid_match_indices(indices, candidate_count):
    """Keep only integer indices that address one of ``candidate_count`` items."""
    if not isinstance(indices, (list, tuple)):
        return []
    return [
        i for i in indices
        if isinstance(i, int) and not isinstance(i, bool) and 0 <= i < candidate_count
    ]


def _safe_filename_part(name):
    """Replace path separators and ':' so ``name`` can be embedded in a file name."""
    return "".join('-' if ch in '/\\:' else ch for ch in str(name))


def evaluate_llm_judge(pred_filename, pred_df, gt_df, model='gpt-5', batch_size=20, group_by='sentence'):
    """Score predicted claims against ground-truth claims with an LLM judge.

    Each ground-truth claim is compared with the predicted claims that share
    its group key (same sample and sentence position, or same sample and
    paragraph). Results are printed and written next to ``pred_filename`` as
    ``<prefix>_details.xlsx`` and ``<prefix>_metrics.json``.

    Args:
        pred_filename: Path of the prediction file; its name without the
            extension is the prefix of the output files.
        pred_df: Predictions with columns ``sample``, ``position``, ``claim``
            and ``class``.
        gt_df: Ground truth with the same four columns.
        model: Judge model name; also embedded in the output file names.
        batch_size: Number of judge calls awaited concurrently per batch.
        group_by: ``'sentence'`` or ``'paragraph'``.

    Returns:
        Dict with ``"semantic"`` and ``"class"`` entries holding precision,
        recall and F1 (plus ``accuracy_conditional`` for ``"class"``).
        Precision divides the number of matched ground-truth claims by the
        number of predictions, so it is an approximation.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print(f"Evaluating with LLM Judge: {model}...")

    # Filter valid classes
    valid_classes = ['A', 'B', 'C', 'D', 'E', 'F']
    pred_df = pred_df[pred_df['class'].isin(valid_classes)].copy()
    gt_df = gt_df[gt_df['class'].isin(valid_classes)].copy()

    # Normalize
    gt_df['claim'] = gt_df['claim'].astype(str).str.strip()
    pred_df['claim'] = pred_df['claim'].astype(str).str.strip()

    print(f"Preparing evaluation items (Group by: {group_by})...")

    # Group predictions by key, keeping the full record (claim and class).
    pred_grouped_full = {}
    for record in pred_df.to_dict('records'):
        key = _judge_group_key(record['sample'], record['position'], group_by)
        pred_grouped_full.setdefault(key, []).append(record)

    # One evaluation item per ground-truth claim: (gt_claim, candidate_texts, gt_id)
    eval_items = []
    gt_id_map = {}
    for gt_id, row in enumerate(gt_df.to_dict('records')):
        sample = row['sample']
        position = row['position']
        gt_claim = row['claim']

        candidates = pred_grouped_full.get(_judge_group_key(sample, position, group_by), [])
        candidate_texts = [c['claim'] for c in candidates]

        gt_id_map[gt_id] = {
            'sample': sample,
            'position': position,
            'gt_claim': gt_claim,
            'gt_class': row['class'],
            'candidate_claims': candidate_texts,
            'candidate_classes': [c['class'] for c in candidates],
            'found': False,
            'class_match': False,
            'matched_indices': []
        }
        eval_items.append((gt_claim, candidate_texts, gt_id))

    total_gt = len(eval_items)
    print(f"Total Ground Truth Claims: {total_gt}")

    num_batches = (total_gt + batch_size - 1) // batch_size
    print(f"Running evaluation in {num_batches} batches...")

    async def _run_all_batches():
        # Batches run one after another; calls inside a batch run concurrently.
        outcomes = []
        for i in tqdm(range(num_batches)):
            batch_items = eval_items[i * batch_size:(i + 1) * batch_size]
            outcomes.extend(await evaluate_batch(batch_items, model))
        return outcomes

    # asyncio.run creates and closes its own event loop, which avoids the
    # deprecated asyncio.get_event_loop() call outside a running loop.
    results = asyncio.run(_run_all_batches())

    for (_, _, gt_id), (covered, indices) in zip(eval_items, results):
        gt_info = gt_id_map[gt_id]
        candidate_classes = gt_info['candidate_classes']

        # The judge's indices are untrusted: drop anything that does not
        # address an actual candidate before using them for lookups.
        indices = _valid_match_indices(indices, len(candidate_classes))

        gt_info['found'] = covered
        gt_info['matched_indices'] = indices

        # Relaxed rule: the class counts as matched if AT LEAST ONE of the
        # semantically matched candidates carries the ground-truth class.
        gt_info['class_match'] = bool(covered) and any(
            candidate_classes[i] == gt_info['gt_class'] for i in indices
        )

    # Aggregate results for Metrics
    total_pred = len(pred_df)

    tp_semantic = sum(1 for info in gt_id_map.values() if info['found'])
    tp_class = sum(1 for info in gt_id_map.values() if info['class_match'])

    # Semantic
    precision_sem = tp_semantic / total_pred if total_pred > 0 else 0
    recall_sem = tp_semantic / total_gt if total_gt > 0 else 0
    f1_sem = 2 * (precision_sem * recall_sem) / (precision_sem + recall_sem) if (precision_sem + recall_sem) > 0 else 0

    # Class
    precision_cls = tp_class / total_pred if total_pred > 0 else 0  # Approximation: assumes 1 GT matches 1 Pred roughly
    recall_cls = tp_class / total_gt if total_gt > 0 else 0
    f1_cls = 2 * (precision_cls * recall_cls) / (precision_cls + recall_cls) if (precision_cls + recall_cls) > 0 else 0

    # Classification Accuracy (among semantic matches)
    acc_cls = tp_class / tp_semantic if tp_semantic > 0 else 0

    print("\n" + "="*50)
    print("LLM JUDGE EVALUATION RESULTS")
    print("="*50)
    print(f"Model: {model}")
    print(f"Group By: {group_by}")
    print(f"Total GT: {total_gt}")
    print(f"Total Pred: {total_pred}")
    print("-" * 30)
    print("SEMANTIC MATCH:")
    print(f"  Recall:    {recall_sem:.4f} ({tp_semantic}/{total_gt})")
    print(f"  Precision: {precision_sem:.4f} ({tp_semantic}/{total_pred}) [Approx]")
    print(f"  F1:        {f1_sem:.4f}")
    print("-" * 30)
    print("CLASS MATCH:")
    print(f"  Recall:    {recall_cls:.4f} ({tp_class}/{total_gt})")
    print(f"  Precision: {precision_cls:.4f} ({tp_class}/{total_pred}) [Approx]")
    print(f"  F1:        {f1_cls:.4f}")
    print(f"  Accuracy:  {acc_cls:.4f} (Given semantic match)")
    print("="*50)

    metrics = {
        "semantic": {"precision": precision_sem, "recall": recall_sem, "f1": f1_sem},
        "class": {"precision": precision_cls, "recall": recall_cls, "f1": f1_cls, "accuracy_conditional": acc_cls}
    }

    # Save Detailed Results
    details_list = []
    for info in gt_id_map.values():
        matched_claims = [info['candidate_claims'][i] for i in info['matched_indices']]
        matched_classes = [info['candidate_classes'][i] for i in info['matched_indices']]

        details_list.append({
            "sample": info['sample'],
            "position": info['position'],
            "gt_claim": info['gt_claim'],
            "gt_class": info['gt_class'],
            "semantic_match": info['found'],
            "class_match": info['class_match'],
            "matched_claims": str(matched_claims),
            "matched_classes": str(matched_classes),
            "candidate_count": len(info['candidate_claims'])
        })

    df_details = pd.DataFrame(details_list)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_prefix = os.path.splitext(pred_filename)[0]
    # Model names may contain path separators (e.g. "provider/model"); left
    # as-is they would point the output at a directory that does not exist.
    output_prefix = f"{output_prefix}_{_safe_filename_part(model)}_{group_by}_{timestamp}"

    # Save Excel
    df_details.to_excel(f"{output_prefix}_details.xlsx", index=False)
    print(f"\nSaved details to {output_prefix}_details.xlsx")

    # Save Metrics JSON
    with open(f"{output_prefix}_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4)
    print(f"Saved metrics to {output_prefix}_metrics.json")

    return metrics

def main(argv=None):
    """Command-line entry point: evaluate prediction files with the LLM judge.

    Ground truth is read from ``<sample>.xlsx`` files in the current
    directory. Predictions come either from ``--input_file`` or, when that is
    not given, from every ``combined_*.xlsx`` file in ``--results_dir``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Evaluate claim extraction recall using LLM Judge.")
    parser.add_argument("--results_dir", type=str, default="results", help="Directory containing result files")
    parser.add_argument("--input_file", type=str, help="Specific input file to evaluate")
    parser.add_argument("--model", type=str, default="gpt-5", help="LLM model name")
    parser.add_argument("--batch_size", type=int, default=50, help="Batch size for async calls")
    parser.add_argument("--group_by", type=str, default="sentence", choices=["sentence", "paragraph"], help="Grouping level for evaluation")

    args = parser.parse_args(argv)

    base_dir = "."
    samples = ["sample01"]

    # Load every available ground-truth sheet and tag its rows with the sample name.
    all_gt_dfs = []
    for sample in samples:
        gt_file = os.path.join(base_dir, f"{sample}.xlsx")
        if os.path.exists(gt_file):
            df = pd.read_excel(gt_file)
            df['sample'] = sample
            all_gt_dfs.append(df)
        else:
            print(f"Warning: Ground truth file {gt_file} not found.")

    if not all_gt_dfs:
        print("No ground truth files found.")
        return

    combined_gt_df = pd.concat(all_gt_dfs, ignore_index=True)

    # Resolve the prediction files. An explicit --input_file is evaluated on
    # its own; otherwise every combined result file is evaluated.
    if args.input_file:
        if not os.path.exists(args.input_file):
            print(f"Input file not found: {args.input_file}")
            return
        pred_files = [args.input_file]
    else:
        pattern = os.path.join(args.results_dir, "combined_*.xlsx")
        # Sorted so the files are processed in a deterministic order.
        pred_files = sorted(glob.glob(pattern))

        if not pred_files:
            print(f"No combined result files found in {args.results_dir}")
            return

    for pred_file in pred_files:
        print(f"Evaluating file: {pred_file}")

        pred_df = pd.read_excel(pred_file)
        evaluate_llm_judge(pred_file, pred_df, combined_gt_df, model=args.model, batch_size=args.batch_size, group_by=args.group_by)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
