import os
import ast
import json
import random
import argparse
import pandas as pd
import time
import re
import torch
from typing import List, Dict, Tuple
from tqdm import trange
from transformers import AutoTokenizer, AutoModelForCausalLM

def prepare_messages(user_prompt, prompt_format=None):
    """Wrap a user prompt as a single-turn chat message list.

    Args:
        user_prompt: Text for the user turn.
        prompt_format: Accepted for API compatibility; currently ignored.

    Returns:
        A new list holding one ``{"role": "user", "content": ...}`` dict.
    """
    user_turn = dict(role="user", content=user_prompt)
    return [user_turn]

def generate_local(gen_model, gen_tokenizer, device, sys_prompt, user_prompt,
                   max_new_tokens=1024, temperature=0.7, top_p=0.9, prompt_format=None):
    """Produce one chat completion from a local Hugging Face model.

    Args:
        gen_model: Causal LM exposing ``generate``.
        gen_tokenizer: Tokenizer with a chat template.
        device: Device the encoded prompt tensors are moved to.
        sys_prompt: Optional system message; skipped when empty or whitespace.
        user_prompt: The user message.
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature; ``<= 0`` switches to greedy decoding.
        top_p: Nucleus-sampling threshold.
        prompt_format: Forwarded to ``prepare_messages``.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped), whitespace-stripped and with a leading "assistant" marker
        removed if the model echoed one.
    """
    conversation = prepare_messages(user_prompt, prompt_format=prompt_format)
    if bool(sys_prompt and sys_prompt.strip()):
        conversation = [{"role": "system", "content": sys_prompt}] + conversation

    encoded = gen_tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_length = model_inputs["input_ids"].shape[1]

    use_sampling = temperature > 0
    with torch.no_grad():
        generated = gen_model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=use_sampling,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=gen_tokenizer.eos_token_id,
        )

    completion_ids = generated[0][prompt_length:]
    completion = gen_tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    # Some chat templates leak the role name into the decoded text.
    role_marker = "assistant"
    if completion.lower().startswith(role_marker):
        completion = completion[len(role_marker):].lstrip()
    return completion

def extract_python_code(text):
    """Extract Python source code from a model response.

    Resolution order:
      1. The first fenced block tagged ``python``, ``python3`` or ``py``
         (case-insensitive), with or without a newline before the closing
         fence.
      2. The first untagged fenced block.
      3. A line scan: once a ``def`` line is seen, keep it and every following
         blank or indented line (and later ``def`` lines), stopping at the
         first other unindented line. Top-level ``import`` / ``from ... import``
         lines appearing before the first ``def`` are kept too, so the
         extracted code can still run.
      4. Otherwise the whole response, stripped.

    Args:
        text: Raw generated text.

    Returns:
        The extracted code string.
    """
    # Consume opening and closing fences as pairs so the closing fence of one
    # block is never mistaken for the opening fence of the next.
    fenced_blocks = re.findall(r'```[ \t]*([\w+-]*)[ \t]*\r?\n(.*?)```', text, re.DOTALL)
    for lang_tag, body in fenced_blocks:
        if lang_tag.lower() in ('python', 'python3', 'py'):
            return body.strip()
    for lang_tag, body in fenced_blocks:
        if not lang_tag:
            return body.strip()

    import_line = re.compile(r'(?:import\s+[\w.]+|from\s+[\w.]+\s+import\s)')
    code_lines = []
    in_function = False

    for line in text.split('\n'):
        if line.strip().startswith('def '):
            in_function = True
            code_lines.append(line)
        elif in_function:
            # Blank or indented lines belong to the current function; any other
            # unindented line ends the extracted code.
            if line.strip() == '' or line.startswith((' ', '\t')):
                code_lines.append(line)
            else:
                break
        elif import_line.match(line):
            code_lines.append(line)

    # Only trust the scan if it actually found a function; a lone import line
    # is not a better answer than the full response.
    if in_function:
        return '\n'.join(code_lines)

    return text.strip()

def parse_assert_test(assert_statement):
    """Parse an MBPP-style ``assert func(args...) == expected`` statement.

    The statement is parsed with :mod:`ast` rather than split on text, so
    ``==`` or the word ``assert`` inside string arguments, an optional
    assertion message (``assert f(x) == y, "msg"``) and surrounding
    parentheses are all handled correctly.

    Args:
        assert_statement: One line of Python source, e.g.
            ``"assert add(1, 2) == 3"``.

    Returns:
        ``{'function_name': str, 'inputs': list, 'expected': object}`` when the
        statement is a single equality assertion whose left side calls a plain
        name with positional arguments only and both sides evaluate without
        error; otherwise ``None``.
    """
    try:
        tree = ast.parse(assert_statement.strip())
    except Exception:
        # Non-strings and unparsable source are simply "not a usable test".
        return None

    if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Assert):
        return None

    comparison = tree.body[0].test  # any assertion message is ignored
    if not (isinstance(comparison, ast.Compare)
            and len(comparison.ops) == 1
            and isinstance(comparison.ops[0], ast.Eq)):
        return None

    call = comparison.left
    if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Name) or call.keywords:
        return None

    args_expr = ast.fix_missing_locations(
        ast.Expression(body=ast.List(elts=call.args, ctx=ast.Load()))
    )
    expected_expr = ast.fix_missing_locations(ast.Expression(body=comparison.comparators[0]))

    try:
        # SECURITY: eval runs dataset-provided expressions. Acceptable for the
        # trusted MBPP release; do not feed untrusted test strings here.
        args = eval(compile(args_expr, '<assert-args>', 'eval'), {})
        expected_result = eval(compile(expected_expr, '<assert-expected>', 'eval'), {})
    except Exception:
        return None

    return {
        'function_name': call.func.id,
        'inputs': args,
        'expected': expected_result
    }

def convert_mbpp_tests(test_list):
    """Turn MBPP ``assert`` strings into ``{'inputs', 'expected'}`` test dicts.

    Statements that ``parse_assert_test`` cannot parse are dropped, and the
    parsed function name is not carried over.

    Args:
        test_list: Iterable of assert-statement strings.

    Returns:
        A list of ``{'inputs': list, 'expected': object}`` dicts.
    """
    parsed_tests = (parse_assert_test(statement) for statement in test_list)
    return [
        {'inputs': info['inputs'], 'expected': info['expected']}
        for info in parsed_tests
        if info
    ]

def _select_test_function(namespace, preferred_name, source_name):
    """Return the name of the callable to test in ``namespace``, or None.

    Preference order: ``preferred_name`` if it is callable there; then the
    first public function compiled from ``source_name`` (i.e. defined by the
    candidate itself rather than imported); then the first public callable.
    """
    if preferred_name and callable(namespace.get(preferred_name)):
        return preferred_name

    public_callables = [
        name for name, obj in namespace.items()
        if callable(obj) and not name.startswith('_')
    ]
    for name in public_callables:
        code_obj = getattr(namespace[name], '__code__', None)
        if getattr(code_obj, 'co_filename', None) == source_name:
            return name
    return public_callables[0] if public_callables else None


def run_tests_detailed(code, test_cases):
    """Execute ``code`` against each test case and collect per-test results.

    For every test case the code is compiled and executed in a fresh
    namespace. A target function is then selected (see
    ``_select_test_function``), called with the test inputs, and its result is
    compared to ``expected`` with ``==``. Inputs are deep-copied before the
    call, so candidates that mutate their arguments in place cannot corrupt
    the shared ``test_cases`` data.

    WARNING: ``code`` is arbitrary model output run via ``exec`` with no
    sandbox or timeout. Run only in an isolated environment.

    Args:
        code: Python source to test.
        test_cases: List of dicts with ``'inputs'`` (a list of positional
            arguments, or a single non-list argument), ``'expected'`` and an
            optional ``'function_name'`` naming the function to call.

    Returns:
        ``(pass_rate, passed, total, errors, test_results)``. ``errors`` is a
        list of human-readable failure strings. ``test_results`` is a list of
        per-test dicts with ``test_id``, ``inputs``, ``expected``, ``actual``,
        ``passed`` and ``error``.
    """
    import copy

    if not test_cases:
        return 0.0, 0, 0, [], []

    # Distinctive filename lets us recognise functions the candidate defined.
    source_name = '<candidate>'
    passed = 0
    total = len(test_cases)
    errors = []
    test_results = []

    for i, test_case in enumerate(test_cases):
        try:
            # Fresh namespace per test so state from one call cannot leak into the next.
            namespace = {}
            exec(compile(code, source_name, 'exec'), namespace)

            func_name = _select_test_function(namespace, test_case.get('function_name'), source_name)

            if func_name is None:
                error_msg = f"Test {i+1}: No callable function found"
                errors.append(error_msg)
                test_results.append({
                    'test_id': i+1,
                    'inputs': test_case.get('inputs', []),
                    'expected': test_case.get('expected'),
                    'actual': None,
                    'passed': False,
                    'error': error_msg
                })
                continue

            inputs = test_case.get('inputs', [])
            expected = test_case.get('expected')

            # Deep-copy so in-place mutation by the candidate cannot alter the
            # test data reused by later candidates, iterations and prompts.
            call_args = copy.deepcopy(inputs)
            if isinstance(call_args, list):
                result = namespace[func_name](*call_args)
            else:
                result = namespace[func_name](call_args)

            test_passed = bool(result == expected)
            if test_passed:
                passed += 1

            test_results.append({
                'test_id': i+1,
                'inputs': inputs,
                'expected': expected,
                'actual': result,
                'passed': test_passed,
                'error': None
            })

            if not test_passed:
                errors.append(f"Test {i+1}: Expected {expected}, got {result}")

        except Exception as e:
            error_msg = f"Test {i+1}: Runtime error - {str(e)}"
            errors.append(error_msg)
            test_results.append({
                'test_id': i+1,
                'inputs': test_case.get('inputs', []),
                'expected': test_case.get('expected'),
                'actual': None,
                'passed': False,
                'error': str(e)
            })

    pass_rate = passed / total if total > 0 else 0.0
    return pass_rate, passed, total, errors, test_results

def evaluate_code_detailed(gen_model, gen_tokenizer, device, code, problem_description, test_cases):
    """Score ``code`` on ``test_cases``; the score is the pass rate times 10.

    ``gen_model``, ``gen_tokenizer``, ``device`` and ``problem_description``
    are accepted for interface compatibility but are not used. Only test
    results contribute to the score.

    Returns:
        ``(score, passed, total, test_results, errors)``. Note the order
        differs from ``run_tests_detailed``, which returns errors before
        test_results. Empty or missing code scores 0.0 with 0 of
        ``len(test_cases)`` tests passed.
    """
    if not code or not code.strip():
        # Empty output must never outrank a real attempt, and reporting the
        # true total keeps callers' ``passed == total`` success checks honest.
        return 0.0, 0, len(test_cases) if test_cases else 0, [], ["Empty code"]

    pass_rate, passed, total, errors, test_results = run_tests_detailed(code, test_cases)

    # Only use test score for evaluation
    final_score = pass_rate * 10

    return final_score, passed, total, test_results, errors

class Plan2AlignGenerator:
    """Plan2Align inspired code generator focusing on promising state exploration.

    The ``generate_*`` methods build a prompt, query the local model through
    ``generate_local`` and post-process the reply with ``extract_python_code``.
    Any exception raised while generating is printed and turned into ``None``,
    so one failed call does not abort the search.
    """

    # Word-boundary match so identifiers such as ``diff`` or ``format`` are not
    # mistaken for the ``if`` / ``for`` keywords.
    _CONTROL_KEYWORDS = re.compile(r'\b(?:if|elif|for|while|return)\b')

    def __init__(self, gen_model, gen_tokenizer, device):
        self.gen_model = gen_model
        self.gen_tokenizer = gen_tokenizer
        self.device = device
        self.promising_approaches = []  # Store promising code segments/approaches (not read within this class)

    def identify_promising_elements(self, code, test_results, iteration):
        """Extract reusable "promising" elements from a scored solution.

        Args:
            code: The solution source.
            test_results: Per-test dicts with ``passed``, ``inputs`` and
                ``expected`` (as produced by ``run_tests_detailed``).
            iteration: Iteration index recorded on every element.

        Returns:
            A list of dicts:
              - At most one ``'working_logic'`` element, present when any test
                passed. It carries the full code and the passing cases.
              - If ``code`` contains ``'def '``, one ``'control_structure'``
                element per line containing an ``if``/``elif``/``for``/
                ``while``/``return`` keyword, with one line of context on each
                side.
        """
        promising_elements = []

        # Analyze which parts of the code work well
        passed_tests = [t for t in test_results if t['passed']]

        if passed_tests:
            promising_elements.append({
                'type': 'working_logic',
                'description': f"Successfully handles {len(passed_tests)} test cases",
                'test_cases': [{'inputs': t['inputs'], 'expected': t['expected']} for t in passed_tests],
                'code_snippet': code,
                'iteration': iteration
            })

        # Extract potential algorithmic insights
        if 'def ' in code:
            lines = code.split('\n')
            for i, line in enumerate(lines):
                if self._CONTROL_KEYWORDS.search(line.lower()):
                    promising_elements.append({
                        'type': 'control_structure',
                        'description': f"Control logic: {line.strip()}",
                        'context': '\n'.join(lines[max(0, i-1):i+2]),
                        'iteration': iteration
                    })

        return promising_elements

    def generate_exploration_candidates(self, problem_description, test_cases, history, iteration, num_candidates=3):
        """Generate up to ``num_candidates`` new solutions for one iteration.

        Strategy:
          1. If any recorded iteration has a ``'working_logic'`` element,
             extend the one that passed the most tests.
          2. Fill the remaining slots. The first slot gets a diverse solution
             when ``iteration <= 2``; all other slots get refinements of the
             recorded promising elements.
          3. Top up with exploratory solutions, retrying a bounded number of
             times.

        Args:
            problem_description: Natural-language task statement.
            test_cases: Converted test dicts with ``inputs`` / ``expected``.
            history: Mapping of iteration -> record. Records may carry a
                ``'promising_elements'`` list.
            iteration: Current iteration number.
            num_candidates: Desired number of candidates.

        Returns:
            A list of at most ``num_candidates`` non-empty code strings. It may
            be shorter if generation keeps failing.
        """
        candidates = []

        # Collect all promising elements from history
        all_promising = []
        for iter_data in history.values():
            if 'promising_elements' in iter_data:
                all_promising.extend(iter_data['promising_elements'])

        # Strategy 1: Build upon most promising working logic
        if all_promising:
            working_elements = [e for e in all_promising if e['type'] == 'working_logic']
            if working_elements:
                best_working = max(working_elements, key=lambda x: len(x.get('test_cases', [])))
                candidate = self.generate_from_promising_base(
                    problem_description, test_cases, best_working, iteration
                )
                if candidate:
                    candidates.append(candidate)

        # Strategy 2: Diverse exploration with high temperature
        for i in range(num_candidates - len(candidates)):
            if i == 0 and iteration <= 2:
                # Early iterations: encourage diverse approaches
                candidate = self.generate_diverse_solution(problem_description, test_cases, iteration, temperature=0.8)
            else:
                # Later iterations: revisit and refine promising approaches
                candidate = self.generate_refined_solution(problem_description, test_cases, all_promising, iteration)

            if candidate:
                candidates.append(candidate)

        # Top up with exploratory candidates. Generation may fail (None) or
        # yield empty code, so cap the attempts instead of looping forever.
        attempts_left = 2 * num_candidates
        while len(candidates) < num_candidates and attempts_left > 0:
            attempts_left -= 1
            candidate = self.generate_exploratory_solution(problem_description, test_cases, iteration)
            if candidate:
                candidates.append(candidate)

        if len(candidates) < num_candidates:
            print(f"      Warning: only generated {len(candidates)}/{num_candidates} candidates")

        return candidates[:num_candidates]

    def generate_from_promising_base(self, problem_description, test_cases, promising_element, iteration):
        """Ask the model to extend a partially working solution.

        Returns:
            Extracted code, or ``None`` if generation raised.
        """

        working_cases = promising_element.get('test_cases', [])
        base_code = promising_element.get('code_snippet', '')

        sys_prompt = """You are a Python expert. 
You have a PROMISING partial solution that handles some test cases correctly. 
Build upon this working foundation while extending it to handle all cases.
Maintain the working logic and expand it thoughtfully. Output ONLY clean Python code."""

        user_prompt = f"""Problem: {problem_description}

PROMISING BASE SOLUTION (handles {len(working_cases)} cases correctly):
```python
{base_code}
```

Working test cases:
{working_cases}

ALL test cases that must be handled:
{[{'inputs': test['inputs'], 'expected': test['expected']} for test in test_cases]}

Build upon the working parts of the base solution and extend it to handle all test cases:"""

        try:
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                    sys_prompt, user_prompt, max_new_tokens=512, temperature=0.4)
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating from promising base: {e}")
            return None

    def generate_diverse_solution(self, problem_description, test_cases, iteration, temperature=0.7):
        """Generate a fresh solution steered by an approach hint picked by ``iteration``.

        Returns:
            Extracted code, or ``None`` if generation raised.
        """

        approaches = [
            "Think step by step and use a direct algorithmic approach",
            "Consider edge cases first and build a robust solution", 
            "Use built-in Python functions and libraries effectively",
            "Focus on mathematical or logical patterns in the test cases",
            "Implement with clear conditional logic and error handling"
        ]

        # Rotate through the hints so consecutive iterations try different styles.
        approach = approaches[iteration % len(approaches)]

        sys_prompt = f"""You are a Python expert.
{approach}
Focus on creating a WORKING solution that passes tests. Output ONLY clean Python code."""

        user_prompt = f"""Problem: {problem_description}

Test cases to handle:
{[{'inputs': test['inputs'], 'expected': test['expected']} for test in test_cases]}

Write a complete Python solution focusing on correctness:"""

        try:
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                    sys_prompt, user_prompt, max_new_tokens=512, temperature=temperature)
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating diverse solution: {e}")
            return None

    def generate_refined_solution(self, problem_description, test_cases, promising_elements, iteration):
        """Generate a solution guided by recent promising elements.

        Falls back to ``generate_diverse_solution`` (temperature 0.6) when there
        are no promising elements. Uses elements from iteration
        ``max(1, iteration - 2)`` onward, or the last three recorded ones if
        none are that recent.

        Returns:
            Extracted code, or ``None`` if generation raised.
        """

        if not promising_elements:
            return self.generate_diverse_solution(problem_description, test_cases, iteration, 0.6)

        # Select promising elements to revisit
        recent_promising = [e for e in promising_elements if e.get('iteration', 0) >= max(1, iteration-2)]
        if not recent_promising:
            recent_promising = promising_elements[-3:]  # Take most recent ones

        insights = []
        for elem in recent_promising:
            if elem['type'] == 'working_logic':
                insights.append(f"- Successfully handled cases: {elem['description']}")
            elif elem['type'] == 'control_structure':
                insights.append(f"- Useful pattern: {elem['description']}")

        sys_prompt = """You are a Python expert.
You have insights from previous promising attempts. Use these insights to create an improved solution.
Focus on combining the best aspects while avoiding previous pitfalls. Output ONLY clean Python code."""

        user_prompt = f"""Problem: {problem_description}

Previous promising insights:
{chr(10).join(insights)}

All test cases:
{[{'inputs': test['inputs'], 'expected': test['expected']} for test in test_cases]}

Create an improved solution by leveraging these insights:"""

        try:
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                    sys_prompt, user_prompt, max_new_tokens=512, temperature=0.5)
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating refined solution: {e}")
            return None

    def generate_exploratory_solution(self, problem_description, test_cases, iteration):
        """Generate a "creative" solution, with temperature rising per iteration.

        The temperature is ``min(0.9, 0.6 + 0.1 * iteration)``.

        Returns:
            Extracted code, or ``None`` if generation raised.
        """

        sys_prompt = """You are a Python expert.
Try a completely different approach than typical solutions. Be creative and think outside the box.
Output ONLY clean Python code that solves the problem."""

        user_prompt = f"""Problem: {problem_description}

Test cases:
{[{'inputs': test['inputs'], 'expected': test['expected']} for test in test_cases]}

Write a creative Python solution:"""

        try:
            # Use high temperature for exploration
            response = generate_local(self.gen_model, self.gen_tokenizer, self.device,
                                    sys_prompt, user_prompt, max_new_tokens=512, 
                                    temperature=min(0.9, 0.6 + iteration * 0.1))
            return extract_python_code(response)
        except Exception as e:
            print(f"      Error generating exploratory solution: {e}")
            return None

def load_mbpp_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed records.

    Blank (whitespace-only) lines are skipped; every other line must be a
    valid JSON document.

    Args:
        file_path: Path to the UTF-8 encoded ``.jsonl`` file.

    Returns:
        The decoded records, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw.strip()) for raw in handle if raw.strip()]

def run_plan2align_mbpp(args):
    """Main function: Run Plan2Align on MBPP JSONL.

    For each problem whose index lies in ``[args.start, args.end)``:
      1. Convert its ``test_list`` asserts.
      2. Generate and score an initial solution.
      3. Run up to ``args.max_iterations`` rounds of candidate exploration,
         keeping the best solution by tests passed, then by score.
      4. Write the per-iteration history to
         ``<args.output_folder>/problem_<task_id>.json``.

    Problems with no parsable tests skip exploration, since there is no
    signal to compare candidates by.

    Args:
        args: Namespace with ``input_file``, ``output_folder``, ``start``,
            ``end``, ``max_iterations`` and ``num_candidates``. Optional
            ``device`` and ``model_name`` attributes override the defaults
            (``cuda:3`` if CUDA is available else ``cpu``; Llama-3.1-8B-Instruct).
    """

    # Setup device
    device_name = getattr(args, "device", None) or ("cuda:3" if torch.cuda.is_available() else "cpu")
    device = torch.device(device_name)
    print(f"Using device: {device}")

    # Load model
    gen_model_name = getattr(args, "model_name", None) or "meta-llama/Meta-Llama-3.1-8B-Instruct"
    print(f"Loading model: {gen_model_name}")

    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token

    # NOTE(review): trust_remote_code executes code from the model repo; confirm it is needed.
    gen_model = AutoModelForCausalLM.from_pretrained(
        gen_model_name, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True
    ).to(device)
    gen_model.eval()

    print("Model loaded successfully!")

    # Load data
    data = load_mbpp_jsonl(args.input_file)
    os.makedirs(args.output_folder, exist_ok=True)

    print(f"Loaded {len(data)} problems from MBPP JSONL")
    print(f"Processing problems {args.start} to {args.end-1}")
    print("Using Plan2Align strategy with promising state exploration!")

    # Initialize generator
    plan2align_generator = Plan2AlignGenerator(gen_model, gen_tokenizer, device)

    # Visit only indices in [args.start, args.end) that exist in the dataset.
    for idx in range(max(args.start, 0), min(args.end, len(data))):
        item = data[idx]

        task_id = item.get('task_id', idx)
        problem_description = item.get('text', '')
        test_list = item.get('test_list', [])

        # Convert test cases
        test_cases = convert_mbpp_tests(test_list)
        num_tests = len(test_cases)

        print(f"\nProblem {task_id}: {problem_description[:80]}...")
        print(f"  Converted {len(test_cases)} test cases")

        # Generate initial solution with exploration
        sys_prompt = """You are a Python expert. 
Generate a working solution for the problem. Focus on correctness over complexity.
Output only clean Python code."""
        user_prompt = f"Problem: {problem_description}\n\nWrite a Python solution:"

        initial_code = generate_local(gen_model, gen_tokenizer, device, sys_prompt, user_prompt, temperature=0.3)
        initial_code = extract_python_code(initial_code)

        # Detailed evaluation
        initial_score, initial_passed, initial_total, initial_test_results, initial_errors = evaluate_code_detailed(
            gen_model, gen_tokenizer, device, initial_code, problem_description, test_cases
        )

        print(f"  Initial: {initial_score:.2f} (tests: {initial_passed}/{initial_total})")

        # Initialize tracking
        best_code = initial_code
        best_score = initial_score
        best_test_results = initial_test_results
        best_passed = initial_passed

        # Identify initial promising elements
        initial_promising = plan2align_generator.identify_promising_elements(
            initial_code, initial_test_results, 0
        )

        history = {
            0: {
                "task_id": task_id,
                "problem": problem_description,
                "test_cases": test_cases,
                "original_tests": test_list,
                "code": initial_code,
                "score": initial_score,
                "test_passed": initial_passed,
                "test_total": num_tests,
                "promising_elements": initial_promising
            }
        }

        # Zero tests would make "passed == total" trivially true, so only
        # report success (and only explore) when there is something to pass.
        if num_tests == 0:
            print("  Warning: no test cases could be parsed; skipping exploration.")
        elif initial_passed == num_tests:
            print(f"  All tests passed initially! Doing one exploration iteration for robustness.")

        max_iterations = args.max_iterations if num_tests > 0 else 0

        # Iterative improvement with Plan2Align
        for iteration in range(1, max_iterations + 1):
            print(f"  Iteration {iteration}")

            # Generate candidates using Plan2Align approach
            print(f"    Plan2Align: Exploring promising directions...")
            candidates = plan2align_generator.generate_exploration_candidates(
                problem_description, test_cases, history, iteration, args.num_candidates
            )

            print(f"    Generated {len(candidates)} exploration candidates")

            best_candidate = None
            best_candidate_score = best_score
            best_candidate_test_results = best_test_results
            best_candidate_passed = best_passed

            for j, candidate in enumerate(candidates):
                if candidate:
                    score, passed, total, test_results, errors = evaluate_code_detailed(
                        gen_model, gen_tokenizer, device, candidate, problem_description, test_cases
                    )
                    print(f"      Candidate {j+1}: {score:.2f} (tests: {passed}/{total})")

                    # Prefer solutions with better test passage, then score
                    if (passed > best_candidate_passed) or (passed == best_candidate_passed and score > best_candidate_score):
                        best_candidate = candidate
                        best_candidate_score = score
                        best_candidate_test_results = test_results
                        best_candidate_passed = passed

            # Update best solution if improvement found
            if best_candidate and (best_candidate_passed > best_passed or
                                 (best_candidate_passed == best_passed and best_candidate_score > best_score)):
                best_code = best_candidate
                best_score = best_candidate_score
                best_test_results = best_candidate_test_results
                best_passed = best_candidate_passed
                print(f"    ✓ Improvement found!")

            # Identify promising elements from this iteration
            iteration_promising = plan2align_generator.identify_promising_elements(
                best_code, best_test_results, iteration
            )

            history[iteration] = {
                "task_id": task_id,
                "problem": problem_description,
                "test_cases": test_cases,
                "original_tests": test_list,
                "code": best_code,
                "score": best_score,
                "test_passed": best_passed,
                "test_total": num_tests,
                "promising_elements": iteration_promising
            }

            print(f"    Best overall: {best_score:.2f} (tests: {best_passed}/{num_tests})")

            # Early stopping if all tests pass and we've done some exploration
            if best_passed == num_tests and iteration >= 2:
                print(f"    All tests passed with exploration! Moving to next problem.")
                break

        # Save results (int keys become strings in JSON)
        output_path = os.path.join(args.output_folder, f"problem_{task_id}.json")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(history, f, ensure_ascii=False, indent=2)

    print("\nPlan2Align MBPP completed!")

def evaluate_results_strict(
    folder_path: str,
    max_iteration: int,
    output_file: str,
    max_id: int = 256,
    id_start: int = 1,
    filename_prefix: str = "problem_"
):
    """
    Evaluate results for problem_{id_start}..problem_{max_id}.json (inclusive).

    Rules:
      - Missing file => counted with zeros, error='missing_file'
      - Load error => zeros, error='load_error: ...'
      - File exists but has no valid iterations <= max_iteration => zeros,
        error='no_valid_iters'. This includes a top-level JSON value that is
        not an object, and iteration entries that are not objects.
      - Valid file: pick the iteration (<= max_iteration) with max 'score'.
        Ties go to the first iteration key in file order.

    Output:
      - CSV with columns:
        task_id, problem, best_iteration, score, test_passed, test_total, pass_rate, code, error
      - Console summary with averages (including zeros from errors)
    """
    import os
    import json
    import pandas as pd

    def zero_record(task_id, problem, error):
        # Row for a problem that could not be scored; counts as zero in averages.
        return {
            'task_id': task_id,
            'problem': problem,
            'best_iteration': None,
            'score': 0.0,
            'test_passed': 0,
            'test_total': 0,
            'pass_rate': 0.0,
            'code': None,
            'error': error
        }

    def safe_score(it_dict, default=0.0):
        try:
            s = it_dict.get('score', default)
            return float(s) if s is not None else default
        except (TypeError, ValueError):
            return default

    def safe_int(value):
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            return 0

    records = []

    for task_id in range(id_start, max_id + 1):
        file_path = os.path.join(folder_path, f"{filename_prefix}{task_id}.json")

        if not os.path.exists(file_path):
            records.append(zero_record(task_id, '', 'missing_file'))
            continue

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:  # best-effort: any unreadable file becomes a zero row
            records.append(zero_record(task_id, '', f'load_error: {e}'))
            continue

        # A non-object top level (e.g. a list) has no iterations to score.
        if not isinstance(data, dict):
            data = {}

        valid_iters = {}
        for k, v in data.items():
            try:
                ki = int(k)
            except (TypeError, ValueError):
                continue
            if ki <= max_iteration and isinstance(v, dict):
                valid_iters[ki] = v

        if not valid_iters:
            records.append(zero_record(data.get('task_id', task_id), data.get('problem') or '', 'no_valid_iters'))
            continue

        best_iter = max(valid_iters, key=lambda x: safe_score(valid_iters[x], 0.0))
        best_result = valid_iters[best_iter]

        score = safe_score(best_result, 0.0)
        test_passed = safe_int(best_result.get('test_passed', 0))
        test_total = safe_int(best_result.get('test_total', 0))
        pass_rate = (test_passed / test_total) if test_total > 0 else 0.0

        records.append({
            'task_id': best_result.get('task_id', task_id),
            'problem': (best_result.get('problem') or '')[:100],
            'best_iteration': best_iter,
            'score': score,
            'test_passed': test_passed,
            'test_total': test_total,
            'pass_rate': pass_rate,
            'code': best_result.get('code'),
            'error': None
        })

    df = pd.DataFrame(records)
    df.to_csv(output_file, index=False)

    if not df.empty:
        avg_score = df['score'].mean()
        avg_pass = df['pass_rate'].mean()
        perfect = (df['pass_rate'] >= 0.999).sum()
        errors = df['error'].notna().sum()

        print(f"\nResults saved to {output_file}")
        print(f"Problems considered: {len(df)} (IDs {id_start}..{max_id})")
        print(f"Errors (missing/load/no_valid_iters): {errors}")
        print(f"Average Score: {avg_score:.3f}")
        print(f"Average Pass Rate: {avg_pass:.3f}")
        print(f"Problems with perfect pass: {perfect}")
    else:
        print("No results produced (empty DataFrame).")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Plan2Align for MBPP JSONL with Local Llama3.1")

    # Basic parameters. Required only in generation mode; they are validated
    # after parsing so that --evaluate does not demand them.
    parser.add_argument("--input_file", type=str, help="MBPP JSONL dataset file")
    parser.add_argument("--output_folder", type=str, help="Output folder")

    # Generation parameters
    parser.add_argument("--max_iterations", type=int, default=5, help="Max iterations")
    parser.add_argument("--num_candidates", type=int, default=2, help="Candidates per iteration")
    parser.add_argument("--start", type=int, default=0, help="Start index")
    parser.add_argument("--end", type=int, default=100, help="End index")

    # Evaluation mode
    parser.add_argument("--evaluate", action="store_true", help="Evaluation mode")
    parser.add_argument("--eval_folder", type=str, help="Evaluation folder")
    parser.add_argument("--eval_max_iter", type=int, default=0, help="Max iteration to evaluate")
    parser.add_argument("--eval_output", type=str, default="plan2align_results.csv", help="Evaluation output file")
    parser.add_argument("--eval_range", type=int, default=256, help="Last problem ID to evaluate (inclusive)")
    parser.add_argument("--eval_start", type=int, default=1, help="First problem ID to evaluate")

    args = parser.parse_args()

    if args.evaluate:
        if not args.eval_folder:
            parser.error("--eval_folder is required with --evaluate")
        evaluate_results_strict(args.eval_folder, args.eval_max_iter, args.eval_output, args.eval_range,
                                id_start=args.eval_start)
    else:
        missing = [flag for flag, value in (("--input_file", args.input_file),
                                            ("--output_folder", args.output_folder)) if not value]
        if missing:
            parser.error("the following arguments are required: " + ", ".join(missing))
        run_plan2align_mbpp(args)