import colorsys
import os
from typing import Dict, Iterable, List, Optional

import matplotlib.pyplot as plt
import seaborn as sns

from scripts.generate_prompts import SIZE_PATTERNS

# Model family color system.
#
# Each family has a single base hue (a colorsys hue in [0, 1]). Each model in
# the family has its own saturation and lightness (also in [0, 1]), so related
# models share a hue but stay distinguishable. generate_model_colors turns
# these entries into MODEL_COLORS. get_model_family_order uses the insertion
# order here (families first, then models within each family) as the plot order.
MODEL_FAMILIES = {
    # QWEN family - Blue tones (different saturations/lightness)
    "qwen": {
        "base_hue": 0.6,  # Blue
        "models": {
            "qwen-3-1.7b-v4": {"saturation": 0.5, "lightness": 0.7},  # Lightest blue
            "qwen-3-4b-v4": {"saturation": 0.6, "lightness": 0.6},  # Light blue
            "qwen-3-8b-v4": {"saturation": 0.7, "lightness": 0.5},  # Medium blue
            "qwen-3-14b-v4": {"saturation": 0.8, "lightness": 0.4},  # Dark blue
            "qwen-3-32b-v4": {"saturation": 0.9, "lightness": 0.3},  # Darkest blue
        },
    },
    # LLaMA family - Orange/Red tones
    "llama": {
        "base_hue": 0.05,  # Orange-red
        "models": {
            "llama-7b-hf": {"saturation": 0.6, "lightness": 0.6},  # Light orange
            "llama-2-7b-chat-hf": {
                "saturation": 0.7,
                "lightness": 0.5,
            },  # Medium orange
            "llama-3-8b-i": {"saturation": 0.8, "lightness": 0.4},  # Dark orange
            "llama-3.1-8b-i": {"saturation": 0.9, "lightness": 0.3},  # Darkest orange
        },
    },
    # OpenAI family - Green tones
    "openai": {
        "base_hue": 0.33,  # Green
        "models": {
            "gpt-4.1-nano": {"saturation": 0.5, "lightness": 0.6},  # Light green
            "gpt-4.1-mini": {"saturation": 0.6, "lightness": 0.5},  # Medium green
            "o1-mini": {"saturation": 0.7, "lightness": 0.4},  # Dark green
            "o3-mini": {"saturation": 0.9, "lightness": 0.35},  # Very dark green
            "o4-mini": {"saturation": 0.8, "lightness": 0.3},  # Even darker green
            "gpt-5": {"saturation": 0.7, "lightness": 0.2},  # Darkest green
        },
    },
    # Mistral family - Purple tones
    "mistral": {
        "base_hue": 0.75,  # Purple
        "models": {
            "mistral-7b-i-v0.2": {"saturation": 0.8, "lightness": 0.5},  # Medium purple
        },
    },
    # OLMo family - Teal tones
    "olmo": {
        "base_hue": 0.5,  # Teal
        "models": {
            "olmo-2-7b-i": {"saturation": 0.7, "lightness": 0.5},  # Medium teal
        },
    },
    # DeepSeek family - Magenta tones
    "deepseek": {
        "base_hue": 0.83,  # Magenta
        "models": {
            "deepseek-r1-distill-llama-8b": {
                "saturation": 0.8,
                "lightness": 0.4,
            },  # Dark magenta
        },
    },
}

# MODEL NAME MAPPING - easily editable display names.
# Maps each internal model name to the label shown in plots. Models missing from
# this dict are shown under their internal name (see get_display_name).
# get_internal_name inverts this mapping, so keep display names unique.
# NOTE: "gemini-2.5-pro" has a display name but no MODEL_FAMILIES entry. It
# therefore has no MODEL_COLORS entry, and get_color_palette gives it a
# fallback color.
MODEL_DISPLAY_NAMES = {
    # QWEN models - remove -v4 suffix etc.
    "qwen-3-1.7b-v4": "Qwen3 1.7b",
    "qwen-3-4b-v4": "Qwen3 4b",
    "qwen-3-8b-v4": "Qwen3 8b",
    "qwen-3-14b-v4": "Qwen3 14b",
    "qwen-3-32b-v4": "Qwen3 32b",
    # OpenAI models
    "gpt-4.1-nano": "GPT 4.1-nano",
    "gpt-4.1-mini": "GPT 4.1-mini",
    "o1-mini": "o1-mini",
    "o3-mini": "o3-mini",
    "o4-mini": "o4-mini",
    "gpt-5": "GPT 5",
    # LLaMA models
    "llama-7b-hf": "llama-7b-hf",
    "llama-2-7b-chat-hf": "llama-2-7b-chat-hf",
    "llama-3-8b-i": "llama-3-8b-i",
    "llama-3.1-8b-i": "llama-3.1-8b-i",
    # Gemini models
    "gemini-2.5-pro": "Gemini 2.5-pro",
    # Other models
    "mistral-7b-i-v0.2": "mistral-7b-i-v0.2",
    "olmo-2-7b-i": "olmo-2-7b-i",
    "deepseek-r1-distill-llama-8b": "deepseek-r1-distill-llama-8b",
}


def get_display_name(model_name: str) -> str:
    """
    Return the human-readable label for a model.

    Names without an entry in MODEL_DISPLAY_NAMES are returned unchanged, so
    unknown models still get a usable label.

    Parameters:
    - model_name: Internal model name

    Returns:
    - Display name for use in visualizations
    """
    if model_name in MODEL_DISPLAY_NAMES:
        return MODEL_DISPLAY_NAMES[model_name]
    return model_name


def get_internal_name(display_name: str) -> str:
    """
    Map a display name back to its internal model name.

    This inverts MODEL_DISPLAY_NAMES. If several internal names share a
    display name, the one listed last wins. Unknown display names are
    returned unchanged.

    Parameters:
    - display_name: Display name from visualization

    Returns:
    - Internal model name
    """
    # Pair each display label with its internal key to build the inverse lookup.
    inverse = dict(zip(MODEL_DISPLAY_NAMES.values(), MODEL_DISPLAY_NAMES))
    try:
        return inverse[display_name]
    except KeyError:
        return display_name


def apply_display_names_to_list(model_names: List[str]) -> List[str]:
    """
    Translate every internal model name in a list to its display name.

    Order is preserved. Names without a mapping pass through unchanged.

    Parameters:
    - model_names: List of internal model names

    Returns:
    - List of display names
    """
    return list(map(get_display_name, model_names))


def get_size_pattern_order(
    patterns: List[str], *, reference_order: Optional[Iterable[str]] = None
) -> List[str]:
    """
    Sort size patterns according to their definition order in SIZE_PATTERNS.

    Parameters:
    - patterns: Pattern names to sort. Any iterable is accepted and consumed
      once. Duplicates are collapsed, so each pattern appears at most once
      in the result.
    - reference_order: Optional explicit ordering to sort by. Defaults to the
      key order of SIZE_PATTERNS.

    Returns:
    - List of patterns in reference order, followed by any patterns not in
      the reference order, sorted alphabetically.
    """
    if reference_order is None:
        reference_order = SIZE_PATTERNS.keys()

    # dict.fromkeys drops duplicate reference entries while keeping their order.
    defined_order = list(dict.fromkeys(reference_order))
    defined_set = set(defined_order)

    # Materialize the input once: iterating it twice would silently lose
    # items when a generator is passed. Sets also make membership O(1).
    requested = set(patterns)

    ordered_known = [p for p in defined_order if p in requested]
    # Deduplicate unknown patterns too, consistent with the known ones.
    unknown_patterns = sorted(requested - defined_set)

    return ordered_known + unknown_patterns


def hsl_to_hex(h, s, l):
    """
    Convert an HSL color to a "#rrggbb" hex string.

    Parameters:
    - h: Hue in [0, 1].
    - s: Saturation in [0, 1].
    - l: Lightness in [0, 1].

    Returns:
    - Lower-case hex string, e.g. "#ff0000" for hsl_to_hex(0, 1, 0.5).
    """
    # colorsys calls this model HLS and takes its arguments as (h, l, s).
    channels = colorsys.hls_to_rgb(h, l, s)
    # Round each channel to the nearest 8-bit value, as matplotlib.colors.to_hex
    # does. Truncating with int() biased every channel downward (0.5 became
    # 0x7f instead of 0x80). The clamp guards against float drift outside 0-255.
    return "#" + "".join(
        "{:02x}".format(min(255, max(0, round(c * 255)))) for c in channels
    )


def generate_model_colors():
    """
    Build the model -> hex color mapping for every model in MODEL_FAMILIES.

    Each model's color combines its family's "base_hue" with the model's own
    "saturation" and "lightness". Models are added in family order, then in
    definition order within each family.
    """
    colors = {}
    for family in MODEL_FAMILIES.values():
        hue = family["base_hue"]
        colors.update(
            (name, hsl_to_hex(hue, spec["saturation"], spec["lightness"]))
            for name, spec in family["models"].items()
        )
    return colors


def get_model_family_order(
    models: List[str], *, families: Optional[Dict[str, dict]] = None
) -> List[str]:
    """
    Sort models according to the order defined in MODEL_FAMILIES.

    Families appear in the order they are defined, and models within each
    family appear in their definition order. Models not found in any family
    are appended at the end in alphabetical order. Every model appears at
    most once in the result.

    Parameters:
    - models: Model names to sort. Any iterable is accepted and consumed once.
    - families: Optional family table with the same layout as MODEL_FAMILIES
      (each value has a "models" dict). Defaults to MODEL_FAMILIES.

    Returns:
    - List of models in the correct order
    """
    if families is None:
        families = MODEL_FAMILIES

    # Read the input once. Repeated `in` checks against a generator would
    # consume it, and a set gives O(1) membership.
    requested = set(models)

    ordered_models = []
    placed = set()
    for family_info in families.values():
        for family_model in family_info["models"]:
            if family_model in requested and family_model not in placed:
                ordered_models.append(family_model)
                placed.add(family_model)

    # Unknown models, deduplicated and alphabetical.
    unknown_models = sorted(requested - placed)

    return ordered_models + unknown_models


# Hex color for every model in MODEL_FAMILIES, computed once at import time.
MODEL_COLORS = generate_model_colors()

# Fixed palettes for the other categorical dimensions. get_color_palette uses
# them for palette_type="question_types" and palette_type="targets".
QUESTION_TYPE_COLORS = {
    "full_output": "#1f77b4",
    "node_count": "#ff7f0e",
    "edge_count": "#2ca02c",
    "blue_node_count": "#d62728",
    "colored_node_count": "#9467bd",
    "is_connected": "#8c564b",
    "is_tree": "#e377c2",
    "has_cycles": "#7f7f7f",
    "max_degree": "#bcbd22",
    "min_degree": "#17becf",
    "component_count": "#ffbb78",
}

# Colors for the "input"/"output" targets. get_color_palette puts "input" first.
TARGET_COLORS = {
    "input": "#2ca02c",  # Green for input (first)
    "output": "#ff7f0e",  # Orange for output (second)
}


def get_color_palette(items: List[str], palette_type: str = "models") -> Dict[str, str]:
    """
    Get consistent colors for a list of items.

    Parameters:
    - items: Names to color (model names, question types, or target names).
    - palette_type: "models", "question_types" or "targets" selects the
      matching predefined palette. Any other value colors every item from
      seaborn's "husl" palette.

    Returns:
    - Mapping from item to a "#rrggbb" hex color string. For "models", keys
      follow get_model_family_order. For "targets", "input" comes first.
      Items missing from the predefined palette get a fallback "husl" color.
    """
    if palette_type == "models":
        base_palette = MODEL_COLORS
        # Use family-based ordering for models
        ordered_items = get_model_family_order(items)
    elif palette_type == "question_types":
        base_palette = QUESTION_TYPE_COLORS
        ordered_items = items
    elif palette_type == "targets":
        base_palette = TARGET_COLORS
        # Ensure "input" comes first for targets
        ordered_items = (["input"] if "input" in items else []) + [
            item for item in items if item != "input"
        ]
    else:
        # No predefined palette: color everything from husl. as_hex() keeps the
        # values as hex strings, matching the annotation and the other branches.
        colors = sns.color_palette("husl", len(items)).as_hex()
        return {item: colors[i] for i, item in enumerate(items)}

    # Fallback colors for items without a predefined color. They are converted
    # to hex so every value in the result has the same form. Predefined palettes
    # are already hex, while raw seaborn palettes are RGB tuples.
    fallback_colors = sns.color_palette("husl", len(ordered_items)).as_hex()

    result = {}
    color_idx = 0
    for item in ordered_items:
        if item in base_palette:
            result[item] = base_palette[item]
        else:
            result[item] = fallback_colors[color_idx % len(fallback_colors)]
            color_idx += 1

    return result


def print_model_color_preview():
    """
    Print each model's internal name, display name and hex color, grouped
    by family in MODEL_FAMILIES order. Intended for debugging the color system.
    """
    print("Model Color Mapping:")
    print("=" * 50)

    for family, info in MODEL_FAMILIES.items():
        print(f"\n{family.upper()} Family:")
        for model in info["models"]:
            hex_color = MODEL_COLORS[model]
            label = get_display_name(model)
            print(f"  {model:<25} -> {label:<20} {hex_color}")


def setup_plot_style():
    """
    Apply the shared plot styling.

    Resets matplotlib to its default style and selects seaborn's "husl"
    palette. It then overrides font sizes and figure/save settings: 300 dpi,
    PDF as the default save format, tight bounding boxes, and Type 42
    (TrueType) font embedding in PDFs.
    """
    overrides = {
        # Text sizes
        "font.size": 10,
        "axes.titlesize": 12,
        "axes.labelsize": 10,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "legend.fontsize": 9,
        "figure.titlesize": 14,
        # Output settings
        "figure.dpi": 300,
        "savefig.format": "pdf",
        "savefig.bbox": "tight",
        "pdf.fonttype": 42,
    }

    plt.style.use("default")
    sns.set_palette("husl")
    plt.rcParams.update(overrides)


def format_accuracy(value: float) -> str:
    """Render an accuracy value as fixed-point text with three decimals, e.g. "0.857"."""
    return format(value, ".3f")


def add_sample_counts(ax, data: Dict[str, int], y_offset: float = 0.02):
    """
    Annotate bars with "n=<count>" labels.

    Labels are placed at x positions 0, 1, 2, ... in the iteration order of
    `data`. Only the counts (the dict values) are used.

    Parameters:
    - ax: Object with a matplotlib-style `text` method (presumably an Axes).
    - data: Mapping of bar label -> sample count, in bar order.
    - y_offset: y coordinate of each label. The text is anchored at its bottom.
    """
    label_style = dict(ha="center", va="bottom", fontsize=8, alpha=0.7)
    for x_position, count in enumerate(data.values()):
        ax.text(x_position, y_offset, f"n={count}", **label_style)


def save_plot(
    fig, filepath: str, title: Optional[str] = None, no_titles: bool = False
):
    """
    Save a figure as PDF with consistent settings, then close it.

    Parameters:
    - fig: Figure to save (presumably a matplotlib Figure).
    - filepath: Destination path. A trailing ".png" extension is replaced by
      ".pdf", and the content is always written in PDF format. Missing parent
      directories are created.
    - title: Optional figure-level title (suptitle). Ignored when no_titles
      is True.
    - no_titles: When True, clear the suptitle and every axes title.
    """
    if no_titles:
        # Remove any existing titles
        fig.suptitle("")
        for ax in fig.get_axes():
            ax.set_title("")
    elif title:
        fig.suptitle(title, y=0.98, fontsize=14)

    # Lay out this figure. plt.tight_layout() would act on whichever figure
    # is current, which need not be `fig`.
    fig.tight_layout()

    # Swap only the final extension. str.replace would also rewrite ".png"
    # anywhere else in the path, e.g. inside a directory name.
    if filepath.endswith(".png"):
        filepath = filepath[: -len(".png")] + ".pdf"

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save as PDF (vector format, so DPI is less relevant)
    fig.savefig(filepath, format="pdf", bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    # Running the module directly prints the family / display-name / color
    # table so the color system can be checked by eye.
    print_model_color_preview()
