"""
python agent.py --query_mode=one_shot

python ../../analysis/visualize_episode.py --save --path logs/2025-05-12-11-41-07-976902/2025-05-12-11-41-08-603891

python agent.py --query_mode=one_shot --backend=gpt-4.1-nano --deploy --openai_api_key <key>
"""

import base64
import concurrent.futures
import dataclasses
import enum
import functools
import json
import logging as py_logging
import re
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List, Tuple

import joblib
import numpy as np
import openai
from absl import app, flags, logging
from google import genai
from google.genai import types
from google.genai.errors import ServerError
from line_profiler import profile
from PIL import Image
from tqdm import tqdm

from knotgym.envs import KnotEnv

# Tasks
flags.DEFINE_integer(
  "task_max_n_crossings", 3, "Max number of crossings in the knot."
)
flags.DEFINE_enum(
  "task", "tie_unknot", ["unknot", "tie_unknot", "eq1"], "Task to perform."
)
# run_episode integer-divides the 50-step budget by this factor (and scales
# the frame skip by it), so it must be a positive integer; 0 would raise
# ZeroDivisionError deep inside an episode instead of at flag parsing.
flags.DEFINE_integer(
  "speedup_factor",
  1,
  "1 means no speedup, 2 means 2x speedup, and 1/2 the number of frames.",
  lower_bound=1,
)
flags.DEFINE_enum(
  "query_mode",
  "stateful",
  ["one_shot", "stateless", "stateful"],
  "Prompt format.",
)

flags.DEFINE_integer(
  "stateful_window",
  -1,
  "Window size for stateful mode. -1 means no window, use all previous states (likely going to break things)",
)
# A window of 0 (or anything below -1) does not yield a usable stateful
# prompt -- the truncated history would not contain the current knot state --
# so reject such values when the flags are parsed.
flags.register_validator(
  "stateful_window",
  lambda window: window == -1 or window >= 1,
  message="--stateful_window must be -1 (no window) or a positive integer.",
)

# Backend
# Without --deploy no API is called and a canned dummy response is used.
flags.DEFINE_boolean("deploy", False, "Whether to deploy the agent.")
flags.DEFINE_string("genai_api_key", "", "API key for Google GenAI.")
flags.DEFINE_string("openai_api_key", "", "API key for OpenAI.")
flags.DEFINE_enum(
  "backend",
  "gemini-2.0-flash",
  ["gemini-2.0-flash", "o4-mini", "gpt-4.1-nano"],
  "Backend for the VLM agent.",
)

# Parallelization
# Bounds make bad values fail at flag-parsing time rather than after the log
# directory has been created (ThreadPoolExecutor rejects max_workers <= 0).
flags.DEFINE_integer(
  "num_workers", 1, "Number of workers to use.", lower_bound=1
)
flags.DEFINE_integer(
  "num_episodes", 1, "Number of episodes to run.", lower_bound=0
)

logger = logging.get_absl_logger()
# `logger` is a stdlib logging.Logger, so it needs a stdlib level. absl's
# `logging.INFO` is an absl verbosity constant (0), which the stdlib treats
# as NOTSET, so the original call did not actually select INFO.
logger.setLevel(py_logging.INFO)
logging.get_absl_handler().setFormatter(
  py_logging.Formatter(
    # Prefix each line with the level's first letter (I/W/E/D/C) instead of
    # a hard-coded "I", so warnings and errors are distinguishable.
    fmt="%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s",
    datefmt="%m%d %H:%M:%S",
  )
)
FLAGS = flags.FLAGS


class QueryMode(enum.Enum):
  """How the VLM is queried over the course of an episode."""

  # A single query up front; the model returns the whole action sequence.
  ONE_SHOT = 1
  # One query per step showing only the latest knot state (no history).
  STATELESS = 2
  # One query per step that replays the conversation so far (previous model
  # responses and the knot states they produced).
  STATEFUL = 3


@dataclass
class State:
  """History accumulated while the agent interacts with the environment.

  `initial` is the observation the agent starts from, and `target` is the
  goal observation. The three `all_*` lists grow in lockstep, with one entry
  per environment step taken.
  """

  initial: np.ndarray
  target: np.ndarray
  all_obs: List[np.ndarray] = dataclasses.field(default_factory=list)
  all_response: List[str] = dataclasses.field(default_factory=list)
  all_action: List[List[float]] = dataclasses.field(default_factory=list)

  @property
  def step(self):
    """Number of environment steps recorded so far."""
    return len(self.all_obs)

  def add(self, obs: np.ndarray, response: str, action: List[float]):
    """Record one step: the resulting observation, model response and action."""
    for history, item in (
      (self.all_obs, obs),
      (self.all_response, response),
      (self.all_action, action),
    ):
      history.append(item)


@dataclass
class Part:
  """One element of a multimodal prompt: a text snippet or an image.

  `role` is "user" for prompt content, or "model" for a previous model
  response that is replayed as conversation history.
  """

  content: str | Image.Image
  role: str = "user"

  def __post_init__(self):
    # Raise instead of assert so the check still runs under `python -O`.
    if self.role not in ("user", "model"):
      raise ValueError(f"Invalid role: {self.role}")

  def __str__(self):
    """Render as "<role>: <content>" for the plain-text prompt logs."""
    if isinstance(self.content, str):
      return f"{self.role}: {self.content}"
    elif isinstance(self.content, Image.Image):
      return f"{self.role}: {str(self.content)}"
    else:
      raise ValueError(f"Unsupported content type: {type(self.content)}")


@dataclass
class Prompt:
  """Builds the multimodal prompt (text and images) for one VLM query.

  Every prompt starts with a shared preamble: the task description, the
  dynamics example image, and the goal and action specifications. The goal
  knot follows. The remainder depends on `query_mode`:

  * ONE_SHOT: the initial knot and a request for a whole action sequence.
  * STATELESS: only the latest knot and a request for the next action.
  * STATEFUL: the interaction history -- previous model responses
    interleaved with the knots they produced -- ending with a request for
    the next action. A window W >= 1 shows at most W knot states, and -1
    shows the full history.
  """

  state: State
  query_mode: QueryMode
  prefix: str = "Output a series of actions to transform the knot from its initial configuration to the goal gauss code."
  eg_dynamics: Path = Path("assets/eg_dynamics_gc.jpeg")
  target_spec: str = "Goal specification: The conversion is considered successful, when the current knot has the same gauss code as the goal knot. When determining the gauss code, always start from the white segment and traverse the rope towards the red segment, record positive for over-cross and negative for under-cross. A visual example is included in the image. An flat loop has gauss code of []."
  action_spec: str = "Action specification: We follow a right-hand coordinate system, centered in the figure. Each action is in the form of [x,y,z,fx,fy,fz] where (x,y,z) are 3D coordinates which will be rounded to the closest rope segment, and (fx,fy,fz) are force vectors to be applied to that rope segment. x,y,z,fx,fy,fz are floating points bounded by [-1, 1]. In the image are three examples of before-and-after pairs of the unit directions. Use them as a reference. You can compose an action, for example, [1.0,1.0,0.0,0.9,0.0,-0.7] means pulling the most upper-right segment with 0.9 unit force in +x direction and 0.7 unit force in -z direction. You can select a segment somewhere in the middle of the rope, for example, let [x,y,z]=[-0.5,0.5,0.0] would be in the center of second quadrant."
  instr_target: str = "Now consider a goal knot of the goal gauss code (what's the gauss code of the following knot?): "
  instr_state: str = "Here is the current knot:"
  instr_series_actions: str = "What are a series of actions that will transform the current knot such that it has the same gauss code as the goal knot? Think step by step, and end your answer in one <answer></answer> block, like: <answer> [-0.8, 0.8, 0.0, 0.9, 0.0, 0.0] \n [0.0, 0.2, 0.0, -0.7, 0.0, 0.0] </answer>. You can include multiple lists of six floats in the block, separated by new lines."
  instr_next_action: str = "What is the next action to take? Think step by step, and end your answer in one <answer></answer> block, like: <answer> [0.0, 0.2, 0.0, -0.7, 0.0, 0.0] </answer>. You should only include one list of six floats in the block."
  # History window for STATEFUL mode (-1 = unlimited). None falls back to the
  # --stateful_window flag, so existing callers behave as before.
  stateful_window: int | None = None

  @staticmethod
  def _build_general_prompt(prompt: "Prompt | None" = None) -> List[Part]:
    """Return the preamble shared by all query modes.

    Args:
      prompt: instance whose text/asset fields are used, so that per-instance
        overrides take effect. None uses the class-level defaults, which
        preserves the zero-argument form.
    """
    src = Prompt if prompt is None else prompt
    return [
      Part(src.prefix, "user"),
      Part(Image.open(src.eg_dynamics), "user"),
      Part(src.target_spec, "user"),
      Part(src.action_spec, "user"),
    ]

  def _state_turn(self, obs: np.ndarray, question: str) -> List[Part]:
    """One user turn: "here is the current knot", its image, then `question`."""
    return [
      Part(self.instr_state, "user"),
      Part(Image.fromarray(obs), "user"),
      Part(question, "user"),
    ]

  def _resolve_window(self) -> int:
    """Return the effective STATEFUL window; raise ValueError if unusable."""
    window = (
      FLAGS.stateful_window
      if self.stateful_window is None
      else self.stateful_window
    )
    if window != -1 and window < 1:
      # 0 or < -1 would produce a history without the current knot state.
      raise ValueError(f"stateful_window must be -1 or >= 1, got {window}")
    return window

  def _stateful_history(self) -> List[Part]:
    """Interleaved (model response, resulting knot) history for STATEFUL mode.

    The history always opens with a user turn. Without truncation, that turn
    is the initial knot. With truncation, it is the oldest knot inside the
    window, and the model response that produced it is dropped.
    """
    window = self._resolve_window()
    step = self.state.step
    # `step < window` (not `<=`): the full history holds step + 1 knot states,
    # so this keeps every prompt at <= `window` states.
    if window == -1 or step < window:
      opening, replayed = self.state.initial, range(step)
    else:
      start = step - window
      opening, replayed = self.state.all_obs[start], range(start + 1, step)
    history = self._state_turn(opening, self.instr_next_action)
    for i in replayed:
      history.append(Part(self.state.all_response[i], "model"))
      history += self._state_turn(self.state.all_obs[i], self.instr_next_action)
    return history

  def build(self) -> List[Part]:
    """Assemble the full prompt for the current state and query mode."""
    parts = Prompt._build_general_prompt(self)
    parts += [
      Part(self.instr_target, "user"),
      Part(Image.fromarray(self.state.target), "user"),
    ]
    if self.query_mode is QueryMode.ONE_SHOT:
      parts += self._state_turn(self.state.initial, self.instr_series_actions)
    elif self.query_mode is QueryMode.STATELESS:
      latest = (
        self.state.initial if self.state.step == 0 else self.state.all_obs[-1]
      )
      parts += self._state_turn(latest, self.instr_next_action)
    else:
      assert self.query_mode is QueryMode.STATEFUL
      parts += self._stateful_history()
    return parts


# --- Backend Base ---
class VLMBackend:
  """Interface shared by the concrete VLM API clients.

  Subclasses implement `get_response` and `close`. `model_name` identifies
  the remote model and also prefixes the log file names.
  """

  def __init__(self, model_name: str):
    self.model_name = model_name

  def close(self):
    """Release client resources. Subclasses must override this."""
    raise NotImplementedError

  def get_response(self, prompt_contents: List[Part]) -> Tuple[str, int]:
    """Send the prompt and return (response text, total token count or -1).

    Subclasses must override this.
    """
    raise NotImplementedError


def retry_on_exception(max_attempts: int = 3, backoff_s: float = 1.0):
  """Decorator factory that retries the wrapped call on `ServerError`.

  The wrapped function is called up to `max_attempts` times. After every
  failed attempt except the last, a warning is logged and the wrapper sleeps
  `backoff_s * 2**attempt` seconds (exponential backoff) before retrying.
  Immediate retries tend to hit the same server-side error. The last failure
  is re-raised. Other exception types propagate at once.

  Args:
    max_attempts: total number of calls to try; must be >= 1.
    backoff_s: delay before the first retry, in seconds; 0 disables sleeping.

  Raises:
    ValueError: if `max_attempts < 1`. Without this check, the wrapper would
      silently return None without ever calling the function.
  """
  if max_attempts < 1:
    raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

  def decorator(func):
    @functools.wraps(func)  # keep __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
      for attempt in range(max_attempts):
        try:
          return func(*args, **kwargs)
        except ServerError as e:
          if attempt == max_attempts - 1:
            raise
          delay = backoff_s * 2**attempt
          logger.warning(
            f"Attempt {attempt + 1} failed: {e}; retrying in {delay:.1f}s"
          )
          if delay > 0:
            time.sleep(delay)

    return wrapper

  return decorator


# --- Google Backend ---
class GoogleBackend(VLMBackend):
  """VLM backend that uses the Google GenAI (Gemini) SDK."""

  def __init__(self, api_key: str, model_name="gemini-1.5-flash"):
    super().__init__(model_name)
    self.client = genai.Client(api_key=api_key)

  def _proc_part(self, part: Part) -> types.Content | str | Image.Image:
    """Convert a `Part` into an item of the `contents` list.

    User parts pass through unchanged (raw text or PIL image). Model parts
    are wrapped in a `types.Content` whose role is "model".
    """
    if part.role != "model":
      return part.content
    assert isinstance(part.content, str), "only support str output from model"
    text_part = types.Part.from_text(text=part.content)
    return types.Content(parts=[text_part], role="model")

  @retry_on_exception()
  def get_response(self, prompt_contents: List[Part]) -> Tuple[str, int]:
    """Query the model and return (text, total token count, or -1 if absent)."""
    payload = list(map(self._proc_part, prompt_contents))
    t0 = time.time()
    response = self.client.models.generate_content(
      model=self.model_name, contents=payload
    )
    elapsed = time.time() - t0
    logger.debug(f"{self.model_name} Latency: {elapsed:.2f}s")
    text = response.text
    n_tokens = getattr(response.usage_metadata, "total_token_count", -1)
    return text, n_tokens

  def close(self):
    """No-op: this backend does not explicitly close its client."""


# --- OpenAI Backend ---
class OpenAIBackend(VLMBackend):
  """VLM backend that uses the OpenAI chat-completions API."""

  def __init__(self, api_key: str, model_name="o4-mini"):
    super().__init__(model_name)
    self.client = openai.OpenAI(api_key=api_key, max_retries=5)
    # Part roles -> OpenAI chat message roles.
    self._role_map = dict(user="user", model="assistant")

  @staticmethod
  def _img2b64(image: Image.Image) -> str:
    """Encode a PIL image as PNG and return the base64 text."""
    with BytesIO() as buffer:
      image.save(buffer, format="PNG")
      raw = buffer.getvalue()
    return base64.b64encode(raw).decode()

  def _proc_part(self, part: Part) -> dict:
    """Convert a `Part` into one chat message (text, or a PNG data-URL image)."""
    content = part.content
    if isinstance(content, str):
      payload = content
    elif isinstance(content, Image.Image):
      data_url = f"data:image/png;base64,{self._img2b64(content)}"
      payload = [{"type": "image_url", "image_url": {"url": data_url}}]
    else:
      raise ValueError(f"Unsupported content type: {type(part.content)}")
    return {"role": self._role_map[part.role], "content": payload}

  def get_response(self, prompt_contents: List[Part]) -> Tuple[str, int]:
    """Query the model and return (text, total token count, or -1 if absent)."""
    messages = list(map(self._proc_part, prompt_contents))
    t0 = time.time()
    completion = self.client.chat.completions.create(
      model=self.model_name, messages=messages
    )
    elapsed = time.time() - t0
    logger.info(f"{self.model_name} Latency: {elapsed:.2f}s")
    reply = completion.choices[0].message.content
    return reply, getattr(completion.usage, "total_tokens", -1)

  def close(self):
    """Close the underlying OpenAI client."""
    self.client.close()


# --- Dummy Response ---
def _get_response_dummy(prompt_contents) -> Tuple[str, int]:
  text = """Here's a sequence of actions aimed at transforming the initial knot into the target knot. The strategy involves focusing on manipulating the loop in the top right of the initial knot and moving the lower left strand.

```
<answer>
[0.4, 0.6, 0.0, 0.0, -0.5, 0.0]
[0.6, 0.4, 0.0, -0.5, 0.0, 0.0]
[0.2, -0.7, 0.0, 0.0, 0.5, 0.0]
[-0.7, -0.2, 0.0, 0.5, 0.0, 0.0]
[-0.5, 0.1, 0.0, 0.0, -0.5, 0.0]
</answer>
```"""
  return text, -1


class Vlm:
  """Decides which backend to use, then parses and logs the response.

  The backend is chosen from --backend: names containing "gemini" use
  GoogleBackend, and all others use OpenAIBackend. Without --deploy, no API
  call is made and a canned dummy response is used instead.
  """

  # Group 1 = text between <answer> and </answer>; may span several lines.
  _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
  # Each "[...]" group inside the answer block is one action.
  _ACTION_RE = re.compile(r"\[([^\]]+)\]")

  def __init__(self):
    if "gemini" in FLAGS.backend:
      self.backend = GoogleBackend(
        api_key=FLAGS.genai_api_key, model_name=FLAGS.backend
      )
    else:
      self.backend = OpenAIBackend(
        api_key=FLAGS.openai_api_key, model_name=FLAGS.backend
      )

  def _parse(self, text: str) -> List[List[float]]:
    """Extract actions from the LAST <answer>...</answer> block of `text`.

    The prompts ask the model to *end* its answer with the block. The prompt
    text also contains a literal "<answer></answer>", which models may echo
    while reasoning. The final block is therefore the authoritative one.

    Returns:
      One list of floats per "[a, b, ...]" group in the block. The list may
      be empty if the block contains no brackets.

    Raises:
      ValueError: if there is no <answer> block or a value is not a float.
    """
    blocks = self._ANSWER_RE.findall(text)
    if not blocks:
      raise ValueError("No <answer> block found in the response.")
    groups = self._ACTION_RE.findall(blocks[-1])
    # Models sometimes write U+2212 (MINUS SIGN) instead of ASCII "-", and
    # float() rejects it.
    actions = [
      [float(x.strip().replace(chr(8722), chr(45))) for x in group.split(",")]
      for group in groups
    ]
    logger.debug(f"Actions: {actions}")
    return actions

  def __call__(
    self,
    prompt_contents: List[Part],
    logdir: Path,
    save_suffix: str = "",
  ) -> Tuple[str, List[List[float]]]:
    """Query the model and log the exchange under `logdir`.

    Files written, with prefix = "<model_name><save_suffix>":
      <prefix>_prompt.txt  str() of every prompt part, one per line
      <prefix>_resp.txt    raw response text
      <prefix>.lz4         joblib dump of prompt, response, actions, tokens;
                           written only if parsing succeeds

    Returns:
      (raw response text, parsed actions).
    """
    prefix = f"{self.backend.model_name}{save_suffix}"
    # Explicit UTF-8: responses may contain non-ASCII characters (e.g. U+2212)
    # that the platform's default encoding cannot represent.
    (logdir / f"{prefix}_prompt.txt").write_text(
      "\n".join(str(p) for p in prompt_contents), encoding="utf-8"
    )

    if FLAGS.deploy:
      text, n_tokens = self.backend.get_response(prompt_contents)
    else:
      text, n_tokens = _get_response_dummy(prompt_contents)
    if text is None:
      # The SDKs type the reply text as optional (e.g. for blocked or refused
      # outputs). Normalise it so it can be logged and then reported by
      # _parse as a missing <answer> block.
      text = ""

    (logdir / f"{prefix}_resp.txt").write_text(text, encoding="utf-8")

    actions = self._parse(text)
    logger.debug(f"Parsed {len(actions)} actions, {n_tokens} tokens")

    joblib.dump(
      {
        "model_name": self.backend.model_name,
        "prompt": prompt_contents,
        "response": text,
        "actions": actions,
        "n_token": n_tokens,
      },
      logdir / f"{prefix}.lz4",
      compress=("lz4", 3),  # type:ignore
    )

    return text, actions

  def close(self):
    """Close the underlying backend client."""
    self.backend.close()


def _make_env(logdir: Path, frame_skip: int, duration: int) -> KnotEnv:
  """Construct the KnotEnv for one episode from the task flags."""
  return KnotEnv(
    task=FLAGS.task,
    split="tr",  # always the tr split
    frame_skip=frame_skip,
    task_max_n_states=20,
    task_max_n_crossings=FLAGS.task_max_n_crossings,
    duration=duration,
    logdir=logdir,
    render_both=False,
    width=128,  # same as default
    height=128,  # same as default
  )


def _rollout(
  env,
  model: Vlm,
  logdir: Path,
  seed: int,
  query_mode: QueryMode,
  warmup_steps: int,
  max_episode_length: int,
):
  """Reset `env`, take `warmup_steps` zero-action steps, then act on VLM output.

  Returns:
    (state, reward, terminated, truncated): the accumulated State and the
    outcome of the last env step taken.
  """
  obs, _ = env.reset(seed=seed)
  target = env.unwrapped.task_spec.obsg  # type:ignore

  terminated, truncated = False, False
  reward = None
  for _ in range(warmup_steps):
    action = np.zeros(env.action_space.shape, dtype=np.float32)  # type:ignore
    obs, reward, terminated, truncated, _info = env.step(action)

  state = State(initial=obs, target=target)

  if query_mode is QueryMode.ONE_SHOT:
    prompt = Prompt(state=state, query_mode=query_mode)
    text, actions = model(prompt.build(), logdir, save_suffix="_one_shot")
    if len(actions) > max_episode_length:
      # Model output is external input: don't `assert` on it (asserts vanish
      # under -O). The env's `duration` only budgets max_episode_length steps
      # after warm-up, so keep the leading actions.
      logger.warning(
        f"Seed {seed}: model proposed {len(actions)} actions, keeping the "
        f"first {max_episode_length}."
      )
      actions = actions[:max_episode_length]
    horizon = len(actions)
  else:
    text, actions = "", []
    horizon = max_episode_length

  for i in range(horizon):
    if query_mode is QueryMode.ONE_SHOT:
      action = actions[i]
    else:
      prompt = Prompt(state=state, query_mode=query_mode)
      text, actions = model(prompt.build(), logdir, save_suffix=f"_{i}")
      if not actions:
        raise ValueError(
          f"Seed {seed} step {i}: the <answer> block contained no actions"
        )
      action = actions[0]

    obs, reward, terminated, truncated, _info = env.step(
      np.array(action, dtype=np.float32)
    )
    logger.info(
      f"Seed {seed} Step {i}: {reward=}, {terminated=}, {truncated=}, {action=}"
    )
    state.add(obs, text, action)
    if terminated or truncated:
      break

  if not (terminated or truncated):
    # not enough actions from the open ended model
    # force write buffer when ending prematurely
    env.env._write_buffer()  # type:ignore
  return state, reward, terminated, truncated


@profile
def run_episode(model: Vlm | None, logdir: Path, seed: int = 0):
  """Run one episode and append its summary to `logdir/scores.jsonl`.

  Args:
    model: VLM to query. If None, a fresh Vlm is created for this episode and
      closed when the episode ends, even on error. A caller-provided model is
      left open so that it can be reused across episodes.
    logdir: run directory. Episode artifacts go to `logdir/<seed:04d>/`.
    seed: environment reset seed; it also names the episode directory.

  Returns:
    The report dict (seed, terminated, truncated, logdir, reward, length),
    which is also appended as one JSON line to `logdir/scores.jsonl`.
  """
  owns_model = model is None
  if owns_model:
    model = Vlm()
  try:
    run_logdir = logdir
    logdir = logdir / f"{seed:04d}"  # type:ignore
    logdir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Logdir: {logdir}")
    logger.info(f"Query mode: {FLAGS.query_mode}")
    logger.info(f"Backend: {FLAGS.backend}")
    query_mode = QueryMode[FLAGS.query_mode.upper()]

    warmup_steps = 2
    # At least one warm-up step is needed: it defines `reward` for the report.
    assert warmup_steps > 0
    speedup_factor = FLAGS.speedup_factor
    max_episode_length = 50 // speedup_factor
    frame_skip = 24 * speedup_factor

    env = _make_env(
      logdir, frame_skip, duration=max_episode_length + warmup_steps
    )
    try:
      state, reward, terminated, truncated = _rollout(
        env, model, logdir, seed, query_mode, warmup_steps, max_episode_length
      )
    finally:
      env.close()  # also on errors (e.g. an unparsable model response)
  finally:
    if owns_model:
      model.close()

  report = dict(
    seed=seed,
    # Coerce, in case the env returns NumPy scalars: json cannot serialize
    # e.g. np.bool_ or np.float32.
    terminated=bool(terminated),
    truncated=bool(truncated),
    logdir=str(logdir),
    reward=float(reward),
    length=len(state.all_action),
  )

  with open(run_logdir / "scores.jsonl", "a+") as f:
    f.write(json.dumps(report) + "\n")
  return report


def main(_):
  """Run --num_episodes episodes concurrently and record the results.

  Steps:
    1. Create logs/<timestamp>_<query_mode>_<backend>_<task>_nx<N>[_dummy].
       If the directory already exists, it is deleted first.
    2. Save the flag values to flags.json, with API keys redacted.
    3. Run one `run_episode` per seed on a pool of --num_workers threads.

  A failing episode is logged with its traceback and does not stop the
  others. Per-episode scores are appended to scores.jsonl by `run_episode`.
  """
  model = None  # each episode constructs its own Vlm
  # Never persist secrets: flags.json sits with the logs, which get shared.
  flags_dict = {
    name: ("<redacted>" if name.endswith("api_key") and value else value)
    for name, value in FLAGS.flag_values_dict().items()
  }
  time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  folder_name = f"{time_str}_{FLAGS.query_mode}_{FLAGS.backend}_{FLAGS.task}_nx{FLAGS.task_max_n_crossings}"
  if not FLAGS.deploy:
    folder_name += "_dummy"
  logdir = Path("logs") / folder_name

  if logdir.exists():
    logger.warning(f"{logdir} already exists, overwriting...")
    shutil.rmtree(logdir, ignore_errors=True)

  logdir.mkdir(parents=True)
  with open(logdir / "flags.json", "w") as f:
    json.dump(flags_dict, f, indent=2)

  with ThreadPoolExecutor(max_workers=FLAGS.num_workers) as executor:
    futures = [
      executor.submit(run_episode, model, logdir, seed)
      for seed in range(FLAGS.num_episodes)
    ]
    for future in tqdm(
      concurrent.futures.as_completed(futures), total=len(futures)
    ):
      try:
        future.result()
      except Exception as e:
        # logger.exception also records the failed episode's traceback.
        logger.exception(f"Error running episode: {e}")

  logger.info(f"Done. See scores in {(logdir / 'scores.jsonl').absolute()}")


if __name__ == "__main__":
  app.run(main)
