import os
import json
import re
import black
import subprocess
from pathlib import Path
import shutil
import uuid
from verl.utils.reward_score.interpreter import Interpreter
import math
import ray
from verl.workers.rollout.vllm_rollout.vllm_rollout import vLLMRollout


# Leaderboard metadata, keyed by the ``ground_truth`` value passed to
# ``compute_score`` (which reads ``entry["is_lower_better"]`` from it).
# Loaded once at import time; the ``with`` block closes the file handle.
with open('/workdir/leaderboard.json') as _leaderboard_file:
    LEADERBOARD = json.load(_leaderboard_file)

# Lazily created vLLM rollout engine, shared by every call to get_vllm_engine().
_vllm_engine = None

def get_vllm_engine():
    """Return the shared vLLM rollout engine, creating it on the first call.

    The model, tokenizer and HF config are loaded from the local checkpoint
    ``/workdir/Qwen2.5-3B-Instruct`` and wrapped in a ``vLLMRollout``. The
    instance is cached in the module-level ``_vllm_engine`` and reused by
    later calls.
    """
    global _vllm_engine
    if _vllm_engine is None:
        # Deferred imports: only the dense-reward path needs the model stack.
        # ``DictConfig`` is not imported at module level, so it must be brought
        # into scope here. It is presumably OmegaConf's config class, the type
        # verl's rollout configs use — confirm.
        from omegaconf import DictConfig
        from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

        model_path = "/workdir/Qwen2.5-3B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model_config = AutoConfig.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        config = {
            "tensor_model_parallel_size": 1,
            "prompt_length": 1024,
            "response_length": 1024,
            "dtype": "float16",
            "enforce_eager": True,
            "gpu_memory_utilization": 0.4,
            "free_cache_engine": False,
            "load_format": "auto",
            "n": 1
        }

        _vllm_engine = vLLMRollout(
            actor_module=model,
            config=DictConfig(config),
            tokenizer=tokenizer,
            model_hf_config=model_config
        )
    return _vllm_engine


def wrap_code(code: str, lang="python") -> str:
    """Return *code* inside a Markdown fenced block tagged with *lang*."""
    fence = "```"
    return "\n".join((f"{fence}{lang}", f"{code}", fence))


def is_valid_python_script(script):
    """Return True if *script* compiles as a Python module, False otherwise.

    The script is only compiled, never executed.
    """
    try:
        compile(script, "<string>", "exec")
    except (SyntaxError, ValueError, RecursionError):
        # SyntaxError: ordinary parse errors.
        # ValueError: source containing NUL bytes (Python < 3.12).
        # RecursionError: pathologically deep nesting exhausting the compiler.
        return False
    return True


def extract_jsons(text):
    """Return every JSON object literal found in *text*, in order of appearance.

    Candidates are non-greedy ``{...}`` spans, so nested objects are not
    supported. Spans that fail to decode are skipped. If nothing decodes and
    *text* does not already end with ``}``, one retry is made with a closing
    brace appended, because models sometimes forget the final bracket.
    """
    def _decode_candidates(source):
        decoded = []
        for candidate in re.findall(r"\{.*?\}", source, re.DOTALL):
            try:
                decoded.append(json.loads(candidate))
            except json.JSONDecodeError:
                continue
        return decoded

    found = _decode_candidates(text)
    if found or text.endswith("}"):
        return found
    return _decode_candidates(text + "}")


def trim_long_string(string, threshold=5100, k=2500):
    """Shorten *string* to its first and last *k* characters when it is too long.

    Strings of at most *threshold* characters are returned unchanged. Longer
    strings become the first *k* characters, a marker line reporting how many
    characters (``len(string) - 2 * k``) were cut, and the last *k* characters.
    *k* is expected to be non-negative.
    """
    length = len(string)
    if length <= threshold:
        return string
    head = string[:k]
    # ``string[-k:]`` would return the whole string for k == 0, so slice from an
    # explicit start index (clamped at 0 when k exceeds the length).
    tail = string[max(length - k, 0):]
    truncated_len = length - 2 * k
    return f"{head}\n ... [{truncated_len} characters truncated] ... \n{tail}"


def extract_code(text):
    """Extract the valid Python code in *text* and format it with Black.

    Fenced code blocks are collected first. Any info-string tag on the opening
    fence (``python``, ``py``, ``text``, ``Python3``, ...) is dropped instead of
    being kept as the first line of the code, where it would compile as a
    bare name and fail at runtime. If *text* has no complete fenced block, the
    whole text is treated as code, ignoring an optional leading and trailing
    fence. Candidates that do not compile are discarded; the survivors are
    formatted, joined with blank lines, and formatted once more.
    """
    # Optional fence tag: word characters / "+" / "-" followed by a newline.
    # Requiring the newline keeps inline blocks such as ```print(1)``` intact
    # instead of treating "print" as a tag.
    fence_tag = r"(?:[ \t]*[\w+-]*[ \t]*\n)?"

    # When code is in fenced (text, python, or other tagged) blocks.
    parsed_codes = re.findall("```" + fence_tag + r"\n*(.*?)\n*```", text, re.DOTALL)

    # When the entire text is code or the backticks of the code block are missing.
    if not parsed_codes:
        match = re.match(
            r"^(?:```" + fence_tag + r")?\n?(.*?)\n?(?:```)?$", text, re.DOTALL
        )
        if match:
            parsed_codes.append(match.group(1))

    # Keep only candidates that compile, formatting each one.
    valid_code_blocks = [
        format_code(c) for c in parsed_codes if is_valid_python_script(c)
    ]
    return format_code("\n\n".join(valid_code_blocks))


def extract_text_up_to_code(s):
    """Return the stripped text preceding the first ``` fence.

    An empty string is returned when *s* contains no fence at all.
    """
    before, fence, _ = s.partition("```")
    if not fence:
        return ""
    return before.strip()


def format_code(code) -> str:
    """Run *code* through Black with its default mode.

    Input that Black cannot parse is returned untouched.
    """
    mode = black.FileMode()
    try:
        formatted = black.format_str(code, mode=mode)
    except black.parsing.InvalidInput:  # type: ignore
        formatted = code
    return formatted

def _instrument_with_prints(code):
    """Ask the local vLLM model to insert progress ``print`` calls into *code*.

    Used by the dense-reward path of ``compute_score``, which then searches the
    program output for markers such as "loaded data" or "trained model".

    Returns:
        The Python code extracted (via ``extract_code``) from the model's reply.
        This may be empty if the reply contained no valid code.
    """
    # Deferred imports: these are only needed when dense rewards are enabled.
    import torch
    from tensordict import TensorDict
    from verl import DataProto

    # Continuation lines deliberately keep their 6-space indent; that
    # whitespace is part of the prompt text sent to the model.
    prompt = f"""Please insert print statements in the given python script. The print statements are supposed to reflect the progress of executing a script that solves a Kaggle challenge machine learning benchmark. These print statements will be used to debug the python script. Note that the print statements should include:
      ```
      print("imported packages")
      print("loaded data")
      print("defined model")
      print("training loss:")
      print("trained model")
      print("testing loss:")
      print("predicted test labels")
      ```
      Only insert print statement AFTER an operation is actually performed (e.g., data have actually been loaded). Insert print statements for "training loss:" and "testing loss:" if applicable (i.e., the code actually computes training or testing losses). Output the entire python script after inserting print statements in a single markdown code block (wrapped in ```). Do not modify the original python code, other than inserting print statements.
      \nHere is the python script: ```\n{code}```
      """
    vllm_engine = get_vllm_engine()
    tokenizer = vllm_engine.inference_engine.tokenizer
    encoded_inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=vllm_engine.config.prompt_length,
        truncation=True,
    )
    # Every tensor in the batch must be on the same device, so position ids are
    # derived from the mask *after* it has been moved to CUDA.
    # NOTE(review): padded positions are not clamped to 0 and the tokenizer's
    # padding side is not set here; confirm both match what vLLMRollout expects.
    input_ids = encoded_inputs["input_ids"].cuda()
    attention_mask = encoded_inputs["attention_mask"].cuda()
    position_ids = attention_mask.long().cumsum(-1) - 1
    batch = TensorDict(
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        },
        batch_size=1,
    )
    meta_info = {
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "do_sample": False,
    }
    data_proto = DataProto(batch=batch, meta_info=meta_info)
    with torch.no_grad():
        output_proto = vllm_engine.generate_sequences(data_proto)
    # Assumes the returned ``input_ids`` are the padded prompt followed by the
    # response, so everything after the prompt length is the response.
    output_ids = output_proto.batch["input_ids"][0]
    prompt_length = encoded_inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)
    return extract_code(response_text)


def _grade_submission(submission_file, ground_truth, run_dir):
    """Grade *submission_file* with the mle-bench CLI and return its report dict.

    The CLI's combined stdout/stderr is written to ``run_dir/grade_report.json``.
    The first JSON object in that output is decoded and returned.

    Raises:
        ValueError: if the output contains no decodable JSON object
            (``json.JSONDecodeError`` is a subclass).
        OSError: if the grader cannot be launched or the report file cannot be
            written or read.
    """
    report_path = Path(run_dir) / "grade_report.json"
    cmd = [
        "python",
        "/workdir/mle-bench/mlebench/cli.py",
        "grade-sample",
        str(submission_file),
        str(ground_truth),
        "--data-dir=/workdir/mle-bench/data/",
    ]
    # An argument list (no shell) avoids quoting/injection issues. Merging
    # stderr into the same file mirrors a shell's ``> file 2>&1``.
    with open(report_path, "w") as report_file:
        subprocess.run(cmd, stdout=report_file, stderr=subprocess.STDOUT)
    output = report_path.read_text()
    # The report may be surrounded by log lines. Decode exactly one JSON value
    # starting at the first "{", so that trailing output or nested objects do
    # not break parsing.
    start = output.index("{")
    report, _ = json.JSONDecoder().raw_decode(output, start)
    return report


def compute_score(solution_str, ground_truth, output_dir, timeout, dense_reward=False, is_eval=False):
    """Score one model rollout on an mle-bench style task.

    The assistant part of *solution_str* must contain a natural-language plan
    followed by Python code. The code is executed in a fresh
    ``/workdir/tmp/run_<uuid>`` directory. If it writes ``submission.csv``,
    that file is graded with the mle-bench CLI.

    Args:
        solution_str: Decoded rollout text. Everything up to the first
            "Assistant:" / "<|im_start|>assistant" marker is discarded.
        ground_truth: Key into ``LEADERBOARD``; also passed to the grader.
        output_dir: Base directory (``gs://`` is mapped to ``/gcs/``). Replies
            that obtained a graded score are saved under its ``generation/``
            subdirectory.
        timeout: Execution timeout forwarded to ``Interpreter``.
        dense_reward: If true (and *is_eval* is false), the code is first
            instrumented with progress prints by the local vLLM model, and 0.1
            is added per progress keyword found in the program output.
        is_eval: Disables the dense-reward shaping.

    Returns:
        A tuple ``(final_score, actual_score, exec_time)``.
        ``final_score`` starts at -10 and gains the shaping bonus. For a valid
        graded submission it also gains 1 plus the (sign-adjusted)
        leaderboard score. ``actual_score`` is that sign-adjusted score, or
        -10 if there is none. Format failures return ``(-10.0, -10.0, 1.0)``.
    """
    entry = LEADERBOARD[ground_truth]
    worst_score = -10.
    actual_score = -10.

    rand = uuid.uuid4().hex
    run_dir = Path('/workdir/tmp') / f"run_{rand}"

    # Keep only the assistant's reply; the prompt/question part is discarded.
    if "Assistant:" in solution_str:
        _, solution_str = solution_str.split("Assistant:", 1)
    elif "<|im_start|>assistant" in solution_str:
        _, solution_str = solution_str.split("<|im_start|>assistant", 1)
    else:
        print(f"REWARD={worst_score} wrong assistant format")
        return worst_score, actual_score, 1.

    code = extract_code(solution_str)
    nl_text = extract_text_up_to_code(solution_str)

    # Both a plan (text before the first fence) and valid code are required.
    if not code or not nl_text:
        print(f"REWARD={worst_score} wrong plan + code format")
        return worst_score, actual_score, 1.

    use_dense_reward = dense_reward and not is_eval
    if use_dense_reward:
        instrumented = _instrument_with_prints(code)
        # If the model's reply held no usable code, run the original script
        # rather than an empty one, so a working solution can still be graded.
        if instrumented:
            code = instrumented

    # Created only after the early returns above, which never clean it up.
    interpreter = Interpreter(run_dir, timeout=timeout)
    try:
        res = interpreter.run(code)
    finally:
        # Release the interpreter session even if execution raised.
        interpreter.cleanup_session()

    final_score = worst_score

    if use_dense_reward:
        term_out = "\n".join(res.term_out).lower()
        keywords = ["loaded data", "defined model", "training loss", "testing loss", "trained model", "predicted test labels"]
        # +0.1 for each progress marker that appears in the program output.
        final_score += sum(0.1 for keyword in keywords if keyword in term_out)

    submission_file = run_dir / "submission.csv"
    if not submission_file.exists():
        print("REWARD=no submission.csv")
    else:
        # Deliberate best-effort boundary: a grading failure must not crash the
        # reward computation, so any error is reported and the score is kept.
        try:
            result = _grade_submission(submission_file, ground_truth, run_dir)
            if result["valid_submission"] and result["score"] is not None and result["score"] != 0 and not math.isnan(result["score"]):
                valid_submission_score = 1.
                final_score += valid_submission_score
                actual_score = -result["score"] if entry["is_lower_better"] else result["score"]
                final_score += actual_score
                print("REWARD=VALID", final_score)
                print(f"TIME={res.exec_time}")
            else:
                print(f"REWARD=valid_submission={result['valid_submission']},submission_exists={result['submission_exists']}")
        except Exception as e:
            print("REWARD=error parsing grade.json", e)

    if actual_score != worst_score:
        # Keep the full reply of every rollout that obtained a graded score.
        with open(f'{output_dir.replace("gs://", "/gcs/")}/generation/reward_{actual_score}_{res.exec_time}.txt', 'w') as f:
            f.write(solution_str)

    return final_score, actual_score, res.exec_time
