import json
import argparse
import logging
import os
import subprocess
from math_verify import parse, verify


# Root logger: INFO and above, timestamped.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Absolute directory containing this file, and the directory one level above it.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))


def read_json_fields(filename):
    """Load and return the parsed contents of a JSON file.

    The file is decoded as UTF-8 (matching ``write_data_to_json_file``,
    which writes non-ASCII characters unescaped).

    Args:
        filename: Path to the JSON file.

    Returns:
        The decoded JSON value, or ``None`` if the file is missing, is not
        valid JSON, or cannot be read for another reason (the error is
        logged, not raised).
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        logging.error(f"The file {filename} was not found.")
    except json.JSONDecodeError as e:
        logging.error(f"There was an error decoding the JSON file {filename}: {e}")
    except Exception as e:
        logging.error(f"An error occurred while reading {filename}: {e}")
    return None


def _select_fields(record, field_names):
    """Restrict a dict record to ``field_names``; return it unchanged when no filter applies."""
    if not isinstance(field_names, (list, tuple)) or not isinstance(record, dict):
        return record
    return {name: record[name] for name in field_names if name in record}


def read_json_objects(filename, field_names=None):
    """Load a list of records from a ``.jsonl`` or ``.json`` file.

    Args:
        filename: Path to the file. ``.jsonl`` files are parsed one JSON
            value per line, with blank/whitespace-only lines skipped.
            ``.json`` files are parsed as a whole and the top-level value is
            iterated (so a top-level array yields its elements).
        field_names: Optional list/tuple of keys. When given, each dict
            record is reduced to the listed keys it actually contains. Any
            other value (including ``None``) disables filtering.

    Returns:
        A list of records. Returns ``[]`` for an unsupported file extension.
        Returns ``None`` if the file is missing, cannot be decoded, or
        another error occurs (the error is logged, not raised).
    """
    file_extension = os.path.splitext(filename)[1]
    if file_extension not in ('.jsonl', '.json'):
        logging.error(f"Unknown file extension {file_extension}")
        return []

    try:
        with open(filename, 'r', encoding='utf-8') as file:
            if file_extension == '.jsonl':
                # Skip blank lines (e.g. extra newlines at EOF) instead of
                # failing the whole read with a JSONDecodeError.
                records = [json.loads(line) for line in file if line.strip()]
            else:
                records = list(json.load(file))
        return [_select_fields(record, field_names) for record in records]
    except FileNotFoundError:
        logging.error(f"The file {filename} was not found.")
    except json.JSONDecodeError as e:
        kind = 'JSONL' if file_extension == '.jsonl' else 'JSON'
        logging.error(f"There was an error decoding the {kind} file {filename}: {e}")
    except Exception as e:
        logging.error(f"An error occurred while reading {filename}: {e}")
    return None


def write_data_to_json_file(data, file_path):
    """Serialize ``data`` to ``file_path`` as JSON indented by 4 spaces.

    Non-ASCII characters are written unescaped (``ensure_ascii=False``), so
    the file is explicitly encoded as UTF-8 instead of relying on the
    platform's default encoding.

    Any exception is logged and swallowed; nothing is returned.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
    except Exception as e:
        logging.error(f"An error occurred while writing {file_path}: {e}")
        return
    logging.info(f"Data successfully written to {file_path}")


def create_parent_directory(file_path):
    """
    Create the parent directory of ``file_path`` (and any missing ancestors).

    The file itself is not created. Calling this when the directory already
    exists is a no-op apart from the log message.

    Args:
        file_path (str or Path): The path to the file.
    """
    from pathlib import Path

    parent_directory = Path(file_path).parent
    # parents=True creates intermediate directories; exist_ok=True tolerates
    # a directory that is already there.
    parent_directory.mkdir(parents=True, exist_ok=True)
    # Log (rather than print) so the message follows the module's logging config.
    logging.info(f"Parent directory '{parent_directory}' ensured to exist.")


def run_cmd(cmd):
    """Run ``cmd``, stream its combined output to the log, and report success.

    stdout and stderr are merged and each line is logged as it arrives.

    Args:
        cmd: A shell command string (run with ``shell=True``) or a sequence
            of program arguments (run directly, without a shell).

    Returns:
        ``True`` only if the process exited with status 0 and no output line
        contained one of the error keywords (case-insensitive substring
        match). ``False`` otherwise, including when the command could not
        be started or reading its output raised.
    """
    # Lowercase keywords, matched case-insensitively as substrings of each line.
    error_keywords = (
        "error",
        "unrecognized model",
        "failed",
        "exception",
        "traceback",
    )

    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # Merge stderr into stdout
            # Only strings go through the shell. On POSIX, a sequence passed
            # with shell=True would hand everything after the first item to
            # the shell itself rather than to the command.
            shell=isinstance(cmd, str),
            universal_newlines=True,  # Ensure output is in text mode
        )

        error_detected = False
        # Read output in real time and detect errors.
        for line in process.stdout:
            logging.info(line.rstrip())
            lowered = line.lower()
            if any(keyword in lowered for keyword in error_keywords):
                error_detected = True
                logging.error(f"Detected error in output: {line.strip()}")

        returncode = process.wait()
        if error_detected or returncode != 0:
            logging.error(
                f"Command failed (returncode={returncode}, "
                f"error keywords detected={error_detected})"
            )
            return False
        return True

    except Exception as e:
        logging.error(f"Unexpected error running command: {e}")
        return False
    finally:
        # On an early exit (exception, KeyboardInterrupt), do not leave the
        # child running or its pipe open.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()


def math_verify(llm_res, ground_truth):
    """Return whether the model's answer matches the ground truth per Math-Verify.

    Both strings are parsed with ``math_verify.parse``. The parsed values are
    then compared with ``math_verify.verify``, which takes the GOLD answer
    first and the prediction second. The comparison is not symmetric, so the
    argument order matters.

    Args:
        llm_res: Raw model output text.
        ground_truth: Reference answer text.
    """
    ground_truth_answer = parse(ground_truth)
    llm_answer = parse(llm_res)
    return verify(ground_truth_answer, llm_answer)
 
def _build_eval_config(model, input_path, output_path):
    """Return the ``eval_model_hf`` job config written to disk for ``infer.py``.

    Args:
        model: Value for ``inference.model`` (path to the model directory).
        input_path: Dataset file passed as ``dataset.input_path``.
        output_path: Passed as ``dataset.output_path``; ``main`` reads
            predictions back from this path after inference.
    """
    return {
        "job_type": "eval_model_hf",
        "dataset": {
            "input_path": input_path,
            "output_path": output_path,
            "seed": 42,
        },
        "inference": {
            "model": model,
            "batch_size": 8,
            "dp_size": 8,
            "tp_size": 1,
            "enable_chunked_prefill": True,
            "seed": 777,
            "gpu_memory_utilization": 0.8,
            "temperature": 0.2,
            "trust_remote_code": True,
            "enforce_eager": True,
            "max_new_tokens": 1024,
        },
    }


def main():
    """Run inference via ``accelerate launch`` and report Math-Verify accuracy.

    Steps:
      1. Write the job config to a JSON file under /tmp.
      2. Launch ``infer.py`` (located next to this script) with that config.
      3. Read the records from ``--output_path`` and score each record's
         ``output`` against its ``answer`` with :func:`math_verify`.

    Exits with status 1 in three cases: inference fails, the output file is
    missing, or no records can be read from it.
    """
    import shlex
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='path to the model directory')
    parser.add_argument('--input_path', type=str, default='data/gsm8k/test.jsonl', help='path to the input file')
    parser.add_argument('--output_path', type=str, required=True, help='path to the ouptut file')
    args = parser.parse_args()

    config = _build_eval_config(args.model, args.input_path, args.output_path)
    # NOTE(review): fixed path shared by every run; concurrent evaluations on
    # the same host would overwrite each other's config.
    temp_dir = '/tmp/ladders_of_thought/eval_model/'
    temp_config_path = os.path.join(temp_dir, 'tmp.json')
    create_parent_directory(temp_config_path)
    write_data_to_json_file(config, temp_config_path)

    cmd_parts = [
        'accelerate', 'launch',
        # NOTE(review): resolved against the current working directory, unlike
        # infer.py below, which is located via script_dir. Confirm the script
        # is always launched from the directory containing configs/.
        '--config_file', 'configs/accelerate_config/muti_gpu.yaml',
        os.path.join(script_dir, 'infer.py'),
        '--config', temp_config_path,
    ]
    # run_cmd executes a string through the shell, so quote every argument to
    # keep paths containing spaces or shell metacharacters intact.
    cmd = ' '.join(shlex.quote(part) for part in cmd_parts)
    logging.info(f"Running command: {cmd}")
    if not run_cmd(cmd):
        logging.error("Infer failed")
        sys.exit(1)

    if not os.path.exists(args.output_path):
        logging.error(f"Output file {args.output_path} does not exist")
        sys.exit(1)

    # read_json_objects returns None on read/decode errors; treat that as
    # "no predictions" instead of crashing on iteration.
    data_list = read_json_objects(args.output_path) or []
    total = len(data_list)
    if total == 0:
        logging.error(f"No predictions could be read from {args.output_path}")
        sys.exit(1)

    correct = sum(1 for item in data_list if math_verify(item['output'], item['answer']))
    logging.info(f"Total: {total}, correct: {correct}, Math-verify accuracy: {float(correct) / total * 100} %")
    

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()