import argparse
import json
import re
import os
from vllm import LLM, SamplingParams


def construct_guidance(scores, diagnostic_description):
    """Build the two-phase remediation guidance text for one diagnosed sample.

    Args:
        scores: Mapping of dimension name to score. The keys read are
            ``compliance``, ``richness``, ``soundness`` (falling back to
            ``logic``) and ``presentation``. A missing, ``None`` or
            non-numeric value counts as 5, and ``None`` for the whole mapping
            is treated as an empty mapping.
        diagnostic_description: Free-text diagnosis, quoted verbatim in the
            guidance.

    Returns:
        A ``(remediation_guidance, metadata)`` tuple. ``remediation_guidance``
        is the guidance text; ``metadata`` records which phases received
        directives plus a per-dimension ``defect_profile`` of booleans.
    """
    scores = scores or {}

    def _as_score(value, default=5):
        # Score values may be None, a float, or a string such as "4.0".
        # Coerce defensively so a single malformed record cannot abort the
        # caller; floats are truncated exactly as int() does.
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            pass
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return default

    C = _as_score(scores.get('compliance', 5))
    R = _as_score(scores.get('richness', 5))
    S = _as_score(scores.get('soundness', scores.get('logic', 5)))
    P = _as_score(scores.get('presentation', 5))

    # Defect thresholds; the same booleans drive both the directives and the
    # defect profile returned in the metadata.
    soundness_risk = S <= 3
    compliance_risk = C < 5
    richness_gap = R <= 3
    presentation_issue = P <= 2

    # The numeric prefixes are part of the fixed directive text, so a lone
    # directive can be labelled "2." when its sibling is not triggered.
    p1_directives = []
    if soundness_risk:
        p1_directives.append("1. [LOGIC/FACT] Scrutinize premises. If a premise is FALSE, correct it immediately. If reasoning relies on weak assumptions, make them explicit or rigorously derived.")
    if compliance_risk:
        p1_directives.append("2. [ALIGNMENT] Enforce user constraints (format, length, negatives). Rewrite affected sections to strictly satisfy them.")

    p2_directives = []
    if richness_gap:
        p2_directives.append("1. [DEPTH] Vertically expand the explanation. Instantiate abstract claims with concrete examples.")
    if presentation_issue:
        p2_directives.append("2. [FORMAT] Reorganize text into clear hierarchy (headers, bullets).")

    # With no defect at all, still give the editor a light-touch task.
    if not p1_directives and not p2_directives:
        p2_directives.append("1. [POLISH] Review for clarity and flow; make minor improvements.")

    phase1_text = "\n".join(p1_directives) if p1_directives else "No critical structural defects detected. Maintain integrity."
    phase2_text = "\n".join(p2_directives) if p2_directives else "No enrichment required."

    remediation_guidance = f"""
Remediation Guidance

Diagnostic Description: "{diagnostic_description}"

[Phase 1: Structural Remediation]
(Focus: Correctness & Compliance. NO Fluff.)
{phase1_text}

[Phase 2: Content Enrichment and Expression Refinement]
(Focus: Depth & Readability. Only proceed if Phase 1 is secure.)
{phase2_text}

Global Constraints
- FACTUAL SUPREMACY: Accuracy overrides all other instructions.
- Minimal Intervention: Do not rewrite parts that are already perfect.
- Integration: Phase 2 edits must be woven seamlessly into the Phase 1 base.
"""

    metadata = {
        "phase1_active": bool(p1_directives),
        "phase2_active": bool(p2_directives),
        "defect_profile": {
            "soundness_risk": soundness_risk,
            "compliance_risk": compliance_risk,
            "richness_gap": richness_gap,
            "presentation_issue": presentation_issue,
        },
    }

    return remediation_guidance, metadata


# Prompt template for the refiner model. It is filled with str.format using
# the fields `instruction`, `input`, `old_output` and `remediation_guidance`.
# The tag named in Step 2 must be spelled exactly like the one in the
# OUTPUT FORMAT section (`<fixed_response>`), since the response is extracted
# by that tag name.
REMEDIATION_PROMPT = """
You are a precision editor for LLM-generated data. 
You do not freely rewrite text. 
You revise the response strictly following the repair guidance below, using a two-phase process. 
    
Input Data
User Instruction: {instruction}
User Context: {input}
Original Response: {old_output}

Remediation Guidance
{remediation_guidance}

Execution
Step 1: Repair Plan
- Identify which repair operations are activated in Phase 1 and Phase 2.
- Explicitly state how you will resolve conflicts (e.g., Accuracy > Alignment).

Step 2: Refined Response
- Produce the final repaired output enclosed within <fixed_response> tags. 
- Apply Phase 2 modifications only after Phase 1 corrections are completed. 

OUTPUT FORMAT
Repair Plan:
<Brief plan>

<fixed_response>
[The final refined response goes here]
</fixed_response>
"""

def parse_model_output(raw_output):
    """Extract the repaired response from the refiner model's raw completion.

    Extraction is attempted in three steps, most reliable first:

    1. A complete ``<fixed_response>...</fixed_response>`` pair.
    2. An opening tag with no closing tag (generation was cut off); the rest
       of the text is returned with Markdown code fences removed.
    3. No tag at all but a ``Repair Plan:`` header; the last blank-line
       separated segment after the header is returned.

    The tag name is matched case-insensitively and tolerates a space or
    hyphen instead of the underscore (``<fixed response>``), because models
    do not always reproduce the exact spelling.

    Args:
        raw_output: The text generated by the model; may be ``None`` or empty.

    Returns:
        The extracted response text, or ``None`` when nothing can be extracted.
    """
    if not raw_output:
        return None

    open_tag = r"<fixed[ _-]?response>"
    close_tag = r"</fixed[ _-]?response>"
    flags = re.DOTALL | re.IGNORECASE

    # Case 1: well-formed pair; non-greedy so only the first pair is taken.
    complete = re.search(open_tag + r"\s*(.*?)\s*" + close_tag, raw_output, flags)
    if complete:
        return complete.group(1).strip()

    # Case 2: the closing tag is missing, so keep everything after the opener.
    truncated = re.search(open_tag + r"\s*(.*)", raw_output, flags)
    if truncated:
        content = truncated.group(1).strip()
        return content.replace("```", "").strip()

    # Case 3: no tags; fall back to the final paragraph after the plan header.
    if "Repair Plan:" in raw_output:
        remaining = raw_output.split("Repair Plan:", 1)[1]
        segments = remaining.split('\n\n')
        if len(segments) >= 2:
            return segments[-1].strip()
    return None

def load_data(input_path):
    """Read a JSON Lines file into a list of parsed records.

    Lines that are empty or contain only whitespace are ignored; every other
    line is parsed with ``json.loads``. Progress messages are printed before
    and after loading.

    Args:
        input_path: Path of the JSONL file to read (UTF-8).

    Returns:
        The parsed records, in file order.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
    """
    print(f"Loading data from {input_path}...")
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with open(input_path, 'r', encoding='utf-8') as handle:
        records = [json.loads(raw_line) for raw_line in handle if raw_line.strip()]

    print(f"Loaded {len(records)} records to fix.")
    return records

def get_processed_count(output_path):
    """Count the non-blank lines already present in the output file.

    Args:
        output_path: Path of the output file (UTF-8); it may not exist yet.

    Returns:
        The number of lines that contain something other than whitespace, or
        0 when the file does not exist.
    """
    if not os.path.exists(output_path):
        return 0
    with open(output_path, 'r', encoding='utf-8') as handle:
        return sum(1 for row in handle if row.strip())


def main(args):
    """Run the refinement pipeline: load records, build prompts, generate, save.

    Records are read from ``args.input_path``. A remediation prompt is built
    for each record, prompts longer than the token budget are dropped, and
    the rest are sent to the model in batches. One JSON line per generated
    record is appended to ``args.output_path``. Each line carries the
    original output, the diagnosis, the (possibly refined) output, an
    ``is_refined`` flag and, when the completion cannot be parsed, a
    ``refinement_error`` field.

    Args:
        args: Parsed command-line namespace providing ``input_path``,
            ``output_path``, ``model_name``, ``gpus``, ``quantization`` and
            ``max_len``.
    """
    all_data = load_data(args.input_path)
    if not all_data:
        return

    # Resume support: the count of lines already written is used as the
    # offset into the input.
    # NOTE(review): over-length prompts are skipped without writing a line,
    # so after a run that skipped records this offset is too small and the
    # tail of the input is generated again on the next start. Fixing this
    # needs a decision on how skipped records should be recorded.
    processed_count = get_processed_count(args.output_path)
    if processed_count > 0:
        print(f"Resuming from {processed_count}...")
        if processed_count >= len(all_data):
            print("Done.")
            return
        data_to_process = all_data[processed_count:]
    else:
        print("Starting fresh run...")
        data_to_process = all_data
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        output_dir = os.path.dirname(args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    print(f"Initializing Refiner with {args.model_name}...")
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    llm = LLM(
        model=args.model_name,
        tensor_parallel_size=args.gpus,
        trust_remote_code=True,
        quantization=args.quantization,
        max_model_len=args.max_len,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    tokenizer = llm.get_tokenizer()

    print(f"Composing Dynamic Protocols & Filtering (Max: {args.max_len})...")

    final_prompts = []
    final_data_items = []
    skipped_count = 0
    stats = {"Phase1_Active": 0, "Phase2_Active": 0}

    # NOTE(review): this leaves only 16 tokens of headroom below max_len,
    # while generation requests up to 1024 new tokens. Confirm that prompts
    # near the limit do not yield truncated completions.
    safe_limit = args.max_len - 16

    for item in data_to_process:
        diagnosis_result = item.get('diagnosis_result', {})
        scores = diagnosis_result.get('scores', {})
        diagnostic_description = diagnosis_result.get('diagnostic_description', 'No specific hint provided.')
        remediation_guidance, meta = construct_guidance(scores, diagnostic_description)

        prompt = REMEDIATION_PROMPT.format(
            instruction=item['instruction'],
            input=item.get('input', ''),
            old_output=item['output'],
            remediation_guidance=remediation_guidance,
        )

        token_ids = tokenizer.encode(prompt)

        if len(token_ids) <= safe_limit:
            final_prompts.append(prompt)
            final_data_items.append(item)

            if meta["phase1_active"]:
                stats["Phase1_Active"] += 1
            if meta["phase2_active"]:
                stats["Phase2_Active"] += 1
        else:
            skipped_count += 1
            # Only every tenth skip is reported, to keep the log short.
            if skipped_count % 10 == 0:
                print(f"[Warning] Skipped prompt with {len(token_ids)} tokens.")

    print(f"Pipeline Stats (Filtered): {stats}")
    print(f"Original: {len(data_to_process)} | Skipped: {skipped_count} | Final: {len(final_prompts)}")

    if not final_prompts:
        print("No valid prompts left. Exiting.")
        return

    print("Starting batched refinement...")
    sampling_params = SamplingParams(temperature=0.3, max_tokens=1024)

    batch_size = 200
    total_batches = (len(final_prompts) + batch_size - 1) // batch_size
    success_count = 0

    for batch_idx in range(total_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(final_prompts))

        batch_prompts = final_prompts[start_idx:end_idx]

        outputs = llm.generate(batch_prompts, sampling_params)

        # Appending after every batch means an interrupted run keeps the
        # batches that already finished.
        with open(args.output_path, 'a', encoding='utf-8') as f_out:
            for j, output in enumerate(outputs):
                item = final_data_items[start_idx + j]
                res_text = output.outputs[0].text

                fixed_response = parse_model_output(res_text)

                # Build a new record instead of mutating the input item; on
                # parse failure the original output is passed through.
                final_item = {
                    "instruction": item.get('instruction', ''),
                    "input": item.get('input', ''),
                    "original_output": item.get('output', ''),
                    "diagnosis_result": item.get('diagnosis_result', {}),
                    "is_refined": False,
                    "output": item.get('output', ''),
                }

                if fixed_response:
                    final_item['output'] = fixed_response
                    final_item['is_refined'] = True
                    success_count += 1
                else:
                    final_item['refinement_error'] = "Parse Failed"

                # Write the assembled record, not the untouched input item,
                # so the refined output is actually saved.
                json_line = json.dumps(final_item, ensure_ascii=False) + "\n"
                f_out.write(json_line)

        current_total = processed_count + batch_idx * batch_size + len(batch_prompts)
        print(f"Batch {batch_idx+1}/{total_batches} done. Total processed: {current_total}")

    print(f"Refined {success_count}/{len(final_prompts)} records.")
    print(f"Done. Output: {args.output_path}")

if __name__ == "__main__":
    # Script entry point: declare the command-line options from a table,
    # parse them, and hand the namespace to main().
    cli_options = (
        ("--input_path", {"type": str, "default": "alpaca/data/data_to_fix.jsonl"}),
        ("--model_name", {"type": str, "default": "Qwen2.5-72B-Instruct-AWQ"}),
        ("--output_path", {"type": str, "default": "alpaca/data/data_refined.jsonl"}),
        ("--gpus", {"type": int, "default": 2}),
        ("--quantization", {"type": str, "default": "awq_marlin"}),
        ("--max_len", {"type": int, "default": 4096, "help": "Max length for prompt filtering"}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)