import datetime
import math

import simpy
import random
import numpy as np
import sys
import matplotlib
matplotlib.use('TkAgg')  # Adjust based on your OS and environment
import matplotlib.pyplot as plt
import pandas as pd

# ---------------------------------------------------------------------------
# Shared simulation state.  run_simulation() re-initialises these at the
# start of every run; they live at module scope because the SimPy processes
# share them through ``global`` declarations.  Every name that some function
# declares ``global`` and assigns is defined here so it exists at import time.
# ---------------------------------------------------------------------------

# Response times (completion time minus arrival time) of finished jobs.
response_times = []                # every completed job
response_pshort_times = []         # jobs that finished within their first L units of service
response_plong_times = []          # jobs that completed in the long-queue phase
response_plong_times_details = []  # per-job detail records (append is currently commented out)

# Queues, checked in this priority order by check_schedule().
initial_queue = []       # (arrival_time, size)
long_queue = []          # (arrival_time, prediction, age, size)
expensive_p_queue = []   # (arrival_time, prediction, age, size)

start_time = None        # clock time the running long job was (re)started, or None
job_processes = {}       # start_time -> SimPy process serving that long job
n_cheap_p = 0            # reset by run_simulation; not incremented in the visible code
n_expensive_p = 0        # number of completed expensive predictions


SIMULATION_TIME = 1000000

L = 1          # service each job receives before it counts as long (overwritten per run)
c2 = 0.00005   # TODO: change -- cost charged per expensive prediction (overwritten per run)

current_job = None   # job tuple currently occupying the server, or None
LOG_EVENT_PRINT = 1  # 1 = print every scheduling event, 0 = silent
PLOT_GRAPHS = 0      # 1 = also save matplotlib plots of the sweep results

predictor = 'expP'
#predictor = 'perfectP'
#predictor = 'uniP'


#dist = 'weibull'
dist = 'exponential'

cheap_alpha = 0.8
expensive_alpha = 0.01

# Prediction cost per predictor-error alpha; not referenced in the code shown.
alpha_cost_mapping = {
    0: 2,    # When alpha is 0 (perfect predictor), cost is 2
    0.1: 1.85,
    0.3: 1.55,
    0.5: 1.25,
    0.8: 0.75,
    1: 0.5
}

def log_event(time, event, job_id="--", size="--", predicted_size="--", notes="--", queue_content="--"):
    """Print one row of the event trace when LOG_EVENT_PRINT is enabled.

    Columns: time (5 decimals) | event | job id | size | predicted size |
    notes | queue content.  Numeric cells are shown with two decimals.
    String cells (the "--" placeholder, "DONE", or any other label) are
    printed verbatim.  Previously, any string other than "--" (or "DONE"
    for ``notes``) was pushed through the ``.2f`` format and raised
    ValueError.

    Args:
        time: Simulation clock value.
        event: Event name.
        job_id, size, predicted_size, notes: A number or a string label.
        queue_content: Printed as-is in the last column.
    """
    if not LOG_EVENT_PRINT:
        return

    def cell(value):
        # Labels pass through untouched; numbers get two decimals.
        return value if isinstance(value, str) else f"{value:.2f}"

    time_str = f"{time:.5f}"
    print(
        f"{time_str:<10}| {event:<20}| {cell(job_id):<10}| {cell(size):<10}| "
        f"{cell(predicted_size):<15}| {cell(notes):<20}| {queue_content}"
    )



def short_job_size_distribution(threshold=None):
    """Sample a job size from job_size_distribution() conditioned on being short.

    Uses rejection sampling: draws until a sample falls strictly below
    ``threshold``.

    Args:
        threshold: Exclusive upper bound on the returned size.  When omitted,
            the module-level ``T`` is used if it has been defined.  Otherwise
            the cutoff ``L`` is used, which serve_job also uses to classify a
            job as short (``job_size < L``).  The previous version relied
            solely on a global ``T`` that this file never defines, so every
            call raised NameError.

    Raises:
        ValueError: if ``threshold`` is not positive.  The visible
            distributions only produce non-negative sizes, so such a bound
            would reject every draw and loop forever.
    """
    if threshold is None:
        threshold = globals().get("T", L)
    if threshold <= 0:
        raise ValueError(f"threshold must be positive, got {threshold!r}")
    while True:
        sample = job_size_distribution()
        if sample < threshold:
            return sample

#def job_size_distribution():
#    return random.expovariate(1)


#def job_size_distribution():
#    U = random.random()
#    return (-math.log(1 - U))**2 / 2

def job_size_distribution(distribution=None):
    """Draw one job size from the configured size distribution.

    Both supported distributions have mean 1:
      * ``'exponential'``: Exp(1).
      * ``'weibull'``: the inverse-transform sample (-ln(1-U))**2 / 2, i.e.
        a Weibull with shape 1/2 and scale 1/2.

    Args:
        distribution: ``'exponential'`` or ``'weibull'``.  Defaults to the
            module-level ``dist`` setting.

    Raises:
        ValueError: for an unknown distribution name.  Previously the function
            silently returned None, which only failed later inside the
            scheduler.
    """
    if distribution is None:
        distribution = dist

    if distribution == 'exponential':
        return random.expovariate(1)

    if distribution == 'weibull':
        uniform = random.random()  # in [0, 1), so 1 - uniform > 0 and log is defined
        return (-math.log(1 - uniform)) ** 2 / 2

    raise ValueError(f"unknown job size distribution: {distribution!r}")

def predict_service_time(job, uni_alpha, predictor_type=None, limit=None):
    """Return a predicted service time for a job whose true size is ``job``.

    Predictors:
      * ``'perfectP'`` (or any predictor with ``uni_alpha == 0``): the exact size.
      * ``'expP'``: an exponential draw with mean ``job``.
      * ``'uniP'``: uniform on [(1 - uni_alpha) * job, (1 + uni_alpha) * job].

    The result is floored at ``limit``.  serve_job only asks for a
    prediction once a job has already received L units of service.

    Args:
        job: True job size (must be > 0 for ``'expP'``).
        uni_alpha: Error parameter of the predictor.
        predictor_type: Predictor name; defaults to the module-level ``predictor``.
        limit: Lower bound on the prediction; defaults to the module-level ``L``.

    Raises:
        ValueError: for an unknown predictor name.  Previously the function
            silently returned None.
    """
    if predictor_type is None:
        predictor_type = predictor
    if limit is None:
        limit = L

    if predictor_type == 'perfectP' or uni_alpha == 0:
        estimate = job
    elif predictor_type == 'expP':
        estimate = random.expovariate(1 / job)
    elif predictor_type == 'uniP':
        estimate = random.uniform((1 - uni_alpha) * job, (1 + uni_alpha) * job)
    else:
        raise ValueError(f"unknown predictor: {predictor_type!r}")

    return max(estimate, limit)

############### Different types of predictors  #################

#def predict_service_time(job): #TODO: perfect predictor
#    return job

#def predict_service_time(job): #TODO uniP predictor
#    lower_bound = (1 - alpha) * job
#    upper_bound = (1 + alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_service_time(z): #TODO: exponential predictor
#     return random.expovariate(1/z)

#def predict_uniP_cheap(job): #TODO cheap uniP predictor
#    global cheap_alpha
#    lower_bound = (1 - cheap_alpha) * job
#    upper_bound = (1 + cheap_alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_uniP_expensive(job):  # TODO expensive uniP predictor
#    global expensive_alpha
#    lower_bound = (1 - expensive_alpha) * job
#    upper_bound = (1 + expensive_alpha) * job
#    return random.uniform(lower_bound, upper_bound)
#####################################################

def remaining_time(job):
    """Predicted work still outstanding for an (id, prediction, age, size) job.

    Used as the sort key for ``long_queue`` (least predicted remaining work first).
    """
    predicted, served = job[1], job[2]
    return predicted - served



def check_schedule(env, server):
    """Start a serve_job process for the highest-priority non-empty queue.

    Queues are checked in the order initial_queue, expensive_p_queue,
    long_queue.  At most one process is spawned per call, and none when all
    three queues are empty.  The spawned process then requests ``server``
    with its queue's priority (see serve_job), so it may wait behind the
    job currently being served.

    Args:
        env: SimPy environment used to spawn the process and read the clock.
        server: Resource forwarded to serve_job.
    """
    # Build the queue snapshot only when it will actually be printed.  This
    # runs on every arrival and completion, and formatting every queued tuple
    # is O(queue length) work that log_event would otherwise discard.
    if LOG_EVENT_PRINT:
        initial_queue_str = [f"({x:.2f}, {y:.2f})" for x, y in initial_queue]
        expensive_p_str = [f"({x:.2f}, {y:.2f}, {z:.2f}, {w:.2f})" for x, y, z, w in expensive_p_queue]
        long_queue_str = [f"({x:.2f}, {y:.2f}, {z:.2f}, {w:.2f})" for x, y, z, w in long_queue]
        queue_str = f"Initial: {initial_queue_str}, Expensive_p: {expensive_p_str}, Long: {long_queue_str}"
        log_event(env.now, "Checking Policy", "--", "--", "--", "--", queue_str)

    if initial_queue:
        log_event(env.now, "Schedule Initial Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, "initial_queue", server, preemptive=False))
    elif expensive_p_queue:
        log_event(env.now, "Schedule Expensive p", "--", "--", "--", "--", "--")
        env.process(serve_job(env, "expensive_p_queue", server, preemptive=True))
    elif long_queue:
        log_event(env.now, "Schedule long Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, "long_queue", server, preemptive=True))



def serve_job(env, type, server, preemptive):
    """SimPy process: serve the next job from one of the three queues.

    Args:
        env: SimPy environment.
        type: Which queue to serve: "initial_queue", "expensive_p_queue" or
            "long_queue".  (This parameter shadows the ``type`` builtin.)
        server: PriorityResource the job must acquire.  Requests use priority
            0 (initial), 1 (expensive) and 3 (long); SimPy grants smaller
            values first.
        preemptive: Passed by check_schedule but not read in this function.

    All state is shared through module globals: the queues, current_job,
    start_time, job_processes, the response-time lists and n_expensive_p.
    """
    global response_times, response_pshort_times, response_plong_times, response_plong_times_details
    global start_time, job_processes, SIMULATION_TIME
    global L
    global n_expensive_p
    global initial_queue, expensive_p_queue, long_queue
    global current_job
    global expensive_alpha
    global c2


    if type=="initial_queue":
        # A new arrival takes over: interrupt the long job whose serve_long_job
        # process is registered under start_time, if any.  Its parent serve_job
        # receives the Interrupt with this cause.
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Initial Job Arrival")

        with server.request(priority=0) as req:
            yield req
            # Another serve_job may already have drained the queue.
            if not initial_queue:
                return
            next_job = initial_queue.pop(0)
            log_event(env.now, "Serve Initial", next_job[0], next_job[1], "--", next_job[0], "--")
            job_size = next_job[1]
            # First phase: at most L units of service.  This phase is never
            # registered in job_processes, so nothing in this file interrupts it.
            run_for_max_L = min(L, job_size)
            yield env.timeout(run_for_max_L)
            if (job_size < L):
                # Finished inside its first L units: counted as a short job.
                log_event(env.now, "Initial Short Job Done", next_job[0], "--", "--", "DONE", "--")

                response_times.append(env.now - next_job[0])
                response_pshort_times.append(env.now - next_job[0])
            else:
                # Needs at least L: predict its size and move it to the
                # expensive-prediction phase with age L (job_size == L also lands here).
                log_event(env.now, "Initial L Long Job Done", next_job[0], "--", "--", "DONE", "--")
                predicted_service_time = predict_service_time(job_size, expensive_alpha)
                expensive_p_queue.append((next_job[0], predicted_service_time, L, job_size)) #(ID, prediction, age, size)
                log_event(env.now, "Append to Expensive", next_job[0], "--", predicted_service_time, "--", "--")
            check_schedule(env, server)
            current_job = None
            start_time = None


    elif type=="expensive_p_queue":
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Expensive P Arrival")

        with server.request(priority=1) as req:
            yield req
            if not expensive_p_queue:
                return
            next_job = expensive_p_queue.pop(0)
            log_event(env.now, "Serve Expensive p", next_job[0],  "--", "--", next_job[0], f"Age: {next_job[2]}")
            current_job = next_job
            # NOTE(review): next_job[2] is the job's age, which is L on entry
            # from the initial branch, so this delay is c2 - L.  SimPy rejects
            # negative delays with ValueError (e.g. c2 = 2 with L in {4, 5, 8}
            # as used by test_cost_vs_L).  Confirm the intended
            # prediction-time model.
            remaining_expensivep = c2 - next_job[2]
            try:
                yield env.timeout(remaining_expensivep)

                log_event(env.now, "Expensive p Done", next_job[0], "--", "--", "DONE", "--")
                n_expensive_p += 1
                # Age is shifted by -c2 before the job enters long_queue, so
                # serve_long_job will later run it for size - (age - c2).
                long_queue.append((next_job[0], next_job[1], next_job[2] - c2, next_job[3]))
                check_schedule(env, server)
                current_job = None
                start_time = None

            except simpy.Interrupt as interrupt:
                # NOTE(review): this branch neither sets start_time nor registers
                # itself in job_processes, so the interrupts sent in this file do
                # not appear to reach it.  If one did while start_time is None,
                # env.now - start_time would raise TypeError.
                if str(interrupt.cause) == "Initial Job Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "Initial Preempted Expensive_p", job_id=current_job[0])
                    expensive_p_queue.append(updated_job)
                    current_job = None
                    start_time = None



    if type == "long_queue":

        # Least predicted remaining work first.
        long_queue.sort(key=remaining_time)

        # New arrivals take precedence: hand control back to the scheduler.
        if initial_queue:
            check_schedule(env, server)
            return

        if current_job:
            # Preempt the running job if the queue head has less predicted work left.
            # NOTE(review): assumes start_time is set and long_queue is non-empty.
            # current_job may also be an expensive-phase job, for which start_time
            # is never assigned.  Confirm.
            age_received = env.now - start_time
            if (current_job[1] - (current_job[2] + age_received)) > remaining_time(long_queue[0]): #remaining time current > remaining time of new
                updated_job = (current_job[0], current_job[1], current_job[2] + age_received, current_job[3])
                log_event(env.now, "Long Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")

                long_queue.append(updated_job)
                long_queue.sort(key=remaining_time)
                current_process = job_processes.get(start_time)
                if current_process:
                    # No cause and no is_alive check (unlike the other branches).
                    # A cause-less Interrupt matches neither cause string handled
                    # below, so the preempted serve_job only reaches the trailing
                    # cleanup; its job_processes entry and start_time are left as-is.
                    current_process.interrupt()


        with server.request(priority=3) as req:
            yield req

            if initial_queue:
                check_schedule(env, server)
                return

            if not long_queue:
                return
            next_job = long_queue.pop(0)
            log_event(env.now, "Serve Long", next_job[0], next_job[1], "--", next_job[0], f"Age: {next_job[2]}")
            start_time = env.now
            current_job = (next_job[0], next_job[1], next_job[2], next_job[3])
            # The actual service runs in a child process registered under
            # start_time so that new arrivals can interrupt it.
            current_process = env.process(serve_long_job(env, next_job))
            job_processes[start_time] = current_process
            try:
                # An Interrupt sent to the child is re-raised here when the child fails.
                yield current_process

                response_time = env.now - current_job[0]
                log_event(env.now, "Long Job Done", current_job[0], "--", "--", "DONE", f"L-response:{response_time}")

                # Built but unused while the append below is commented out.
                job_details = {
                    'response_time': response_time,
                    'actual_size': current_job[3],
                    'predicted_size': current_job[1]
                }
                #response_plong_times_details.append(job_details)
                response_times.append(response_time)
                response_plong_times.append(response_time)
                current_job = None
                if start_time in job_processes:
                    del job_processes[start_time]
                start_time = None

            except simpy.Interrupt as interrupt:
                if str(interrupt.cause) == "Initial Job Arrival":
                    # Requeue the preempted job with the service it received added to its age.
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "Initial Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                    if start_time in job_processes:
                        del job_processes[start_time]

                    long_queue.append(updated_job)
                    long_queue.sort(key=remaining_time)
                    current_job = None
                    # NOTE(review): unlike the branch below, start_time is not reset here.
                    check_schedule(env, server)

                if str(interrupt.cause) == "Expensive P Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "ExpensiveP Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                    if start_time in job_processes:
                        del job_processes[start_time]

                    long_queue.append(updated_job)
                    long_queue.sort(key=remaining_time)
                    current_job = None
                    start_time = None
                    check_schedule(env, server)

    # Reached after every branch that did not return early.  Branches that
    # already called check_schedule above may therefore spawn a second
    # serve_job for the same queue.  The `if not <queue>: return` guards turn
    # the extra one into a no-op once the queue is drained.
    current_job = None
    check_schedule(env, server)


def serve_long_job(env, job):
    """SimPy process that holds the server for a long job's true remaining work.

    ``job`` is an (id, prediction, age, size) tuple; the process waits
    ``size - age`` time units.  Interrupts are not caught here: serve_job,
    which yields this process, handles ``simpy.Interrupt``.
    """
    age, size = job[2], job[3]
    yield env.timeout(size - age)


def log_job_completion(env, job):
    """Record a finished long job and clear the running-job state.

    Mirrors the completion path of serve_job's long branch:
    - logs the event;
    - appends the response time (now minus arrival) to response_times and
      response_plong_times;
    - resets current_job, start_time and the job's job_processes entry.

    Not called in the visible part of this file.

    Args:
        env: SimPy environment (only ``env.now`` is read).
        job: (id, prediction, age, size) tuple; ``job[0]`` is the arrival time.
    """
    global current_job, start_time

    log_event(env.now, "Long Job Done", job_id=job[0])
    response_time = env.now - job[0]
    response_times.append(response_time)
    response_plong_times.append(response_time)

    current_job = None
    job_processes.pop(start_time, None)
    start_time = None


def job_generator(env, server, arriving_rate):
    """SimPy process producing a Poisson arrival stream.

    Inter-arrival gaps are Exp(arriving_rate).  Each job is identified by
    its arrival time and gets a size from job_size_distribution().  It is
    appended to initial_queue, after which a scheduling check runs.
    """
    while True:
        gap = random.expovariate(arriving_rate)
        yield env.timeout(gap)

        arrival = env.now
        size = job_size_distribution()
        initial_queue.append((arrival, size))  # (job, job_size)
        log_event(env.now, "Append to Initial queue", arrival, size, "--", "--", "--")
        check_schedule(env, server)




def _safe_mean(values, extra=0.0):
    """Return (sum(values) + extra) / len(values), or NaN when ``values`` is empty.

    A run can finish with no short or no long completions (e.g. an extreme
    L).  Plain division then raises ZeroDivisionError and aborts the whole
    batch of replications.
    """
    if not values:
        return float('nan')
    return (sum(values) + extra) / len(values)


def run_simulation(arrival_rate, limit, c2_given_price, test_expen_alpha):
    """Run one DelayPredict simulation and return its mean response times.

    Resets all shared module state, runs the SimPy model for
    SIMULATION_TIME time units, then prints and returns summary statistics.

    Args:
        arrival_rate: Poisson arrival rate (lambda).
        limit: Service L each job receives before it counts as long.
        c2_given_price: Cost c2 charged per completed expensive prediction.
        test_expen_alpha: Error parameter for the expensive predictor.

    Returns:
        A 4-tuple:
        - mean_response_time: includes n_expensive_p * c2;
        - mean_response_time_pshort: no cost added;
        - mean_response_time_plong: includes n_expensive_p * c2;
        - response_plong_times_details.
        A mean over an empty set of jobs is NaN.
    """
    global response_times, response_pshort_times, response_plong_times, response_plong_times_details
    global start_time, job_processes
    global L
    global n_cheap_p, n_expensive_p
    global initial_queue, expensive_p_queue, long_queue
    global current_job
    global expensive_alpha
    global c2

    # Re-initialise all shared state so replications are independent.
    response_times = []
    response_pshort_times = []
    response_plong_times = []
    response_plong_times_details = []
    initial_queue = []
    long_queue = []
    expensive_p_queue = []

    L = limit
    n_cheap_p = 0
    n_expensive_p = 0
    c2 = c2_given_price
    expensive_alpha = test_expen_alpha

    start_time = None
    job_processes = {}
    current_job = None

    print(f'Simulating DelayPredict M/M/1 queue with external cost, lambda: {arrival_rate}, L:{L}')
    env = simpy.Environment()
    server = simpy.PriorityResource(env, capacity=1)
    if LOG_EVENT_PRINT:
        # Column layout matches the rows printed by log_event.
        print(f"{'TIME':<10}| {'EVENT':<20}| {'JOB ID':<10}| {'SIZE':<10}| {'PREDICTED SIZE':<15}| {'NOTES':<20}| {'QUEUES / DETAILS'}")
        print("-" * 145)
    env.process(job_generator(env, server, arrival_rate))

    env.run(until=SIMULATION_TIME)

    print(f'n_cheap_p: {n_cheap_p}')
    print(f'n_expensive_p: {n_expensive_p}')

    prediction_cost = n_expensive_p * c2
    mean_all_raw = _safe_mean(response_times)
    mean_pshort_raw = _safe_mean(response_pshort_times)
    mean_plong_raw = _safe_mean(response_plong_times)

    print(f"\nMean Response Time without costs: {mean_all_raw:.2f}, short: {mean_pshort_raw:.2f}, long:{mean_plong_raw:.2f}")

    print(f"\nMean Response Time without costs: {mean_all_raw:.2f} time units")
    mean_response_time = _safe_mean(response_times, prediction_cost)
    print(f"\nMean Response Time with costs: {mean_response_time:.2f} time units lambda: {arrival_rate}, L:{L}")

    print(f"\nPredicted short: Mean Response Time without costs: {mean_pshort_raw:.2f} time units")
    # Expensive predictions are only made for jobs that exceed L, so no cost is added here.
    mean_response_time_pshort = mean_pshort_raw
    print(f"\nPredicted short: Mean Response Time with costs: {mean_response_time_pshort:.2f} time units lambda: {arrival_rate}, L:{L}")

    print(f"\nPredicted long: Mean Response Time without costs: {mean_plong_raw:.2f} time units")
    mean_response_time_plong = _safe_mean(response_plong_times, prediction_cost)
    print(f"\nPredicted long: Mean Response Time with costs: {mean_response_time_plong:.2f} time units lambda: {arrival_rate}, L:{L}")

    print(f"\nPredicted long: Mean Response Time for full list: {mean_plong_raw:.2f} time units")

    return mean_response_time, mean_response_time_pshort, mean_response_time_plong, response_plong_times_details


def simulation_wrapper(arrival_rate, limit, c2_given_price, test_expen_alpha, n_runs=100):
    """Average run_simulation() over ``n_runs`` independent replications.

    run_simulation resets all shared module state at the start of each run,
    so no extra reset is needed here.

    Args:
        arrival_rate, limit, c2_given_price, test_expen_alpha: Forwarded to
            run_simulation.
        n_runs: Number of replications (default 100, the previously
            hard-coded value).

    Returns:
        (average mean response time including prediction cost,
         average predicted-short mean response time,
         average predicted-long mean response time including prediction cost).

    Raises:
        ValueError: if ``n_runs`` < 1, since the averages would divide by zero.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs!r}")

    mean_response_times = []
    mean_response_times_pshort = []
    mean_response_times_plong = []

    for _ in range(n_runs):
        overall, pshort, plong, _details = run_simulation(arrival_rate, limit, c2_given_price, test_expen_alpha)
        mean_response_times.append(overall)
        mean_response_times_pshort.append(pshort)
        mean_response_times_plong.append(plong)

    average_mean_response_time = sum(mean_response_times) / len(mean_response_times)
    average_mean_response_time_pshort = sum(mean_response_times_pshort) / len(mean_response_times_pshort)
    average_mean_response_time_plong = sum(mean_response_times_plong) / len(mean_response_times_plong)

    print(f"Average Mean Response Time: {average_mean_response_time:.2f} time units")
    print(f"Average Predicted Short Mean Response Time: {average_mean_response_time_pshort:.2f} time units")
    print(f"Average Predicted Long Mean Response Time: {average_mean_response_time_plong:.2f} time units")

    return average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong

##### Plot styling shared by the test_cost_vs_* sweeps.
markers = list("ox^sDp")  # one matplotlib marker per series; extend as needed
colors = list("bgrcmy")   # matplotlib single-letter colours: blue, green, red, cyan, magenta, yellow


def test_cost_vs_L():
    """Sweep the service budget L and save the averaged response times.

    Runs simulation_wrapper for each L in ``L_test_values`` at a fixed
    arrival rate, c2 and expensive-predictor alpha.  It writes one CSV row
    per L under ``res/`` and, when PLOT_GRAPHS is set, a cost-vs-L plot
    under ``graphs/``.  Output directories are created if missing.
    """
    import os

    L_test_values = [0.1, 0.5, 1, 1.5, 2, 4, 5, 8]
    default_arrival_rate = 0.9
    # NOTE(review): serve_job delays the expensive-prediction phase by
    # c2 - age with age == L.  With this c2 that delay is negative for
    # L > 2, which SimPy rejects.  Confirm the intended model before
    # relying on this sweep.
    fixed_c2 = 2  # TODO
    default_cheap_alpha = 0.8  # recorded in the CSV only; not passed to the simulation
    default_expen_alpha = 0.01

    label = 'DelayPredict'

    # Running the test
    results = []
    for l_test_value in L_test_values:
        average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong = simulation_wrapper(
            default_arrival_rate, l_test_value, fixed_c2, default_expen_alpha)

        results.append({
            'Alg': label,
            'L Value': l_test_value,
            'Average Mean Response Time': average_mean_response_time,
            'Average PShort Response Time': average_mean_response_time_pshort,
            'Average PLong Response Time': average_mean_response_time_plong,
            'Default arrival rate': default_arrival_rate,
            'Default c2': fixed_c2,
            'Default calpha': default_cheap_alpha,
            'Default expalpha': default_expen_alpha
        })

    # Saving the results to a CSV file (the directory is created on first use).
    results_df = pd.DataFrame(results)
    os.makedirs('res', exist_ok=True)
    results_df.to_csv(f'res/delayP_cost_vs_L_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.csv', index=False)

    if PLOT_GRAPHS:
        # Only one algorithm is swept, so plot a single series against the
        # 'L Value' column.
        df_subset = results_df[results_df['Alg'] == label]
        plt.plot(df_subset['L Value'], df_subset['Average Mean Response Time'],
                 marker=markers[0], color=colors[0], label=label)

        plt.xlabel(r'$L$')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        os.makedirs('graphs', exist_ok=True)
        plt.savefig(f'graphs/delayP_cost_vs_L_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_ext.png')
        plt.clf()


def test_cost_vs_arrivalrate():
    """Sweep the arrival rate and save the averaged response times.

    Runs simulation_wrapper for each rate in ``arrival_rate_values`` with
    fixed L, c2 and expensive-predictor alpha.  It writes one CSV row per
    rate under ``res/`` and, when PLOT_GRAPHS is set, a cost-vs-rate plot
    under ``graphs/``.  Output directories are created if missing.
    """
    import os

    # Test parameters
    arrival_rate_values = [0.5, 0.6, 0.7, 0.9, 0.95]
    L_value = 1
    fixed_c2 = 2
    default_expen_alpha = 0.01

    label = 'DelayPredict'

    # Running the test
    results = []
    for arrival_rate in arrival_rate_values:
        avg_mean, avg_pshort, avg_plong = simulation_wrapper(arrival_rate, L_value, fixed_c2, default_expen_alpha)
        results.append({
            'Alg': label,
            'Arrival Rate': arrival_rate,
            'Average Mean Response Time': avg_mean,
            'Average PShort Response Time': avg_pshort,
            'Average PLong Response Time': avg_plong,
            'Default L': L_value,
            'Default c2': fixed_c2,
            'Default expalpha': default_expen_alpha
        })

    # Saving the results to a CSV file (the directory is created on first use).
    results_df = pd.DataFrame(results)
    os.makedirs('res', exist_ok=True)
    results_df.to_csv(f'res/delayP_cost_vs_arrivalrate_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.csv', index=False)

    if PLOT_GRAPHS:
        # Only one algorithm is swept, so plot a single series.  A loop over
        # markers would redraw the same line and duplicate the legend entry.
        df_subset = results_df[results_df['Alg'] == label]
        plt.plot(df_subset['Arrival Rate'], df_subset['Average Mean Response Time'],
                 marker=markers[0], color=colors[0], label=label)

        plt.xlabel('Arrival Rate')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        os.makedirs('graphs', exist_ok=True)
        plt.savefig(f'graphs/delayP_cost_vs_arrivalrate_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_ext.png')
        plt.clf()


import matplotlib.pyplot as plt
import pandas as pd


if __name__ == "__main__":
    # Run the parameter sweeps in order; each one writes its results under res/.
    for experiment in (test_cost_vs_L, test_cost_vs_arrivalrate):
        experiment()

    # Single ad-hoc run instead of the sweeps:
    # simulation_wrapper(0.5, 1, 0, 0)


