import json
import argparse
from vllm import LLM, SamplingParams
import re
import os
import json_repair
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Prompt sent to the auditor model for every record. It is filled in with
# str.format using the {instruction}, {input} and {output} fields, which is
# why the literal braces of the JSON example at the bottom are doubled.
# NOTE(review): task step 2 asks for a "PRIMARY limiting factor", but the
# requested output schema only contains "scores" and "diagnostic_description"
# -- confirm whether a dedicated field was intended.
DIAGNOSE_TEMPLATE = """
You are an expert Data Quality Auditor.
Your task is to DIAGNOSE quality bottlenecks in a model response.
You do NOT rewrite, fix, or optimize the response.
Your feedback is diagnostic and NON-BINDING, used by an automated refinement system.

TASK
Given an (Instruction, Input, Response) triple:
1. Assign quality scores across four independent dimensions.
2. Identify the PRIMARY limiting factor that prevents the response from being near-perfect.
3. Provide a non-binding content hint describing what information appears to be missing or weak.

SCORING PRINCIPLES
- Each dimension is evaluated independently.
- Use the full 1-5 scale (5 = near-perfect).
- Score meanings: 2 = substandard effort; 3 = passable but limited.

SCORING RUBRIC
1. Instruction Compliance
   - 5: Fully follows instruction and all constraints.
   - 4: Minor looseness in interpreting vague constraints.
   - 3: Misses a minor constraint.
   - 2: Violates a hard constraint.
   - 1: Fails to address the task.

2. Information Richness (Focus: Depth, specificity, structure)
   - 5: Insightful, specific, includes reasoning, examples, or limitations.
   - 4: Adequate depth, at least one concrete supporting element.
   - 3: Correct but generic; surface-level explanation.
   - 2: Sparse or underdeveloped.
   - 1: Empty or meaningless.

3. Logical Soundness
   - 5: Premises are factually true AND reasoning is flawless.
   - 4: Premises are true, reasoning has minor imprecision.
   - 3: Premises are true, but reasoning steps are incomplete.
   - 2: The core factual premise is WRONG, even if the reasoning flows logically.
   - 1: Complete fabrication or nonsensical content.

4. Presentation Quality
   - 5: Clear structure, effective formatting.
   - 4: Readable but could be better organized.
   - 3: Wall-of-text but understandable.
   - 2: Formatting harms readability.
   - 1: Unreadable.

---

### INPUT
Instruction: {instruction}
Input Context: {input}
Model Response: {output}

---
### OUTPUT FORMAT (JSON ONLY)

{{
  "scores": {{
    "compliance": <int 1-5>,
    "richness": <int 1-5>,
    "soundness": <int 1-5>,
    "presentation": <int 1-5>
  }},
  "diagnostic_description": "A brief, non-binding description of what seems missing or weak. Required unless all scores are 5."
}}
"""


def get_triage_decision(scores):
    """Map the auditor's four dimension scores to a triage bucket.

    Args:
        scores: Mapping with the keys ``compliance``, ``richness``,
            ``presentation`` and ``soundness`` (``logic`` is accepted as a
            fallback name for ``soundness``). Values must be convertible with
            ``int()``; a missing key counts as 0.

    Returns:
        A ``(decision, reason)`` tuple. ``decision`` is ``"DROP"`` when the
        scores cannot be parsed or when compliance <= 3, soundness <= 2 or
        richness <= 2; ``"KEEP"`` when compliance == 5, richness >= 4,
        soundness >= 4 and presentation >= 3; ``"FIX"`` otherwise.
    """
    try:
        compliance = int(scores.get('compliance', 0))
        richness = int(scores.get('richness', 0))
        presentation = int(scores.get('presentation', 0))
        soundness = int(scores.get('soundness', scores.get('logic', 0)))
    except (AttributeError, TypeError, ValueError, OverflowError):
        # AttributeError: `scores` is not a mapping (e.g. None or a list).
        # TypeError / ValueError / OverflowError: a value is not convertible
        # to int. Anything else (KeyboardInterrupt, SystemExit, ...) must
        # propagate instead of being reported as a parsing problem.
        return "DROP", "Score Parsing Error"

    if compliance <= 3 or soundness <= 2 or richness <= 2:
        return "DROP", "Hard Constraint Violation"

    if compliance == 5 and richness >= 4 and soundness >= 4 and presentation >= 3:
        return "KEEP", "Elite Quality"

    return "FIX", "Reclamation Target"


def parse_response(response_text):
    """Extract the auditor's JSON verdict from raw model output.

    If the text contains a ``` fenced block wrapping a JSON object, only that
    object is parsed; otherwise the whole stripped text is handed to
    ``json_repair``. A list result is reduced to its first element.

    Returns:
        A dict that has a ``scores`` key, or ``None`` when nothing usable
        could be decoded. A flat dict holding ``compliance`` directly is
        wrapped as ``{'scores': <that dict>}``.
    """
    candidate = response_text.strip()
    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', candidate, re.DOTALL)
    if fenced is not None:
        candidate = fenced.group(1)

    try:
        parsed = json_repair.repair_json(candidate, return_objects=True)
        if isinstance(parsed, list):
            # An empty list raises IndexError here and is treated as a failure.
            parsed = parsed[0]
    except Exception:
        return None

    if not isinstance(parsed, dict):
        return None
    if 'scores' in parsed:
        return parsed
    if 'compliance' in parsed:
        # The model emitted the score fields at the top level.
        return {'scores': parsed}
    return None

def load_data(input_path):
    """Read the input records from a ``.jsonl`` or JSON file.

    Args:
        input_path: Path string. Files ending in ``.jsonl`` are read one JSON
            document per non-blank line; anything else is parsed as a single
            JSON document.

    Returns:
        The decoded records (a list for ``.jsonl`` input, otherwise whatever
        the JSON document decodes to).

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
    """
    print(f"Loading data from {input_path}...")
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with open(input_path, 'r', encoding='utf-8') as handle:
        if input_path.endswith('.jsonl'):
            records = [json.loads(raw) for raw in handle if raw.strip()]
        else:
            records = json.load(handle)

    print(f"Loaded {len(records)} records from input.")
    return records

def get_processed_count(file_paths):
    """Return the total number of non-blank lines across ``file_paths``.

    Paths that do not exist are skipped and contribute nothing to the total.
    """
    total = 0
    for path in file_paths:
        if not os.path.exists(path):
            continue
        with open(path, 'r', encoding='utf-8') as handle:
            total += sum(1 for row in handle if row.strip())
    return total

def plot_triage_stats(stats_dict, output_dir):
    """Draw a bar chart of the DROP / FIX / KEEP counts and save it as a PNG.

    Args:
        stats_dict: Mapping from bucket name to sample count; missing buckets
            count as 0.
        output_dir: Directory in which ``triage_stats.png`` is written.

    Nothing is drawn or saved when the three counts sum to zero.
    """
    print("Generating triage statistics plot...")

    labels = ['DROP', 'FIX', 'KEEP']
    values = [stats_dict.get(label, 0) for label in labels]
    grand_total = sum(values)
    if grand_total == 0:
        return

    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(10, 6))
    palette = ['#c44e52', '#ccb974', '#55a868']

    rects = plt.bar(labels, values, color=palette, alpha=0.8, width=0.6)

    # Annotate each bar with its absolute count and its share of the total.
    for rect in rects:
        bar_height = rect.get_height()
        share = (bar_height / grand_total) * 100
        label_x = rect.get_x() + rect.get_width()/2.
        plt.text(label_x, bar_height,
                 f'{int(bar_height)}\n({share:.1f}%)',
                 ha='center', va='bottom', fontweight='bold')

    plt.title('Triage Results: The Data Funnel', fontsize=15, pad=20)
    plt.ylabel('Sample Count', fontsize=12)
    # Extra headroom so the annotation above the tallest bar is not clipped.
    plt.ylim(0, max(values) * 1.15)

    save_path = os.path.join(output_dir, 'triage_stats.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Stats plot saved to: {save_path}")
    plt.close()

def main(args):
    """Run the diagnose-and-triage pipeline end to end.

    Loads the input records, skips as many leading records as there are lines
    already present in the three output files (resume), scores the remainder
    with a vLLM-hosted auditor model in batches, appends each record to the
    KEEP / FIX / DROP JSONL file matching its triage decision, and finally
    plots the per-bucket line counts.

    Args:
        args: Parsed command-line namespace providing ``input_path``,
            ``model_name``, ``output_keep``, ``output_fix``, ``output_drop``,
            ``gpus`` and ``quantization``.
    """
    all_data = load_data(args.input_path)

    output_files = [args.output_keep, args.output_fix, args.output_drop]

    # Make sure every output directory exists, for fresh and resumed runs
    # alike. A bare filename has an empty dirname, which os.makedirs rejects,
    # so that case is skipped.
    for f_path in output_files:
        out_dir = os.path.dirname(f_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    processed_count = get_processed_count(output_files)

    if processed_count > 0:
        print(f"Detected {processed_count} already processed records. Resuming...")
        if processed_count < len(all_data):
            # Resume relies on records being written in input order: the total
            # number of output lines is the number of leading records done.
            data_to_process = all_data[processed_count:]
        else:
            print("All data processed. Skipping inference.")
            data_to_process = []
    else:
        print("Starting fresh run...")
        data_to_process = all_data

    if data_to_process:
        prompts = [
            DIAGNOSE_TEMPLATE.format(
                instruction=item['instruction'],
                input=item.get('input', ''),
                output=item['output'],
            )
            for item in data_to_process
        ]

        print(f"Initializing vLLM with {args.model_name}...")
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

        llm = LLM(model=args.model_name,
                  tensor_parallel_size=args.gpus,
                  trust_remote_code=True,
                  quantization=args.quantization,
                  max_model_len=4096,
                  gpu_memory_utilization=0.8,
                  enforce_eager=True
            )

        sampling_params = SamplingParams(temperature=0.0, max_tokens=512)

        print("Starting batched diagnosis...")
        batch_size = 100
        total_batches = (len(prompts) + batch_size - 1) // batch_size

        for batch_idx in range(total_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(prompts))
            batch_prompts = prompts[start_idx:end_idx]
            batch_items = data_to_process[start_idx:end_idx]

            outputs = llm.generate(batch_prompts, sampling_params)

            # Results are appended after every batch so an interrupted run can
            # resume from the line counts of the output files.
            with open(args.output_keep, 'a', encoding='utf-8') as f_keep, \
                 open(args.output_fix, 'a', encoding='utf-8') as f_fix, \
                 open(args.output_drop, 'a', encoding='utf-8') as f_drop:

                for item, output in zip(batch_items, outputs):
                    res_text = output.outputs[0].text

                    decision_obj = parse_response(res_text)

                    if decision_obj and 'scores' in decision_obj:
                        decision, _reason = get_triage_decision(decision_obj['scores'])

                        item['diagnosis_result'] = decision_obj
                        item['triage'] = decision
                    else:
                        # Unparseable model output is always dropped. The
                        # routing variable must be set here as well; otherwise
                        # it would be unbound on the first item or carry over
                        # the previous item's decision and misroute the record.
                        decision = "DROP"
                        item['triage'] = decision
                        item['diagnosis_result'] = {"error": "JSON Parse Error", "raw_text": res_text}

                    json_line = json.dumps(item, ensure_ascii=False) + "\n"

                    if decision == 'KEEP':
                        f_keep.write(json_line)
                    elif decision == 'FIX':
                        f_fix.write(json_line)
                    else:
                        f_drop.write(json_line)

            current_total = processed_count + end_idx
            print(f"Progress: {current_total}/{len(all_data)}")

    # Final statistics and plot, computed from the output files so that
    # records written by earlier (resumed) runs are included.
    final_stats = {
        "KEEP": get_processed_count([args.output_keep]),
        "FIX": get_processed_count([args.output_fix]),
        "DROP": get_processed_count([args.output_drop])
    }

    output_dir = os.path.dirname(args.output_keep)
    print(f"\nFinal Statistics: {final_stats}")
    plot_triage_stats(final_stats, output_dir)
    print("\nRun Complete.")

if __name__ == "__main__":
    # Command-line entry point. Every option carries a default, so the script
    # can be launched without arguments.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--input_path", str, "alpaca/data/geo_sampled.json"),
        ("--model_name", str, "Qwen2.5-72B-Instruct-AWQ"),
        ("--output_keep", str, "alpaca/data/data_keep.jsonl"),
        ("--output_fix", str, "alpaca/data/data_to_fix.jsonl"),
        ("--output_drop", str, "alpaca/data/data_dropped.jsonl"),
        ("--gpus", int, 2),
        ("--quantization", str, "awq_marlin"),
    )
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)
    main(cli.parse_args())