# viz/plotting.py
"""
Visualization helpers:
  - hit rate vs utilization curve
  - slack heatmap
  - Gantt chart generation for a schedule
  - training loss / reward plots

These utilities accept numpy arrays or python lists and save PNG figures.
"""
import itertools
import os
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

def plot_hit_rate_vs_util(utils_list: List[float], hit_rates: List[float], out_path: str = "plots/hit_rate_vs_util.png"):
    """
    Plot hit rate against utilization as a marked line chart and save it.

    utils_list: x-axis values (utilization), plotted in the order given.
    hit_rates: y-axis values, one per entry of ``utils_list``.
    out_path: destination image path; parent directories are created if missing.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.plot(utils_list, hit_rates, marker="o")
        ax.set_xlabel("Utilization")
        ax.set_ylabel("Hit Rate")
        ax.set_title("Hit Rate vs Utilization")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Close this specific figure even if plotting/saving fails, so repeated
        # calls don't accumulate open figures in pyplot's global registry.
        plt.close(fig)

def plot_slack_heatmap(slack_matrix: np.ndarray, out_path: str = "plots/slack_heatmap.png"):
    """
    Render a slack matrix as a heatmap with a colorbar and save it.

    slack_matrix: 2-D array-like passed to ``imshow``; rows are drawn along the
        y-axis (labelled "Task ID") and columns along the x-axis ("Time step").
    out_path: destination image path; parent directories are created if missing.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        image = ax.imshow(slack_matrix, aspect="auto", interpolation="nearest", cmap="viridis")
        # Pass the mappable explicitly instead of relying on pyplot's
        # "current image" global state.
        fig.colorbar(image, ax=ax, label="Slack")
        ax.set_xlabel("Time step")
        ax.set_ylabel("Task ID")
        ax.set_title("Slack over Time")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even when imshow/savefig raises.
        plt.close(fig)

def _save_line_plot(values, title: str, xlabel: str, ylabel: str, out_path: str):
    """Plot one series as a line chart, save it to *out_path*, and close the figure."""
    fig, ax = plt.subplots()
    try:
        ax.plot(values)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails, to avoid leaking open figures.
        plt.close(fig)


def plot_training_curves(losses: List[float], rewards: List[float], out_dir: str = "plots"):
    """
    Save training-loss and episode-reward line plots into *out_dir*.

    losses: loss values, plotted against their index ("Step").
    rewards: reward values, plotted against their index ("Episode").
    out_dir: output directory (created if missing); the files written are
        ``train_loss.png`` and ``train_reward.png``.
    """
    os.makedirs(out_dir, exist_ok=True)
    _save_line_plot(losses, "Training Loss", "Step", "Loss",
                    os.path.join(out_dir, "train_loss.png"))
    _save_line_plot(rewards, "Episode Reward", "Episode", "Reward",
                    os.path.join(out_dir, "train_reward.png"))

def gantt_chart_from_schedule(tasks, schedule: List[int], out_path: str = "plots/gantt.png"):
    """
    Draw a simple Gantt chart for one schedule using matplotlib broken bars.

    tasks: sequence of task records; only each record's ``.id`` attribute is
        read. Lane ``i`` (counting from the bottom) shows ``tasks[i]``.
    schedule: task id executed at each time step (idle = -1). Consecutive
        equal entries are merged into a single bar; ids that match no task
        (such as the idle marker) are not drawn.
    out_path: destination image path; parent directories are created if missing.
    """
    # Collapse the schedule into (start, length) runs keyed by task id in a
    # single pass, instead of rescanning the whole schedule once per task.
    runs: dict = {}
    step = 0
    for task_id, group in itertools.groupby(schedule):
        length = sum(1 for _ in group)
        runs.setdefault(task_id, []).append((step, length))
        step += length

    n_tasks = len(tasks)
    horizon = len(schedule)
    fig, ax = plt.subplots(figsize=(12, max(3, n_tasks)))
    try:
        colors = plt.cm.tab20.colors
        for row, task in enumerate(tasks):
            bars = runs.get(task.id)
            if bars:
                # Each task owns a 10-unit lane; bars are 8 units tall, leaving
                # a 2-unit gap between neighbouring lanes.
                ax.broken_barh(bars, (10 * row, 8), facecolors=[colors[row % len(colors)]])
        # Keep limits non-degenerate: matplotlib warns on identical bounds when
        # there are no tasks or the schedule is empty.
        ax.set_ylim(0, 10 * max(n_tasks, 1))
        ax.set_xlim(0, max(horizon, 1))
        # Tick at the vertical centre of each 8-unit-tall bar.
        ax.set_yticks([10 * i + 4 for i in range(n_tasks)])
        ax.set_yticklabels([f"Task{task.id}" for task in tasks])
        ax.set_xlabel("Time step")
        ax.set_title("Gantt Chart")
        fig.tight_layout()
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
