#from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import utils
from utils import IDLE_TASK_ID, Task
import torch
import torch.nn as nn
import torch.optim as optim
#from collections import deque
from collections import deque, defaultdict
import random
from utils import generate_random_taskset
from typing import Tuple, List, DefaultDict
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from collections import Counter
from typing import List, Tuple

# Seed NumPy's global RNG so NumPy-driven draws in this module (e.g. the
# epsilon-greedy coin flips and random action sampling) are repeatable.
# NOTE(review): only NumPy is seeded here; the `random` module (replay
# sampling) and torch (weight initialisation) are not seeded in this file —
# confirm whether full run-to-run reproducibility is expected.
np.random.seed(0)

class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    The network is a three-layer perceptron: two hidden layers of 128 units,
    each followed by a ReLU, and a linear output layer of width ``action_dim``.
    """

    def __init__(self, state_dim: int, action_dim: int):
        """Build the layer stack.

        Args:
            state_dim: Number of features in a state vector.
            action_dim: Number of discrete actions (size of the output layer).
        """
        super().__init__()
        hidden_width = 128
        # Layers are created in input-to-output order and stored under `net`
        # so parameter names stay `net.<index>.*`.
        layer_stack = [
            nn.Linear(state_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, action_dim),
        ]
        self.net = nn.Sequential(*layer_stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-value estimates produced by the layer stack for ``x``."""
        q_values = self.net(x)
        return q_values
        
class Job:
    """One released instance of a task, tracking remaining work and time to deadline."""

    def __init__(self, task: Task, t: int):
        """Create a job of ``task`` released at time ``t``.

        Args:
            task: The task this job belongs to; its ``deadline`` and
                ``exectime`` attributes are read here.
            t: Release time of the job.
        """
        self.task = task

        # Relative countdown values, updated every time step.
        self.time_until_deadline = task.deadline
        self.exectime_remaining = task.exectime
        # Absolute deadline, fixed at release time.
        self.deadline = t + task.deadline

    def do_job(self):
        """Advance one time step while this job is being executed."""
        self.time_until_deadline -= 1
        self.exectime_remaining -= 1

    def idle_job(self):
        """Advance one time step without executing this job."""
        self.time_until_deadline -= 1

    def is_done(self) -> bool:
        """Return True once no execution time remains."""
        return not self.exectime_remaining > 0

    def quantize(self, val: int, quanta_base: int) -> int:
        """Round ``val`` down to a multiple of ``quanta_base``, never below -4."""
        floored = quanta_base * (val // quanta_base)
        return max(int(floored), -4)

    def slack(self, quanta_base: int) -> int:
        """Return the quantized slack (time to deadline minus remaining work).

        A finished job reports the floor value -4.
        """
        if self.is_done():
            return -4
        raw_slack = self.time_until_deadline - self.exectime_remaining
        return self.quantize(raw_slack, quanta_base)

    def __repr__(self):
        return (
            f"task_id: {self.task.id}, "
            f"deadline: {self.deadline}, "
            f"exectime_remaining: {self.exectime_remaining}"
        )


class TaskSchedulingEnvironment:
    """Discrete-time scheduling environment for one set of periodic tasks.

    Each task owns a FIFO queue of released jobs. At every step the agent
    either executes the head job of one task or idles. The observation is a
    tuple with one entry per task (see ``step``).

    NOTE(review): actions are compared against ``task.id`` while the idle
    action is ``len(tasks)``; this presumably assumes task ids are
    ``0 .. len(tasks) - 1`` — confirm against ``utils``.
    """

    def __init__(self, tasks: List[Task], n_quanta: int):
        """Set up the environment and release the first job of every task.

        Args:
            tasks: The task set to schedule.
            n_quanta: Number of quantization levels the largest deadline is
                divided into when computing the slack quantum.
        """
        self.tasks = tasks

        self.max_deadline = max(t.deadline for t in tasks)
        # Slack values are rounded down to multiples of this quantum. The
        # integer division yields 0 whenever max_deadline < n_quanta, which
        # would make Job.quantize divide by zero; fall back to a quantum of 1
        # (i.e. no coarsening) in that case.
        self.quanta_base = self.max_deadline // n_quanta or 1

        self.n_states = len(tasks)

        # The extra action after the per-task actions means "idle".
        self.idle_index = len(tasks)
        self.n_actions = len(tasks) + 1

        # One episode spans the value returned by utils.get_lcm_period; the
        # best achievable reward is one hit per job released in that span.
        self.n_iter_episode = utils.get_lcm_period(tasks)
        self.max_reward = sum(self.n_iter_episode / task.period for task in tasks)

        self.reset_state()

    def reset_state(self) -> Tuple[int, ...]:
        """Release a fresh job per task at time 0 and return the initial state."""
        self.jobs = {t.id: [Job(t, 0)] for t in self.tasks}

        self.state = tuple(self.jobs[t.id][0].slack(self.quanta_base) for t in self.tasks)
        return self.state

    def no_jobs_to_do(self) -> bool:
        """Return True when every queue is empty or headed by a finished job."""
        return all(not q or q[0].is_done() for q in self.jobs.values())

    def sample(self) -> int:
        """Pick a random action among tasks with an unfinished head job, plus idle."""
        valid = []
        for t in self.tasks:
            queue = self.jobs.get(t.id, [])
            if queue and not queue[0].is_done():
                valid.append(t.id)

        # Idling is always allowed.
        valid.append(self.idle_index)
        return int(np.random.choice(valid))

    def step(self, action: int, i: int) -> Tuple[Tuple[int, ...], int, int]:
        """Apply ``action`` at time step ``i``.

        Args:
            action: A task id to execute, or ``idle_index`` to idle.
            i: Index of the current time step within the episode.

        Returns:
            ``(next_state, reward, action)``. ``reward`` gains 1 when the
            executed job finishes and loses 1 for every head job that can no
            longer meet its deadline; ``action`` is echoed back unchanged.
        """
        reward = 0

        # Release a new job for each task whose period boundary is i + 1.
        # Releases are skipped on the very first step (i == 0).
        if i != 0:
            for t in self.tasks:
                if (i + 1) % t.period == 0:
                    queue = self.jobs[t.id]

                    # Drop a finished head job before queueing the new one.
                    if queue and queue[0].is_done():
                        queue.pop(0)

                    queue.append(Job(t, i + 1))

        if action == self.idle_index:
            # Only the head job of each queue has its clock advanced.
            # NOTE(review): no deadline-miss penalty is applied on idle steps.
            for queue in self.jobs.values():
                if queue:
                    queue[0].idle_job()
        else:
            for t in self.tasks:
                queue = self.jobs[t.id]
                if not queue:
                    continue
                job = queue[0]
                if t.id == action:
                    job.do_job()
                    if job.is_done():
                        reward += 1
                        queue.pop(0)
                else:
                    job.idle_job()

                # `job` is the head that was just advanced; penalise it when
                # its remaining work exceeds the time left to its deadline.
                if queue and not job.is_done() and job.time_until_deadline < job.exectime_remaining:
                    reward -= 1

        # Tasks without a pending job report max_deadline as a placeholder
        # instead of a slack value.
        next_state = []
        for t in self.tasks:
            queue = self.jobs[t.id]
            if queue and not queue[0].is_done():
                next_state.append(queue[0].slack(self.quanta_base))
            else:
                next_state.append(self.max_deadline)
        self.state = tuple(next_state)

        return self.state, reward, action



def q_learning(
    n_tasksets: int = 4000,
    n_repeat: int = 5,
    n_quanta: int = 5,
    epsilon: float = 1.0,
    epsilon_decay: float = 0.9995,
    epsilon_min: float = 0.01,
    gamma: float = 0.99,
    lr: float = 1e-3,
    batch_size: int = 64,
    memory_size: int = 10000,
    update_target_every: int = 1000,
    print_status: bool = False,
) -> Tuple[
    DQN,
    List[List[Task]],
    List[float],
    List[float],
    DefaultDict[float, List[int]],
    List[float],
    List[int],
    np.ndarray,
]:
    """Train a DQN scheduler with epsilon-greedy exploration and experience replay.

    ``n_tasksets`` random task sets are generated once and replayed
    ``n_repeat`` times; each (task set, repeat) pair is one episode.

    Args:
        n_tasksets: Number of random task sets to generate.
        n_repeat: Number of passes over the generated task sets.
        n_quanta: Quantization level count forwarded to the environment.
        epsilon: Initial exploration probability.
        epsilon_decay: Multiplicative decay applied to epsilon after each episode.
        epsilon_min: Lower bound for epsilon.
        gamma: Discount factor of the Q-learning target.
        lr: Adam learning rate.
        batch_size: Replay mini-batch size.
        memory_size: Maximum number of transitions kept in the replay buffer.
        update_target_every: Number of optimisation steps between target-network syncs.
        print_status: When True, print epsilon every 250 episodes.

    Returns:
        A tuple ``(policy_net, training_set, losses, rewards_per_episode,
        metrics, epsilon_values, all_actions_taken, slack_matrix)``.
        ``metrics`` maps utilization (rounded to 0.05) to ``[hits, misses]``
        step counts; ``slack_matrix`` holds the per-step state values of the
        last episode only (empty when no episode ran, in which case
        ``policy_net`` is ``None``).

    Raises:
        ValueError: If a task set yields state/action dimensions different
            from those the networks were built with.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    memory = deque(maxlen=memory_size)

    losses: List[float] = []
    rewards_per_episode: List[float] = []
    metrics: DefaultDict[float, List[int]] = defaultdict(lambda: [0, 0])

    epsilon_values: List[float] = []
    all_actions_taken: List[int] = []

    training_set = [generate_random_taskset() for _ in range(n_tasksets)]

    # The networks are built lazily from the first task set because the
    # input/output sizes depend on the number of tasks.
    policy_net = None
    target_net = None
    optimizer = None
    net_dims = None
    step_counter = 0

    # Bound up front so the return statement is valid even when no episode
    # runs (n_tasksets == 0 or n_repeat == 0).
    slack_matrix = np.zeros((0, 0), dtype=int)

    for repeat_i in range(n_repeat):
        for idx, tasks in enumerate(training_set):
            env = TaskSchedulingEnvironment(tasks=tasks, n_quanta=n_quanta)
            state_vec = env.reset_state()
            state_dim = len(state_vec)
            action_dim = env.n_actions

            if policy_net is None:
                policy_net = DQN(state_dim, action_dim).to(device)
                target_net = DQN(state_dim, action_dim).to(device)
                target_net.load_state_dict(policy_net.state_dict())
                optimizer = optim.Adam(policy_net.parameters(), lr=lr)
                net_dims = (state_dim, action_dim)
            elif (state_dim, action_dim) != net_dims:
                # A differently sized task set cannot be fed to the existing
                # networks nor stacked with earlier replay entries; fail with
                # a clear message instead of a shape error inside torch.
                raise ValueError(
                    f"task set {idx} has state/action dims "
                    f"{(state_dim, action_dim)}, expected {net_dims}"
                )

            state = torch.tensor(state_vec, dtype=torch.float32, device=device)
            total_reward = 0.0
            hit_count = 0
            miss_count = 0
            t = idx + repeat_i * n_tasksets  # global episode index

            if print_status and t % 250 == 0:
                print(f"[Episode {t}] epsilon={epsilon:.3f}")

            # Overwritten every episode; only the last one is returned.
            slack_matrix = np.zeros((len(tasks), env.n_iter_episode), dtype=int)

            for step in range(env.n_iter_episode):
                if env.no_jobs_to_do():
                    action = env.idle_index
                else:
                    if np.random.rand() < epsilon:
                        action = env.sample()
                    else:
                        with torch.no_grad():
                            action = int(torch.argmax(policy_net(state)))

                all_actions_taken.append(action)

                next_state_raw, reward, actual_action = env.step(action, step)

                for tid, slack in enumerate(next_state_raw):
                    slack_matrix[tid, step] = slack

                total_reward += reward
                if reward > 0:
                    hit_count += 1
                elif reward < 0:
                    miss_count += 1

                # NOTE(review): `done` becomes True as soon as no job is
                # pending, which can happen before the episode's time span is
                # over; q_learning_test, in contrast, always runs every step.
                # Confirm the early termination is intended.
                done = env.no_jobs_to_do()
                next_state = torch.tensor(next_state_raw, dtype=torch.float32, device=device)

                # NOTE(review): idle transitions are never stored, so the
                # idle output of the network receives no direct training.
                if actual_action != env.idle_index:
                    memory.append((state, actual_action, reward, next_state, done))
                    if len(memory) >= batch_size:
                        batch = random.sample(memory, batch_size)
                        states, actions, rewards, next_states, dones = zip(*batch)
                        states = torch.stack(states)
                        next_states = torch.stack(next_states)
                        actions = torch.tensor(actions, device=device)
                        rewards = torch.tensor(rewards, device=device, dtype=torch.float32)
                        dones = torch.tensor(dones, device=device, dtype=torch.float32)

                        # squeeze(1) removes only the gathered column, so the
                        # batch dimension survives even when batch_size == 1.
                        q_pred = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                        with torch.no_grad():
                            q_next = target_net(next_states).max(1)[0]
                            q_target = rewards + gamma * q_next * (1 - dones)

                        loss = nn.functional.mse_loss(q_pred, q_target)
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        losses.append(loss.item())
                        step_counter += 1
                        if step_counter % update_target_every == 0:
                            target_net.load_state_dict(policy_net.state_dict())

                state = next_state
                if done:
                    break

            epsilon = max(epsilon_min, epsilon * epsilon_decay)
            epsilon_values.append(epsilon)
            rewards_per_episode.append(total_reward / env.max_reward)

            # Bucket the task set's utilization to the nearest 0.05.
            util = round(sum(task.exectime / task.period for task in tasks) * 20) / 20
            metrics[util][0] += hit_count
            metrics[util][1] += miss_count

    return (
        policy_net,
        training_set,
        losses,
        rewards_per_episode,
        metrics,
        epsilon_values,
        all_actions_taken,
        slack_matrix,
    )



def q_learning_test(
    policy_net: DQN,
    test_set: List[List[Task]] = None,
    n_tasksets: int = 10000,
    return_schedules: bool = False,
    n_quanta: int = 5,
) -> Tuple[
    List[Tuple[float, float]],
    DefaultDict[float, List[Tuple[List[Task], List[int]]]],
    List[int],
    List[int],
    List[float],
    List[Tuple[float, float]],
    List[Tuple[List[Task], List[int]]],
]:
    """Run a trained policy greedily over a set of task sets and collect statistics.

    Args:
        policy_net: Trained Q-network; it is switched to eval mode.
        test_set: Task sets to evaluate. When None, ``n_tasksets`` random
            task sets are generated.
        n_tasksets: Number of task sets to generate when ``test_set`` is None.
        return_schedules: When True, record the per-step schedule of each
            task set (idle steps are recorded as ``IDLE_TASK_ID``).
        n_quanta: Quantization level count forwarded to the environment.

    Returns:
        A tuple ``(avg_rewards_by_util, schedules_by_reward, y_true, y_pred,
        y_scores, util_q_pairs, example_schedules)``:

        * ``avg_rewards_by_util``: sorted ``(utilization, mean normalized reward)``.
        * ``schedules_by_reward``: normalized reward -> ``(tasks, schedule)``
          entries (filled only when ``return_schedules`` is True).
        * ``y_true`` / ``y_pred``: per-step flags, 1 when the step reward was
          positive / when the maximum predicted Q-value was positive.
        * ``y_scores``: per-step maximum predicted Q-value.
        * ``util_q_pairs``: per-step ``(utilization, max Q-value)``.
        * ``example_schedules``: up to five ``(tasks, schedule)`` entries.
    """
    policy_net.eval()
    device = next(policy_net.parameters()).device

    if test_set is None:
        test_set = [generate_random_taskset() for _ in range(n_tasksets)]

    rewards_by_utilization: DefaultDict[float, List[float]] = defaultdict(list)
    schedules_by_reward: DefaultDict[float, List[Tuple[List[Task], List[int]]]] = defaultdict(list)

    y_true: List[int] = []
    y_pred: List[int] = []
    y_scores: List[float] = []
    util_q_pairs: List[Tuple[float, float]] = []
    example_schedules: List[Tuple[List[Task], List[int]]] = []

    for tasks in test_set:
        env = TaskSchedulingEnvironment(tasks=tasks, n_quanta=n_quanta)
        state = torch.tensor(env.reset_state(), dtype=torch.float32, device=device)
        total_reward = 0.0
        schedule: List[int] = []

        # Utilization depends only on the task set, so compute it once per
        # task set (bucketed to the nearest 0.05) rather than on every step.
        util = round(sum(task.exectime / task.period for task in tasks) * 20) / 20

        for t in range(env.n_iter_episode):
            with torch.no_grad():
                q_vals = policy_net(state)
                action = int(torch.argmax(q_vals))
            # Force idling when there is nothing to execute.
            if env.no_jobs_to_do():
                action = env.idle_index

            # The score is the network's maximum Q-value, independent of
            # whether the action was overridden above.
            pred_q = q_vals.max().item()
            y_scores.append(pred_q)

            next_state_raw, reward, actual_action = env.step(action, t)
            total_reward += reward

            y_true.append(1 if reward > 0 else 0)
            y_pred.append(1 if pred_q > 0 else 0)

            util_q_pairs.append((util, pred_q))

            if return_schedules:
                if actual_action == env.idle_index:
                    schedule.append(IDLE_TASK_ID)
                else:
                    schedule.append(actual_action)

            state = torch.tensor(next_state_raw, dtype=torch.float32, device=device)

        avg_reward = total_reward / env.max_reward
        rewards_by_utilization[util].append(avg_reward)

        if return_schedules:
            schedules_by_reward[avg_reward].append((tasks, schedule))

        if return_schedules and len(example_schedules) < 5:
            example_schedules.append((tasks, schedule))

    avg_rewards_by_util = sorted(
        (u, float(np.mean(rs))) for u, rs in rewards_by_utilization.items()
    )

    return (
        avg_rewards_by_util,
        schedules_by_reward,
        y_true,
        y_pred,
        y_scores,
        util_q_pairs,
        example_schedules,
    )
if __name__ == '__main__':
    # Script entry point: train three policies, evaluate them, then write
    # diagnostic figures and CSV files into the current working directory.

    def _hit_rate_by_util(metric_table):
        """Return sorted (utilization, hit rate) pairs from a [hits, misses] table.

        Buckets with no recorded hit or miss map to 0.0 instead of dividing
        by zero.
        """
        return sorted(
            (u, hits / (hits + misses) if hits + misses > 0 else 0.0)
            for u, (hits, misses) in metric_table.items()
        )

    print("==============TRAINING MODE==============")

    # 1) main policy: 4000 task sets replayed 5 times
    (
        policy_net,
        training_set,
        losses,
        rewards_per_episode,
        metrics,
        epsilon_values,
        all_actions_taken,
        slack_matrix,
    ) = q_learning(
        n_tasksets=4000,
        n_repeat=5,
        print_status=True
    )

    # 2) train on a unique larger taskset
    (
        policy_net_unique,
        training_set_unique,
        losses_u,
        rewards_u,
        metrics_u,
        epsilon_values_u,
        all_actions_taken_u,
        slack_matrix_u,
    ) = q_learning(
        n_tasksets=20000,
        n_repeat=1,
        print_status=True
    )

    print("==============TESTING MODE==============")

    # 3) evaluate the first (main) policy, collect schedules + y_true/y_pred
    (
        avg_by_util,
        schedules_by_reward,
        y_true,
        y_pred,
        y_scores,
        util_q_pairs,
        example_schedules,
    ) = q_learning_test(
        policy_net,
        return_schedules=True
    )

    # 4) evaluate the unique-trained policy on its own training set
    (
        avg_by_util_unique,
        schedules_by_reward_unique,
        y_true_u,
        y_pred_u,
        y_scores_u,
        util_q_pairs_u,
        example_schedules_u,
    ) = q_learning_test(
        policy_net_unique,
        test_set=training_set_unique,
        return_schedules=False
    )

    # 5) baseline trained with epsilon pinned at 1.0 (pure exploration).
    # NOTE(review): q_learning_test always acts greedily on the network, so
    # this evaluates a network trained on random experience rather than a
    # random-action scheduler — confirm that is the intended baseline.
    (
        policy_net_random,
        _,
        losses_r,
        rewards_r,
        metrics_r,
        epsilon_values_r,
        all_actions_taken_r,
        slack_matrix_r,
    ) = q_learning(
        n_tasksets=5000,
        n_repeat=1,
        epsilon=1.0,
        epsilon_decay=1.0,
        epsilon_min=1.0,
        print_status=False
    )

    (
        avg_by_util_random,
        schedules_random,
        y_true_r,
        y_pred_r,
        y_scores_r,
        util_q_pairs_r,
        example_schedules_r,
    ) = q_learning_test(
        policy_net_random,
        return_schedules=False
    )

    print("==============PLOTTING==============")

    # Training loss curve
    plt.figure()
    plt.title("Training Loss Curve")
    plt.plot(losses)
    plt.xlabel("Training Steps")
    plt.ylabel("MSE Loss")
    plt.tight_layout()
    plt.savefig("training_loss_curve.png")
    plt.close()

    # Episode reward curve
    plt.figure()
    plt.title("Episode Reward Over Training")
    plt.plot(rewards_per_episode)
    plt.xlabel("Episode Index")
    plt.ylabel("Normalized Reward")
    plt.tight_layout()
    plt.savefig("episode_reward_curve.png")
    plt.close()

    # Deadline hit rate vs. utilization (training metrics). This figure used
    # to be rendered twice into the same file; it is now produced once, from
    # the zero-safe helper.
    train_hit_rates = _hit_rate_by_util(metrics)
    utils_vals = [u for u, _ in train_hit_rates]
    hit_rates = [rate for _, rate in train_hit_rates]
    plt.figure()
    plt.title("Deadline Hit Rate vs. Utilization")
    plt.plot(utils_vals, hit_rates, marker='o')
    plt.xlabel("Utilization")
    plt.ylabel("Hit Rate")
    plt.tight_layout()
    plt.savefig("hit_rate_vs_utilization.png")
    plt.close()

    # Precision & Recall for main policy
    from sklearn.metrics import precision_score, recall_score
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)

    plt.figure()
    plt.title("Precision & Recall of DQN Predictions")
    plt.bar(["Precision", "Recall"], [precision, recall])
    plt.ylabel("Score")
    plt.tight_layout()
    plt.savefig("precision_recall.png")
    plt.close()

    # Same reward series as above, saved under a second title/file name.
    plt.figure()
    plt.title("Online Convergence: Reward per Episode")
    plt.plot(rewards_per_episode)
    plt.xlabel("Episode")
    plt.ylabel("Normalized Reward")
    plt.tight_layout()
    plt.savefig("online_convergence.png")
    plt.close()

    # avg_by_util = [(util, avg_reward), ...]
    utils_test, rewards_test = zip(*avg_by_util)
    utils_unique, rewards_unique = zip(*avg_by_util_unique)
    utils_rand, rewards_rand = zip(*avg_by_util_random)

    plt.figure()
    plt.title("Average Normalized Reward vs. Utilization")
    plt.plot(utils_test,   rewards_test,   label="DQN main",    marker="d")
    plt.plot(utils_unique, rewards_unique, label="DQN unique",  marker="s")
    plt.plot(utils_rand,   rewards_rand,   label="Random base", marker="^")
    plt.xlabel("Utilization")
    plt.ylabel("Normalized Reward")
    plt.legend()
    plt.tight_layout()
    plt.savefig("reward_vs_utilization.png")
    plt.close()

    # Combined comparison figure.
    # NOTE(review): the "During-training" and "Post-training (repeat×5)"
    # series are both derived from the same `metrics` table, so the two
    # curves coincide — confirm which data the second one should show.
    tr_rewards_by_utilization = _hit_rate_by_util(metrics)
    trained_rewards_by_utilization = _hit_rate_by_util(metrics)
    trained_rewards_by_utilization_unique_set = _hit_rate_by_util(metrics_u)
    test_rewards_by_utilization = avg_by_util
    rand_rewards_by_utilization = avg_by_util_random

    fig, ax = plt.subplots(figsize=(10, 4), tight_layout=True)
    ax.set_title("Average Hit Rate vs. Utilization")
    ax.plot(*zip(*tr_rewards_by_utilization), label="During-training", marker="o")
    ax.plot(*zip(*trained_rewards_by_utilization), label="Post-training (repeat×5)", marker="x")
    ax.plot(*zip(*trained_rewards_by_utilization_unique_set),
            label="Post-training (unique)", marker="s")
    ax.plot(*zip(*test_rewards_by_utilization), label="Testing (reward)", marker="d")
    ax.plot(*zip(*rand_rewards_by_utilization), label="Random (reward)", marker="^")
    ax.set_xlabel("Utilization")
    ax.set_ylabel("Hit Rate / Reward")
    ax.legend(loc="best")
    ax.grid(True)
    plt.savefig("reward_comparison.png", dpi=300)
    plt.close()

    # Epsilon schedule
    plt.figure()
    plt.plot(epsilon_values)
    plt.title("Epsilon Decay Over Episodes")
    plt.xlabel("Episode")
    plt.ylabel("Epsilon")
    plt.tight_layout()
    plt.savefig("epsilon_decay.png")
    plt.close()

    # Moving-average loss
    window = 100
    smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
    plt.figure()
    plt.plot(smoothed)
    plt.title("Smoothed Training Loss (window=100)")
    plt.xlabel("Training Step")
    plt.ylabel("MSE Loss")
    plt.tight_layout()
    plt.savefig("smoothed_loss.png")
    plt.close()

    # How often each action was selected during training
    cnt = Counter(all_actions_taken)
    ids, freqs = zip(*sorted(cnt.items()))
    plt.figure()
    plt.bar(ids, freqs)
    plt.title("Action Selection Frequency")
    plt.xlabel("Task ID")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig("action_frequency.png")
    plt.close()

    # Precision–Recall & ROC
    precisions, recalls, _ = precision_recall_curve(y_true, y_scores)
    plt.figure()
    plt.plot(recalls, precisions)
    plt.title("Precision–Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.tight_layout()
    plt.savefig("precision_recall_curve.png")
    plt.close()

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC={roc_auc:.2f}")
    plt.title("ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.tight_layout()
    plt.savefig("roc_curve.png")
    plt.close()

    # slack_matrix: shape (n_tasks, n_iter_episode), last training episode only
    plt.figure(figsize=(12, 4))
    plt.imshow(slack_matrix, aspect='auto', interpolation='nearest', cmap='viridis')
    plt.colorbar(label="Slack")
    plt.title("Slack Heatmap Over Time")
    plt.xlabel("Time Step")
    plt.ylabel("Task ID")
    plt.tight_layout()
    plt.savefig("slack_heatmap.png")
    plt.close()

    # Gantt chart of the first recorded example schedule: one row per task,
    # one bar per maximal run of consecutive steps spent on that task.
    tasks, sched = example_schedules[0]
    n_steps = len(sched)
    fig, ax = plt.subplots(figsize=(10, 2))
    for task in tasks:
        starts = []
        lengths = []
        in_run = False
        for t, a in enumerate(sched):
            if a == task.id and not in_run:
                in_run = True
                start = t
            elif a != task.id and in_run:
                in_run = False
                starts.append(start)
                lengths.append(t - start)
        # Close a run that extends to the end of the schedule.
        if in_run:
            starts.append(start)
            lengths.append(n_steps - start)
        for s, l in zip(starts, lengths):
            ax.broken_barh([(s, l)], (10 * task.id, 8), facecolors='tab:blue')
    ax.set_ylim(0, 10 * len(tasks))
    ax.set_xlim(0, n_steps)
    ax.set_yticks([10 * i + 4 for i in range(len(tasks))])
    ax.set_yticklabels([f"Task{task.id}" for task in tasks])
    ax.set_xlabel("Time Step")
    ax.set_title("Gantt Chart of Example Schedule")
    plt.tight_layout()
    plt.savefig("gantt_example.png")
    plt.close()

    # Per-step (utilization, max Q) pairs of the main policy
    df_q_long = pd.DataFrame(util_q_pairs, columns=['utilization', 'max_q'])
    df_q_long.to_csv("qvalue_long.csv", index=False)

    # Draw into an explicit axes: with `by=` and no `ax`, pandas opens its
    # own figure, which left the sized figure empty and never closed.
    fig, ax = plt.subplots(figsize=(8, 4))
    df_q_long.boxplot(column='max_q', by='utilization', ax=ax)
    plt.title("Q-Value Distribution by Utilization")
    plt.suptitle("")
    plt.xlabel("Utilization")
    plt.ylabel("Max Q-value")
    plt.tight_layout()
    plt.savefig("qvalue_distribution.png")
    plt.close()

    # CSV exports
    pd.DataFrame({
        "Episode": np.arange(len(epsilon_values)),
        "Epsilon": epsilon_values,
        "Reward": rewards_per_episode[:len(epsilon_values)]
    }).to_csv("epsilon_reward.csv", index=False)

    pd.DataFrame({
        "Step": np.arange(len(losses)),
        "Loss": losses
    }).to_csv("losses.csv", index=False)

    df_q_long.groupby("utilization")["max_q"] \
        .median() \
        .reset_index(name="MedianQ") \
        .to_csv("qvalue_by_util.csv", index=False)

    print("Saved all additional figures and CSV data files.")
    print("Saved figures in current directory:")
    print(" - training_loss_curve.png")
    print(" - episode_reward_curve.png")
    print(" - hit_rate_vs_utilization.png")
    print(" - precision_recall.png")

    print("==============EXAMPLE SCHEDULES==============")

    # Print one randomly chosen schedule for every distinct reward value.
    i = 0
    for reward, tasks_schedules in schedules_by_reward.items():
        i += 1
        print(f"\nTask set #{i}")

        tasks, schedule = tasks_schedules[np.random.choice(len(tasks_schedules))]
        print(f"utilization: {sum(task.exectime / task.period for task in tasks)}")
        print(f"reward: {reward}")

        utils.print_by_task(tasks, schedule)
        
