import json
import os
import subprocess
import tempfile
from itertools import combinations
from collections import defaultdict
import pdb
import numpy as np
import re
from openai import OpenAI
import cyaron as cy
import time
import ast
import tempfile
import subprocess
import sys

LOG_FILE = "api_usage.log"

def log_usage(model_name, usage):
    """Append one JSON line describing a model call's token usage to LOG_FILE.

    Args:
        model_name: Identifier of the model that served the request.
        usage: Object exposing ``completion_tokens``, ``prompt_tokens`` and
            ``total_tokens`` attributes.
    """
    token_fields = ("completion_tokens", "prompt_tokens", "total_tokens")
    record = {
        "model_name": model_name,
        "usage": {field: getattr(usage, field) for field in token_fields},
    }
    # Append mode: earlier entries are kept, one JSON object per line.
    with open(LOG_FILE, 'a', encoding='utf-8') as log_file:
        serialized = json.dumps(record, ensure_ascii=False)
        log_file.write(serialized + '\n')

def fetch_gpt4_tong(query, model_name = "o4-mini-2025-04-16"):
    """Send a chat-completion request and return the first choice's text.

    The reply, the usage object, the model name and the completion-token
    count are printed, and the usage is also recorded via ``log_usage``.

    Args:
        query: Value passed as ``messages`` to the chat-completions call.
        model_name: Model identifier passed to the API.

    Returns:
        The ``message.content`` of the first returned choice.
    """
    print('fetching gpt ...')

    # NOTE(review): api_key and base_url are empty strings here; presumably
    # they are filled in before deployment - confirm.
    response = OpenAI(api_key='', base_url='').chat.completions.create(
        model=model_name,
        messages=query,
    )
    answer = response.choices[0].message.content
    print(answer)

    token_usage = response.usage
    print(token_usage)
    print(model_name)
    print("Completion tokens:", token_usage.completion_tokens)
    log_usage(model_name, token_usage)

    return answer

def normalize_code(code):
    """Return a canonical text form of Python source.

    The source is parsed and unparsed through ``ast`` so that formatting
    differences (spacing, blank lines, comments) disappear. If that fails
    for any ordinary reason (e.g. the text is not valid Python), the input
    is returned with surrounding whitespace stripped instead.

    Args:
        code: Source text to normalize.

    Returns:
        The normalized source, or ``code.strip()`` as a fallback.
    """
    try:
        return ast.unparse(ast.parse(code))
    except Exception:
        # Deliberately broad best-effort fallback, but no longer a bare
        # ``except`` so KeyboardInterrupt / SystemExit still propagate.
        return code.strip()
    
def merge_input_files(file_paths):
    """Merge the ``input_string`` lists of several JSON-lines files by question.

    Every non-blank line of every file is parsed as a JSON object. Records
    with a falsy ``question`` are ignored. For each remaining question, the
    ``input_string`` entries from all records are combined with duplicates
    removed.

    Entries are kept in first-seen order, so the result is deterministic
    across runs (merging through a ``set`` made the order depend on string
    hashing). A missing or ``null`` ``input_string`` is treated as empty.

    Args:
        file_paths: Iterable of paths to UTF-8 JSON-lines files.

    Returns:
        Dict mapping each question to a list of its unique input strings.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        TypeError: If an ``input_string`` entry is unhashable.
    """
    merged = {}

    for file_path in file_paths:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                key = data.get('question')
                if not key:
                    continue
                # A dict used as an ordered set: keys keep insertion order.
                seen = merged.setdefault(key, {})
                for item in data.get('input_string') or []:
                    seen.setdefault(item, None)

    return {key: list(seen) for key, seen in merged.items()}

def gen_random_input_generator(new_problems):
    """Query the model with the random-input-generator prompt for a problem.

    The prompt template is read from ``./prompts/random_input_generator.txt``,
    the problem description is appended, and the result is sent as a single
    user message through ``safe_fetch_gpt4``.

    Args:
        new_problems: Problem description interpolated into the prompt.

    Returns:
        Whatever ``safe_fetch_gpt4`` returns for that message.
    """
    with open('./prompts/random_input_generator.txt', 'r', encoding='utf-8') as prompt_file:
        template = prompt_file.read()

    prompt = template + f"Here is the problem description: {new_problems}\n"
    messages = [{"role": "user", "content": prompt}]
    return safe_fetch_gpt4(messages)

def gen_adv_input_generator(new_problems):
    """Query the model with the adversarial-input-generator prompt for a problem.

    The prompt template is read from ``./prompts/adversial_input_generator.txt``
    (the file name is spelled that way on disk), the problem description is
    appended, and the result is sent as a single user message through
    ``safe_fetch_gpt4``.

    Args:
        new_problems: Problem description interpolated into the prompt.

    Returns:
        Whatever ``safe_fetch_gpt4`` returns for that message.
    """
    with open('./prompts/adversial_input_generator.txt', 'r', encoding='utf-8') as prompt_file:
        template = prompt_file.read()

    prompt = template + f"Here is the problem description: {new_problems}\n"
    messages = [{"role": "user", "content": prompt}]
    return safe_fetch_gpt4(messages)

def parse_ans(ans):
    """Split a model answer into its "Part 2" and "Part 3" sections.

    The text after the line containing the Part 2 marker, up to the Part 3
    marker, is the generator code. The text after the line containing the
    Part 3 marker is the validator code. Both are stripped.

    If no newline follows a marker, that section is empty. Previously the
    start offset collapsed to 0 in that case, so the slice began at the top
    of the answer and, for Part 3, returned the entire answer.

    Args:
        ans: Full answer text.

    Returns:
        Tuple ``(part2_code, part3_code)``.

    Raises:
        ValueError: If either marker is missing from ``ans``.
    """
    part2_marker = "Part 2: Code for Test Input Generation"
    part3_marker = "Part 3: Code to Validate Test Input"

    part2_start = ans.find(part2_marker)
    part3_start = ans.find(part3_marker)

    if part2_start == -1 or part3_start == -1:
        raise ValueError("Invalid answer format: Missing Part 2 or Part 3 markers")

    def body_start(marker_pos):
        # Index of the first character after the marker's line; end of text
        # when the marker line is the last line (no trailing newline).
        newline_pos = ans.find("\n", marker_pos)
        return len(ans) if newline_pos == -1 else newline_pos + 1

    part2_code = ans[body_start(part2_start):part3_start].strip()
    part3_code = ans[body_start(part3_start):].strip()

    return part2_code, part3_code

def find_equivalent_groups(all_outputs, outputs_match, input_strings):
    """Group candidates whose outputs agree on every test input.

    Candidates are scanned in index order. Each candidate that is not yet
    grouped starts a new group and collects every later, still-ungrouped
    candidate whose outputs match its own on all inputs.

    A candidate already placed in a group is never compared or added again,
    so the groups always form a partition of the indices. Without that check
    a non-transitive ``outputs_match`` (e.g. a tolerance-based comparison)
    could put the same index in two groups and yield spurious pairs.

    Args:
        all_outputs: Sequence where ``all_outputs[i][k]`` is candidate ``i``'s
            output on input ``k``.
        outputs_match: Callable ``(out_a, out_b) -> bool``.
        input_strings: The test inputs; only their count is used.

    Returns:
        Tuple ``(matching_pairs, groups)``: ``groups`` is a list of index
        lists and ``matching_pairs`` lists every index pair that shares a
        group.
    """
    groups = []
    assigned = set()
    num_inputs = len(input_strings)
    num_candidates = len(all_outputs)

    for i in range(num_candidates):
        if i in assigned:
            continue

        current_group = [i]
        assigned.add(i)

        for j in range(i + 1, num_candidates):
            if j in assigned:
                continue
            if all(
                outputs_match(all_outputs[i][k], all_outputs[j][k])
                for k in range(num_inputs)
            ):
                current_group.append(j)
                assigned.add(j)

        groups.append(current_group)

    matching_pairs = [
        pair for group in groups for pair in combinations(group, 2)
    ]

    return matching_pairs, groups


def run_code(code_str, input_str="", timeout=5, workdir=None):
    """Run Python source in a child interpreter and return its output.

    The source is written to a temporary ``.py`` file inside ``workdir`` and
    executed with the current interpreter (``sys.executable``), feeding
    ``input_str`` on stdin. The temporary file is removed afterwards on every
    path.

    Args:
        code_str: Python source to execute.
        input_str: Text sent to the child's stdin.
        timeout: Seconds to wait before the child is killed.
        workdir: Directory for the temporary script; defaults to
            ``<cwd>/shahe_env`` and is created if missing.

    Returns:
        The child's stripped stdout on success, otherwise a string starting
        with one of:
        - ``"Runtime Error:"`` - non-zero exit with a traceback on stderr
          (last two stderr lines are included).
        - ``"Execution Error:"`` - any other non-zero exit.
        - ``"Timeout:"`` - the time limit was exceeded.
        - ``"System Error:"`` - the run could not be set up or launched.
    """
    if workdir is None:
        workdir = os.path.join(os.getcwd(), "shahe_env")
    os.makedirs(workdir, exist_ok=True)

    tmp_name = None
    try:
        # Python reads source files as UTF-8 by default, so write it that way
        # regardless of the locale encoding.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.py', delete=False, dir=workdir, encoding='utf-8'
        ) as tmp_file:
            tmp_name = tmp_file.name
            tmp_file.write(code_str)

        try:
            # subprocess.run kills AND waits for the child on timeout, so no
            # zombie process or open pipe is left behind. A None input is
            # mapped to "" so stdin is always a (closed) pipe rather than
            # being inherited from this process.
            completed = subprocess.run(
                [sys.executable, tmp_name],
                input="" if input_str is None else input_str,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return "Timeout: Execution exceeded time limit"

        if completed.returncode != 0:
            error_msg = completed.stderr.strip()
            if "Traceback" in error_msg:
                error_lines = error_msg.split('\n')
                if len(error_lines) >= 2:
                    return f"Runtime Error: {error_lines[-2]}\n{error_lines[-1]}"
            return f"Execution Error: {error_msg}"
        return completed.stdout.strip()
    except Exception as e:
        return f"System Error: {str(e)}"
    finally:
        # Best-effort cleanup; also covers the case where launching the
        # child failed after the file had been created.
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass


def is_valid_output(output):
    """Tell whether ``output`` is a real program output rather than an error tag.

    ``None``, the exact string ``"Timeout"``, and any string beginning with
    one of the error prefixes produced by the runner are rejected.

    Args:
        output: Output string (or ``None``).

    Returns:
        ``True`` if the output is usable, ``False`` otherwise.
    """
    if output is None:
        return False

    failure_prefixes = (
        "Execution Error:",
        "Runtime Error:",
        "Syntax Error:",
        "System Error:",
        "Timeout:",
    )
    if output != "Timeout":
        return not output.startswith(failure_prefixes)
    return False

def outputs_match(output1, output2):
    """Return whether two outputs are both valid and compare as equal.

    An output rejected by ``is_valid_output`` never matches anything;
    otherwise the decision is delegated to ``compare_std_results``.
    """
    if is_valid_output(output1) and is_valid_output(output2):
        return compare_std_results(output1, output2)
    return False

def compare_std_results(exec_outputs, outputs, debug=False):
    """Leniently decide whether two program outputs are equivalent.

    The comparison tries several strategies in order and returns ``True`` as
    soon as one succeeds:

    1. A string compared against a list/tuple is always a mismatch.
    2. Line-wise: both sides are split into stripped, non-empty lines. A line
       that parses as a single float is rounded to a canonical string. The
       resulting lists must be equal.
    3. Literal: string sides are parsed with ``ast.literal_eval`` and the
       resulting structures compared element by element.
    4. Numeric: both sides are converted to float arrays of the same shape
       and compared with ``np.allclose(atol=1e-5, rtol=1e-3)``.
    5. List-vs-list: elements are compared after stripping strings, then
       after converting strings to floats.

    NOTE(review): numeric lines with magnitude >= 10000 are rounded to 6
    significant digits, so large integers that differ only beyond the sixth
    digit compare equal. Confirm this leniency is intended.

    Args:
        exec_outputs: Produced output (string, list/tuple, or scalar).
        outputs: Reference output in the same kind of form.
        debug: When true, print the normalized line lists.

    Returns:
        ``True`` if any strategy considers the outputs equal, else ``False``.
    """

    def normalize_number_str(s):
        # Canonical text for a numeric value; non-numeric input is returned
        # as its stripped string form.
        try:
            f = float(s)
        except Exception:
            return str(s).strip()
        magnitude = abs(f)
        if magnitude < 1e-10:
            return "0.0"
        if magnitude >= 10000:
            return f"{f:.6g}"
        if magnitude >= 1:
            return f"{f:.8g}"
        return f"{f:.12g}"

    def normalize_lines(obj):
        # Flatten the value into a list of stripped, non-empty text lines.
        if isinstance(obj, str):
            return [line.strip() for line in obj.strip().splitlines() if line.strip()]
        elif isinstance(obj, (list, tuple)):
            lines = []
            for item in obj:
                if isinstance(item, str):
                    lines.extend([l.strip() for l in item.strip().splitlines() if l.strip()])
                else:
                    lines.append(str(item).strip())
            return lines
        return [str(obj).strip()]

    def comparable_form(lines):
        return [normalize_number_str(x) for x in lines]

    def deep_compare(a, b):
        # Structural comparison of nested lists/tuples/dicts; leaves are
        # compared through their normalized string form.
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            if len(a) != len(b):
                return False
            return all(deep_compare(x, y) for x, y in zip(a, b))

        if isinstance(a, dict) and isinstance(b, dict):
            if set(a.keys()) != set(b.keys()):
                return False
            return all(deep_compare(a[k], b[k]) for k in a)

        return normalize_number_str(a) == normalize_number_str(b)

    if isinstance(exec_outputs, str) and isinstance(outputs, (list, tuple)):
        return False
    if isinstance(outputs, str) and isinstance(exec_outputs, (list, tuple)):
        return False

    norm_exec = comparable_form(normalize_lines(exec_outputs))
    norm_outputs = comparable_form(normalize_lines(outputs))
    if debug:
        print("Normalized Exec   :", norm_exec)
        print("Normalized Outputs:", norm_outputs)
    if norm_exec == norm_outputs:
        return True

    # Each fallback below is best-effort: an ordinary failure just means the
    # strategy does not apply. ``except Exception`` (not a bare ``except``)
    # lets KeyboardInterrupt / SystemExit through.
    try:
        parsed_exec = ast.literal_eval(exec_outputs) if isinstance(exec_outputs, str) else exec_outputs
        parsed_output = ast.literal_eval(outputs) if isinstance(outputs, str) else outputs
        if deep_compare(parsed_exec, parsed_output):
            return True
    except Exception:
        pass

    try:
        exec_arr = np.array(exec_outputs, dtype=float)
        output_arr = np.array(outputs, dtype=float)
        # np.allclose broadcasts its arguments, so without the shape check
        # [1.0] would "match" [1.0, 1.0, 1.0] and [] would match anything.
        if exec_arr.shape == output_arr.shape and np.allclose(
            exec_arr, output_arr, atol=1e-5, rtol=1e-3
        ):
            return True
    except Exception:
        pass

    if isinstance(exec_outputs, list) and isinstance(outputs, list):
        stripped_exec = [x.strip() if isinstance(x, str) else x for x in exec_outputs]
        stripped_out = [x.strip() if isinstance(x, str) else x for x in outputs]
        if stripped_exec == stripped_out:
            return True
        try:
            num_exec = [float(x) if isinstance(x, str) else x for x in stripped_exec]
            num_out = [float(x) if isinstance(x, str) else x for x in stripped_out]
            if deep_compare(num_exec, num_out):
                return True
        except Exception:
            pass

    return False

def extract_code(model_output: str):
    """Return the text between the last two fence lines of ``model_output``.

    A fence line is any line containing three backticks. The lines strictly
    between the second-to-last and the last fence line are joined with
    newlines. If fewer than two fence lines exist, an empty string is
    returned.
    """
    lines = model_output.split("\n")

    fence_rows = []
    for row, text in enumerate(lines):
        if "```" in text:
            fence_rows.append(row)

    if len(fence_rows) >= 2:
        opening, closing = fence_rows[-2:]
        return "\n".join(lines[opening + 1:closing])
    return ""

def safe_fetch_gpt4(messages, max_retries=3):
    """Call ``fetch_gpt4_tong`` with retries.

    Each failed attempt is reported on stdout. Between attempts the function
    sleeps for five seconds; the exception from the final attempt is
    re-raised.

    Args:
        messages: Chat messages forwarded to ``fetch_gpt4_tong``.
        max_retries: Maximum number of attempts.

    Returns:
        The result of the first successful call, or ``None`` when
        ``max_retries`` allows no attempt at all.
    """
    for attempt_no in range(1, max_retries + 1):
        try:
            return fetch_gpt4_tong(messages)
        except Exception as err:
            print(f"Attempt {attempt_no} failed: {str(err)}")
            if attempt_no == max_retries:
                raise
        # Only reached after a failed, non-final attempt.
        time.sleep(5)
    return None

def fetch_r1(query):
    """Send a chat-completion request to the ``deepseek-r1`` model.

    The reply text is printed and returned. The request uses a temperature
    of 0.2.

    Args:
        query: Value passed as ``messages`` to the chat-completions call.

    Returns:
        The ``message.content`` of the first returned choice.
    """
    print('fetching gpt ...')
    # NOTE(review): api_key and base_url are empty strings here; presumably
    # they are filled in before deployment - confirm.
    response = OpenAI(api_key='', base_url='').chat.completions.create(
        model="deepseek-r1",
        messages=query,
        temperature=0.2,
    )
    answer = response.choices[0].message.content
    print(answer)
    return answer