import datetime
import math
import simpy
import random
import numpy as np
import sys
import matplotlib
matplotlib.use('TkAgg')  # Adjust based on your OS and environment
import matplotlib.pyplot as plt
import pandas as pd
#import seaborn as sns
from real_dataset.parse_datasets import load_and_prepare_data

# ---------------------------------------------------------------------------
# Per-run results (re-initialised by run_simulation / simulation_wrapper)
# ---------------------------------------------------------------------------
response_times = []         # response time of every completed job
response_pshort_times = []  # response times of jobs served from short_queue
response_plong_times = []   # response times of jobs served from long_queue

# Server time consumed by one cheap (c1) / expensive (c2) prediction.
c1, c2 = 0.00001, 0.00005  # TODO: change

# Scheduler queues (tuple layouts as appended by serve_job / job_generator):
short_queue = []        # (arrival_time, job_size)
long_queue = []         # (arrival_time, predicted_size, age, job_size)
cheap_p_queue = []      # (arrival_time, cheap-prediction work already done)
expensive_p_queue = []  # (arrival_time, predicted_size, expensive-prediction work done, job_size)

# Service start time of whatever currently holds the server; for long jobs it
# is also the key into job_processes.
start_time = None
# Not updated by the code visible here -- TODO confirm whether still needed.
start_cheap_time = start_expensive_time = None

job_processes = {}  # service start time -> serve_long_job process (used for preemption)
n_cheap_p = n_expensive_p = 0  # prediction counters, printed when LOG_EVENT_PRINT is set

SIMULATION_TIME = 100000  # simulated time units per run (alternative tried: 1000000)

# The threshold T is a module global assigned by run_simulation:
# T = float('inf') reproduces FCFS and T = 0 reproduces SPRPT.
current_job = None
# Not updated by the code visible here -- TODO confirm whether still needed.
current_cheap_job = current_expensive_job = None

LOG_EVENT_PRINT = 0  # 1 -> print a per-event trace through log_event
PLOT_GRAPHS = 0      # 1 -> save matplotlib figures from the test_* sweeps

# Job-size distribution and predictor. Synthetic alternatives:
#   dist = 'exponential' | 'weibull'
#   predictor = 'perfectP' | 'uniP' | 'expP'
# The active setting replays the real dataset, which supplies both the actual
# and the predicted size of every job.
dist = 'real'
predictor = 'real'
real_data_index = 0  # next real_data row to consume

# Relative noise widths of the cheap and the expensive predictor
# (passed to predict_service_time as uni_alpha).
cheap_alpha = 0.8
expensive_alpha = 0.2

# When set, the cheap predictor is a short/long classifier that is correct
# with probability cheap_acc instead of a size predictor.
use_seperate_cheapp = 0
cheap_acc = 0.8


class EndOfDataException(Exception):
    """Raised by job_size_distribution when the real dataset has no rows left."""

def log_event(time, event, job_id="--", size="--", predicted_size="--", notes="--", queue_content="--"):
    """Print one fixed-width row of the event trace when LOG_EVENT_PRINT is set.

    ``time`` is shown with five decimals; ``job_id``, ``size`` and
    ``predicted_size`` with two, unless they are the placeholder ``"--"``.
    ``notes`` may additionally be the literal ``"DONE"``. ``queue_content``
    is printed verbatim. Nothing is printed when LOG_EVENT_PRINT is falsy.
    """
    if not LOG_EVENT_PRINT:
        return

    def two_decimals(value):
        # "--" marks an unused column and is printed unchanged.
        return "--" if value == "--" else f"{value:.2f}"

    stamp = f"{time:.5f}"
    job_col = two_decimals(job_id)
    size_col = two_decimals(size)
    predicted_col = two_decimals(predicted_size)
    if notes == "--":
        notes_col = "--"
    elif notes == "DONE":
        notes_col = "DONE"
    else:
        notes_col = f"{notes:.2f}"

    print(f"{stamp:<10}| {event:<20}| {job_col:<10}| {size_col:<10}| {predicted_col:<15}| {notes_col:<20}| {queue_content}")

def short_job_size_distribution(): #TODO: change
    """Draw job sizes until one falls strictly below the threshold ``T``; return it.

    ``job_size_distribution`` returns a ``(job_size, predicted_size)`` pair, so
    only the true size is compared against ``T`` and returned. Loops until a
    qualifying sample is drawn; with the real dataset, EndOfDataException from
    ``job_size_distribution`` propagates if the data runs out first.
    """
    while True:
        job_size, _predicted_size = job_size_distribution()
        if job_size < T:
            return job_size

def job_size_distribution():
    """Sample one job and return ``(job_size, predicted_job_size)``.

    The source is selected by the module-level ``dist``:

    * ``'exponential'`` -- Exp(1) size; predicted size is a placeholder 0.
    * ``'weibull'`` -- inverse-CDF sample ``(-ln(1 - U))**2 / 2``; predicted
      size is a placeholder 0.
    * ``'real'`` -- the next row of ``real_data`` (columns
      ``normalized_runtime`` / ``normalized_predicted_runtime``), advancing
      ``real_data_index``.

    Raises:
        EndOfDataException: ``dist == 'real'`` and every row has been consumed.
        ValueError: ``dist`` is not one of the values above.
    """
    global real_data_index

    if dist == 'exponential':
        return random.expovariate(1), 0
    if dist == 'weibull':
        u = random.random()
        return (-math.log(1 - u)) ** 2 / 2, 0
    if dist != 'real':
        raise ValueError("Unknown distribution type specified.")

    if real_data_index >= len(real_data):
        raise EndOfDataException("End of data reached")
    row = real_data_index
    actual = real_data.loc[row, 'normalized_runtime']
    predicted = real_data.loc[row, 'normalized_predicted_runtime']
    real_data_index = row + 1
    return actual, predicted

def predict_service_time(job, uni_alpha):
    """Return a noisy prediction of a job's service time ``job``.

    The noise model is chosen by the module-level ``predictor``:

    * ``'perfectP'`` -- the exact size (also used for any predictor when
      ``uni_alpha == 0``).
    * ``'expP'`` -- exponential with mean ``job`` (``job`` must be non-zero).
    * ``'uniP'`` -- uniform on ``[(1 - uni_alpha) * job, (1 + uni_alpha) * job]``.

    Raises:
        ValueError: ``predictor`` is none of the above. This surfaces the
            misconfiguration here instead of as a later ``None < T``
            TypeError in the scheduler.
    """
    if predictor == 'perfectP' or uni_alpha == 0:
        return job

    if predictor == 'expP':
        return random.expovariate(1 / job)

    if predictor == 'uniP':
        lower_bound = (1 - uni_alpha) * job
        upper_bound = (1 + uni_alpha) * job
        return random.uniform(lower_bound, upper_bound)

    raise ValueError(f"Unknown predictor type specified: {predictor!r}")

def is_service_time_less_than_T(true_service_time, T, accuracy_prob):
    """Simulate a binary short/long classifier with a given accuracy.

    Returns ``(prediction, is_accurate)``. ``is_accurate`` is True with
    probability ``accuracy_prob`` (one ``random.random()`` draw), in which
    case ``prediction`` equals ``true_service_time < T``. Otherwise the
    prediction is the negation of that comparison.
    """
    truly_short = true_service_time < T
    correct = random.random() <= accuracy_prob
    return (truly_short if correct else not truly_short), correct

############### Different types of predictors  #################

#def predict_service_time(job): #TODO: perfect predictor
#    return job

#def predict_service_time(job): #TODO uniP predictor
#    lower_bound = (1 - alpha) * job
#    upper_bound = (1 + alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_service_time(z): #TODO: exponential predictor
#     return random.expovariate(1/z)

#def predict_uniP_cheap(job): #TODO cheap uniP predictor
#    global cheap_alpha
#    lower_bound = (1 - cheap_alpha) * job
#    upper_bound = (1 + cheap_alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_uniP_expensive(job):  # TODO expensive uniP predictor
#    global expensive_alpha
#    lower_bound = (1 - expensive_alpha) * job
#    upper_bound = (1 + expensive_alpha) * job
#    return random.uniform(lower_bound, upper_bound)
#####################################################

def remaining_time(job):
    """Return the *predicted* remaining size of a long-queue job.

    ``job`` is laid out as ``(arrival_time, predicted_size, age, job_size)``;
    the result is ``predicted_size - age`` and is used as the sort key that
    orders ``long_queue``.
    """
    predicted_size, age = job[1], job[2]
    return predicted_size - age

def check_schedule(env, server):
    """Start a ``serve_job`` process for the highest-priority non-empty queue.

    Priority order: short jobs, then pending cheap predictions, then pending
    expensive predictions, then long jobs. When all four queues are empty,
    no process is started.
    """
    if LOG_EVENT_PRINT:
        # Rendering every queue costs O(total queue length). This function runs
        # on every event, so only do it when the trace is actually printed.
        cheap_p_str = [f"({x:.2f}, {y:.2f})" for x, y in cheap_p_queue]
        short_queue_str = [f"({x:.2f}, {y:.2f})" for x, y in short_queue]
        expensive_p_str = [f"({x:.2f}, {y:.2f}, {z:.2f}, {w:.2f})" for x, y, z, w in expensive_p_queue]
        long_queue_str = [f"({x:.2f}, {y:.2f}, {z:.2f}, {w:.2f})" for x, y, z, w in long_queue]
        queue_str = f"Cheap_p: {cheap_p_str}, Short: {short_queue_str}, Expensive_p: {expensive_p_str}, Long: {long_queue_str}"
        log_event(env.now, "Checking Policy", "--", "--", "--", "--", queue_str)

    # (queue, serve_job type, log message), in scheduling priority order.
    priority_table = (
        (short_queue, "short_queue", "Schedule Short Job"),
        (cheap_p_queue, "cheap_p_queue", "Schedule Cheap p"),
        (expensive_p_queue, "expensive_p_queue", "Schedule Expensive p"),
        (long_queue, "long_queue", "Schedule long Job"),
    )
    for queue, queue_type, message in priority_table:
        if queue:
            log_event(env.now, message, "--", "--", "--", "--", "--")
            env.process(serve_job(env, queue_type, server))
            return


def serve_job(env, type, server):
    """SimPy process that serves one item from the queue named by ``type``.

    ``type`` is ``"short_queue"``, ``"cheap_p_queue"``, ``"expensive_p_queue"``
    or ``"long_queue"``. It selects which queue to pop and the server priority
    to request (0 = most urgent ... 3 = least).

    * short: run the job's true size to completion and record its response time.
    * cheap_p: spend the remaining ``c1`` of cheap-prediction work, draw the
      job's size (and, for real data, its predicted size), then route it to
      ``short_queue`` (predicted below ``T``) or ``expensive_p_queue``.
    * expensive_p: spend the remaining ``c2`` of expensive-prediction work,
      then move the job to ``long_queue``.
    * long: serve the job with the smallest predicted remaining time through
      a ``serve_long_job`` sub-process, which other arrivals may interrupt.

    All scheduler state is shared through module-level globals.
    """
    global current_job, current_cheap_job, current_expensive_job, start_cheap_time, start_expensive_time, start_time, n_cheap_p, n_expensive_p
    global cheap_p_queue, expensive_p_queue, short_queue, long_queue
    global cheap_acc, use_seperate_cheapp

    if type=="short_queue":
        # job_processes only holds serve_long_job processes (registered in the
        # long branch below), so this can only preempt a long job in service.
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Short Job Arrival")

        with server.request(priority=0) as req:
            yield req
            # Several serve_job processes may be pending; the queue can
            # already have been drained by the time the server is granted.
            if not short_queue:
                return
            next_job = short_queue.pop(0)
            log_event(env.now, "Serve Short", next_job[0], next_job[1], "--", next_job[0], "--")
            yield env.timeout(next_job[1])
            log_event(env.now, "Short Job Done", next_job[0], "--", "--", "DONE", f"S-response:{env.now - next_job[0]}")
            response_times.append(env.now - next_job[0])
            response_pshort_times.append(env.now - next_job[0])

        check_schedule(env, server)

    elif type=="cheap_p_queue":
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Cheap P Arrival")

        with server.request(priority=1) as req:
            yield req
            if not cheap_p_queue:
                return
            next_job = cheap_p_queue.pop(0)
            log_event(env.now, "Serve Cheap p", next_job[0],  "--", "--", next_job[0], f"Age: {next_job[1]}")
            current_job = next_job
            start_time = env.now
            # next_job[1] is the cheap-prediction work already completed.
            remaining_cheapp = c1 - next_job[1]
            try:
                yield env.timeout(remaining_cheapp)
                log_event(env.now, "Cheap p Done", next_job[0], "--", "--", "DONE", "--")
                try:
                    job_size, cheap_predicted_service_time = job_size_distribution()
                except EndOfDataException:
                    # NOTE(review): returns without calling check_schedule, so
                    # queued work waits until the next arrival triggers scheduling.
                    print("End of data reached. Terminating simulation.")
                    return

                # Synthetic distributions return a placeholder prediction of 0;
                # generate the cheap prediction from the true size instead.
                if dist == 'exponential' or dist == 'weibull':
                    cheap_predicted_service_time = predict_service_time(job_size, cheap_alpha)
                n_cheap_p += 1

                if use_seperate_cheapp:
                    # Binary short/long classifier that is right with probability cheap_acc.
                    prediction, truth = is_service_time_less_than_T(job_size, T, cheap_acc)
                    is_short = prediction
                else:
                    is_short = cheap_predicted_service_time < T

                if is_short:
                    short_queue.append((next_job[0], job_size))
                    log_event(env.now, "Append to Short", next_job[0], "--", "--", "--", "--")

                else:
                    # n_expensive_p is counted once, when the expensive prediction
                    # completes (expensive_p_queue branch), like n_cheap_p above.
                    if dist == 'exponential' or dist == 'weibull':
                        predicted_service_time = predict_service_time(job_size, expensive_alpha)
                    else:
                        predicted_service_time = cheap_predicted_service_time

                    expensive_p_queue.append((next_job[0], predicted_service_time, 0, job_size))
                    log_event(env.now, "Append to Expensive", next_job[0], "--", predicted_service_time, "--", "--")
                check_schedule(env, server)
                current_job = None
                start_time = None

            except simpy.Interrupt as interrupt:
                # NOTE(review): this process is never registered in job_processes,
                # so nothing in the visible code interrupts it.
                if str(interrupt.cause) == "Short Job Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1] + elapsed_time)
                    log_event(env.now, "Short Preempted cheap_p", job_id=current_job[0])
                    if start_time in job_processes:
                        del job_processes[start_time]
                    cheap_p_queue.append(updated_job)
                    current_job = None
                    start_time = None


    elif type=="expensive_p_queue":
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Expensive P Arrival")

        with server.request(priority=2) as req:
            yield req
            if not expensive_p_queue:
                return
            next_job = expensive_p_queue.pop(0)
            log_event(env.now, "Serve Expensive p", next_job[0],  "--", "--", next_job[0], f"Age: {next_job[2]}")
            current_job = next_job
            # NOTE(review): start_time is not set here, so the Interrupt handlers
            # below would subtract a stale (possibly None) start_time. This
            # process is also never registered in job_processes.
            remaining_expensivep = c2 - next_job[2]
            try:
                yield env.timeout(remaining_expensivep)

                log_event(env.now, "Expensive p Done", next_job[0], "--", "--", "DONE", "--")
                n_expensive_p += 1
                long_queue.append((next_job[0], next_job[1], 0, next_job[3]))
                check_schedule(env, server)
                current_job = None
                start_time = None

            except simpy.Interrupt as interrupt:
                if str(interrupt.cause) == "Cheap P Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "Cheap_p Preempted Expensive_p", job_id=current_job[0])
                    expensive_p_queue.append(updated_job)
                    current_job = None
                    start_time = None
                if str(interrupt.cause) == "Short Job Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "Short Preempted Expensive_p", job_id=current_job[0])
                    expensive_p_queue.append(updated_job)
                    current_job = None
                    start_time = None



    elif type=="long_queue":

        # Shortest predicted remaining time first.
        long_queue.sort(key=remaining_time)


        if current_job: #think to seperate current_jobs of each type
            age_received = env.now - start_time
            if (current_job[1] - (current_job[2] + age_received)) > remaining_time(long_queue[0]): #remaining time current > remaining time of new
                updated_job = (current_job[0], current_job[1], current_job[2] + age_received, current_job[3])
                log_event(env.now, "Long Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                long_queue.append(updated_job)
                long_queue.sort(key=remaining_time)
                current_process = job_processes.get(start_time)
                if current_process:
                    # NOTE(review): interrupt() carries no cause, so none of the
                    # cause-specific handlers below run for this preemption; the
                    # job was already re-queued above.
                    current_process.interrupt()

        with server.request(priority=3) as req:
            yield req

            if not long_queue:
                return
            next_job = long_queue.pop(0)
            log_event(env.now, "Serve Long", next_job[0], next_job[1], "--", next_job[0], f"Age: {next_job[2]}")
            start_time = env.now
            current_job = (next_job[0], next_job[1], next_job[2], next_job[3])
            current_process = env.process(serve_long_job(env, next_job, server))
            # Register the sub-process so arrivals can interrupt (preempt) it.
            job_processes[start_time] = current_process
            try:
                yield current_process
                response_time = env.now - current_job[0]
                log_event(env.now, "Long Job Done", current_job[0], "--", "--", "DONE", f"L-response:{response_time}")

                job_details = {
                    'response_time': response_time,
                    'actual_size': current_job[3],
                    'predicted_size': current_job[1]
                }
                #response_plong_times_details.append(job_details)
                response_times.append(response_time)
                response_plong_times.append(response_time)
                current_job = None
                if start_time in job_processes:
                    del job_processes[start_time]
                start_time = None

            except simpy.Interrupt as interrupt:
                # The interrupted serve_long_job fails with the Interrupt, which
                # is re-raised here in the waiting parent: bank the served age
                # and re-queue the job.
                if str(interrupt.cause) == "Cheap P Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "Cheap Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                    if start_time in job_processes:
                        del job_processes[start_time]

                    long_queue.append(updated_job)
                    long_queue.sort(key=remaining_time)
                    current_job = None
                    start_time = None
                    check_schedule(env, server)

                if str(interrupt.cause) == "Short Job Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "Short Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                    if start_time in job_processes:
                        del job_processes[start_time]

                    long_queue.append(updated_job)
                    long_queue.sort(key=remaining_time)
                    current_job = None
                    start_time = None
                    check_schedule(env, server)

                if str(interrupt.cause) == "Expensive P Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "ExpensiveP Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                    if start_time in job_processes:
                        del job_processes[start_time]

                    long_queue.append(updated_job)
                    long_queue.sort(key=remaining_time)
                    current_job = None
                    start_time = None
                    check_schedule(env, server)

        check_schedule(env, server)



def serve_long_job(env, job, server):
    """Sub-process that waits out a long job's remaining *actual* work.

    ``job`` is a long-queue tuple ``(arrival_time, predicted_size, age,
    job_size)``. The process waits ``job_size - age`` time units and then
    re-runs the scheduler. serve_job's long branch registers it in
    ``job_processes`` so that arrivals can interrupt it.
    """
    served_so_far, true_size = job[2], job[3]
    yield env.timeout(true_size - served_so_far)
    check_schedule(env, server)

def job_generator(env, server, arrival_rate):
    """Poisson arrival process that never terminates.

    After each exponential inter-arrival gap (rate ``arrival_rate``), the new
    job, identified by its arrival time, is queued for a cheap prediction
    and the scheduler is re-run.
    """
    while True:
        gap = random.expovariate(arrival_rate)
        yield env.timeout(gap)
        arrival_time = env.now

        # (job, age): age tracks cheap-prediction work already done, because a
        # cheap prediction can be preempted and later resumed.
        cheap_p_queue.append((arrival_time, 0))
        log_event(env.now, "Append to Cheap p", arrival_time, "--", "--", "--", "--")
        check_schedule(env, server)


def dummy_job_generator(env, server):
    """Deterministic three-job arrival stream for debugging the scheduler.

    Jobs arrive 1.0, 1.5 and 2.0 time units after the process starts. Each
    job is queued for a cheap prediction as ``(arrival_time, 0)``, the same
    ``(job, age)`` layout that job_generator uses and that check_schedule /
    serve_job unpack. The scheduler is re-run after every arrival.
    """
    # Gaps between consecutive arrivals.
    for gap in (1, 0.5, 0.5):
        yield env.timeout(gap)
        job = env.now
        cheap_p_queue.append((job, 0))
        log_event(env.now, "Append to Cheap p", job, "--", "--", "--", "--")
        check_schedule(env, server)


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float('nan')


def run_simulation(arrival_rate, threshold, c1_price, c2_price, test_cheap_alpha=0.8, test_expen_alpha=0.2):
    """Simulate the server-cost model once and return its mean response times.

    Resets the shared scheduler state; sets the prediction costs c1/c2, the
    threshold T and the predictor parameters; then simulates Poisson arrivals
    on a single priority server for SIMULATION_TIME time units.
    ``threshold == 0`` runs SPRPT (c1 forced to 0), and
    ``threshold == float('inf')`` runs FCFS (c1 and c2 forced to 0).

    Args:
        arrival_rate: Poisson arrival rate (lambda).
        threshold: SkipPredict threshold T.
        c1_price: server time of one cheap prediction.
        c2_price: server time of one expensive prediction.
        test_cheap_alpha: stored as cheap_acc when use_seperate_cheapp is set,
            otherwise as cheap_alpha.
        test_expen_alpha: stored as expensive_alpha.

    Returns:
        ``(mean, mean_predicted_short, mean_predicted_long)`` response times.
        A class that does not apply to the policy (short for SPRPT, long for
        FCFS) is reported as 0. A class that applies but had no completed jobs
        is NaN, rather than aborting the sweep with ZeroDivisionError.

    NOTE(review): real_data_index is not reset here, so with the real dataset
    consecutive calls continue from where the previous run stopped -- confirm
    this is intended.
    """
    global response_times, response_pshort_times, response_plong_times
    global start_time, job_processes
    global T
    global n_cheap_p, n_expensive_p
    global cheap_p_queue, expensive_p_queue, short_queue, long_queue
    global cheap_alpha, expensive_alpha, cheap_acc
    global current_job
    global c1, c2

    # Re-initialize the shared state at the start of each run.
    response_times = []
    response_pshort_times = []
    response_plong_times = []

    short_queue = []
    long_queue = []
    cheap_p_queue = []
    expensive_p_queue = []

    c1 = c1_price
    c2 = c2_price

    T = threshold
    n_cheap_p = 0
    n_expensive_p = 0

    expensive_alpha = test_expen_alpha

    if use_seperate_cheapp:
        cheap_acc = test_cheap_alpha
    else:
        cheap_alpha = test_cheap_alpha

    start_time = None
    job_processes = {}

    if threshold == 0: #SPRPT
        c1 = 0
    if threshold == float('inf'): #FCFS
        c1 = 0
        c2 = 0

    print(f'Simulating M/M/1 queue with server cost, lambda: {arrival_rate}, T:{T}, c1:{c1}, c2:{c2}')
    env = simpy.Environment()
    server = simpy.PriorityResource(env, capacity=1)
    current_job = None
    if LOG_EVENT_PRINT:
        # Column header for the per-event trace; only useful when it is printed.
        print(f"{'TIME':<10} | {'EVENT':<18} | {'JOB ID':<10} | {'SIZE':<8} | {'PREDICTED SIZE':<15} | {'SERVER':<18} | {'NOTES'}")
        print("-" * 145)
    env.process(job_generator(env, server, arrival_rate))

    env.run(until=SIMULATION_TIME)

    if LOG_EVENT_PRINT:
        print(f'response_times: {response_times}')
        print(f'n_cheap_p: {n_cheap_p}')
        print(f'n_expensive_p: {n_expensive_p}')

    mean_response_time = _mean_or_nan(response_times)
    print(f"\nMean Response Time with costs: {mean_response_time:.2f} time units lambda: {arrival_rate}, T:{T}")

    # 0 marks a class that does not exist under the chosen policy.
    mean_response_time_pshort = 0
    mean_response_time_plong = 0

    if threshold != 0:  # SPRPT (T=0) has no predicted-short jobs
        mean_response_time_pshort = _mean_or_nan(response_pshort_times)
        print(f"\nPredicted short: Mean Response Time with costs: {mean_response_time_pshort:.2f} time units lambda: {arrival_rate}, T:{T}")

    if threshold != float('inf'):  # FCFS (T=inf) has no predicted-long jobs
        mean_response_time_plong = _mean_or_nan(response_plong_times)
        print(f"\nPredicted long: Mean Response Time with costs: {mean_response_time_plong:.2f} time units lambda: {arrival_rate}, T:{T}")

    return mean_response_time, mean_response_time_pshort, mean_response_time_plong


def simulation_wrapper(arrival_rate, threshold, c1, c2, test_cheap_alpha=0.8, test_expen_alpha=0.2, n_runs=100):
    """Average the three mean response times over ``n_runs`` simulations.

    Args:
        arrival_rate: Poisson arrival rate passed to every run.
        threshold: SkipPredict threshold T; 0 runs SPRPT and ``inf`` runs FCFS.
        c1, c2: cheap / expensive prediction costs. For the baselines c1
            (and, for FCFS, also c2) is forced to 0.
        test_cheap_alpha, test_expen_alpha: predictor parameters forwarded to
            run_simulation.
        n_runs: number of repetitions to average (default 100).

    Returns:
        ``(avg mean, avg predicted-short mean, avg predicted-long mean)``.

    Raises:
        ValueError: ``n_runs`` is smaller than 1.

    NOTE(review): real_data_index is not reset between repetitions, so with
    the real dataset each run continues where the previous one stopped --
    confirm this is intended.
    """
    global response_times, response_pshort_times, response_plong_times, short_queue, long_queue, cheap_p_queue, expensive_p_queue, job_processes, start_time, n_cheap_p, n_expensive_p

    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    # Start from a clean scheduler state (run_simulation also resets these).
    response_times = []
    response_pshort_times = []
    response_plong_times = []
    short_queue = []
    long_queue = []
    cheap_p_queue = []
    expensive_p_queue = []
    job_processes = {}
    start_time = None
    n_cheap_p = 0
    n_expensive_p = 0

    if threshold == 0: #SPRPT
        c1 = 0
    if threshold == float('inf'): #FCFS
        c1 = 0
        c2 = 0

    runs = [
        run_simulation(arrival_rate, threshold, c1, c2, test_cheap_alpha, test_expen_alpha)
        for _ in range(n_runs)
    ]
    # zip(*runs) regroups the per-run triples into one column per metric.
    average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong = (
        sum(metric) / len(metric) for metric in zip(*runs)
    )

    print(f"Average Mean Response Time: {average_mean_response_time:.2f} time units")
    print(f"Average Predicted Short Mean Response Time: {average_mean_response_time_pshort:.2f} time units")
    print(f"Average Predicted Long Mean Response Time: {average_mean_response_time_plong:.2f} time units")

    return average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong
########

# Matplotlib marker / colour codes, consumed in order (one per plotted policy)
# by the test_* sweeps; extend both strings to plot more series.
markers = list('ox^sDp')
colors = list('bgrcmy')  # blue, green, red, cyan, magenta, yellow

def test_cost_vs_ratio():
    """Sweep the expensive/cheap prediction-cost ratio c2/c1 for each policy.

    For SPRPT (T=0), FCFS (T=inf) and SkipPredict, runs simulation_wrapper with
    c1 fixed and ``c2 = ratio * c1``. Writes one CSV row per (policy, ratio)
    under ``res/``; with PLOT_GRAPHS set, it also plots mean response time
    against the ratio under ``graphs/``.
    """
    global real_data_index

    # Test parameters
    ratio_values = [1, 2, 3, 5, 8]
    T_values = [0, float('inf'), 1]
    default_arrival_rate = 0.9
    default_c1 = 0.01
    labels = {0: 'SPRPT', float('inf'): 'FCFS', 1: 'SkipPredict'}

    if dist == 'real':
        # The real dataset uses SkipPredict with T=4 and a wider ratio range.
        T_values = [0, float('inf'), 4]
        labels = {0: 'SPRPT', float('inf'): 'FCFS', 4: 'SkipPredict'}
        ratio_values = [8, 10, 15, 20, 30]

    results = []
    for t in T_values:
        for ratio in ratio_values:
            # Replay the real dataset from its first row for every configuration.
            real_data_index = 0

            avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, default_c1, ratio * default_c1)
            results.append({
                'Alg': labels[t],
                'Ratio': ratio,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': default_c1,
                # The threshold actually simulated for this row.
                'Default T': t
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv(f'res/servercost_cost_vs_ratio_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.csv', index=False)

    if PLOT_GRAPHS:
        unique_labels = results_df['Alg'].unique()

        for label, marker, color in zip(unique_labels, markers, colors):
            df_subset = results_df[results_df['Alg'] == label]
            plt.plot(df_subset['Ratio'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label)

        plt.xlabel('Prices Ratio')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'graphs/servercost_cost_vs_ratio_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.png')
        plt.clf()


def test_cost_vs_T():
    """Sweep the SkipPredict threshold T at a fixed arrival rate and prediction costs.

    SPRPT (T=0) and FCFS (T=inf) are simulated once each as baselines, and
    SkipPredict is simulated at every value of ``T_test_values``. One CSV row
    per simulation is written under ``res/``; with PLOT_GRAPHS set, mean
    response time is plotted against T under ``graphs/``.
    """
    global real_data_index

    # SkipPredict's nominal threshold (only used as its label key here).
    skip_predict_T = 4 if dist == 'real' else 1
    T_values = [0, float('inf'), skip_predict_T]
    labels = {0: 'SPRPT', float('inf'): 'FCFS', skip_predict_T: 'SkipPredict'}
    T_test_values = [0.1, 0.5, 1, 1.5, 2, 4, 5, 8]
    default_arrival_rate = 0.9
    fixed_c1 = 0.01
    fixed_c2 = 0.05

    rows = []
    for t in T_values:
        # Baselines run at their own fixed threshold; SkipPredict sweeps T.
        thresholds = [t] if t in (0, float('inf')) else T_test_values
        for threshold in thresholds:
            # Replay the real dataset from its first row for every configuration.
            real_data_index = 0
            avg_mean, avg_pshort, avg_plong = simulation_wrapper(
                default_arrival_rate, threshold, fixed_c1, fixed_c2)
            rows.append({
                'Alg': labels[t],
                'T Value': threshold,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': fixed_c1,
                'Default c2': fixed_c2,
            })

    results_df = pd.DataFrame(rows)
    results_df.to_csv(f'res/servercost_cost_vs_T_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_dataset_{dataset}.csv', index=False)

    if not PLOT_GRAPHS:
        return

    for t, marker, color in zip(T_values, markers, colors):
        alg = labels[t]
        subset = results_df[results_df['Alg'] == alg]
        plt.plot(subset['T Value'], subset['Average Mean Response Time'], marker=marker, color=color,
                 label=alg)

    plt.xlabel(r'$T$')
    plt.ylabel('Cost')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'graphs/servercost_cost_vs_T_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.png')
    plt.clf()


def test_cost_vs_arrivalrate():
    """Sweep the arrival rate for SPRPT (T=0), FCFS (T=inf) and SkipPredict.

    Prediction costs are fixed (c1=0.01, c2=0.05). Writes one CSV row per
    (policy, arrival rate) under ``res/``; with PLOT_GRAPHS set, it also plots
    mean response time against arrival rate under ``graphs/``.
    """
    global real_data_index

    # Test parameters
    arrival_rate_values = [0.5, 0.6, 0.7, 0.9, 0.95]
    T_values = [0, float('inf'), 1]
    fixed_c1 = 0.01
    fixed_c2 = 0.05
    labels = {0: 'SPRPT', float('inf'): 'FCFS', 1: 'SkipPredict'}

    if dist == 'real':
        # The real dataset uses SkipPredict with T=4.
        T_values = [0, float('inf'), 4]
        labels = {0: 'SPRPT', float('inf'): 'FCFS', 4: 'SkipPredict'}

    results = []
    for t in T_values:
        for arrival_rate in arrival_rate_values:
            # Replay the real dataset from its first row for every configuration.
            real_data_index = 0

            avg_mean, avg_pshort, avg_plong = simulation_wrapper(arrival_rate, t, fixed_c1, fixed_c2)
            results.append({
                'Alg': labels[t],
                'Arrival Rate': arrival_rate,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default T': t,
                'Default c1': fixed_c1,
                'Default c2': fixed_c2
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv(f'res/servercost_cost_vs_arrivalrate_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_dataset_{dataset}.csv', index=False)

    if PLOT_GRAPHS:
        for (t, marker, color) in zip(T_values, markers, colors):
            label_t = labels[t]
            df_subset = results_df[results_df['Alg'] == label_t]

            plt.plot(df_subset['Arrival Rate'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label_t)

        plt.xlabel('Arrival Rate')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        # Include the distribution in the file name, as the other sweeps do, so
        # plots for different distributions do not overwrite each other.
        plt.savefig(f'graphs/servercost_cost_vs_arrivalrate_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.png')
        plt.clf()

# def test_cost_vs_accuracy():
#     cheap_alpha_prices = {0.05: 0.4, 0.1: 0.35, 0.3: 0.28, 0.5: 0.23, 0.8: 0.225, 0.9: 0.22}
#     expensive_alpha_prices = {0.05: 0.4, 0.1: 0.35, 0.3: 0.28, 0.5: 0.23, 0.8: 0.225, 0.9: 0.22}
#
#
#     # Test parameters
#     T_values = [1]  # Only using SkipPredict for simplicity
#     default_arrival_rate = 0.9
#     labels = {1: 'SkipPredict'}
#
#     results = []
#     for t in T_values:
#
#         for calpha, cprice in cheap_alpha_prices.items():
#             for exp_alpha, eprice in expensive_alpha_prices.items():
#                 avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, cprice,
#                                                                      eprice, calpha, exp_alpha)
#                 results.append({
#                     'Algorithm': labels[t],
#                     'Cheap Alpha': calpha,
#                     'Expensive Alpha': exp_alpha,
#                     'Average Mean Response Time': avg_mean,
#                     'Average PShort Response Time': avg_pshort,
#                     'Average PLong Response Time': avg_plong,
#                     'Cheap Alpha Cost': cprice,
#                     'Expensive Alpha Cost': eprice,
#                     'Default Arrival Rate': default_arrival_rate
#                 })
#
#     # Saving the results to a CSV file
#     results_df = pd.DataFrame(results)
#     filename = f'res/servercost_cost_vs_alpha_results.csv'
#     results_df.to_csv(filename, index=False)
#
#     # # Plotting the results
#
#     if PLOT_GRAPHS:
#         results_df = pd.read_csv(f'res/cost_vs_alpha_results.csv')
#
#         pivot_table = results_df.pivot(index='Cheap Alpha', columns='Expensive Alpha', values='Average Mean Response Time')
#
#         row_labels = [f"{row} (c1: {results_df[results_df['Cheap Alpha'] == row]['Cheap Alpha Cost'].values[0]})" for row in
#                       pivot_table.index]
#
#         col_labels = [f"{col} (c2: {results_df[results_df['Expensive Alpha'] == col]['Expensive Alpha Cost'].values[0]})" for col in
#                       pivot_table.columns]
#
#         # Create the heat matrix plot
#         plt.figure(figsize=(12, 10))
#         sns.heatmap(pivot_table, annot=True, cmap='YlGnBu', xticklabels=col_labels, yticklabels=row_labels)
#         plt.xlabel(r'Expensive $\alpha$', fontsize=14)
#         plt.ylabel(r'Cheap $\alpha$', fontsize=14)
#         plt.savefig(f'graphs/servercost_cost_vs_acc_ext.png')
#         plt.clf()


def test_cost_vs_accuracy():
    """Sweep cheap-predictor accuracy against expensive-predictor alpha.

    For every (cheap accuracy, expensive alpha) pair, runs ``simulation_wrapper``
    at a fixed arrival rate with the T value(s) in ``T_values`` (only 1, labelled
    'SkipPredict'). It records the average mean, pshort and plong response times
    together with the cost attached to each predictor setting (c1 for the cheap
    predictor, c2 for the expensive one).

    Results are written to ``res/servercost_cost_vs_alpha_results.csv``. When
    ``PLOT_GRAPHS`` is set, a heat map of the average mean response time is saved
    to ``graphs/servercost_vs_acc_ext_arrival_<rate>.png``.
    """
    # Cheap predictor: probability of an accurate prediction -> cost (c1).
    cheap_acc_prices = {0.05: 0.008, 0.5: 0.15, 0.9: 0.3}
    # Expensive predictor: alpha -> cost (c2).
    expensive_alpha_prices = {0.05: 0.5, 0.6: 0.25, 0.9: 0.22}

    # Test parameters
    T_values = [1]  # Only using SkipPredict for simplicity
    default_arrival_rate = 0.9
    labels = {1: 'SkipPredict'}

    results = []
    for t in T_values:
        for calpha, cprice in cheap_acc_prices.items():
            for exp_alpha, eprice in expensive_alpha_prices.items():
                avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, cprice,
                                                                     eprice, calpha, exp_alpha)
                results.append({
                    'Algorithm': labels[t],
                    'Cheap Alpha': calpha,
                    'Expensive Alpha': exp_alpha,
                    'Average Mean Response Time': avg_mean,
                    'Average PShort Response Time': avg_pshort,
                    'Average PLong Response Time': avg_plong,
                    'Cheap Alpha Cost': cprice,
                    'Expensive Alpha Cost': eprice,
                    'Default Arrival Rate': default_arrival_rate
                })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv('res/servercost_cost_vs_alpha_results.csv', index=False)

    if PLOT_GRAPHS:
        # Plot straight from the in-memory frame; re-reading the CSV just written
        # added nothing. NOTE(review): pivot() requires unique (cheap, expensive)
        # pairs, which holds only while T_values has a single entry.
        _plot_cost_vs_accuracy_heatmap(results_df, cheap_acc_prices,
                                       expensive_alpha_prices, default_arrival_rate)


def _cheap_tick_label(accuracy, cost):
    """Return the heat-map row label: cheap-predictor accuracy and its cost c1."""
    return f"{accuracy} (c1: {cost})"


def _expensive_tick_label(alpha, cost):
    """Return the heat-map column label: ``1 - alpha`` and the expensive cost c2.

    ``:g`` keeps the significant digits. A fixed ``.1f`` would render
    1 - 0.05 as "0.9" instead of "0.95".
    """
    return f"{1 - alpha:g} (c2: {cost})"


def _plot_cost_vs_accuracy_heatmap(results_df, cheap_acc_prices, expensive_alpha_prices, arrival_rate):
    """Save an annotated heat map of mean response time over both predictor settings.

    Rows are cheap-predictor accuracies and columns are expensive-predictor
    alphas, shown as ``1 - alpha``. Drawn with matplotlib directly because this
    module does not import seaborn; its import is commented out at the top of
    the file.
    """
    pivot_table = results_df.pivot(index='Cheap Alpha', columns='Expensive Alpha',
                                   values='Average Mean Response Time')
    row_labels = [_cheap_tick_label(acc, cheap_acc_prices[acc]) for acc in pivot_table.index]
    col_labels = [_expensive_tick_label(alpha, expensive_alpha_prices[alpha])
                  for alpha in pivot_table.columns]

    values = pivot_table.to_numpy()
    fig, ax = plt.subplots(figsize=(12, 10))
    image = ax.imshow(values, cmap='YlGnBu', aspect='auto')
    fig.colorbar(image, ax=ax)

    # Annotate every cell with its value (same '.2g' format seaborn's annot uses).
    for i, row in enumerate(values):
        for j, value in enumerate(row):
            ax.text(j, i, f"{value:.2g}", ha='center', va='center')

    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, fontsize=15)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels, fontsize=15)
    ax.set_xlabel(r'$1-\alpha$ (Expensive Prediction)', fontsize=18)
    ax.set_ylabel(r'Probability (Accurate Cheap Prediction)', fontsize=18)
    fig.savefig(f'graphs/servercost_vs_acc_ext_arrival_{arrival_rate}.png')
    # Close rather than clf() so repeated sweeps do not accumulate open figures.
    plt.close(fig)

# Compressed trace file for each real dataset that the main loop can run on.
DATASET_PATHS = {
    'twosigma': 'real_dataset/jvupredict_twosigma.csv.gz',
    'google': 'real_dataset/jvupredict_google_all_features.csv.gz',
    'mustang': 'real_dataset/jvupredict_mustang_full.csv.gz',
    'trinity': 'real_dataset/jvupredict_trinity.csv.gz',
}

if __name__ == "__main__":
    datasets = ['twosigma', 'google', 'trinity']

    for dataset in datasets:
        # Fail loudly on an unknown name. The old if-chain either left file_path
        # unbound or silently reused the previous dataset's path.
        try:
            file_path = DATASET_PATHS[dataset]
        except KeyError:
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(DATASET_PATHS)}"
            ) from None

        # Assigned at module level (not inside a function), so `real_data` is a
        # global. Presumably it is read by the simulation code elsewhere in
        # this file; verify.
        # NOTE(review): job_size_distribution declares `global real_data_index`.
        # Confirm that index is reset for each dataset (e.g. in simulation_wrapper).
        real_data = load_and_prepare_data(file_path)

        #simulation_wrapper(0.7, 0, 0, 0)

        #simulation_wrapper(0.5, 0, 0, 0)
        #simulation_wrapper(0.5, float('inf') , 0, 0)

        #TESTS
        # NOTE(review): the results/graph filenames written by the test functions
        # may not include the dataset name, so each iteration could overwrite the
        # previous dataset's output. Confirm.
        test_cost_vs_ratio()
        test_cost_vs_T()
        test_cost_vs_arrivalrate()

