import numpy as np
import argparse
import re
from pathlib import Path

from scipy.spatial.transform import Rotation as R
from model import *
from vima_utils import *


def action_to_text(obs_img: np.ndarray, action: dict, spatula: bool) -> str:
    act_pos_start = action["pose0_position"]
    act_pos_end = action["pose1_position"]
    act_rot_start = action["pose0_rotation"]
    act_rot_end = action["pose1_rotation"]
    # Convert to Euler angles (in radians)
    e1 = R.from_quat(act_rot_start).as_euler("xyz", degrees=True)
    e2 = R.from_quat(act_rot_end).as_euler("xyz", degrees=True)

    de = int(e2[-1] - e1[-1])

    px, py = pos2pix_front(act_pos_start[0], act_pos_start[1])
    tx, ty = pos2pix_front(act_pos_end[0], act_pos_end[1])

    obj_text = "object"

    assert len(obs_img.shape) == 3
    w = obs_img.shape[2]
    h = obs_img.shape[1]
    if spatula:
        return f"Sweep the {obj_text} at <b>({px / w :.3f}, {py / h: .3f})</b>, rotate <r>[{-de}]</r> degrees, and stop at <b>({tx / w :.3f}, {ty / h :.3f})</b>."
    else:
        return f"Pick up the {obj_text} at <b>({px / w :.3f}, {py / h: .3f})</b>, rotate <r>[{-de}]</r> degrees, and drop it at <b>({tx / w :.3f}, {ty / h :.3f})</b>."


def query_bc(
    tokenizer,
    model,
    image_processor,
    prompt,
    prompt_mode,
    prompt_assets,
    action_hist,
    obs,
    detector,
    prompt_id,
):
    prepared_prompt, image_list = prepare_prompt(
        tokenizer,
        model,
        image_processor,
        prompt,
        mode=prompt_mode,
        prompt_assets=prompt_assets,
        spatula=obs["ee"] > 0,
        detector=detector,
        uid=prompt_id,
    )

    # tag for action history
    if "h" in prompt_mode and len(action_hist):
        # append history action to the prompt
        prepared_prompt += "\nYou have finished: "
        prepared_prompt += "\n".join(
            [
                f"Step {acti + 1}: "
                + action_to_text(obs["rgb"]["front"], acth, spatula=obs["ee"] > 0)
                for acti, acth in enumerate(action_hist)
            ]
        )

    image_list.append(obs["rgb"]["front"].copy())
    ans, _ = model_generation(
        tokenizer, model, image_processor, image_list, prepared_prompt
    )

    # parse positions and rotations
    str_actions = re.findall(r"\(.+?,.+?\)", ans)
    str_rotation = re.findall(r"\[(-*\d*)\]", ans)
    m_len = min(len(str_actions) // 2, len(str_rotation))
    if "s" in prompt_mode:
        m_len = min(m_len, 1)
    return (
        [
            str_actions[idx * 2 : idx * 2 + 2] + str_rotation[idx : idx + 1]
            for idx in range(m_len)
        ],
        prepared_prompt,
        ans,
    )


def parse_bc(pick_place_point):
    p0 = parse_coor(pick_place_point[0])
    p1 = parse_coor(pick_place_point[1])

    action = {
        "pose0_position": np.array(pix2pos_front(*p0), dtype=np.float32),
        "pose1_position": np.array(pix2pos_front(*p1), dtype=np.float32),
    }
    rotation = -float(pick_place_point[2])
    rot_quat = np.array(
        R.from_euler("z", rotation, degrees=True).as_quat(), dtype=np.float32
    )
    action["pose1_rotation"] = rot_quat
    return action


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("filename")
    parser.add_argument("--model-path", type=str)
    parser.add_argument("--output-path", type=str, default="../results/")
    parser.add_argument("--prompt-mode", type=str, default="hs")
    parser.add_argument("--prompt-id", type=int, default=-1)
    parser.add_argument("--seed", type=int, default=200000)
    parser.add_argument("--num-env", type=int, default=20)
    parser.add_argument("--max-length", type=int, default=8)
    parser.add_argument("--partition", type=str, default="")
    parser.add_argument("--detector", type=str, default="")
    parser.add_argument("--detector-thre", type=float, default=0.6)

    args = parser.parse_args()
    assert "RT" not in args.model_path
    assert Path(args.model_path).is_dir()
    eval_episode(args, query_bc, parse_bc)
