import colorsys
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import graphviz

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.tasks.base import Task
from llm_mcts.models.openai_api import PRICING as OPENAI_PRICING
from llm_mcts.models.claude_api import PRICING as CLAUDE_PRICING
from llm_mcts.models.gemini_api import PRICING as GEMINI_PRICING


def generate_color_variations(
    base_color: str,
    model_names: Iterable[str],
    saturation_range: Tuple[float, float] = (
        0.05,
        0.25,
    ),  # Reduced saturation for lighter colors
    brightness_range: Tuple[float, float] = (
        0.92,
        0.98,
    ),  # Increased brightness for lighter colors
) -> Dict[str, str]:
    """Generate color variations for a group of LLM models based on a base color.

    Models are spread linearly across the given ranges in list order: the
    first model gets the lowest saturation and highest brightness (lightest),
    the last gets the highest saturation and lowest brightness.

    Args:
        base_color: Base color name ('red', 'green', 'blue', 'purple',
            'orange' or 'teal').
        model_names: Model names to generate colors for. Any iterable is
            accepted (lists, dict key views, generators); it is consumed once.
        saturation_range: Tuple of (min, max) saturation values in [0, 1].
        brightness_range: Tuple of (min, max) brightness values in [0, 1].

    Returns:
        Dictionary mapping model names to ``#rrggbb`` hex color codes.

    Raises:
        ValueError: If ``base_color`` is unsupported, or a range endpoint lies
            outside [0, 1] (which would otherwise yield malformed hex codes).
    """
    # Hue, in degrees, of each supported base color in HSV space.
    base_colors = {
        "red": 0,
        "green": 120,
        "blue": 240,
        "purple": 270,
        "orange": 30,
        "teal": 180,
    }

    if base_color not in base_colors:
        raise ValueError(
            f"Unsupported base color: {base_color}. Supported colors are: {list(base_colors.keys())}"
        )

    for range_name, value_range in (
        ("saturation_range", saturation_range),
        ("brightness_range", brightness_range),
    ):
        # Values above 1.0 would give channels > 255 and therefore 3-digit hex.
        if not all(0.0 <= bound <= 1.0 for bound in value_range):
            raise ValueError(
                f"{range_name} must lie within [0, 1], got {tuple(value_range)}"
            )

    # colorsys expects hue normalized to the 0-1 range.
    hue = base_colors[base_color] / 360

    # Materialize once: callers pass dict key views, and a generator would
    # not support len().
    names = list(model_names)
    n_models = len(names)

    colors: Dict[str, str] = {}
    for i, model_name in enumerate(names):
        # progress runs from 0 (first model) to 1 (last model).
        progress = i / (n_models - 1) if n_models > 1 else 0
        saturation = (
            saturation_range[0] + (saturation_range[1] - saturation_range[0]) * progress
        )
        brightness = (
            brightness_range[1] - (brightness_range[1] - brightness_range[0]) * progress
        )

        red, green, blue = colorsys.hsv_to_rgb(hue, saturation, brightness)

        # Truncation (not rounding) is kept so existing colors do not shift.
        colors[model_name] = "#{:02x}{:02x}{:02x}".format(
            int(red * 255), int(green * 255), int(blue * 255)
        )

    return colors


# Fill colour per known model id, one hue per provider family. Later families
# win on any key collision, exactly as successive dict.update() calls would.
COLOR_DICT: Dict[str, str] = {
    **generate_color_variations("green", OPENAI_PRICING.keys()),
    **generate_color_variations("red", CLAUDE_PRICING.keys()),
    **generate_color_variations("blue", GEMINI_PRICING.keys()),
}


def _shorten_llm_name(llm_name: str) -> str:
    """Strip the long Anthropic model-id prefix so the caption stays short."""
    if "us.anthropic." in llm_name:
        return llm_name.replace("us.anthropic.", "")
    if "anthropic." in llm_name:
        return llm_name.replace("anthropic.", "")
    return llm_name


def _describe_node(
    node: Node, label: str, scorer: MCTSScorer, task: Task
) -> Tuple[str, str]:
    """Build the caption text and border color for a single tree node.

    Returns:
        ``(displayed_text, border_color)``. The border is "black" for a zero
        score and "blue" otherwise. It is overridden to "green" when a
        "transform"/"answer" node passes every test case.
    """
    # Nodes without a last action are not scored.
    score = scorer.get_score(node=node) if node.last_action else 0

    # Temporarily removed verifier logic here since it is ARC-specific:
    # the decimal part was the verifier score and the integer part the
    # number of correct answers.
    border_color = "black" if score == 0 else "blue"

    displayed_text = f"{label}\n{node.last_action}; {score:.2f}"
    if node.eval_results is not None:
        displayed_text += f"/{len(node.eval_results)}"

    if node.last_action in ("transform", "answer"):
        eval_results, test_score = task.evaluate_on_test(
            node.next_prompt.get_last_generation_result()
        )
        if len(eval_results) > 0:
            displayed_text += f"; test score {test_score}"
            displayed_text += f"/{len(eval_results)}"
            if test_score == len(eval_results):
                border_color = "green"

        # The LLM name is only added to the caption of transform/answer nodes.
        # getattr + None check: a node whose llm_name is None would otherwise
        # raise TypeError in the substring tests.
        llm_name = getattr(node, "llm_name", None)
        if llm_name is not None:
            displayed_text += f"\nLLM: {_shorten_llm_name(llm_name)}"

    return displayed_text, border_color


def _node_style(node: Node, border_color: str) -> Dict[str, str]:
    """Graphviz attributes for one node: colored border, fill by LLM family."""
    return {
        "color": border_color,
        "style": "filled",
        "fontcolor": "black",  # Fixed black text color
        "penwidth": "5.0",
        # Fill with the model-family color when the LLM is known, else white.
        "fillcolor": COLOR_DICT.get(getattr(node, "llm_name", None), "white"),
    }


def build_graph_view(
    root: Node,
    scorer: MCTSScorer,
    task: Task,
    dot: graphviz.Digraph,
    label_box: List[int],
    labels: Dict[Node, str],
    mcts_result: MCTSResult,
) -> graphviz.Digraph:
    """Add ``root`` and all of its descendants to ``dot``.

    Nodes are visited in pre-order: a node first, then each child subtree
    in order. Each node is drawn with a caption built from its label, last
    action and score. An edge is drawn from its parent when the parent has
    already been labelled.

    Args:
        root: Top of the (sub)tree to draw.
        scorer: Provides the score shown for nodes that have a last action.
        task: Evaluates "transform"/"answer" nodes on the test set.
        dot: Graph that nodes and edges are added to (mutated in place).
        label_box: One-element list used as a mutable counter. Its value is
            the label for nodes without ``serial_number``; incremented once
            per visited node.
        labels: Filled with the Graphviz node id assigned to each node.
        mcts_result: Not used here; kept for interface compatibility.

    Returns:
        ``dot`` itself.
    """
    # An explicit stack instead of recursion, so that deep search trees
    # cannot exceed Python's recursion limit. Children are pushed in reverse
    # so that nodes are still visited (and numbered) in the original
    # pre-order.
    stack = [root]
    while stack:
        node = stack.pop()
        label = str(getattr(node, "serial_number", label_box[0]))

        displayed_text, border_color = _describe_node(node, label, scorer, task)
        dot.node(label, displayed_text, **_node_style(node, border_color))
        labels[node] = label
        label_box[0] += 1

        # Explicit None check (not truthiness). Also skip the edge when
        # drawing a subtree whose parent was never labelled.
        parent = node.parent
        if parent is not None and parent in labels:
            dot.edge(labels[parent], label)

        stack.extend(reversed(list(node.children)))
    return dot


def render_mcts_graph(
    mcts_result: MCTSResult,
    scorer: MCTSScorer,
    task: Task,
    fpath: Path,
    view: bool = False,
) -> graphviz.Digraph:
    """Draw the whole search tree of ``mcts_result`` and render it to ``fpath``.

    Args:
        mcts_result: Result whose ``root`` node is drawn together with all
            its descendants.
        scorer: Passed through to ``build_graph_view`` for node scores.
        task: Passed through to ``build_graph_view`` for test evaluation.
        fpath: Output path handed to ``Digraph.render``.
        view: Forwarded to ``Digraph.render``'s ``view`` flag.

    Returns:
        The populated graph.
    """
    # Fresh traversal state: the label counter starts at 0 and no nodes
    # are labelled yet.
    counter = [0]
    node_labels: Dict[Node, str] = {}
    graph = build_graph_view(
        mcts_result.root,
        scorer,
        task,
        graphviz.Digraph(),
        counter,
        node_labels,
        mcts_result=mcts_result,
    )
    graph.render(fpath, view=view)
    return graph
