import itertools
from functools import partial
from models import gpt
import random
from collections import Counter, defaultdict
from game24 import get_current_numbers
import re
from typing import Tuple, List, Dict

def compare(task, x, pair, evaluate_times, backend):
    """Ask the LLM which of two candidate traces is better, by majority vote.

    Args:
        task: Task object whose ``compare_prompt_wrap(x, num, i)`` builds the
            prompt for the i-th vote.
        x: The original puzzle input.
        pair: Two candidate traces; ``get_current_numbers`` is applied to each
            and the results are passed to the prompt builder.
        evaluate_times: Number of LLM votes to take (must be >= 1; with zero
            votes ``Counter.most_common`` returns nothing and this raises).
        backend: Model name forwarded to ``gpt``.

    Returns:
        0 if the vote favours ``pair[0]`` (answer "1"), 1 if it favours
        ``pair[1]`` (answer "2"). A reply without a usable 1/2 counts as a
        random vote.
    """
    num = [get_current_numbers(pair[0]), get_current_numbers(pair[1])]
    votes = []
    for i in range(evaluate_times):
        compare_prompt = task.compare_prompt_wrap(x, num, i)
        value_output = gpt(compare_prompt, model=backend)
        # Elsewhere in this file gpt() is used as returning a string, and
        # indexing a string with [0] would only inspect its first character.
        # Accept either a string or a list of completions (use the first).
        if isinstance(value_output, (list, tuple)):
            value_output = value_output[0] if value_output else ''
        digits = re.findall(r'\d+', value_output or '')
        # re.findall yields strings: convert before testing against 1/2. A
        # str-vs-int membership test never matches, which made every vote random.
        choice = int(digits[0]) if digits else None
        votes.append(choice if choice in (1, 2) else random.randint(1, 2))
    winner = Counter(votes).most_common(1)[0][0]
    return winner - 1

def get_proposals(x, z, args, phase, skill):
    """Ask the LLM for one continuation of trace ``z`` per skill.

    Args:
        x: The original puzzle numbers, interpolated into the prompt.
        z: The current partial reasoning trace; it prefixes every proposal.
        args: Run configuration; only ``args.backend`` is read.
        phase: The current sub-goal, interpolated into the prompt.
        skill: Iterable of skill descriptions, one prompt per skill. ``None``
            (e.g. when skill generation failed) is treated as "no skills"
            instead of crashing.

    Returns:
        A list of ``z + proposal`` strings, each followed by a newline, for
        every skill that produced a non-empty response within 3 attempts.
    """
    proposals = []
    max_attempts = 3
    for skill_item in (skill or []):
        system_prompt = f'''You are a 24-game AI calculator. Strictly use each of these 4 numbers exactly once: {x} to get 24.

        **Current Numbers Available:** {z}
        **Current Sub-goal:** {phase}
        **Required Operations:** {skill_item}

        **RULES:**
        1. MUST use all original numbers exactly once in final expression
        2. Only use numbers from current state
        3. Perform ACTUAL calculations each step
        4. Immediately correct calculation errors
        5. No theoretical explanations - only computations

        **STEP FORMAT:**
        (number) (operator) (number) = (result) [remaining: a,b,c]

        **FINAL ANSWER FORMAT:**
        ### (expression using all original numbers) = 24

        **EXAMPLE:**
        Input: 1 1 6 4
        Output: ### (6 * 4 * 1 * 1) = 24

        **REQUIREMENTS:**
        - Execute all required operations
        - Show calculations with remaining numbers
        - Final expression must contain all original numbers
        - Ensure mathematical correctness at each step
        - Never reuse or omit original numbers
        '''
        for _attempt in range(max_attempts):
            # Best-effort retry: any backend error or blank reply costs one attempt.
            try:
                proposal = gpt(system_prompt, model=args.backend)
                if proposal and proposal.strip():
                    proposals.append(proposal)
                    break
                print("Empty response, retrying...")
            except Exception as e:
                print(f" Error during GPT call: {e}, retrying...")
    return [z + proposal + '\n' for proposal in proposals]

def extract_key_phases(llm_output):
    """Pull sub-goal phrases out of a numbered or bulleted LLM reply.

    The first pattern (in priority order) that yields any match is used:
    numbered items ending at a comma+newline or the next number, numbered
    items ending at the line end, then ``-``/``*`` bullets. Each phrase is
    stripped of trailing ``,.;`` and kept only if it has at least three
    words, starts with a letter, and contains no bracket characters.

    Returns:
        A list of cleaned phase strings (possibly empty).
    """
    candidate_patterns = (
        r'(?:\d+\.\s*)(.*?)(?=,\s*\n|\n\d+\.|$)',
        r'(?:\d+\.\s*)(.*?)(?=\n|$)',
        r'(?:[-*]\s*)(.*?)(?=\n|$)',
    )
    raw_items = []
    for pattern in candidate_patterns:
        raw_items = re.findall(pattern, llm_output, re.MULTILINE)
        if raw_items:
            break

    def _acceptable(text):
        # Reject fragments, lines not starting with a letter, and anything
        # that still looks like a placeholder/formula (brackets of any kind).
        if not text or len(text.split()) < 3:
            return False
        if re.match(r'^[A-Za-z]', text) is None:
            return False
        return re.search(r'[{}()\[\]<>]', text) is None

    result = []
    for item in raw_items:
        text = re.sub(r'[,\.;]+$', '', item.strip()).strip()
        if _acceptable(text):
            result.append(text)
    return result

def generate_phase(task, args, max_attempts=5):
    """Ask the LLM to split solving the puzzle into ordered sub-goal phases.

    Args:
        task: The puzzle numbers interpolated into the prompt (``solve`` passes
            ``task.get_input(idx)`` here, not the task object).
        args: Run configuration; only ``args.backend`` is read.
        max_attempts: Number of LLM calls to try before giving up.

    Returns:
        A non-empty list of phase descriptions parsed by ``extract_key_phases``.

    Raises:
        ValueError: If no attempt yielded a parsable phase list.
    """
    system_prompt = f'''You are an expert 24-point game strategist. Break down the problem into natural thinking phases based on how humans solve 24-point games.

    Given numbers: {task}

    You must give me the sub-goals according to the **Output Format:**.

    **Core Requirements:**
    - Generate several sub-goals that mirror natural human problem-solving
    - Each sub-goal should represent a distinct cognitive phase
    - Include brief explanations of the strategic purpose
    - Maintain logical progression from analysis to verification
   
    **Output Format:**
    1. [Strategic sub-goal],
    2. [Next phase],
    3. [Final verification]

    **Example Output:**
    1. Analyze numbers and explore groupings - Identify relationships and test pairing possibilities
    2. Plan operations and manage calculations - Determine sequence and ensure mathematical viability
    3. Validate and adjust solution - Verify results and refine approach to achieve 24

    **Critical Guidelines:**
    - Focus on cognitive processes, not specific calculations
    - Each sub-goal should represent a distinct thinking phase
    - Include concise purpose explanations (5-8 words)
    - Ensure logical progression from analysis to verification
    - Avoid mentioning specific numbers or operations
    '''

    for attempt_no in range(1, max_attempts + 1):
        parsed = extract_key_phases(gpt(system_prompt, model=args.backend))
        if parsed:
            return parsed
        print(f"Format validation failed. Attempt {attempt_no}/{max_attempts}. Retrying...")
    raise ValueError(f"Failed to generate properly formatted output after {max_attempts} attempts")


def extract_skills_list(skill_text, min_skills=4, max_skills=6):
    """Parse a list of skills out of an LLM reply.

    Line-oriented item formats are tried in order: symbol bullets, ``1.``
    numbering, ``a)`` lettering, and bare ``[...]`` lines. Matches from every
    format are accumulated until at least ``max_skills`` items are collected.

    Returns:
        The cleaned skill strings (optional surrounding brackets removed, items
        shorter than two characters dropped), or ``None`` if the raw match
        count falls outside ``[min_skills, max_skills]`` or fewer than
        ``min_skills`` items survive cleaning.
    """
    line_patterns = (
        r'^\s*[-*•→]\s*\[?(.+?)\]?\s*$',
        r'^\s*\d+\.\s*\[?(.+?)\]?\s*$',
        r'^\s*[a-z]\)\s*\[?(.+?)\]?\s*$',
        r'^\s*\[(.+?)\]\s*$',
    )
    collected = []
    for pattern in line_patterns:
        matches = re.findall(pattern, skill_text, re.MULTILINE)
        if not matches:
            continue
        collected += matches
        if len(collected) >= max_skills:
            break

    if len(collected) < min_skills or len(collected) > max_skills:
        return None

    cleaned = [
        re.sub(r'^\s*\[|\]\s*$', '', item.strip())
        for item in collected
        if len(item.strip()) >= 2
    ]
    return cleaned if len(cleaned) >= min_skills else None


def generate_skills(
        current_goal,
        current_step,
        args,
        min_skills=6,
        max_skills=8,
        max_attempts=5
):
    """Ask the LLM for a bullet list of arithmetic "skills" for a sub-goal.

    Args:
        current_goal: Sub-goal text interpolated into the prompt.
        current_step: The current reasoning trace / state.
        args: Run configuration; only ``args.backend`` is read.
        min_skills: Minimum number of skills to accept.
        max_skills: Maximum number of skills to accept.
        max_attempts: Number of LLM calls to try.

    Returns:
        A list of between ``min_skills`` and ``max_skills`` skill strings, or
        an empty list if no attempt produced an acceptable list. An empty list
        rather than ``None`` lets callers iterate or ``len()`` the result safely.
    """
    system_prompt = f'''You are a 24-Point Game Skill Generator. Generate {min_skills} to {max_skills} practical arithmetic methods for the current goal.

    **Current step:** {current_step}
    **Goal:** {current_goal}

    Generate {min_skills} to {max_skills} arithmetic methods to achieve:
    Goal: "{current_goal}" based on Current step: {current_step}

    **Skill Requirements:**
    - Each skill must use valid arithmetic operations (+, -, ×, ÷)
    - Skills should directly achieve the stated goal
    - Focus on strategic calculation approaches
    - Ensure mathematical validity (no division by zero)
    - *****  Never try to get zero

    **Output Format:**
    - [Skill 1]
    - [Skill 2]
    ...
    - [Skill n]

    **Critical Rules:**
    - Output ONLY bullet-pointed skills
    - NO additional text or explanations
    - Skills must start with action verbs
    - Keep each skill under 15 words
    - *****  Never try to get zero

    === 24-POINT GOAL-SPECIFIC EXAMPLES ===

    **Example: Number Analysis & Grouping Strategy**
    Current step: "Numbers: 6, 6, 4, 2"
    Goal: "Analyze numbers and explore groupings"
    Skills:
    - Multiply 6 and 4 to get 24 directly
    - Multiply 6 and 2 to get 12 for later use
    - Add 6 and 6 to get 12 as base value
    - Divide 6 by 2 to get 3 for multiplication
    '''

    for attempt in range(max_attempts):
        try:
            skills_output = gpt(system_prompt, model=args.backend)
            skills = extract_skills_list(skills_output, min_skills, max_skills)
            # extract_skills_list returns None for a malformed reply; check it
            # before len() so that case is reported as a bad count, not an error.
            if skills is not None and min_skills <= len(skills) <= max_skills:
                return skills
            print(
                f"Attempt {attempt + 1}: Generated {len(skills) if skills else 0} skills (need {min_skills}-{max_skills})")
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {str(e)}")
    return []

def duel_bandit_select(task, x, pool, args, n=4, m=4, final_rounds=1, compare_fn=None):
    """Pick the strongest candidates from ``pool`` via pairwise LLM duels.

    Phase 1 (elimination): while more than ``m`` candidates remain, each one
    duels up to ``n`` random opponents. Candidates that win at least half of
    the duels they took part in survive. If nobody is eliminated, a random
    subset of ``m`` is kept so the loop always progresses.

    Phase 2 (final): a full round robin among the survivors, repeated
    ``final_rounds`` times. All candidates tied for the most wins are selected.

    Args:
        task, x: Forwarded to the comparison function.
        pool: List of candidate traces.
        args: Run configuration; ``args.backend`` is forwarded to the comparator.
        n: Opponents sampled per candidate in each elimination round.
        m: Pool size at which elimination stops.
        final_rounds: Number of round-robin passes in the final phase.
        compare_fn: Comparator with ``compare``'s signature, returning 0 when
            ``pair[0]`` wins and 1 when ``pair[1]`` wins. Defaults to ``compare``.

    Returns:
        ``(selected, not_selected)``. ``not_selected`` holds every item of
        ``pool`` not equal to a selected item.
    """
    if compare_fn is None:
        compare_fn = compare

    if not pool:
        return [], []

    if len(pool) == 1:
        return pool, []

    def _duel(candidates, i, j):
        # The comparator reports which side of the pair won (0/1); translate
        # that back into the winner's index in `candidates`.
        pair = [candidates[i], candidates[j]]
        side = compare_fn(task, x, pair, evaluate_times=1, backend=args.backend)
        return i if side == 0 else j

    original_pool = pool.copy()
    current_pool = pool.copy()

    while len(current_pool) > m:
        edges = []
        for i in range(len(current_pool)):
            candidates = [j for j in range(len(current_pool)) if j != i]
            opponents = random.sample(candidates, min(n, len(candidates)))
            edges.extend((i, opp) for opp in opponents)

        wins = defaultdict(int)
        comparisons = defaultdict(int)
        for i, j in edges:
            wins[_duel(current_pool, i, j)] += 1
            comparisons[i] += 1
            comparisons[j] += 1

        new_pool = [
            cand for idx, cand in enumerate(current_pool)
            if wins[idx] / max(1, comparisons[idx]) >= 0.5
        ]

        # Nobody eliminated: shrink randomly to guarantee termination.
        if len(new_pool) == len(current_pool):
            new_pool = random.sample(new_pool, min(m, len(new_pool)))

        current_pool = new_pool

    final_scores = defaultdict(float)
    for _ in range(final_rounds):
        for i, j in itertools.combinations(range(len(current_pool)), 2):
            final_scores[_duel(current_pool, i, j)] += 1

    if not final_scores:
        selected_final = current_pool
    else:
        max_score = max(final_scores.values())
        selected_final = [current_pool[i] for i in sorted(final_scores)
                          if final_scores[i] == max_score]

    selected_set = set(selected_final)
    not_selected = [sol for sol in original_pool if sol not in selected_set]

    return selected_final, not_selected


def get_current_subgoal(
        phases_list,
        current_context,
        args,
        max_attempts=3
):
    """Ask the LLM which sub-goal in ``phases_list`` to pursue next.

    The reply is accepted if it is a bare integer, or otherwise contains an
    integer in square brackets, and that integer is a valid index into
    ``phases_list``.

    Args:
        phases_list: The ordered sub-goal descriptions.
        current_context: The current reasoning trace / state.
        args: Run configuration; only ``args.backend`` is read.
        max_attempts: Number of LLM calls to try.

    Returns:
        The chosen index into ``phases_list``, or 0 if every attempt failed.
    """
    system_prompt = f'''You are a 24-Point Game Reasoner. Analyze current calculation state and select the next sub-goal.

    **Sub-goal List:** {phases_list}
    **Current Calculation State:** {current_context}

    **Selection Rules:**
    - Choose the next logical sub-goal from the provided list
    - Skip goals that are already completed or mostly achieved
    - Return ONLY the index number of the selected sub-goal
    - No explanations or additional text

    **Goal Completion Detection:**
    - If numbers are still ungrouped and unanalyzed → choose analysis goal (index 0)
    - If numbers are grouped but operations not planned → choose operation planning goal (index 1)  
    - If expression is formed but not verified → choose verification goal (index 2)
    - If current step shows near-completion of a goal → skip to next goal

    **Output Format:**
    [index]

    **24-Point Examples:**

    **Example: Initial Analysis Phase**
    Sub-goals: ["Analyze numbers and explore groupings", "Plan operations and manage calculations", "Validate and adjust solution"]
    Current: "Numbers: 3, 4, 6, 8"
    Output: 0

    **Critical Instructions:**
    - Output must be single index number only
    - Select from provided sub-goal list based on current progress
    - Choose next uncompleted goal in logical sequence
    - Skip goals that are substantially completed
    - No additional text or explanations
    '''

    # Tried in order: a reply that is only a number, then "[n]" anywhere.
    index_patterns = (r'^\s*(\d+)\s*$', r'\[\s*(\d+)\s*\]')

    for attempt in range(max_attempts):
        try:
            reply = gpt(system_prompt, model=args.backend).strip()
            for pattern in index_patterns:
                found = re.search(pattern, reply)
                if found is None:
                    continue
                candidate = int(found.group(1))
                if 0 <= candidate < len(phases_list):
                    print(f"success: sub-goal={candidate} ({phases_list[candidate]})")
                    return candidate
            print(f"fail-----: '{reply}'")
        except Exception as e:
            print(f"Attempt {attempt + 1} error: {str(e)}")

    print("ALL fails!!!")
    return 0


def solve(args, task, idx, to_print=True):
    """Solve one 24-game puzzle with phase-guided, skill-conditioned LLM search.

    Each round:
    1. Choose a sub-goal for every live trace and generate skills for it.
    2. Expand each trace with one proposal per skill, then filter the proposals.
    3. Top the pool up with earlier unselected traces and run duel-bandit selection.
    4. Harvest selected traces whose last line is a ``###`` answer containing "24".

    The loop stops after the first round that yields an answer, or after
    ``args.max_round`` rounds.

    Args:
        args: Run configuration; reads ``max_round`` and ``n_select_sample``
            here and ``backend`` via the helpers.
        task: Task object providing ``get_input(idx)`` and the comparison
            prompt used during selection.
        idx: Zero-based puzzle index.
        to_print: Whether to print per-round debugging output.

    Returns:
        ``(final_answers, {'steps': infos})``. ``final_answers`` holds the text
        after ``###`` for each answer (or the whole trace if no ``###`` line is
        found), or ``[""]`` when no answer was found.
    """
    x = task.get_input(idx)
    print("idx:", idx + 1)
    print("Task:", x)

    phases = generate_phase(x, args)

    ys = ['']
    remained_ys = []  # previously unselected traces, used to top up later rounds
    infos = []
    answers = []
    round_no = 0  # named to avoid shadowing the built-in round()

    while round_no < args.max_round:
        round_no += 1

        # get_current_subgoal returns an index into `phases`; the downstream
        # prompts ("Goal: ...", "Current Sub-goal: ...") expect the phase text.
        goals = [phases[get_current_subgoal(phases, y, args)] for y in ys]
        skill_list = [generate_skills(goal, step, args) for goal, step in zip(goals, ys)]

        new_ys = list(itertools.chain.from_iterable(
            get_proposals(x, y, args, goal, skill)
            for y, goal, skill in zip(ys, goals, skill_list)
        ))

        # Keep proposals that still have several numbers "left:" or that end
        # in a '###' final-answer line.
        continue_ys = []
        for y in new_ys:
            last_line = y.strip().split('\n')[-1]
            if 'left: ' in last_line:
                nums = last_line.split('left: ')[-1].split(')')[0]
                if len(nums.split()) != 1 or nums == '24':
                    continue_ys.append(y)
            elif '###' in last_line:
                continue_ys.append(y)

        target = 2 * args.n_select_sample
        if len(continue_ys) < target:
            length = min(len(remained_ys), target - len(continue_ys))
            # Explicit guard: a slice like remained_ys[-0:] would be the whole list.
            if length > 0:
                continue_ys = continue_ys + remained_ys[-length:]
                remained_ys = remained_ys[:-length]

        selected_ys, not_selected_ys = duel_bandit_select(
            task, x, continue_ys, args
        )
        print('------selected', selected_ys)
        remained_ys = remained_ys + not_selected_ys

        # Finished traces become answers (if they mention 24); only unfinished
        # ones are expanded next round.
        unfinished = []
        for y in selected_ys:
            last_line = y.strip().split('\n')[-1]
            if '###' in last_line:
                if '24' in last_line:
                    answers.append(y)
            else:
                unfinished.append(y)
        selected_ys = unfinished

        if to_print:
            print(f'-- new_ys --: {new_ys}\n-- choices --: {continue_ys}\n')
            print("round:", round_no)

        infos.append({
            'round': round_no,
            'x': x,
            'ys': ys,
            'new_ys': new_ys,
            'select_ys': selected_ys,
        })

        # Restart from an empty trace when nothing is left to expand.
        ys = selected_ys or ['']

        if answers:
            break

    if not answers:
        return [""], {'steps': infos}

    print("Get answer.")
    final_answers = []
    for ans in answers:
        for line in reversed(ans.strip().split('\n')):
            # Same test as the detection above ('###' anywhere in the line),
            # so an answer line like "Answer: ### ..." is still extracted.
            if '###' in line:
                final_answers.append(line.split('###')[-1].strip())
                break
        else:
            final_answers.append(ans.strip())

    return final_answers, {'steps': infos}