import pdb
import random
from openai import OpenAI
import json
from datasets import load_dataset
import re
from pathlib import Path
from collections import defaultdict
import json
import concurrent.futures
import threading
import os
from typing import Tuple

# Size of the thread pool that get_orig_questions uses to run LLM requests
# concurrently.
MAX_FETCH_WORKERS = 8
# Held by worker_and_write while it appends to the output file and updates the
# shared set of already-written pairs.
write_lock = threading.Lock()

def fetch_gpt4(query, model_name = "o4-mini-2025-04-16"):
    """Send a chat-completion request and return the first choice's text.

    Args:
        query: Message list forwarded as ``messages`` to the chat
            completions endpoint.
        model_name: Model identifier forwarded to the endpoint.

    Returns:
        The ``content`` of the first choice in the response. The same text is
        also printed to stdout.
    """
    print('fetching gpt ...')
    # A new client is constructed on every call; the key and base URL are
    # blank in this file.
    llm_client = OpenAI(api_key='', base_url='')

    response = llm_client.chat.completions.create(model=model_name, messages=query)
    first_choice = response.choices[0]
    reply_text = first_choice.message.content
    print(reply_text)
    return reply_text

def get_skills_by_tag(file_path, tag_name):
    """Look up the ``skills`` entry of a tag in a JSON tags file.

    Tag names are compared after trimming, lower-casing and collapsing every
    run of whitespace to a single space.

    Args:
        file_path: Path to a UTF-8 JSON file holding a list of objects with
            ``tag`` and ``skills`` keys.
        tag_name: Tag to search for.

    Returns:
        The ``skills`` value of the first matching entry, or ``None`` when no
        entry matches.
    """
    def _normalise(text):
        return re.sub(r'\s+', ' ', text.strip().lower())

    wanted = _normalise(tag_name)

    with open(file_path, encoding='utf-8') as handle:
        entries = json.load(handle)

    matches = (
        entry['skills']
        for entry in entries
        if _normalise(entry['tag']) == wanted
    )
    return next(matches, None)

def gen_new_prob(question1, tags1, solution1, question2, tags2, solution2, skills):
    """Ask the LLM for a new problem derived from two existing ones.

    The prompt is the contents of the instruction file followed by both
    problems (statement, tags, solution) and the list of selectable skills.

    Args:
        question1, tags1, solution1: Statement, tags and solution of problem 1.
        question2, tags2, solution2: Statement, tags and solution of problem 2.
        skills: Tags/skills the model may pick from; interpolated as-is.

    Returns:
        Whatever ``fetch_gpt4`` returns for the single user message. The reply
        is also printed to stdout.
    """
    with open('./prompts/gen_new_questions_pairs_same_10.txt', 'r', encoding='utf-8') as prompt_file:
        instructions = prompt_file.read()

    prompt_parts = [
        instructions,
        "Here are two coding problems:\n\n",
        f"Problem 1:\n{question1}\n",
        f"Tags: {tags1}\n",
        f"Solution: {solution1}\n\n",
        f"Problem 2:\n{question2}\n",
        f"Tags: {tags2}\n",
        f"Solution: {solution2}\n\n",
        f"Please generate a completely new coding problem \n\n The tags and skills can be selected from {skills} \n",
    ]
    user_message = {"role": "user", "content": "".join(prompt_parts)}

    reply = fetch_gpt4([user_message])

    print(reply)
    return reply


def _split_comma_list(text):
    """Split a comma-separated string into trimmed, non-empty items."""
    return [item.strip() for item in text.split(',') if item.strip()]


def parse_data(output_text):
    """Parse an LLM reply organised in ``## Part N:`` sections into a dict.

    Recognised sections (every line is stripped before matching):

    * ``## Part 1:`` -- analysis; non-empty lines not starting with ``Step``
      are kept.
    * ``## Part 2:`` -- the new problem; a ``New_problem:`` line restarts the
      text, every other line is appended.
    * ``## Part 3:`` -- examples; ``Input:`` opens a new example, ``Output:``
      and ``Explanation:`` fill the latest one, and any other line continues
      the field seen most recently.
    * ``## Part 4:`` -- metadata lines ``difficulty:``, ``tags:``, ``skills:``.

    Args:
        output_text: Raw text of the reply.

    Returns:
        A dict with keys ``orig-analysis`` (str), ``new_problem`` (str),
        ``example-test cases`` (list of dicts with ``input``, ``output`` and
        ``explanation``), ``tags`` (list of str), ``skills`` (list of str)
        and ``difficulty`` (str). Text fields are stripped of surrounding
        whitespace and tag/skill lists never contain empty entries.
    """
    result = {
        'orig-analysis': '',
        'new_problem': '',
        'example-test cases': [],
        'tags': [],
        'skills': [],
        'difficulty': ''
    }

    section_headers = (
        ('## Part 1:', 'orig-analysis'),
        ('## Part 2:', 'new_problem'),
        ('## Part 3:', 'example-test cases'),
        ('## Part 4:', 'metadata'),
    )
    example_prefixes = (
        ('Input:', 'input'),
        ('Output:', 'output'),
        ('Explanation:', 'explanation'),
    )

    analysis_lines = []
    problem_lines = []
    examples = result['example-test cases']
    current_part = None
    active_field = None  # example field that continuation lines extend

    for raw_line in output_text.split('\n'):
        line = raw_line.strip()

        new_part = next(
            (part for prefix, part in section_headers if line.startswith(prefix)),
            None,
        )
        if new_part is not None:
            current_part = new_part
            continue

        if current_part == 'orig-analysis':
            if line and not line.startswith('Step'):
                analysis_lines.append(line)
        elif current_part == 'new_problem':
            if line.startswith('New_problem:'):
                # A New_problem: marker discards anything collected so far.
                problem_lines = [line[len('New_problem:'):].strip()]
            elif not line.startswith('## Part'):
                # Skips headers other than Parts 1-4 (e.g. "## Part 5:").
                problem_lines.append(line)
        elif current_part == 'example-test cases':
            for prefix, field in example_prefixes:
                if not line.startswith(prefix):
                    continue
                value = line[len(prefix):].strip()
                if field == 'input':
                    examples.append({'input': value, 'output': '', 'explanation': ''})
                    active_field = field
                elif examples:
                    # Output/Explanation before any Input has nowhere to go.
                    examples[-1][field] = value
                    active_field = field
                break
            else:
                if examples and active_field:
                    examples[-1][active_field] += '\n' + line
        elif current_part == 'metadata':
            if line.startswith('difficulty:'):
                result['difficulty'] = line[len('difficulty:'):].strip()
            elif line.startswith('tags:'):
                result['tags'] = _split_comma_list(line[len('tags:'):])
            elif line.startswith('skills:'):
                result['skills'] = _split_comma_list(line[len('skills:'):])

    result['orig-analysis'] = '\n'.join(analysis_lines).strip()
    result['new_problem'] = '\n'.join(problem_lines).strip()

    # Blank separator lines and values that start on the line after the marker
    # leave stray newlines around example fields; trim them.
    for example in examples:
        for field in ('input', 'output', 'explanation'):
            example[field] = example[field].strip()

    return result

def load_existing_pairs(output_file: str):
    """Collect the question pairs already recorded in a JSON-lines file.

    Each non-blank line is parsed as JSON; the pair key is the stripped
    ``question1.question`` and ``question2.question`` texts. Lines that fail
    to parse or lack either question are skipped.

    Args:
        output_file: Path of the results file; it may not exist yet.

    Returns:
        A set of ``(question1, question2)`` tuples. The set is empty when the
        file is missing or when reading it raises.
    """
    seen = set()
    if os.path.exists(output_file):
        try:
            with open(output_file, "r", encoding="utf-8") as handle:
                for raw in handle:
                    text = raw.strip()
                    if text:
                        try:
                            record = json.loads(text)
                            first = record.get('question1', {}).get('question', None)
                            second = record.get('question2', {}).get('question', None)
                            if first is None or second is None:
                                continue
                            seen.add((first.strip(), second.strip()))
                        except Exception:
                            # Malformed record: ignore it and keep reading.
                            continue
        except Exception:
            # A failure while reading discards everything collected so far.
            return set()
    return seen

def worker_and_write(sampled_pair_info: dict, output_file: str, existing_pairs: set):
    """Generate a new problem for one sampled pair and append it to the output.

    Args:
        sampled_pair_info: Dict describing the pair (question texts, tags,
            solutions, indices, difficulties, skills).
        output_file: JSON-lines file the record is appended to.
        existing_pairs: Shared set of ``(question1, question2)`` keys already
            written; updated after a successful write.

    Returns:
        The record that was written, or ``None`` when the pair was skipped,
        generation failed, or the write failed.
    """
    info = sampled_pair_info
    pair_key = (info['question1_text'].strip(), info['question2_text'].strip())
    first_text, second_text = pair_key

    # Cheap pre-check before spending an LLM call; re-checked under the lock.
    if pair_key in existing_pairs:
        print("skip pair:", first_text[:40], "...", second_text[:40])
        return None

    try:
        generated = gen_new_prob(
            info['question1_text'],
            info['tags1'],
            info['sol1'],
            info['question2_text'],
            info['tags2'],
            info['sol2'],
            info['skills'],
        )
    except Exception as e:
        print("fetch failed:", e)
        return None

    if generated is None:
        print("LLM return None")
        return None

    first_question = {
        'question': info['question1_text'],
        'question_index': info['indices'][0],
        'tags': info['tags1'],
        'solution_index': info['idx1'],
        'solution': info['sol1'],
        'difficulty': info['difficulty1'],
    }
    second_question = {
        'question': info['question2_text'],
        'question_index': info['indices'][1],
        'tags': info['tags2'],
        'solution_index': info['idx2'],
        'solution': info['sol2'],
        'difficulty': info['difficulty2'],
    }
    record = {
        'question1': first_question,
        'question2': second_question,
        'new_prob': generated,
    }

    with write_lock:
        if pair_key in existing_pairs:
            print("duplicate found:", first_text[:40], "...", second_text[:40])
            return None
        try:
            with open(output_file, "a", encoding="utf-8") as sink:
                json.dump(record, sink, ensure_ascii=False)
                sink.write("\n")
            existing_pairs.add(pair_key)
            print("write success:", output_file, " pair preview:", first_text[:60], " / ", second_text[:60])
            return record
        except Exception as e:
            print("write failed:", e)
            return None

def get_orig_questions(tag_name):
    """Sample question pairs for one tag and generate a new problem per pair.

    Steps:
      1. Load the shuffled TACO train split and keep rows whose ``tags``
         contain ``tag_name`` and whose ``tags``/``solutions`` are not '[]'.
      2. Take the first (up to) 50 rows and randomly draw up to 40 ordered
         pairs that were neither drawn in this run nor already present in the
         output file, picking one random solution for each question.
      3. Run ``worker_and_write`` for every pair on a thread pool; each
         worker appends its result to
         ``./new_prob/new_prob_<tag with underscores>.json``.

    Args:
        tag_name: Tag to process.

    Returns:
        None. Returns early (after printing a message) when fewer than two
        rows match the tag.
    """
    dataset = load_dataset("BAAI/TACO", split="train").shuffle(seed=0)

    # NOTE(review): `tags` is compared with the string '[]' below, so the
    # `in` test looks like a substring match on a serialized list -- confirm.
    filtered_dataset = dataset.filter(
        lambda example: (
            tag_name in example["tags"] and
            example['tags'] != '[]' and
            example['solutions'] != '[]'
        )
    )

    if len(filtered_dataset) < 2:
        print(f"skip tag {tag_name}")
        return

    subset_size = min(50, len(filtered_dataset))
    subset = filtered_dataset.select(range(subset_size))
    data_list = list(subset)

    output_file = f"./new_prob/new_prob_{tag_name.replace(' ', '_')}.json"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    N_new = 40
    file_path = 'tags_sample_5.json'

    existing_pairs = load_existing_pairs(output_file)
    print(f"found {len(existing_pairs)} existing records")

    sampled_pairs = []
    sampled_set = set()
    attempts = 0
    # Bound the rejection sampling so it ends even when few fresh pairs remain.
    max_attempts = N_new * 10 + 200
    while len(sampled_pairs) < N_new and attempts < max_attempts:
        attempts += 1
        # data_list always has at least two rows here (guarded above).
        i0, i1 = random.sample(range(len(data_list)), 2)
        first_row, second_row = data_list[i0], data_list[i1]
        key = (first_row['question'].strip(), second_row['question'].strip())
        if key in sampled_set or key in existing_pairs:
            continue

        # Rows whose solutions cannot be decoded or are empty are skipped.
        try:
            solutions1 = json.loads(first_row['solutions'])
            idx1 = random.randrange(len(solutions1))
            sampled_sol1 = solutions1[idx1]
        except Exception:
            continue
        try:
            solutions2 = json.loads(second_row['solutions'])
            idx2 = random.randrange(len(solutions2))
            sampled_sol2 = solutions2[idx2]
        except Exception:
            continue

        sampled_pairs.append({
            # Positions within the shuffled subset, not dataset row ids.
            'indices': (i0, i1),
            'idx1': idx1,
            'idx2': idx2,
            'question1_text': first_row['question'],
            'question2_text': second_row['question'],
            'tags1': first_row['tags'],
            'tags2': second_row['tags'],
            'sol1': sampled_sol1,
            'sol2': sampled_sol2,
            'difficulty1': first_row.get('difficulty', ''),
            'difficulty2': second_row.get('difficulty', ''),
        })
        sampled_set.add(key)

    # The skills list depends only on the tag, so the tags file is read once
    # here instead of once per sampled pair.
    if sampled_pairs:
        skills = get_skills_by_tag(file_path, tag_name)
        for pair_info in sampled_pairs:
            pair_info['skills'] = skills

    print(f"generated {len(sampled_pairs)} pairs for tag {tag_name}")

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_FETCH_WORKERS) as executor:
        futures = [
            executor.submit(worker_and_write, pair_info, output_file, existing_pairs)
            for pair_info in sampled_pairs
        ]

        for fut in concurrent.futures.as_completed(futures):
            # Workers report their own failures; this only surfaces exceptions
            # that escaped a worker.
            try:
                fut.result()
            except Exception as e:
                print("task exception:", e)

    print(f'Processed tag {tag_name} and saved results to {output_file}')

def get_unique_tags():
    """Return every tag of the TACO train split, most frequent first.

    Rows without a ``tags`` field, or with an empty one, are ignored. A
    ``tags`` value that is not already a list is parsed as a Python literal.

    Returns:
        A list of tags sorted by descending occurrence count; tags with equal
        counts keep the order in which they were first seen.

    Raises:
        ValueError, SyntaxError: If a non-list ``tags`` value is not a valid
            Python literal.
    """
    import ast  # function-scope import: the file's import block lacks ast

    dataset = load_dataset("BAAI/TACO", split="train")
    tag_counts = defaultdict(int)

    for example in dataset:
        if 'tags' in example and example['tags']:
            raw_tags = example['tags']
            # SECURITY: this string comes from the dataset, so it is parsed
            # with literal_eval rather than eval(), which would execute any
            # expression embedded in the data.
            tags = raw_tags if isinstance(raw_tags, list) else ast.literal_eval(raw_tags)
            for tag in tags:
                tag_counts[tag] += 1

    sorted_tags = sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)

    return [tag for tag, _ in sorted_tags]

def main():
    """Print all dataset tags, then process a fixed selection of them."""
    all_tags = get_unique_tags()
    print(f"found {len(all_tags)} unique tags:")
    print(all_tags)

    # Only this hand-picked selection is processed; the full list above is
    # printed for information.
    selected_tags = (
        'Graph algorithms',
        'Tree algorithms',
        'Ad-hoc',
        'Game theory',
        'Geometry',
    )

    for tag in selected_tags:
        print(f"\nprocessing tag: {tag}")
        get_orig_questions(tag)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()