import os
import utils
from schedulers import rate_monotonic_schedule, earliest_deadline_first
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel, wilcoxon

import qlearning
from qlearning import q_learning, q_learning_test

def eval_classic(tasks):
    """Score the two classic schedulers (RM and EDF) on *tasks*.

    Both schedulers are simulated over ``utils.get_lcm_period(tasks)`` time
    units. Each hit count returned by ``utils.compute_hits`` is normalised by
    ``len(tasks) * L``.

    Args:
        tasks: task set as returned by ``utils.read_tasks`` /
            ``generate_random_taskset``.

    Returns:
        dict mapping ``'RM'`` and ``'EDF'`` to their normalised hit rates.
    """
    horizon = utils.get_lcm_period(tasks)
    # Dict literals evaluate left to right, so RM is simulated before EDF.
    raw_hits = {
        'RM': utils.compute_hits(tasks, rate_monotonic_schedule(tasks, horizon)),
        'EDF': utils.compute_hits(tasks, earliest_deadline_first(tasks, horizon)),
    }
    slots = len(tasks) * horizon
    return {name: count / slots for name, count in raw_hits.items()}

def _lookup_util(pairs, util, default=None, tol=1e-9):
    """Return the value keyed by *util* in an iterable of ``(key, value)`` pairs.

    An exact key match is tried first. Failing that, any numeric key within
    *tol* of *util* is accepted. This handles the case where the caller's
    rounding of the utilisation differs in float representation from the key
    produced by ``q_learning_test`` (assumed to be a utilisation value —
    TODO confirm against qlearning). Returns *default* when nothing matches.
    """
    table = dict(pairs)
    if util in table:
        return table[util]
    for key, value in table.items():
        try:
            if abs(key - util) <= tol:
                return value
        except TypeError:  # non-numeric key: cannot match a utilisation
            continue
    return default


def _paired_pvalues(a, b):
    """Return ``(paired t-test p, Wilcoxon signed-rank p)`` for samples *a* vs *b*.

    Either p-value is NaN when the test is undefined: fewer than two pairs,
    or (Wilcoxon only) every paired difference is zero. In the latter case
    scipy raises ``ValueError`` — e.g. when both methods score 1.0 on every
    sampled set.
    """
    if len(a) < 2:
        return np.nan, np.nan
    p_t = ttest_rel(a, b).pvalue
    try:
        p_w = wilcoxon(a, b).pvalue
    except ValueError:
        p_w = np.nan
    return p_t, p_w


def _sample_hit_rates(policy_net, util, n_samples, max_attempts):
    """Collect paired RM / EDF / DQN hit rates on random task sets at *util*.

    Some task sets are skipped: those where ``eval_classic`` raises
    ``AssertionError``, and those whose DQN result has no entry for *util*.
    Sampling stops after *max_attempts* draws, even if fewer than
    *n_samples* pairs were collected, so an unlucky generator cannot hang
    the script.

    Returns:
        ``(rm_rates, edf_rates, dqn_rates)`` — three equal-length lists.
    """
    rm_rates, edf_rates, dqn_rates = [], [], []
    for _ in range(max_attempts):
        if len(dqn_rates) >= n_samples:
            break
        ts = qlearning.generate_random_taskset(total_utilization=util)
        try:
            classic = eval_classic(ts)
        except AssertionError:
            continue

        avg_tmp, _, _, _, _, _, _ = q_learning_test(
            policy_net,
            test_set=[ts],
            return_schedules=False
        )
        dqn_hit = _lookup_util(avg_tmp, util)
        if dqn_hit is None:
            continue

        rm_rates.append(classic['RM'])
        edf_rates.append(classic['EDF'])
        dqn_rates.append(dqn_hit)

    if len(dqn_rates) < n_samples:
        print(f"Warning: collected only {len(dqn_rates)}/{n_samples} samples "
              f"after {max_attempts} attempts")
    return rm_rates, edf_rates, dqn_rates


def main():
    """Compare RM, EDF and the Transformer-Enhanced DQN on ``data/taskset2.txt``.

    Steps:
      1. Load the target task set and check it with
         ``utils.assert_is_schedulable``.
      2. Score RM and EDF on it with ``eval_classic``.
      3. Train a DQN policy with ``q_learning`` and evaluate it on the target
         set with ``q_learning_test``.
      4. Sample random task sets (same task count, target utilisation rounded
         to 0.05) and collect paired hit rates for significance tests.
      5. Write ``hit_rate_samples.csv``, ``significance.csv`` and
         ``compare_hit_rates.csv``, plus a bar chart and a boxplot, into the
         current working directory.
    """
    out_dir = os.getcwd()
    print(f"Output directory: {out_dir}")

    taskfile = 'data/taskset2.txt'
    print(f"Loading tasks from {taskfile}")
    tasks = utils.read_tasks(taskfile)
    utils.assert_is_schedulable(tasks)
    n_tasks = len(tasks)

    n_samples = 200
    max_attempts = 20 * n_samples  # hard cap so sampling always terminates

    # Pin the task count of every generated set to the target's. Presumably
    # q_learning draws its training sets through this module attribute
    # (TODO confirm); the sampling below calls it directly. Overwriting
    # kwargs (rather than passing n_tasks twice) avoids a duplicate-keyword
    # TypeError if a caller supplies n_tasks. The original function is
    # restored in `finally`.
    orig_gen = qlearning.generate_random_taskset

    def _gen_same_size(*args, **kwargs):
        kwargs['n_tasks'] = n_tasks
        return orig_gen(*args, **kwargs)

    qlearning.generate_random_taskset = _gen_same_size
    try:
        print("Evaluating classic schedulers...")
        classic_rates = eval_classic(tasks)

        print("Training Transformer-Enhanced DQN policy...")
        policy_net, _, _, _, _, _, _, _ = q_learning(
            n_tasksets=1000,
            n_repeat=3,
            print_status=False
        )

        print("Evaluating Transformer-Enhanced DQN policy on target set...")
        avg_list, _, _, _, _, _, _ = q_learning_test(
            policy_net,
            test_set=[tasks],
            return_schedules=False
        )
        # Total utilisation rounded to the nearest 0.05.
        util = round(sum(t.exectime/t.period for t in tasks)*20)/20
        dqn_single = _lookup_util(avg_list, util, default=np.nan)

        print(f"Sampling {n_samples} sets at util={util} for significance tests...")
        rm_rates, edf_rates, dqn_rates = _sample_hit_rates(
            policy_net, util, n_samples, max_attempts
        )
    finally:
        qlearning.generate_random_taskset = orig_gen

    df_samples = pd.DataFrame({'RM': rm_rates, 'EDF': edf_rates, 'Transformer-Enhanced DQN': dqn_rates})
    samples_csv = os.path.join(out_dir, 'hit_rate_samples.csv')
    df_samples.to_csv(samples_csv, index=False)
    print(f"Saved sample hit rates to {samples_csv}")

    p_rm, w_rm = _paired_pvalues(dqn_rates, rm_rates)
    p_edf, w_edf = _paired_pvalues(dqn_rates, edf_rates)

    df_stats = pd.DataFrame([
        {'Comparison': 'Transformer-Enhanced DQN vs RM',  'ttest_p': p_rm,  'wilcoxon_p': w_rm},
        {'Comparison': 'Transformer-Enhanced DQN vs EDF', 'ttest_p': p_edf, 'wilcoxon_p': w_edf}
    ])
    stats_csv = os.path.join(out_dir, 'significance.csv')
    df_stats.to_csv(stats_csv, index=False)
    print(f"Saved significance tests to {stats_csv}")

    df_single = pd.DataFrame([
        {'Algorithm': 'RM',                      'HitRate': classic_rates['RM']},
        {'Algorithm': 'EDF',                     'HitRate': classic_rates['EDF']},
        {'Algorithm': 'Transformer-Enhanced DQN', 'HitRate': dqn_single}
    ])
    single_csv = os.path.join(out_dir, 'compare_hit_rates.csv')
    df_single.to_csv(single_csv, index=False)
    print(f"Saved single-set comparison to {single_csv}")

    plt.figure()
    plt.bar(df_single['Algorithm'], df_single['HitRate'], edgecolor='black')
    plt.title('Hit Rate Comparison (Single Set)')
    plt.ylabel('Hit Rate')
    plt.ylim(0, 1)
    plt.tight_layout()
    bar_png = os.path.join(out_dir, 'compare_hit_rates.png')
    plt.savefig(bar_png)
    plt.close()
    print(f"Saved bar chart to {bar_png}")

    box_labels = ['RM', 'EDF', 'Transformer-Enhanced DQN']
    plt.figure()
    plt.boxplot([rm_rates, edf_rates, dqn_rates])
    # Label ticks separately: boxplot's `labels=` kwarg was renamed to
    # `tick_labels` in Matplotlib 3.9 and is deprecated there. Boxes sit at
    # x = 1..n by default.
    plt.xticks(range(1, len(box_labels) + 1), box_labels)
    plt.title('Hit Rate Distribution Comparison')
    plt.ylabel('Hit Rate')
    plt.tight_layout()
    box_png = os.path.join(out_dir, 'hit_rate_boxplot.png')
    plt.savefig(box_png)
    plt.close()
    print(f"Saved boxplot to {box_png}")

# Script entry point: run the full comparison only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    raise SystemExit(main())