#from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import utils
from utils import IDLE_TASK_ID, Task
import torch
import torch.nn as nn
import torch.optim as optim
#from collections import deque
from collections import deque, defaultdict
import random
from utils import generate_random_taskset
from typing import Tuple, List, DefaultDict
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from collections import Counter
from typing import List, Tuple
import math
# Seed every RNG this script draws from so runs are repeatable: NumPy drives
# epsilon-greedy exploration, `random` drives replay-buffer sampling, and torch
# drives weight initialization and dropout.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature width of the sequence (even or odd).
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the frequency vector to match the slice width.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Buffer (not a parameter): moves with .to(device) but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

class TransformerDQN(nn.Module):
    """Transformer encoder that maps a vector of integer slack tokens to Q-values.

    Every slack value is embedded as a token, a learned CLS token is prepended,
    sinusoidal positions are added, and each output position is projected to a
    single scalar, yielding ``n_tasks + 1`` values per sample.

    Args:
        n_tasks: Number of task slots in the input vector.
        d_model: Embedding / model width.
        n_heads: Number of attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
        num_embeddings: Size of the slack-token vocabulary; inputs are clamped
            into ``[0, num_embeddings - 1]`` before the lookup.
        dropout: Dropout probability used inside every encoder layer.
    """

    def __init__(
        self,
        n_tasks: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        num_embeddings: int = 50,
        dropout: float = 0.1,
    ):
        super().__init__()
        # One position for the CLS token plus one per task.
        self.seq_len = n_tasks + 1
        self.d_model = d_model

        # Lookup table turning each (clamped) slack value into a d_model vector.
        self.token_emb = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=d_model
        )

        # Learned token prepended to every sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # The table is sized exactly for seq_len positions.
        self.pos_enc = PositionalEncoding(d_model, max_len=self.seq_len)

        # Transformer encoder (batch-first tensors throughout).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation='relu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Per-position projection to a scalar.
        self.to_q = nn.Linear(d_model, 1)

    def forward(self, slack_vec: torch.Tensor) -> torch.Tensor:
        """Compute one value per sequence position (CLS + each task).

        Args:
            slack_vec: Integer tensor of shape (batch, n_tasks); it is used as
                an index into the embedding table.

        Returns:
            Tensor of shape (batch, n_tasks + 1).
        """
        # Negative slack values collapse to token 0 and large ones to the last
        # token, so the embedding lookup can never go out of range.
        slack_clamped = slack_vec.clamp(min=0, max=self.token_emb.num_embeddings - 1)

        # 1) Token embedding: (batch, n_tasks, d_model)
        x = self.token_emb(slack_clamped)

        # 2) Prepend the CLS token: (batch, n_tasks + 1, d_model)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # 3) Positional encoding
        x = self.pos_enc(x)
        # 4) Transformer
        x = self.transformer(x)
        # 5) Project every position to a scalar.
        # NOTE(review): callers use output index k as the Q-value of action k,
        # so index 0 comes from the CLS position — confirm this mapping is intended.
        q_all = self.to_q(x).squeeze(-1)
        return q_all
class DQN(nn.Module):
    """Plain MLP Q-network: two 128-unit ReLU hidden layers, linear output.

    Args:
        state_dim: Width of the input state vector.
        action_dim: Number of Q-values produced per state.
    """

    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        ]
        # Built in the same order as before so state_dict keys stay net.0 ... net.4.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values for every row of ``x``."""
        q_values = self.net(x)
        return q_values
        
class Job:
    """One released instance of a task, tracking remaining work and time left."""

    def __init__(self, task: Task, t: int):
        # `t` is the release time; the absolute deadline is relative to it.
        self.task = task

        self.time_until_deadline = task.deadline
        self.exectime_remaining = task.exectime
        self.deadline = t + task.deadline

    def do_job(self):
        """Advance one tick while executing: work and time left both shrink."""
        self.time_until_deadline -= 1
        self.exectime_remaining -= 1

    def idle_job(self):
        """Advance one tick without executing: only the time left shrinks."""
        self.time_until_deadline -= 1

    def is_done(self) -> bool:
        """Return True once no execution time remains."""
        return not self.exectime_remaining > 0

    def quantize(self, val: int, quanta_base: int) -> int:
        """Floor ``val`` to a multiple of ``quanta_base``, never below -4."""
        floored = int(quanta_base * (val // quanta_base))
        return floored if floored > -4 else -4

    def slack(self, quanta_base: int) -> int:
        """Return the quantized slack (time left minus work left).

        A finished job reports the floor value -4.
        """
        if not self.is_done():
            raw_slack = self.time_until_deadline - self.exectime_remaining
            return self.quantize(raw_slack, quanta_base)
        return -4

    def __repr__(self):
        return f"task_id: {self.task.id}, deadline: {self.deadline}, exectime_remaining: {self.exectime_remaining}"


class TaskSchedulingEnvironment:
    """Discrete-time scheduling environment for a set of periodic tasks.

    The state is a tuple with one quantized slack value per task. Actions are
    task ids, plus ``idle_index`` (``len(tasks)``) for "run nothing". Task ids
    are compared directly with actions, so they are expected to be
    ``0 .. len(tasks) - 1``.

    Args:
        tasks: Task set to schedule (must be non-empty).
        n_quanta: Number of slack buckets; the bucket width is
            ``max_deadline // n_quanta``, floored at 1.
    """

    def __init__(self, tasks: List[Task], n_quanta: int):
        self.tasks = tasks

        self.max_deadline = max(t.deadline for t in tasks)
        # Keep the bucket width >= 1: when max_deadline < n_quanta the integer
        # division yields 0, which would make Job.quantize divide by zero.
        self.quanta_base = max(1, self.max_deadline // n_quanta)

        self.n_states = len(tasks)
        self.idle_index = len(tasks)
        self.n_actions = len(tasks) + 1

        # Episode length; presumably the LCM of the task periods — verify
        # against utils.get_lcm_period.
        self.n_iter_episode = utils.get_lcm_period(tasks)
        # One reward point is available per job released during the episode.
        self.max_reward = sum(self.n_iter_episode / task.period for task in tasks)

        self.reset_state()

    def reset_state(self) -> Tuple[int, ...]:
        """Release one job per task at time 0 and return the initial state."""
        self.jobs = {t.id: [Job(t, 0)] for t in self.tasks}
        self.state = tuple(self.jobs[t.id][0].slack(self.quanta_base) for t in self.tasks)
        return self.state

    def no_jobs_to_do(self) -> bool:
        """Return True when every queue is empty or headed by a finished job."""
        return all((q and q[0].is_done()) or not q for q in self.jobs.values())

    def sample(self) -> int:
        """Return a uniformly random action among runnable task ids and idle."""
        valid = []
        for t in self.tasks:
            queue = self.jobs.get(t.id, [])
            if queue and not queue[0].is_done():
                valid.append(t.id)
        # Idling is always allowed.
        valid.append(self.idle_index)
        return int(np.random.choice(valid))

    def step(self, action: int, i: int) -> Tuple[Tuple[int, ...], int, int]:
        """Apply ``action`` at time step ``i``.

        Args:
            action: A task id to run for one tick, or ``idle_index``.
            i: Index of the current time step within the episode.

        Returns:
            ``(next_state, reward, action)``. The reward gains +1 when the
            chosen job completes and loses 1 for every head-of-queue job that
            can no longer meet its deadline.
        """
        reward = 0

        # Release a new job for each task whose period boundary falls on the
        # next tick; the releases at time 0 were done by reset_state().
        if i != 0:
            for t in self.tasks:
                if (i + 1) % t.period == 0:
                    queue = self.jobs[t.id]
                    if queue and queue[0].is_done():
                        queue.pop(0)
                    queue.append(Job(t, i + 1))

        if action == self.idle_index:
            # NOTE(review): no deadline-miss penalty is applied on idle steps —
            # confirm this is intended.
            for queue in self.jobs.values():
                if queue:
                    queue[0].idle_job()
        else:
            for t in self.tasks:
                queue = self.jobs[t.id]
                if not queue:
                    continue
                job = queue[0]
                if t.id == action:
                    job.do_job()
                    if job.is_done():
                        reward += 1
                        queue.pop(0)
                else:
                    job.idle_job()

                # Penalize a pending job whose remaining work no longer fits
                # before its deadline.
                if queue and not job.is_done() and job.time_until_deadline < job.exectime_remaining:
                    reward -= 1

        # Tasks with no pending job report max_deadline instead of a slack value.
        next_state = []
        for t in self.tasks:
            queue = self.jobs[t.id]
            if queue and not queue[0].is_done():
                next_state.append(queue[0].slack(self.quanta_base))
            else:
                next_state.append(self.max_deadline)
        self.state = tuple(next_state)

        return self.state, reward, action



def q_learning(
    n_tasksets: int = 4000,
    n_repeat: int = 5,
    n_quanta: int = 5,
    epsilon: float = 1.0,
    epsilon_decay: float = 0.9995,
    epsilon_min: float = 0.01,
    gamma: float = 0.99,
    lr: float = 1e-3,
    batch_size: int = 64,
    memory_size: int = 10000,
    update_target_every: int = 1000,
    print_status: bool = False,
) -> Tuple[
    nn.Module,                     # now TransformerDQN
    List[List[Task]],
    List[float],
    List[float],
    DefaultDict[float, List[int]],
    List[float],
    List[int],
    np.ndarray,
]:
    """Train a TransformerDQN with epsilon-greedy DQN on random task sets.

    ``n_tasksets`` random task sets are generated once and replayed
    ``n_repeat`` times; each (task set, repeat) pair is one episode.

    Args:
        n_tasksets: Number of random task sets to generate.
        n_repeat: Number of passes over the generated task sets.
        n_quanta: Slack quantization passed to the environment.
        epsilon: Initial exploration probability.
        epsilon_decay: Multiplicative decay applied after every episode.
        epsilon_min: Lower bound for epsilon.
        gamma: Discount factor of the Bellman target.
        lr: Adam learning rate.
        batch_size: Replay minibatch size.
        memory_size: Replay buffer capacity.
        update_target_every: Optimizer steps between target-network syncs.
        print_status: Print epsilon every 250 episodes when True.

    Returns:
        ``(policy_net, training_set, losses, rewards_per_episode, metrics,
        epsilon_values, all_actions_taken, slack_matrix)``. ``metrics`` maps a
        utilization rounded to 0.05 to ``[hit_count, miss_count]``.
        ``slack_matrix`` holds the state trace of the last episode only.
        ``policy_net`` is ``None`` and ``slack_matrix`` is empty when no
        episode was run.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    memory = deque(maxlen=memory_size)

    losses: List[float] = []
    rewards_per_episode: List[float] = []
    metrics: DefaultDict[float, List[int]] = defaultdict(lambda: [0, 0])
    epsilon_values: List[float] = []
    all_actions_taken: List[int] = []
    # Defined up front so the return value exists even when no episode runs
    # (n_tasksets == 0 or n_repeat == 0).
    slack_matrix = np.zeros((0, 0), dtype=int)

    training_set = [generate_random_taskset() for _ in range(n_tasksets)]

    # The networks are built lazily from the size of the first task set.
    # NOTE(review): this assumes every task set has the same number of tasks —
    # confirm against generate_random_taskset.
    policy_net = None
    target_net = None
    optimizer = None
    step_counter = 0

    for repeat_i in range(n_repeat):
        for idx, tasks in enumerate(training_set):
            env = TaskSchedulingEnvironment(tasks=tasks, n_quanta=n_quanta)
            state_vec = env.reset_state()            # Tuple[int, ...], length = n_tasks
            n_tasks = len(state_vec)

            if policy_net is None:
                policy_net = TransformerDQN(
                    n_tasks=n_tasks,
                    d_model=128,
                    n_heads=4,
                    n_layers=2
                ).to(device)
                target_net = TransformerDQN(
                    n_tasks=n_tasks,
                    d_model=128,
                    n_heads=4,
                    n_layers=2
                ).to(device)
                target_net.load_state_dict(policy_net.state_dict())
                # The target network only produces bootstrap targets; eval mode
                # switches its dropout off so those targets are deterministic.
                target_net.eval()
                optimizer = optim.Adam(policy_net.parameters(), lr=lr)

            state = torch.tensor([state_vec], dtype=torch.long, device=device)  # (1, n_tasks)
            total_reward = 0.0
            hit_count = 0
            miss_count = 0
            t = idx + repeat_i * n_tasksets  # global episode index

            if print_status and t % 250 == 0:
                print(f"[Episode {t}] epsilon={epsilon:.3f}")

            slack_matrix = np.zeros((n_tasks, env.n_iter_episode), dtype=int)

            for step in range(env.n_iter_episode):
                if env.no_jobs_to_do():
                    action = env.idle_index
                else:
                    if np.random.rand() < epsilon:
                        action = env.sample()
                    else:
                        with torch.no_grad():
                            # policy_net(state) -> (1, n_tasks + 1)
                            q_vals = policy_net(state)
                            action = int(torch.argmax(q_vals, dim=1))

                all_actions_taken.append(action)

                next_state_raw, reward, actual_action = env.step(action, step)
                for tid, slack in enumerate(next_state_raw):
                    slack_matrix[tid, step] = slack

                total_reward += reward
                if reward > 0:
                    hit_count += 1
                elif reward < 0:
                    miss_count += 1

                done = env.no_jobs_to_do()
                next_state = torch.tensor([next_state_raw], dtype=torch.long, device=device)

                # Idle transitions are neither stored nor trained on.
                if actual_action != env.idle_index:
                    memory.append((state, actual_action, reward, next_state, done))
                    if len(memory) >= batch_size:
                        batch = random.sample(memory, batch_size)
                        states, actions, rewards, next_states, dones = zip(*batch)
                        states      = torch.cat(states, dim=0)       # (B, n_tasks)
                        next_states = torch.cat(next_states, dim=0)  # (B, n_tasks)
                        actions     = torch.tensor(actions, device=device)
                        rewards     = torch.tensor(rewards, device=device, dtype=torch.float32)
                        dones       = torch.tensor(dones, device=device, dtype=torch.float32)

                        # Q(s, a); squeeze only the gathered column so the
                        # result stays (B,) even when B == 1.
                        q_pred = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                        with torch.no_grad():
                            q_next = target_net(next_states).max(1)[0]
                            q_target = rewards + gamma * q_next * (1 - dones)

                        loss = nn.functional.mse_loss(q_pred, q_target)
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        losses.append(loss.item())
                        step_counter += 1
                        if step_counter % update_target_every == 0:
                            target_net.load_state_dict(policy_net.state_dict())

                state = next_state
                # NOTE(review): the episode stops as soon as no job is pending,
                # which can happen before n_iter_episode steps — confirm this
                # early stop is intended.
                if done:
                    break

            epsilon = max(epsilon_min, epsilon * epsilon_decay)
            epsilon_values.append(epsilon)
            rewards_per_episode.append(total_reward / env.max_reward)

            # Bucket the task-set utilization to the nearest 0.05.
            util = round(sum(task.exectime / task.period for task in tasks)*20)/20
            metrics[util][0] += hit_count
            metrics[util][1] += miss_count

    return (
        policy_net,
        training_set,
        losses,
        rewards_per_episode,
        metrics,
        epsilon_values,
        all_actions_taken,
        slack_matrix,
    )



def q_learning_test(
    policy_net: nn.Module,
    test_set: List[List[Task]] = None,
    n_tasksets: int = 10000,
    return_schedules: bool = False,
    n_quanta: int = 5,
) -> Tuple[
    List[Tuple[float, float]],                            # avg_rewards_by_util
    DefaultDict[float, List[Tuple[List[Task], List[int]]]],  # schedules_by_reward
    List[int],                                            # y_true
    List[int],                                            # y_pred
    List[float],                                          # y_scores
    List[Tuple[float, float]],                            # util_q_pairs
    List[Tuple[List[Task], List[int]]]                    # example_schedules
]:
    """Run ``policy_net`` greedily over task sets and collect evaluation data.

    The network is switched to eval mode and left there.

    Args:
        policy_net: Trained Q-network; called with a (1, n_tasks) long tensor.
        test_set: Task sets to evaluate; ``n_tasksets`` random ones are
            generated when None.
        n_tasksets: Number of random task sets used when ``test_set`` is None.
        return_schedules: Record the per-step actions when True.
        n_quanta: Slack quantization passed to the environment.

    Returns:
        ``(avg_rewards_by_util, schedules_by_reward, y_true, y_pred, y_scores,
        util_q_pairs, example_schedules)``. Per step, ``y_scores`` is the
        maximum predicted Q-value, ``y_pred`` is 1 when that value is positive,
        and ``y_true`` is 1 when the step reward is positive. Utilization is
        rounded to the nearest 0.05.
    """
    policy_net.eval()
    device = next(policy_net.parameters()).device

    if test_set is None:
        test_set = [generate_random_taskset() for _ in range(n_tasksets)]

    rewards_by_utilization: DefaultDict[float, List[float]] = defaultdict(list)
    schedules_by_reward: DefaultDict[float, List[Tuple[List[Task], List[int]]]] = defaultdict(list)

    y_true: List[int] = []
    y_pred: List[int] = []
    y_scores: List[float] = []
    util_q_pairs: List[Tuple[float, float]] = []
    example_schedules: List[Tuple[List[Task], List[int]]] = []

    for tasks in test_set:
        env = TaskSchedulingEnvironment(tasks=tasks, n_quanta=n_quanta)
        # state: tuple of ints of length n_tasks
        init_state = env.reset_state()
        # Convert to a (1, n_tasks) long tensor.
        state = torch.tensor([init_state], dtype=torch.long, device=device)

        # Utilization depends only on the task set, so compute it once per
        # episode; this also keeps it defined for a zero-step episode.
        util = round(sum(task.exectime / task.period for task in tasks) * 20) / 20

        total_reward = 0.0
        schedule: List[int] = []

        for t in range(env.n_iter_episode):
            with torch.no_grad():
                q_vals = policy_net(state)           # (1, n_tasks + 1)
                action = int(torch.argmax(q_vals, dim=1))

            # With nothing runnable the greedy choice is overridden by idle;
            # the recorded score below still comes from the network output.
            if env.no_jobs_to_do():
                action = env.idle_index

            pred_q = q_vals.max().item()
            y_scores.append(pred_q)

            next_state_raw, reward, actual_action = env.step(action, t)
            total_reward += reward

            y_true.append(1 if reward > 0 else 0)
            y_pred.append(1 if pred_q > 0 else 0)

            util_q_pairs.append((util, pred_q))

            state = torch.tensor([next_state_raw], dtype=torch.long, device=device)
            if return_schedules:
                schedule.append(actual_action)

        avg_reward = total_reward / env.max_reward
        rewards_by_utilization[util].append(avg_reward)

        if return_schedules:
            schedules_by_reward[avg_reward].append((tasks, schedule))

        # Keep the first five schedules as plotting examples.
        if return_schedules and len(example_schedules) < 5:
            example_schedules.append((tasks, schedule))

    avg_rewards_by_util = sorted(
        (u, float(np.mean(rs))) for u, rs in rewards_by_utilization.items()
    )

    return (
        avg_rewards_by_util,
        schedules_by_reward,
        y_true,
        y_pred,
        y_scores,
        util_q_pairs,
        example_schedules,
    )



print("==============TRAINING MODE==============")

# Main model: 4000 task sets, each replayed over 5 passes.
(
    policy_net,
    training_set,
    losses,
    rewards_per_episode,
    metrics,
    epsilon_values,
    all_actions_taken,
    slack_matrix,
) = q_learning(n_tasksets=4000, n_repeat=5, n_quanta=5, lr=1e-3, print_status=True)

# "Unique" model: 20000 task sets, a single pass (same episode count as above).
(
    policy_net_unique,
    training_set_unique,
    losses_u,
    rewards_u,
    metrics_u,
    epsilon_values_u,
    all_actions_taken_u,
    slack_matrix_u,
) = q_learning(n_tasksets=20000, n_repeat=1, n_quanta=5, lr=1e-3, print_status=True)

print("==============TESTING MODE==============")

# Main model on freshly generated task sets; schedules are kept for the
# Gantt chart and the example printout.
(
    avg_by_util,
    schedules_by_reward,
    y_true,
    y_pred,
    y_scores,
    util_q_pairs,
    example_schedules,
) = q_learning_test(policy_net, return_schedules=True, n_quanta=5)

# "Unique" model evaluated on the very task sets it was trained on.
(
    avg_by_util_unique,
    schedules_by_reward_unique,
    y_true_u,
    y_pred_u,
    y_scores_u,
    util_q_pairs_u,
    example_schedules_u,
) = q_learning_test(
    policy_net_unique,
    test_set=training_set_unique,
    return_schedules=False,
    n_quanta=5,
)

# Baseline: epsilon is pinned at 1.0, so training actions are always sampled
# by env.sample() rather than chosen by the network.
(
    policy_net_random,
    _,
    losses_r,
    rewards_r,
    metrics_r,
    epsilon_values_r,
    all_actions_taken_r,
    slack_matrix_r,
) = q_learning(
    n_tasksets=5000,
    n_repeat=1,
    n_quanta=5,
    epsilon=1.0,
    epsilon_decay=1.0,
    epsilon_min=1.0,
    print_status=False,
)

(
    avg_by_util_random,
    schedules_random,
    y_true_r,
    y_pred_r,
    y_scores_r,
    util_q_pairs_r,
    example_schedules_r,
) = q_learning_test(policy_net_random, return_schedules=False, n_quanta=5)





print("==============PLOTTING==============")

# 1. Training loss curve
fig, ax = plt.subplots()
ax.set_title("Training Loss Curve (Transformer-Enhanced DQN)")
ax.plot(losses, label="Transformer-Enhanced DQN")
ax.set_xlabel("Training Steps")
ax.set_ylabel("MSE Loss")
ax.legend()
fig.tight_layout()
fig.savefig("training_loss_curve.png")
plt.close(fig)

# 2. Episode reward curve
fig, ax = plt.subplots()
ax.set_title("Episode Reward Over Training (Transformer-Enhanced DQN)")
ax.plot(rewards_per_episode, label="Transformer-Enhanced DQN")
ax.set_xlabel("Episode Index")
ax.set_ylabel("Normalized Reward")
ax.legend()
fig.tight_layout()
fig.savefig("episode_reward_curve.png")
plt.close(fig)

# 3. Deadline hit rate vs. utilization (training metrics)
# Buckets without any hit or miss are plotted as 0.0.
utils_sorted = sorted(metrics.items())
utils_vals = [entry[0] for entry in utils_sorted]
hit_rates = [
    (h / (h + m)) if (h + m) > 0 else 0.0
    for _, (h, m) in utils_sorted
]
fig, ax = plt.subplots()
ax.set_title("Deadline Hit Rate vs. Utilization (Training, Transformer-Enhanced DQN)")
ax.plot(utils_vals, hit_rates, marker='o', label="Transformer-Enhanced DQN")
ax.set_xlabel("Utilization")
ax.set_ylabel("Hit Rate")
ax.legend()
fig.tight_layout()
fig.savefig("hit_rate_vs_utilization.png")
plt.close(fig)

# 4. Precision & Recall for main policy
from sklearn.metrics import precision_score, recall_score
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)

fig, ax = plt.subplots()
ax.set_title("Precision & Recall (Transformer-Enhanced DQN)")
ax.bar(["Precision", "Recall"], [precision, recall])
ax.set_ylabel("Score")
fig.tight_layout()
fig.savefig("precision_recall.png")
plt.close(fig)

# 5. Online convergence (reward per episode)
fig, ax = plt.subplots()
ax.set_title("Online Convergence: Reward per Episode (Transformer-Enhanced DQN)")
ax.plot(rewards_per_episode, label="Transformer-Enhanced DQN")
ax.set_xlabel("Episode")
ax.set_ylabel("Normalized Reward")
ax.legend()
fig.tight_layout()
fig.savefig("online_convergence.png")
plt.close(fig)

# 6. Hit rate vs. utilization again, this time dropping empty buckets
hit_rate_by_util = [
    (u, h / (h + m))
    for u, (h, m) in metrics.items()
    if h + m > 0
]
hit_rate_by_util.sort()
utils_vals2, hit_rates2 = zip(*hit_rate_by_util)
fig, ax = plt.subplots()
ax.set_title("Deadline Hit Rate vs. Utilization (Transformer-Enhanced DQN)")
ax.plot(utils_vals2, hit_rates2, marker="o", label="Transformer-Enhanced DQN")
ax.set_xlabel("Utilization")
ax.set_ylabel("Hit Rate")
ax.legend()
fig.tight_layout()
fig.savefig("hit_rate_vs_utilization_repeat.png")
plt.close(fig)

# 7. Normalized Reward vs. Utilization (testing curves)
utils_test, rewards_test = zip(*avg_by_util)
utils_unique, rewards_unique = zip(*avg_by_util_unique)
utils_rand, rewards_rand = zip(*avg_by_util_random)

fig, ax = plt.subplots()
ax.set_title("Average Normalized Reward vs. Utilization")
ax.plot(utils_test, rewards_test, label="Transformer-Enhanced DQN", marker="d")
ax.plot(utils_unique, rewards_unique, label="DQN unique", marker="s")
ax.plot(utils_rand, rewards_rand, label="Random baseline", marker="^")
ax.set_xlabel("Utilization")
ax.set_ylabel("Normalized Reward")
ax.legend()
fig.tight_layout()
fig.savefig("reward_vs_utilization.png")
plt.close(fig)

# 8. Composite hit-rate/reward comparison
# Training curves are hit rates; testing and baseline curves are normalized
# rewards, all drawn on one axis.
tr_curr = sorted((u, h/(h+m)) for u, (h, m) in metrics.items() if h+m > 0)
tr_repeat = tr_curr  # the main run is the 5-pass run, so both curves coincide
tr_unique = sorted((u, h/(h+m)) for u, (h, m) in metrics_u.items() if h+m > 0)
test_rt = avg_by_util
rand_rt = avg_by_util_random

fig, ax = plt.subplots(figsize=(10,4), tight_layout=True)
ax.set_title("Average Hit Rate vs. Utilization")
for series, label, marker in (
    (tr_curr,   "Transformer-Enhanced DQN (training)", "o"),
    (tr_repeat, "Transformer-Enhanced DQN (repeat×5)", "x"),
    (tr_unique, "Transformer-Enhanced DQN (unique)",   "s"),
    (test_rt,   "Transformer-Enhanced DQN (testing)",  "d"),
    (rand_rt,   "Random baseline",                     "^"),
):
    ax.plot(*zip(*series), label=label, marker=marker)
ax.set_xlabel("Utilization")
ax.set_ylabel("Hit Rate / Reward")
ax.legend(loc="best")
ax.grid(True)
fig.savefig("reward_comparison.png", dpi=300)
plt.close(fig)

# 9. Epsilon decay
fig, ax = plt.subplots()
ax.set_title("Epsilon Decay Over Episodes (Transformer-Enhanced DQN)")
ax.plot(epsilon_values, label="Transformer-Enhanced DQN")
ax.set_xlabel("Episode")
ax.set_ylabel("Epsilon")
ax.legend()
fig.tight_layout()
fig.savefig("epsilon_decay.png")
plt.close(fig)

# 10. Smoothed training loss (window=100)
# Moving average: convolve with a uniform kernel, keeping full overlaps only.
window = 100
smoothed = np.convolve(losses, np.ones(window)/window, mode='valid')
fig, ax = plt.subplots()
ax.set_title("Smoothed Training Loss (window=100, Transformer-Enhanced DQN)")
ax.plot(smoothed, label="Transformer-Enhanced DQN")
ax.set_xlabel("Training Step")
ax.set_ylabel("MSE Loss")
ax.legend()
fig.tight_layout()
fig.savefig("smoothed_loss.png")
plt.close(fig)

# 11. Action selection frequency
cnt = Counter(all_actions_taken)
ids, freqs = zip(*sorted(cnt.items()))
fig, ax = plt.subplots()
ax.set_title("Action Selection Frequency (Transformer-Enhanced DQN)")
ax.bar(ids, freqs)
ax.set_xlabel("Task ID")
ax.set_ylabel("Count")
fig.tight_layout()
fig.savefig("action_frequency.png")
plt.close(fig)

# 12. Precision–Recall & ROC curves (score = max predicted Q-value per step)
precisions, recalls, _ = precision_recall_curve(y_true, y_scores)
fig, ax = plt.subplots()
ax.set_title("Precision–Recall Curve (Transformer-Enhanced DQN)")
ax.plot(recalls, precisions, label="Transformer-Enhanced DQN")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
fig.tight_layout()
fig.savefig("precision_recall_curve.png")
plt.close(fig)

fpr, tpr, _ = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)
fig, ax = plt.subplots()
ax.set_title("ROC Curve (Transformer-Enhanced DQN)")
ax.plot(fpr, tpr, label=f"AUC={roc_auc:.2f}")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
fig.tight_layout()
fig.savefig("roc_curve.png")
plt.close(fig)

# 13. Slack heatmap (state trace of the last training episode)
fig, ax = plt.subplots(figsize=(12,4))
ax.set_title("Slack Heatmap Over Time (Transformer-Enhanced DQN)")
heatmap = ax.imshow(slack_matrix, aspect='auto', interpolation='nearest', cmap='viridis')
fig.colorbar(heatmap, ax=ax, label="Slack")
ax.set_xlabel("Time Step")
ax.set_ylabel("Task ID")
fig.tight_layout()
fig.savefig("slack_heatmap.png")
plt.close(fig)

# 14. Gantt chart of an example schedule
tasks_ex, sched_ex = example_schedules[0]
n_steps = len(sched_ex)
fig, ax = plt.subplots(figsize=(10,2))
for task in tasks_ex:
    # Collapse consecutive steps spent on this task into (start, length) runs.
    segments = []
    run_start = None
    for t, a in enumerate(sched_ex):
        if a == task.id:
            if run_start is None:
                run_start = t
        elif run_start is not None:
            segments.append((run_start, t - run_start))
            run_start = None
    # A run still open at the end of the schedule extends to the last step.
    if run_start is not None:
        segments.append((run_start, n_steps - run_start))

    # Each task gets its own 8-unit-high lane at y = 10 * task.id.
    for segment in segments:
        ax.broken_barh([segment], (10*task.id, 8), facecolors='tab:blue')

ax.set_ylim(0, 10*len(tasks_ex))
ax.set_xlim(0, n_steps)
ax.set_yticks([10*i + 4 for i in range(len(tasks_ex))])
ax.set_yticklabels([f"Task{t.id}" for t in tasks_ex])
ax.set_xlabel("Time Step")
ax.set_title("Gantt Chart of Example Schedule (Transformer-Enhanced DQN)")
fig.tight_layout()
fig.savefig("gantt_example.png")
plt.close(fig)

# 15. Q-value distribution by utilization
df_q_long = pd.DataFrame(util_q_pairs, columns=['utilization', 'max_q'])
df_q_long.to_csv("qvalue_long.csv", index=False)

# DataFrame.boxplot(by=...) opens a figure of its own unless it is handed an
# Axes. Passing `ax` makes the 8x4 figure the one that is drawn, saved and
# closed, instead of leaving an empty figure open.
fig, ax = plt.subplots(figsize=(8,4))
df_q_long.boxplot(column='max_q', by='utilization', ax=ax)
ax.set_title("Q-Value Distribution by Utilization (Transformer-Enhanced DQN)")
fig.suptitle("")  # drop the automatic "Boxplot grouped by ..." super-title
ax.set_xlabel("Utilization")
ax.set_ylabel("Max Q-value")
fig.tight_layout()
fig.savefig("qvalue_distribution.png")
plt.close(fig)


# Export CSV
pd.DataFrame({
    "Episode": np.arange(len(epsilon_values)),
    "Epsilon": epsilon_values,
    "Reward": rewards_per_episode[:len(epsilon_values)]
}).to_csv("epsilon_reward.csv", index=False)

pd.DataFrame({
    "Step": np.arange(len(losses)),
    "Loss": losses
}).to_csv("losses.csv", index=False)

# Median of the per-step max Q-value inside every utilization bucket.
df_q_long.groupby("utilization")["max_q"] \
    .median() \
    .reset_index(name="MedianQ") \
    .to_csv("qvalue_by_util.csv", index=False)

# Summary of every artifact written above, in the order it was saved.
print("Saved all figures and CSV data to current directory:")
print(" Figures:")
for figure_name in (
    "training_loss_curve.png",
    "episode_reward_curve.png",
    "hit_rate_vs_utilization.png",
    "precision_recall.png",
    "online_convergence.png",
    "hit_rate_vs_utilization_repeat.png",  # was saved but missing from the list
    "reward_vs_utilization.png",
    "reward_comparison.png",
    "epsilon_decay.png",
    "smoothed_loss.png",
    "action_frequency.png",
    "precision_recall_curve.png",
    "roc_curve.png",
    "slack_heatmap.png",
    "gantt_example.png",
    "qvalue_distribution.png",
):
    print(f"  - {figure_name}")
print(" CSV files:")
for csv_name in (
    "epsilon_reward.csv",
    "losses.csv",
    "qvalue_long.csv",
    "qvalue_by_util.csv",
):
    print(f"  - {csv_name}")



print("==============EXAMPLE SCHEDULES==============")
# Show one randomly picked schedule for each of the first five reward values.
first_five = list(schedules_by_reward.items())[:5]
for idx, (reward, task_scheds) in enumerate(first_five, start=1):
    print(f"\nExample Task Set #{idx}")
    tasks, raw_schedule = random.choice(task_scheds)
    # The environment encodes idle as len(tasks); map it to the shared idle id.
    idle_action = len(tasks)
    schedule = [
        utils.IDLE_TASK_ID if action == idle_action else action
        for action in raw_schedule
    ]
    util = sum(t.exectime / t.period for t in tasks)
    print(f" Utilization: {util:.2f}, Reward: {reward:.2f}")
    print(f" Schedule (len={len(schedule)}):")
    print("  ", schedule)
    print(" Per-task execution (start–length segments):")
    utils.print_by_task(tasks, schedule)
