from datetime import datetime
from typing import Optional
import logging
import pydot
import base64
import json
import re
from pathlib import Path


def get_date(granularity: Optional[str] = "min") -> str:
    """
    Return the current local time as a compact timestamp string.

    Args:
        granularity: ``"min"`` for ``YYYYmmdd-HHMM`` or ``"day"`` for
            ``YYYYmmdd``.

    Returns:
        The formatted timestamp (naive local time from ``datetime.now()``).

    Raises:
        ValueError: if ``granularity`` is not one of the supported values.
            (Previously this surfaced as an opaque ``UnboundLocalError``.)
    """
    formats = {"min": "%Y%m%d-%H%M", "day": "%Y%m%d"}
    if granularity not in formats:
        logging.error(f"Undefined timestamp granularity: {granularity}")
        raise ValueError(f"Undefined timestamp granularity: {granularity}")
    return datetime.now().strftime(formats[granularity])


def convert_time(time, mode):
    """
    Convert between a number of seconds and a clock-style string.

    Args:
        time: number of seconds (``"second2string"`` mode) or a string of the
            form ``"MM:SS"`` / ``"HH:MM:SS"`` (``"string2second"`` mode).
        mode: ``"second2string"`` or ``"string2second"``.

    Returns:
        For ``"second2string"``: ``"M:SS"`` (minutes are not folded into
        hours, fractional seconds are truncated). For ``"string2second"``: the
        total number of seconds as an int. For any other mode: ``None`` (after
        logging an error).

    Raises:
        ValueError: in ``"string2second"`` mode when ``time`` does not have
            two or three ``:``-separated fields.
    """
    logging.info(f"[debug] {time}")

    if mode == "second2string":
        mm = int(time // 60)
        ss = int(time % 60)
        return f"{mm}:{ss:02}"

    if mode == "string2second":
        splits = time.split(":")
        if len(splits) == 2:
            hh = 0
            mm, ss = splits
        elif len(splits) == 3:
            hh, mm, ss = splits
        else:
            logging.error(f"Incorrect format: {time}")
            raise ValueError(f"Incorrect format: {time}")
        # hours were previously unpacked but dropped from the total
        return int(hh) * 3600 + int(mm) * 60 + int(ss)

    logging.error(f"Undefined mode: {mode}")
    return None


def convert_angle(angle, mode):
    """
    Map between camera-angle names and their frame-file identifiers.

    Args:
        angle: an angle name such as ``"center"`` or ``"left-top"``
            (``"string2file"`` mode), or a file identifier such as
            ``"C10118_rgb"`` (``"file2string"`` mode).
        mode: ``"string2file"`` or ``"file2string"``.

    Returns:
        The mapped value, or ``None`` (after logging an error) when the angle
        or the mode is unknown.
    """
    # single source of truth; the reverse direction is derived from it
    angle2file = {
        "center": "C10118_rgb",
        "top": "C10115_rgb",
        "right-bottom": "C10395_rgb",
        "right-center": "C10095_rgb",
        "right-top": "C10390_rgb",
        "left-bottom": "C10379_rgb",
        "left-center": "C10119_rgb",
        "left-top": "C10404_rgb",
    }

    if mode == "string2file":
        mapping = angle2file
    elif mode == "file2string":
        # previously a bare `pass`, which made `return output` raise
        # UnboundLocalError
        mapping = {file_id: name for name, file_id in angle2file.items()}
    else:
        logging.error(f"Undefined mode: {mode}")
        return None

    output = mapping.get(angle)
    if output is None:
        logging.error(f"Undefined angle: {angle} for {mode}")
    return output


def extract_index(filepath):
    """
    Return the first run of decimal digits in ``filepath.stem`` as an int.

    ``filepath`` must expose a ``.stem`` attribute (e.g. ``pathlib.Path``).
    If the stem contains no digits, ``re.search`` yields ``None`` and the
    ``.group`` access raises ``AttributeError``.
    """
    first_digits = re.search(r"\d+", filepath.stem)
    return int(first_digits.group(0))


def sample_frame(dirpath, start, end, max_frames):
    """
    Select at most ``max_frames`` evenly strided frames from a frame window.

    Frames are the ``*.png`` files directly inside ``dirpath`` whose stem is
    an integer frame index in ``[start, end]``. A window with
    ``start == end`` is widened to ``[start - 1, end]``. When the window
    holds more than ``max_frames`` frames, every ``ceil(n / max_frames)``-th
    frame is kept, counting backwards from the last one so that the final
    frame is always included.

    Args:
        dirpath: directory (``pathlib.Path``) containing numbered PNG frames;
            every ``*.png`` stem must parse with ``int()``.
        start: first frame index of the window (inclusive).
        end: last frame index of the window (inclusive).
        max_frames: maximum number of frames to return.

    Returns:
        List of frame paths in ascending frame-index order.
    """
    if start == end:
        start -= 1
    lower, upper = float(start), float(end)

    # parse each stem once and sort by the same index used for filtering
    indexed = []
    for filepath in dirpath.glob("*.png"):
        idx = int(filepath.stem)
        if lower <= idx <= upper:
            indexed.append((idx, filepath))
    indexed.sort(key=lambda pair: pair[0])
    ordered = [filepath for _, filepath in indexed]

    num_frames = len(ordered)
    # ceil(num_frames / max_frames); e.g. 70 frames, max 25 => keep 1 of every 3
    rate_inverse = -(-num_frames // max_frames) if num_frames > max_frames else 1

    # stride backwards from the last frame (so it is always kept), then restore
    # ascending order; replaces the former O(n^2) `insert(0, ...)` loop
    sampled = ordered[::-1][::rate_inverse][::-1]

    assert len(sampled) <= max_frames
    return sampled


def get_fps(dirpath, start, end, max_frames):
    """
    Return the fraction of frames kept when a window is capped at ``max_frames``.

    Counts the ``*.png`` files in ``dirpath`` whose integer stem lies in
    ``[start, end]`` (a window with ``start == end`` is widened to
    ``[start - 1, end]``). Every ``*.png`` stem must parse with ``int()``.

    Returns:
        ``max_frames / count`` when the count exceeds ``max_frames``,
        otherwise the integer ``1``.

    NOTE(review): ``sample_frame`` keeps every ``ceil(count / max_frames)``-th
    frame, so its effective ratio can be lower than the value returned here
    (70 frames, max 25 -> 25/70 here vs. a stride of 3 there). Confirm which
    one callers expect.
    """
    lower = start - 1 if start == end else start

    frame_count = sum(
        1
        for candidate in dirpath.glob("*.png")
        if float(lower) <= int(candidate.stem) <= float(end)
    )

    if frame_count <= max_frames:
        return 1
    return max_frames / frame_count


def format_instruction(
    filepath_text, dirpath_instruction_image=None, dirpath_parts_image=None
):
    """
    Load instruction annotations and build one record per toy.

    Args:
        filepath_text: path to a JSON file with a top-level ``"examples"``
            list. Each example provides ``"toy_id"``, ``"nodes"`` (each with
            ``"id"`` and ``"data"`` holding ``"label"`` and an optional
            ``"checked"`` flag) and ``"edges"`` (``"source"`` / ``"target"``
            node ids). ``"filepath_image"`` is read only when
            ``dirpath_parts_image`` is given.
        dirpath_instruction_image: optional directory; when given, the
            ``"dag"`` entry is ``<dir>/<toy_id>.png``.
        dirpath_parts_image: optional directory; when given, the ``"parts"``
            entry points at the ``_360p`` variant of the example's
            ``"filepath_image"`` under this directory.

    Returns:
        ``{toy_id: {"dot": str, "dag": Path | None, "parts": Path | None}}``,
        where ``"dot"`` is the action graph rendered as DOT source.
    """
    logging.info("Load instructions ... ")

    with open(filepath_text, "r", encoding="utf-8") as f:
        data = json.load(f)

    toy2instruction = {}
    for annotation in data["examples"]:
        toy_id = annotation["toy_id"]

        G = pydot.Dot(graph_type="digraph")
        action_id2description = {}
        for action in annotation["nodes"]:
            label = action["data"]["label"]
            # actions flagged as "checked" are rendered with a screw suffix
            if action["data"].get("checked"):
                description = f"{label} w/ screw"
            else:
                description = f"{label}"
            action_id2description[str(action["id"])] = description
            # NOTE(review): nodes are keyed by their description, so two
            # actions sharing a label collapse into one DOT node; confirm intended.
            G.add_node(pydot.Node(description))

        for edge in annotation["edges"]:
            G.add_edge(
                pydot.Edge(
                    action_id2description[str(edge["source"])],
                    action_id2description[str(edge["target"])],
                )
            )

        # TODO: change based on resolution
        if dirpath_parts_image:
            parts_path = Path(dirpath_parts_image) / annotation["filepath_image"]
            # only rewrite the file name; a ".png" inside a directory name must
            # not be touched (the old str.replace rewrote every occurrence)
            if parts_path.suffix == ".png":
                parts_path = parts_path.with_name(f"{parts_path.stem}_360p.png")
        else:
            # previously `Path(None)` raised TypeError here when the default
            # dirpath_parts_image=None was used
            parts_path = None

        if dirpath_instruction_image:
            dag_path = Path(dirpath_instruction_image) / f"{toy_id}.png"
        else:
            dag_path = None

        toy2instruction[toy_id] = {
            "dot": G.to_string().strip(),
            "dag": dag_path,
            "parts": parts_path,
        }

    return toy2instruction


def encode_image(filepath):
    """
    Read a file as raw bytes and return its standard Base64 encoding as text.

    Args:
        filepath: anything ``open()`` accepts for reading in binary mode.

    Returns:
        The Base64 string (UTF-8 decoded, no line breaks).
    """
    with open(filepath, "rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload)
    return encoded.decode("utf-8")


def check_prompt(messages: list):
    # todo: i think reasoning is missing
    output = []
    for message in messages:
        match message["type"]:
            case "message":
                output.append(f"[{message['role']}]")
                if message["role"] == "user":
                    if isinstance(message["content"], str):
                        output.append(f"{message['content']}")
                    elif isinstance(message["content"], list):
                        for one_content in message["content"]:
                            if one_content["type"] == "input_text":
                                output.append(f"{one_content['text']}")
                            elif one_content["type"] == "input_image":
                                output.append("<image omitted for space>")
                            else:
                                logging.error(
                                    f"Undefined content element type: {one_content['type']}"
                                )
                    else:
                        logging.error(
                            f"Undefined content type: {type(message['content'])}"
                        )
                elif message["role"] in ["developer", "system", "assistant"]:
                    output.append(f"{message['content']}")
                else:
                    logging.error(f"Undefined role: {message['role']}")
            case "reasoning":
                output.append("[reasoning]")
                for summary in message["summary"]:
                    output.append(summary["text"])
            case "function_call":
                output.append("[function_call]")
                output.append(f"{message['name']} {message['arguments']}")
                pass
            case "function_call_output":
                output.append("[function_call_output]")
                output.append(f"{message['output']}")
            case _:
                logging.error(f"Undefined {message['type']=}")

    return output
