from models import *
import itertools
import json
import random
from collections import Counter, defaultdict
import re
from typing import Tuple, List, Dict


def extract_key_phases(llm_output):
    """Parse the list of sub-goals out of an LLM reply.

    Numbered lines (``1. ...``, also ``**1.** ...``) are tried first, then
    bullet lines (``- ...``, ``* ...``, ``• ...``).  Every candidate is
    stripped of surrounding whitespace and trailing ``,``/``.``/``;`` and is
    kept only if it has at least three words, starts with an ASCII letter and
    contains none of ``{}()[]<>``.

    Args:
        llm_output: raw text returned by the model (``None`` is treated as
            an empty reply).

    Returns:
        The cleaned sub-goals from the first line style that yields at least
        one valid entry, or ``[]`` when no style does.
    """
    # Anchored to the start of a line so that decimals inside prose
    # ("3.5 cm") or hyphenated words ("sub-goal") are not taken as markers.
    patterns = [
        r'^[ \t]*(?:\*\*)?\d+\.(?:\*\*)?[ \t]*(.*?)[ \t]*$',  # Numbered: "1. ..."
        r'^[ \t]*[-*•][ \t]+(.*?)[ \t]*$',  # Bullet points: "- ...", "* ..."
    ]

    for pattern in patterns:
        cleaned_phases = []
        for phase in re.findall(pattern, llm_output or "", re.MULTILINE):
            phase = re.sub(r'[,\.;]+$', '', phase.strip()).strip()
            if (len(phase.split()) >= 3 and  # At least 3 words for meaningful sub-goal
                    re.match(r'^[A-Za-z]', phase) and  # Starts with letter
                    not re.search(r'[{}()\[\]<>]', phase)):  # No special brackets
                cleaned_phases.append(phase)
        # Move on to the next line style only when this one produced nothing
        # usable, rather than stopping at the first raw (possibly junk) match.
        if cleaned_phases:
            return cleaned_phases

    return []


def generate_phase(task, args, max_attempts=5):
    """Decompose *task* into an ordered list of high-level sub-goals via the LLM.

    The model is queried up to ``max_attempts`` times.  A reply counts as a
    success once ``extract_key_phases`` recovers at least one sub-goal from
    it.  A failed LLM call counts as a failed attempt and is retried, the same
    way ``generate_skills`` and ``get_current_subgoal`` retry.

    Args:
        task: the problem statement, sent as the user prompt.
        args: run configuration; only ``args.backend`` (the model name) is read.
        max_attempts: maximum number of LLM calls before giving up.

    Returns:
        The non-empty list of sub-goal strings from the first parseable reply.

    Raises:
        ValueError: if no attempt yields any sub-goal.  When the last failure
            was an exception from ``gpt``, that exception is chained as the cause.
    """
    system_prompt = '''You are an expert task decomposer. Your role is to analyze complex problems and break them down into essential high-level sub-goals (key phases). Each sub-goal should represent a critical milestone that moves toward solving the task.
**** Generate only the most essential sub-goals needed to complete this task, excluding all implementation details and optional steps.
**SUB-GOAL DEFINITION:**
Each sub-goal should specify WHAT needs to be accomplished, not HOW to do it. Focus on the key objectives that must be achieved.

**Output Requirements:**
- List sub-goal as numbered bullet points (1., 2., 3., ...)
- Keep sub-goal concise and actionable.
- ***** Don't generate too many sub-goal *****.
- Maintain logical sequence from initial sub-goal to final sub-goal
- Ensure each sub-goal represents a distinct, essential one
- Exclude implementation details, calculations, or explanatory text

**Example Demonstrations:**

1. **Geometry Problem:**
   Task: "Find the area of a triangle with base 8cm and height 5cm"

   Key Sub-Goals:
   1. Identify the area formula for triangles ,
   2. Extract given dimensions from the problem ,
   3. Compute the area using the formula 
   4. Final verification or solution step and give the final answer


**Output Format Strictly Follow This Pattern:**
1. [Action-oriented sub-goal description] ,
2. [Next essential sub-goal] ,
...
N. [Final verification or solution step and give the final answer] 

**Validation Checklist:**
✓ Each sub-goal should specify WHAT needs to be accomplished, not HOW to do it.
✓ Phases cover all critical aspects of the task
✓ Logical progression from start to finish
✓ No implementation details or calculations
✓ Suitable for the given task context

**Critical Reminders:**
- Phases should answer "what needs to be done" not "how to do it"
- Avoid transitional words ("then", "next", "after")
- Exclude mathematical symbols, formulas, or specific methods
- Maintain consistent verb tense and clarity
- Ensure sub-goals are truly sequential and complementary

'''
    last_error = None

    for attempt in range(1, max_attempts + 1):
        try:
            phase_output = gpt(task, system_prompt, model=args.backend)
        except Exception as e:  # transient LLM/API failure: log and retry
            last_error = e
            print(f"Attempt {attempt}/{max_attempts} failed with error: {e}. Retrying...")
            continue

        # Guard against an empty/None reply so the regex parsing cannot crash.
        phases = extract_key_phases(phase_output or "")
        if phases:
            return phases

        print(f"Format validation failed. Attempt {attempt}/{max_attempts}. Retrying...")

    # If all attempts fail, raise error
    raise ValueError(
        f"Failed to generate properly formatted output after {max_attempts} attempts"
    ) from last_error


def extract_skills_list(skill_text, min_skills=2, max_skills=5):
    """Pull a list of skills out of an LLM reply.

    Four line styles are recognised, and matches from every style are
    accumulated in this order: bullets (``-``, ``*``, ``•``, ``→``), numbered
    items (``1.``), lettered items (``a)``), and bare ``[...]`` lines.
    Scanning stops early once ``max_skills`` items have been collected.

    Args:
        skill_text: raw model output.
        min_skills: fewest acceptable skills.
        max_skills: most acceptable skills.

    Returns:
        The skills with whitespace and a wrapping ``[``/``]`` removed; items
        shorter than two characters are dropped.  Returns ``None`` in two
        cases: the raw match count is outside ``[min_skills, max_skills]``, or
        fewer than ``min_skills`` items survive the length filter.
    """
    line_styles = (
        r'^\s*[-*•→]\s*\[?(.+?)\]?\s*$',   # "- skill", "* skill", "• skill", "→ skill"
        r'^\s*\d+\.\s*\[?(.+?)\]?\s*$',    # "1. skill"
        r'^\s*[a-z]\)\s*\[?(.+?)\]?\s*$',  # "a) skill"
        r'^\s*\[(.+?)\]\s*$',              # "[skill]"
    )

    collected = []
    for style in line_styles:
        hits = re.findall(style, skill_text, re.MULTILINE)
        if not hits:
            continue
        collected += hits
        if len(collected) >= max_skills:
            break

    # The count is checked on the raw matches, before any filtering.
    if len(collected) < min_skills or len(collected) > max_skills:
        return None

    stripped = (item.strip() for item in collected)
    usable = [re.sub(r'^\s*\[|\]\s*$', '', item) for item in stripped if len(item) >= 2]

    if len(usable) < min_skills:
        return None
    return usable


def generate_skills(
        current_goal,
        current_step,
        args,
        min_skills=6,
        max_skills=6,
        max_attempts=5
):
    """Ask the LLM for concrete methods ("skills") that advance *current_goal*.

    Args:
        current_goal: the sub-goal to work towards (interpolated into the prompts).
        current_step: the reasoning trace so far (interpolated into the prompts).
        args: run configuration; only ``args.backend`` (the model name) is read.
        min_skills: fewest skills to accept from one reply.
        max_skills: most skills to accept from one reply.
        max_attempts: number of LLM calls before giving up.

    Returns:
        A list of ``min_skills``..``max_skills`` skill strings from the first
        reply that parses correctly.  If every attempt fails, returns ``[]``
        rather than ``None``, so callers that iterate over or ``len()`` the
        result keep working.
    """
    system_prompt = f'''You are a Goal-Oriented Skill Generator. Given the current step and the goal which needs to be achieved, Generate between {min_skills} and {max_skills} practical methods (skills) to achieve the specified goal.
**Current step: {current_step}
**Goal:** {current_goal}

****The skills should help to reach the goal as soon as possible!!!!!

**Skill Definition:**
Each skill should be a concrete, actionable method that:
1. Directly contributes to achieving the goal according to current step
2. Represents a distinct approach or technique
3. Is executable without external knowledge

**Output Requirements:**
- Generate between {min_skills} and {max_skills} skills
- Each skill must start with an action verb
- Format each skill as a bullet point ("- [skill description]")
- Keep skills concise (5-15 words)
- Exclude explanations or examples

**Quality Validation:**
✓ Each skill is a distinct method (not a restatement)
✓ Skills cover different aspects of the goal
✓ Methods are practical and executable
✓ Avoid overlapping or redundant skills

=== COMPLETE EXAMPLES ===
Example 1:
Current step: "I have calculated total sales amount and commission rate"
Goal: "Determine commission per sale"
Skills:
- Multiply sale amount by commission rate
- Convert percentage to decimal
- Round result to 2 decimal places


=== FORMAT REQUIREMENTS ===
Output MUST be:
- [Skill 1]
- [Skill 2]
...
- [Skill n] (where n is between {min_skills} and {max_skills})

**Critical Rules:**
- STRICTLY use the given format
- NO numbering or other formats
- NO additional text outside bullet points
- Skills must answer "how to achieve the goal based on current step"
'''

    prompt = f'''

Generate between {min_skills} and {max_skills} practical methods to achieve:
Goal: "{current_goal}" according to Current step: {current_step}

Output ONLY the bullet-pointed skills:'''

    for attempt in range(max_attempts):
        try:
            skills_output = gpt(prompt, system_prompt, model=args.backend)
            # extract_skills_list returns None when the reply has the wrong
            # number of skills, so None must be handled before calling len().
            skills = extract_skills_list(skills_output, min_skills, max_skills)
        except Exception as e:  # LLM/API or parsing failure: log and retry
            print(f"Attempt {attempt + 1} failed with error: {str(e)}")
            continue

        if skills and min_skills <= len(skills) <= max_skills:
            return skills

        print(
            f"Attempt {attempt + 1}: Generated {len(skills) if skills else 0} skills (need {min_skills}-{max_skills})")

    print(f"Failed to generate {min_skills}-{max_skills} skills after {max_attempts} attempts; returning no skills")
    return []


def get_proposals(z, args, phase, skill):
    """Generate one candidate next step per skill for the partial solution *z*.

    Args:
        z: the reasoning trace so far; sent as the user prompt and used as the
            prefix of every returned trace.
        args: run configuration; only ``args.backend`` (the model name) is read.
        phase: the sub-goal the next step should advance (interpolated into
            the system prompt).
        skill: iterable of skill descriptions.  One proposal is requested per
            entry.  ``None``, returned by failed skill generation, is treated
            as no skills.

    Returns:
        A list of extended traces ``z + proposal + '\\n'``.  A skill is skipped
        when all 3 of its LLM calls return an empty reply or raise.
    """
    max_attempts = 3
    proposals = []

    for current_skill in skill or []:
        system_prompt = f'''You are a heuristic assistant specialized in sub-goal-based problem solving.

**** Pay attention
1. For simple problems (e.g., basic equations), give the direct answer immediately
Consider a problem simple if it involves:
- ≤2 operations (e.g., solve 2x=10)
- Basic arithmetic (+, -, ×, ÷)
- Single-variable equations
2. Never explain obvious calculations (e.g., 2+3=5)

**CURRENT SUB-GOAL:** {phase}
**REQUIRED SKILL:** {current_skill}

**TASK:** Generate exactly the next step that:
1. Directly applies the specified skill: "{current_skill}"
2. Advances the current sub-goal: "{phase}"
3. Reach the goal as fast as possible !!!

**OUTPUT FORMAT RULES:**
- The next step should reach the goal as fast as possible.
- However, when the final step leads you to the final answer, give me only the numerical answer and print "###" before it, format as: ###[ANSWER]
For example: ###3.0, which should be at the last line
- Otherwise, provide a clear action step
- No explanations, just the step itself

**VERIFICATION CHECKLIST:**
✓ Does this step directly use the skill "{current_skill}"?
✓ Does this step achieves the sub-goal "{phase}" as fast as possible?
✓ If final answer, does it start with "###"?
'''
        for _attempt in range(max_attempts):
            try:
                proposal = gpt(z, system_prompt=system_prompt, model=args.backend)
                if proposal and proposal.strip():
                    proposals.append(proposal)
                    break
                print("Empty response, retrying...")
            except Exception as e:  # LLM/API failure: log and retry this skill
                print(f" Error during GPT call: {e}, retrying...")

    return [z + proposal + '\n' for proposal in proposals]


def compare(pair_1, pair_2, backend, evaluate_times):
    """Ask the LLM which of two candidate texts is better.

    The judge is queried ``evaluate_times`` times. Each query uses one of
    three instruction phrasings, cycling through them when there are more
    queries than phrasings.  The majority vote wins.

    Args:
        pair_1: first candidate (labelled "1" in the prompt).
        pair_2: second candidate (labelled "2" in the prompt).
        backend: model name forwarded to ``gpt``.
        evaluate_times: number of judge queries to aggregate.

    Returns:
        ``0`` if ``pair_1`` wins, ``1`` if ``pair_2`` wins.  A reply without a
        standalone ``1``/``2`` counts as a random vote.  With no votes at all
        (``evaluate_times <= 0``), the result is a coin flip.
    """
    compare_prompt = [
        "You are a logical student. You should judge which of the two reasoning steps is better in both logical and correctness. You must only reply 1 or 2. If both inputs are equal, randomly print 1 or 2. Don't explain or reply anything else.\n",
        "Find out which of the two analysis is better. You must only reply 1 or 2.If both inputs are equal, randomly print 1 or 2. Don't explain or reply anything else.\n",
        "Compare the two analysis and find which is better. You must only reply 1 or 2. If both inputs are equal, randomly print 1 or 2. Don't explain or reply anything else.\n",
    ]

    votes = []
    for i in range(evaluate_times):
        value_output = gpt(f"1: {pair_1}\n2: {pair_2}",
                           system_prompt=compare_prompt[i % len(compare_prompt)],
                           model=backend)
        # Scan the whole reply for a standalone 1 or 2 instead of only its
        # first character. This way leading whitespace or a short prefix does
        # not discard the vote, and other digits cannot yield an out-of-range
        # winner.
        match = re.search(r'\b([12])\b', value_output or "")
        # Votes are always ints so the majority count is not split between
        # "1" and 1.
        votes.append(int(match.group(1)) if match else random.randint(1, 2))

    if not votes:
        return random.randint(0, 1)

    winner = Counter(votes).most_common(1)[0][0]
    return winner - 1

def duel_bandit_select(pool, args, n=2, m=12, final_rounds=1, compare_fn=None):
    """Select the strongest candidates from *pool* with pairwise LLM duels.

    Elimination stage: while more than ``m`` candidates remain, every
    candidate duels ``n`` randomly drawn opponents.  Only candidates that win
    at least half of the duels they took part in survive.  If that eliminates
    nobody (or everybody), a random subset of size ``m`` is kept instead, so
    the loop always terminates.

    Final stage: the survivors play ``final_rounds`` round-robin passes, and
    every candidate tied for the most wins is selected.

    Args:
        pool: candidate texts; must be hashable (membership uses a set).
        args: run configuration; ``args.backend`` is read only when
            ``compare_fn`` is not given.
        n: opponents drawn per candidate in each elimination round.
        m: pool size at which elimination stops.
        final_rounds: number of round-robin passes in the final stage.
        compare_fn: optional judge ``(a, b) -> 0 | 1`` returning the position
            of the better candidate.  Defaults to a single-vote ``compare``.

    Returns:
        ``(selected, not_selected)``.  ``not_selected`` keeps the original
        order of ``pool`` minus every entry equal to a selected one.
    """
    if not pool:
        return [], []

    if len(pool) == 1:
        return pool, []

    def judge(first, second):
        # Position (0 or 1) of the better of the two texts.
        if compare_fn is not None:
            return compare_fn(first, second)
        return compare(first, second, args.backend, evaluate_times=1)

    def duel(candidates, i, j):
        # The judge reports the winner's position within the pair, not its
        # pool index; translate it so the right candidate is credited.
        return i if judge(candidates[i], candidates[j]) == 0 else j

    original_pool = pool.copy()
    current_pool = pool.copy()

    # Elimination stage: sparse random duels until at most m candidates remain.
    while len(current_pool) > m:
        size = len(current_pool)
        edges = []
        for i in range(size):
            others = [j for j in range(size) if j != i]
            opponents = random.sample(others, min(n, len(others)))
            edges.extend((i, opp) for opp in opponents)

        wins = defaultdict(int)
        comparisons = defaultdict(int)
        for i, j in edges:
            wins[duel(current_pool, i, j)] += 1
            comparisons[i] += 1
            comparisons[j] += 1

        new_pool = [
            cand for idx, cand in enumerate(current_pool)
            if wins.get(idx, 0) / max(1, comparisons.get(idx, 0)) >= 0.5
        ]

        # Guarantee progress: if nobody (or everybody) was eliminated, keep a
        # random subset of size m.
        if not new_pool or len(new_pool) == len(current_pool):
            new_pool = random.sample(current_pool, min(m, len(current_pool)))

        current_pool = new_pool

    # Final stage: full round robin(s) among the survivors.
    final_scores = defaultdict(int)
    for _ in range(final_rounds):
        for i, j in itertools.combinations(range(len(current_pool)), 2):
            final_scores[duel(current_pool, i, j)] += 1

    if final_scores:
        best = max(final_scores.values())
        selected_final = [
            cand for idx, cand in enumerate(current_pool)
            if final_scores.get(idx) == best
        ]
    else:
        # Zero or one survivor: nothing was compared, keep what is left.
        selected_final = current_pool

    selected_set = set(selected_final)
    not_selected = [sol for sol in original_pool if sol not in selected_set]

    return selected_final, not_selected


def get_current_subgoal(
        phases_list,
        current_context,
        args,
        max_attempts=3
):
    """Ask the LLM which sub-goal in *phases_list* should be worked on next.

    Each reply is parsed first as a bare integer (the whole stripped reply),
    then as an integer in square brackets.  The first parsed index that lies
    within ``phases_list`` is accepted.

    Args:
        phases_list: the ordered sub-goals (interpolated into the prompts).
        current_context: the reasoning trace so far.
        args: run configuration; only ``args.backend`` (the model name) is read.
        max_attempts: number of LLM calls before falling back.

    Returns:
        The chosen 0-based index into ``phases_list``, or ``0`` when no
        attempt produced a valid in-range index.
    """
    system_prompt = f'''You are a Sub-goal Reasoning Engine. I will give the sub-goal list:{phases_list} and current thinking progress:{current_context}. Analyze the task progress and determine:
Which sub-goal should be actively worked on now.
If you think the current step has almost achieved one goal, don't choose that goal, and choose the next goals !!!!
***********
Choose the sub-goal from the list: {phases_list}, give me the number of index in the list.

**TASK ANALYSIS PROCESS:**
1. Compare current progress with each sub-goal's requirements
2. Identify the most immediate sub-goal that needs attention
3. Verify the selection matches logical progression

**OUTPUT FORMAT STRICTLY FOLLOW:**
[index]

**EXAMPLE 1:**
Phases: ["Data collection", "Analysis", "Validation"]
Current Context: "Finished gathering raw data, need to process it"
You choose [Analysis]
Output:
1

**CRITICAL RULES:**
- ******** Sub-goal MUST be from provided list and return in the number of index
- No explanations or additional text
- Sub-goal should logically follow from current_context
'''

    prompt = f'''
    CURRENT PROGRESS: {current_context},
    Goal list : {phases_list}
'''

    # Accepted reply shapes, tried in order: "2" alone, then "[2]" anywhere.
    index_shapes = (r'^\s*(\d+)\s*$', r'\[\s*(\d+)\s*\]')

    for attempt in range(max_attempts):
        try:
            reply = gpt(prompt, system_prompt, model=args.backend).strip()

            for shape in index_shapes:
                found = re.search(shape, reply)
                if found is None:
                    continue
                chosen = int(found.group(1))
                if 0 <= chosen < len(phases_list):
                    print(f"success: sub-goal={chosen} ({phases_list[chosen]})")
                    return chosen

            print(f"fail-----: '{reply}'")

        except Exception as e:
            print(f"Attempt {attempt + 1} error: {str(e)}")

    print("ALL fails!!!")
    return 0

def _split_by_final_answer(traces):
    """Partition *traces* into ``(finished, unfinished)`` lists.

    A trace is finished when the last line of its stripped text contains the
    ``###`` final-answer marker.
    """
    finished, unfinished = [], []
    for trace in traces:
        last_line = trace.strip().split('\n')[-1]
        if "###" in last_line:
            finished.append(trace)
        else:
            unfinished.append(trace)
    return finished, unfinished


def _extract_numeric_answers(traces):
    """Return the first ``###<number>`` value, as a string, from each trace that has one."""
    pattern = r'###\s*([-+]?\d*\.?\d+)'
    found = []
    for trace in traces:
        match = re.search(pattern, trace)
        if match:
            found.append(match.group(1))
    return found


def solve(question, args, to_print):
    """Run the sub-goal-guided search on *question* and return the voted answer.

    Each round does the following:
    1. Picks the active sub-goal for every frontier trace.
    2. Generates skills for that sub-goal.
    3. Expands each trace with one proposal per skill.
    4. Keeps the best expansions with ``duel_bandit_select``.

    Selected traces whose last line contains ``###`` become answers.
    Unselected answers are held in ``to_check_answer`` and re-judged among
    themselves in the next round.

    Args:
        question: the problem statement.
        args: run configuration.  ``max_round`` and ``n_select_sample`` are read
            here, and ``args`` is passed on to the helper calls.
        to_print: if truthy, print per-round statistics.

    Returns:
        ``(final_ans, {'steps': infos})``.  ``final_ans`` is the most common
        numeric ``###`` answer among accepted answers, falling back to the
        still-pending ones and then to ``""``.  ``infos`` holds one dict of
        traces per round.
    """
    x = "Q: " + question + "\n" + "Steps: \n"
    print(x)
    max_round = args.max_round
    phases = generate_phase(question, args)
    zs = [x]
    infos = []
    answers = []
    to_check_answer = []

    for round_idx in range(max_round):
        # get_current_subgoal returns an index into `phases`; the skill and
        # proposal prompts need the sub-goal text, not its number.
        goals = [phases[get_current_subgoal(phases, z, args)] for z in zs]
        skill_list = [generate_skills(goal, step, args) for goal, step in zip(goals, zs)]
        new_zs = list(itertools.chain.from_iterable(
            get_proposals(z, args, goal, skills)
            for z, goal, skills in zip(zs, goals, skill_list)
        ))

        selected_zs, not_selected_zs = duel_bandit_select(new_zs, args)
        print('Selected step:', selected_zs)

        # Answers that were not selected last round compete among themselves;
        # the winners are accepted.
        if to_check_answer:
            selected_ans, not_selected_ans = duel_bandit_select(to_check_answer, args)
        else:
            selected_ans, not_selected_ans = [], []

        current_answers, continue_zs = _split_by_final_answer(selected_zs)
        pending_answers, remained_zs = _split_by_final_answer(not_selected_zs)

        answers = answers + current_answers + selected_ans
        to_check_answer = pending_answers + not_selected_ans

        # Keep at least n_select_sample traces alive by borrowing unselected,
        # unfinished traces from the end of the list.
        shortfall = args.n_select_sample - len(continue_zs)
        if shortfall > 0 and remained_zs:
            continue_zs = continue_zs + remained_zs[-min(shortfall, len(remained_zs)):]

        if to_print:
            print(f'Round {round_idx}:')
            print(f'-- new_zs --: {len(new_zs)} solutions')
            print(f'-- continue_zs --: {len(continue_zs)} solutions')
            print(f'-- selected_zs --: {len(selected_zs)} solutions')
            print(f'-- current_answers --: {len(current_answers)} answers')
            print(f'-- answers --: {len(answers)} total answers')
            print(f'-- to_check_answer --: {len(to_check_answer)} pending answers')

        infos.append({
            'round': round_idx,
            'x': x,
            'zs': zs,
            'new_zs': new_zs,
            'continue_zs': continue_zs,
            'selected_zs': selected_zs,
            'not_selected_zs': not_selected_zs,
            'current_answers': current_answers,
            'to_check_answers': to_check_answer,
            'answers': answers
        })

        zs = continue_zs
        if not zs and not to_check_answer:
            break

    answer_list = _extract_numeric_answers(answers)
    check_list = _extract_numeric_answers(to_check_answer)

    # Prefer accepted answers; fall back to still-pending ones.
    votes = answer_list or check_list
    final_ans = Counter(votes).most_common(1)[0][0] if votes else ""

    print('Final answer: ', final_ans)
    return final_ans, {'steps': infos}

