import datetime
import math
import simpy
import random
import numpy as np
import sys
import matplotlib
matplotlib.use('TkAgg')  # Adjust based on your OS and environment
import matplotlib.pyplot as plt
import pandas as pd
from real_dataset.parse_datasets import load_and_prepare_data

# ---------------------------------------------------------------------------
# Per-run results (run_simulation re-binds these at the start of every run)
# ---------------------------------------------------------------------------
response_times = []                # every completed job
response_pshort_times = []         # jobs routed to the short queue (prediction < T)
response_plong_times = []          # jobs routed to the long queue
response_plong_times_details = []  # one dict per completed long job

# ---------------------------------------------------------------------------
# Scheduler queues
# ---------------------------------------------------------------------------
cheap_p_queue = []  # (arrival_time, age) waiting for the cheap prediction
short_queue = []    # (arrival_time, size)
long_queue = []     # (arrival_time, age, size)

# ---------------------------------------------------------------------------
# Server / bookkeeping state
# ---------------------------------------------------------------------------
start_time = None   # when the current cheap-p / long service began
current_job = None
job_processes = {}  # long-job start_time -> its serve_long_job process
n_cheap_p = 0       # cheap predictions completed
n_expensive_p = 0   # jobs routed to the long queue

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
SIMULATION_TIME = 1000000
#T= float('inf') #to test FCFS
#T= 0 #to test SPRPT
#T = 4
LOG_EVENT_PRINT = 0
PLOT_GRAPHS = 1
c1 = 0.00001 #TODO: change

#predictor = 'expP'
#predictor = 'perfectP'
#predictor = 'uniP'
cheap_alpha = 0.8

#dist = 'weibull'
#dist = 'exponential'

# Real dataset: job sizes and predictions both come from real_data.
dist = 'real'
predictor = 'real'

real_data_index = 0  # next row of real_data to consume

class EndOfDataException(Exception):
    """Raised by job_size_distribution when the real dataset has no rows left."""

def log_event(time, event, job_id="--", size="--", predicted_size="--", notes="--", queue_content="--"):
    """Print one row of the event trace, but only when LOG_EVENT_PRINT is truthy.

    `time` is printed with five decimals. `job_id`, `size`, `predicted_size`
    and `notes` are printed with two decimals unless they are the "--"
    placeholder, and `notes` may also be the literal "DONE". Any other string
    in those fields raises ValueError from the format spec. `queue_content`
    is printed verbatim.
    """
    if not LOG_EVENT_PRINT:
        return

    def two_dp(value, passthrough=("--",)):
        # Placeholders are echoed unchanged; everything else must be numeric.
        return value if value in passthrough else f"{value:.2f}"

    time_cell = format(time, ".5f")
    row = "| ".join([
        f"{time_cell:<10}",
        f"{event:<20}",
        f"{two_dp(job_id):<10}",
        f"{two_dp(size):<10}",
        f"{two_dp(predicted_size):<15}",
        f"{two_dp(notes, ('--', 'DONE')):<20}",
        f"{queue_content}",
    ])
    print(row)



def short_job_size_distribution(): #TODO: change
    """Draw jobs until one whose actual size is below the threshold T.

    job_size_distribution() returns a (job_size, predicted_job_size) tuple, so
    the size element is compared against T. The original compared the whole
    tuple with a float, which raises TypeError. The accepted sample is
    returned unchanged, in the same tuple shape.

    Propagates EndOfDataException from job_size_distribution. Loops forever if
    no drawable size is below T (e.g. T <= 0 with non-negative sizes).
    """
    while True:
        sample = job_size_distribution()
        job_size = sample[0]
        if job_size < T:
            return sample

def job_size_distribution():
    """Draw the next job as a (job_size, predicted_job_size) pair.

    Behaviour is selected by the module-level `dist`:
      - 'exponential': size ~ Exp(rate 1); predicted size is a 0 placeholder
        (the cheap-p branch of serve_job replaces it via predict_service_time).
      - 'weibull': size = (-ln(1 - U))**2 / 2 with U = random.random();
        predicted size is a 0 placeholder.
      - 'real': the next row of the module-level `real_data` DataFrame
        (columns 'normalized_runtime', 'normalized_predicted_runtime');
        advances the global `real_data_index`.

    Raises:
        EndOfDataException: in 'real' mode once every row has been consumed.
        ValueError: if `dist` is not one of the values above.
    """
    global real_data_index

    if dist == 'exponential':
        return random.expovariate(1), 0

    if dist == 'weibull':
        U = random.random()
        return (-math.log(1 - U))**2 / 2, 0

    if dist == 'real':
        if real_data_index >= len(real_data):
            raise EndOfDataException("End of data reached")
        # real_data_index is a row counter, so read by position (.iat) rather
        # than by label (.loc), which assumed a default 0..n-1 index.
        job_size = real_data['normalized_runtime'].iat[real_data_index]
        predicted_job_size = real_data['normalized_predicted_runtime'].iat[real_data_index]
        real_data_index += 1
        return job_size, predicted_job_size

    raise ValueError("Unknown distribution type specified.")

def predict_service_time(job, uni_alpha):
    """Return a predicted service time for a job whose true size is `job`.

    The module-level `predictor` selects the model:
      - 'perfectP' (or whenever uni_alpha == 0): the exact size `job`.
      - 'expP': an exponential sample with mean `job`.
      - 'uniP': a uniform sample on [(1 - uni_alpha) * job, (1 + uni_alpha) * job].

    Raises:
        ValueError: for any other predictor when uni_alpha != 0. Raising here
            avoids silently returning None, which would only fail later at the
            `< T` comparison in serve_job.
    """
    if predictor == 'perfectP' or uni_alpha == 0:
        return job

    if predictor == 'expP':
        return random.expovariate(1 / job)

    if predictor == 'uniP':
        lower_bound = (1 - uni_alpha) * job
        upper_bound = (1 + uni_alpha) * job
        return random.uniform(lower_bound, upper_bound)

    raise ValueError(f"Unknown predictor type specified: {predictor!r}")

############### Different types of predictors  #################

#def predict_service_time(job): #TODO: perfect predictor
#    return job

#def predict_service_time(job): #TODO uniP predictor
#    lower_bound = (1 - alpha) * job
#    upper_bound = (1 + alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_service_time(z): #TODO: exponential predictor
#     return random.expovariate(1/z)

#def predict_uniP_cheap(job): #TODO cheap uniP predictor
#    global cheap_alpha
#    lower_bound = (1 - cheap_alpha) * job
#    upper_bound = (1 + cheap_alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_uniP_expensive(job):  # TODO expensive uniP predictor
#    global expensive_alpha
#    lower_bound = (1 - expensive_alpha) * job
#    upper_bound = (1 + expensive_alpha) * job
#    return random.uniform(lower_bound, upper_bound)
#####################################################


def check_schedule(env, server):
    """Start a serve_job process for the highest-priority non-empty queue.

    Priority order is short_queue, then cheap_p_queue, then long_queue. At
    most one process is started per call, and nothing happens if all queues
    are empty. The function only reads the queues, so no `global` statement
    is needed.
    """
    # Rendering every queue is O(total queue length); only pay for it when
    # the event trace is actually being printed.
    if LOG_EVENT_PRINT:
        cheap_p_str = [f"({x:.2f}, {y:.2f})" for x, y in cheap_p_queue]
        short_queue_str = [f"({x:.2f}, {y:.2f})" for x, y in short_queue]
        long_queue_str = [f"({x:.2f}, {y:.2f}, {z:.2f})" for x, y, z in long_queue]
        queue_str = f"Cheap_p: {cheap_p_str}, Short: {short_queue_str}, Long: {long_queue_str}"
        log_event(env.now, "Checking Policy", "--", "--", "--", "--", queue_str)

    if short_queue:
        log_event(env.now, "Schedule Short Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, "short_queue", server))
    elif cheap_p_queue:
        log_event(env.now, "Schedule Cheap p", "--", "--", "--", "--", "--")
        env.process(serve_job(env, "cheap_p_queue", server))
    elif long_queue:
        log_event(env.now, "Schedule long Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, "long_queue", server))


def serve_job(env, type, server):
    """SimPy process that serves the head job of the queue named by `type`.

    `type` is "short_queue", "cheap_p_queue" or "long_queue". The name shadows
    the builtin but is kept for caller compatibility. Each branch requests the
    shared PriorityResource `server` with priority 0 / 1 / 2 respectively. If
    the queue was already drained by another process once the request is
    granted, the branch returns without doing anything.

    - short: interrupts the process registered in job_processes under the
      current start_time (only long jobs are registered there), then runs the
      job to completion and records its response time.
    - cheap_p: likewise interrupts a running long job, then spends the
      remaining cheap-prediction cost (c1 - age) on the job. It then draws the
      job's size and prediction and routes it to short_queue (prediction < T)
      or long_queue.
    - long: runs serve_long_job for the remaining work. On completion it
      records the response time and size. If interrupted, it re-queues the job
      with its age advanced by the time just served.
    """
    global current_job, current_cheap_job, current_expensive_job, start_cheap_time, start_expensive_time, start_time, n_cheap_p, n_expensive_p
    global cheap_p_queue, expensive_p_queue, short_queue, long_queue

    if type=="short_queue":
        # Preempt a running long job, if any.
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Short Job Arrival")

        with server.request(priority=0) as req:
            yield req
            if not short_queue:
                return
            next_job = short_queue.pop(0)  # (arrival_time, size)
            log_event(env.now, "Serve Short", next_job[0], next_job[1], "--", next_job[0], "--")
            yield env.timeout(next_job[1])
            log_event(env.now, "Short Job Done", next_job[0], "--", "--", "DONE", f"S-response:{env.now - next_job[0]}")
            response_times.append(env.now - next_job[0])
            response_pshort_times.append(env.now - next_job[0])
        check_schedule(env, server)

    elif type=="cheap_p_queue":
        # Preempt a running long job, if any.
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Cheap P Arrival")

        with server.request(priority=1) as req:
            yield req
            if not cheap_p_queue:
                return
            next_job = cheap_p_queue.pop(0)  # (arrival_time, age)
            log_event(env.now, "Serve Cheap p", next_job[0],  "--", "--", next_job[0], f"Age: {next_job[1]}")
            current_job = next_job
            start_time = env.now
            remaining_cheapp = c1 - next_job[1]
            try:
                yield env.timeout(remaining_cheapp)
                log_event(env.now, "Cheap p Done", next_job[0], "--", "--", "DONE", "--")

                try:
                    job_size, cheap_predicted_service_time = job_size_distribution()
                except EndOfDataException:
                    print("End of data reached. Terminating simulation.")
                    return

                # Synthetic distributions return a placeholder prediction.
                if dist == 'exponential' or dist == 'weibull':
                    cheap_predicted_service_time = predict_service_time(job_size, cheap_alpha)

                n_cheap_p += 1

                if cheap_predicted_service_time < T:
                    short_queue.append((next_job[0], job_size))
                    log_event(env.now, "Append to Short", next_job[0], "--", "--", "--", "--")

                else:
                    n_expensive_p += 1
                    long_queue.append((next_job[0], 0, job_size))
                    log_event(env.now, "Append to Long", next_job[0], "--", "--", "--", "--")

                check_schedule(env, server)
                current_job = None
                start_time = None

            except simpy.Interrupt as interrupt:
                # NOTE(review): cheap-p work is never registered in
                # job_processes, so this handler appears unreachable.
                if str(interrupt.cause) == "Short Job Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1] + elapsed_time)
                    log_event(env.now, "Short Preempted cheap_p", job_id=current_job[0])
                    if start_time in job_processes:
                        del job_processes[start_time]
                    cheap_p_queue.append(updated_job)
                    current_job = None
                    start_time = None

    elif type=="long_queue":
        with server.request(priority=2) as req:
            yield req
            if not long_queue:
                return
            next_job = long_queue.pop(0)  # (arrival_time, age, size)
            log_event(env.now, "Serve Long", next_job[0], next_job[2], "--", next_job[0], f"Age: {next_job[1]}")
            start_time = env.now
            current_job = (next_job[0], next_job[1], next_job[2])
            current_process = env.process(serve_long_job(env, next_job))
            # Registered so short / cheap-p arrivals can find and interrupt it.
            job_processes[start_time] = current_process
            try:
                yield current_process

                log_event(env.now, "Long Job Done", job_id=next_job[0])
                response_time = env.now - next_job[0]
                job_details = {
                    'response_time': response_time,
                    # Index 2 is the job size; index 1 is its accumulated age.
                    'actual_size': next_job[2]
                }
                response_plong_times_details.append(job_details)
                response_times.append(response_time)
                response_plong_times.append(response_time)
                current_job = None
                del job_processes[start_time]
                start_time = None

            except simpy.Interrupt as interrupt:
                # serve_long_job died from the interrupt; drop its dead handle
                # so job_processes does not grow with every preemption.
                job_processes.pop(start_time, None)

                if str(interrupt.cause) == "Short Job Arrival":
                    log_event(env.now, "Short Preempted Long", job_id=current_job[0])
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1] + elapsed_time, current_job[2])
                    long_queue.append(updated_job)
                    check_schedule(env, server)

                if str(interrupt.cause) == "Cheap P Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1] + elapsed_time, current_job[2])
                    log_event(env.now, "Cheap Preempted Long", updated_job[0], updated_job[2], "--", updated_job[0], f"Age: {updated_job[1]}")
                    long_queue.append(updated_job)
                    check_schedule(env, server)

        # NOTE(review): after a preemption this dispatches a second time.
        # Surplus processes find their queue empty and return.
        check_schedule(env, server)


def serve_long_job(env, job):
    """SimPy process that occupies the server for a long job's remaining work.

    `job` is a long_queue tuple (arrival_time, age, size). The process waits
    size - age, the work not yet done before any earlier preemption.
    """
    age, size = job[1], job[2]
    remaining = size - age
    yield env.timeout(remaining)


def job_generator(env, server, arrival_rate):
    """SimPy process producing a Poisson arrival stream; never terminates.

    Inter-arrival gaps are exponential with rate `arrival_rate`. Each job is
    identified by its arrival time and enters cheap_p_queue with
    cheap-prediction age 0, after which the scheduler is re-run.
    """
    while True:
        gap = random.expovariate(arrival_rate)
        yield env.timeout(gap)
        arrival = env.now

        # (arrival_time, age): the age lets a preempted cheap prediction resume.
        cheap_p_queue.append((arrival, 0))
        log_event(env.now, "Append to Cheap p", arrival, "--", "--", "--", "--")
        check_schedule(env, server)



def run_simulation(arrival_rate, threshold, c1_price):
    """Run one simulation on a fresh SimPy environment and return mean response times.

    Resets all per-run module state: the result lists, all three scheduler
    queues (including cheap_p_queue, so unfinished jobs from a previous run
    cannot leak in), the counters, job_processes and current_job. It then sets
    T = threshold and c1 = c1_price, with c1 forced to 0 for the SPRPT
    (threshold == 0) and FCFS (threshold == inf) settings, and runs until
    SIMULATION_TIME.

    Returns:
        (mean_response_time, mean_response_time_pshort, mean_response_time_plong,
        response_plong_times_details). A mean is NaN when no job of that class
        completed instead of raising ZeroDivisionError. For example,
        threshold == inf routes every job short, and threshold == 0 typically
        routes every job long.
    """
    global response_times, response_pshort_times, response_plong_times, response_plong_times_details
    global start_time, job_processes, SIMULATION_TIME
    global T
    global n_cheap_p, n_expensive_p
    global cheap_p_queue, short_queue, long_queue
    global current_job
    global c1

    def _mean_or_nan(values):
        # Empty job class (e.g. SPRPT / FCFS extremes) -> NaN rather than a crash.
        return sum(values) / len(values) if values else float('nan')

    # Re-initialize the per-run state at the start of each run
    response_times = []
    response_pshort_times = []
    response_plong_times = []
    response_plong_times_details = []
    cheap_p_queue = []
    short_queue = []
    long_queue = []

    T = threshold
    n_cheap_p = 0
    n_expensive_p = 0
    c1 = c1_price

    start_time = None
    job_processes = {}
    SIMULATION_TIME = 1000000

    if threshold == 0: #SPRPT
        c1 = 0
    if threshold == float('inf'): #FCFS
        c1 = 0

    print(f'Simulating M/M/1 queue with small advice, FCFS of short/long jobs, server cost model, lambda: {arrival_rate}, T:{T}, c1:{c1}')
    env = simpy.Environment()
    server = simpy.PriorityResource(env, capacity=1)
    current_job = None
    print(f"{'TIME':<10} | {'EVENT':<18} | {'JOB ID':<10} | {'SIZE':<8} | {'PREDICTED SIZE':<15} | {'SERVER':<18} | {'NOTES'}")
    print("-" * 145)
    env.process(job_generator(env, server, arrival_rate))

    env.run(until=SIMULATION_TIME)
    print(f'n_cheap_p: {n_cheap_p}')
    print(f'n_expensive_p: {n_expensive_p}')

    mean_response_time = _mean_or_nan(response_times)
    print(f"\nMean Response Time with costs: {mean_response_time:.2f} time units lambda: {arrival_rate}, T:{T}")
    mean_response_time_pshort = _mean_or_nan(response_pshort_times)
    print(
        f"\nPredicted short: Mean Response Time with costs: {mean_response_time_pshort:.2f} time units lambda: {arrival_rate}, T:{T}")

    mean_response_time_plong = _mean_or_nan(response_plong_times)
    print(
        f"\nPredicted long: Mean Response Time with costs: {mean_response_time_plong:.2f} time units lambda: {arrival_rate}, T:{T}")

    return mean_response_time, mean_response_time_pshort, mean_response_time_plong, response_plong_times_details


def simulation_wrapper(arrival_rate, threshold, c1, n_runs=100):
    """Average `n_runs` calls of run_simulation for one configuration.

    Resets the shared scheduler state, runs run_simulation `n_runs` times, and
    prints and returns the averages of the per-run means. It also writes the
    long-job details of the final run to
    long_res/1bit_long_job_response_<predictor>_times_<arrival_rate>_T_<threshold>_<date>.csv,
    creating long_res/ if it is missing.

    Args:
        arrival_rate: Poisson arrival rate passed to run_simulation.
        threshold: prediction threshold T passed to run_simulation.
        c1: cheap-prediction cost passed to run_simulation.
        n_runs: number of repetitions (default 100, the previously hard-coded count).

    Returns:
        (average_mean_response_time, average_mean_response_time_pshort,
        average_mean_response_time_plong).

    Raises:
        ValueError: if n_runs < 1, since the averages would divide by zero.
    """
    global cheap_p_queue, short_queue, long_queue, response_times, response_pshort_times, response_plong_times, start_time, job_processes, n_cheap_p, n_expensive_p
    import os

    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    start_time = None
    job_processes = {}
    n_cheap_p = 0
    n_expensive_p = 0

    response_times = []
    response_pshort_times = []
    response_plong_times = []
    cheap_p_queue = []
    short_queue = []
    long_queue = []

    mean_response_times = []
    mean_response_times_pshort = []
    mean_response_times_plong = []
    # Only the final run's long-job details are kept for the CSV below.
    response_plong_times_details = []

    for _ in range(n_runs):
        mean_response_time, mean_response_time_pshort, mean_response_time_plong, response_plong_times_details = run_simulation(arrival_rate, threshold, c1)
        mean_response_times.append(mean_response_time)
        mean_response_times_pshort.append(mean_response_time_pshort)
        mean_response_times_plong.append(mean_response_time_plong)

    average_mean_response_time = sum(mean_response_times) / len(mean_response_times)
    average_mean_response_time_pshort = sum(mean_response_times_pshort) / len(mean_response_times_pshort)
    average_mean_response_time_plong = sum(mean_response_times_plong) / len(mean_response_times_plong)

    print(f"Average Mean Response Time: {average_mean_response_time:.2f} time units")
    print(f"Average Predicted Short Mean Response Time: {average_mean_response_time_pshort:.2f} time units")
    print(f"Average Predicted Long Mean Response Time: {average_mean_response_time_plong:.2f} time units")

    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")

    # Create the output directory up front so a long sweep can't fail at the very end.
    os.makedirs('long_res', exist_ok=True)
    with open(f'long_res/1bit_long_job_response_{predictor}_times_{arrival_rate}_T_{threshold}_{current_date}.csv', 'w') as file:
        file.write("Response Time,Actual Size\n")
        file.writelines(
            f"{job_detail['response_time']},{job_detail['actual_size']}\n"
            for job_detail in response_plong_times_details
        )

    return average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong

#####
markers = list("ox^sDp")  # one plot marker per series; extend as needed
colors = list("bgrcmy")   # single-letter colour codes (blue, green, red, cyan, magenta, yellow)


def test_cost_vs_T():
    """Sweep the threshold T at a fixed arrival rate and c1, then save a CSV and a plot.

    For every value in T_test_values, real_data_index is rewound to 0 and
    simulation_wrapper is run. Results are written to
    res/servercost_cost_vs_T_results_..._dataset_<dataset>_1bit_new.csv. When
    PLOT_GRAPHS is set, the plot goes to
    graphs/servercost_cost_vs_T_..._dist_<dist>_dataset_<dataset>_1bit.png.
    dist/dataset are part of the PNG name so the per-dataset loop in __main__
    no longer overwrites the previous graph. Output directories are created
    if missing.

    Reads the module global `dataset`, which is only bound by the __main__ loop.
    """
    import os

    global real_data_index
    # T_values only labels the plotted series; the simulated T is test_value.
    T_values = [1]
    T_test_values = [0.1, 0.5, 1, 1.5, 2, 3, 3.5, 4, 5, 8]
    default_arrival_rate = 0.9
    fixed_c1 = 0.01
    labels = {1: '1bit'}

    if dist == 'real':
        T_values = [4]
        labels = {4: '1bit'}

    # Running the test
    results = []
    for t in T_values:
        for test_value in T_test_values:
            real_data_index = 0
            average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong = simulation_wrapper(default_arrival_rate, test_value, fixed_c1)
            results.append({
                'Alg': labels[t],
                'T Value': test_value,
                'Average Mean Response Time': average_mean_response_time,
                'Average PShort Response Time': average_mean_response_time_pshort,
                'Average PLong Response Time': average_mean_response_time_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': fixed_c1
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    os.makedirs('res', exist_ok=True)
    results_df.to_csv(f'res/servercost_cost_vs_T_results_{predictor}_alpha_{cheap_alpha}_dist_{dist}_dataset_{dataset}_1bit_new.csv', index=False)

    if PLOT_GRAPHS:
        for (t, marker, color) in zip(T_values, markers, colors):
            label_t = labels[t]
            df_subset = results_df[results_df['Alg'] == label_t]

            plt.plot(df_subset['T Value'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label_t)

        plt.xlabel(r'$T$')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        os.makedirs('graphs', exist_ok=True)
        plt.savefig(f'graphs/servercost_cost_vs_T_{predictor}_alpha_{cheap_alpha}_dist_{dist}_dataset_{dataset}_1bit.png')
        plt.clf()


def test_cost_vs_arrivalrate():
    """Sweep the arrival rate at a fixed threshold T and c1, then save a CSV and a plot.

    For each T in T_values (4 for the real dataset, otherwise 1) and each rate
    in arrival_rate_values, real_data_index is rewound to 0 and
    simulation_wrapper is run. Results are written to
    res/servercost_cost_vs_arrivalrate_results_..._dataset_<dataset>_1bit_new.csv.
    When PLOT_GRAPHS is set, the plot goes to
    graphs/servercost_cost_vs_arrivalrate_..._dist_<dist>_dataset_<dataset>_1bit.png.
    dist/dataset are part of the PNG name so the per-dataset loop in __main__
    no longer overwrites the previous graph. Output directories are created
    if missing.

    Reads the module global `dataset`, which is only bound by the __main__ loop.
    """
    import os

    global real_data_index
    # Test parameters
    arrival_rate_values = [0.5, 0.6, 0.7, 0.9, 0.95]
    T_values = [1]
    fixed_c1 = 0.01

    labels = {1: '1bit'}

    if dist == 'real':
        T_values = [4]
        labels = {4: '1bit'}

    # Running the test
    results = []
    for t in T_values:
        for arrival_rate in arrival_rate_values:
            real_data_index = 0
            avg_mean, avg_pshort, avg_plong = simulation_wrapper(arrival_rate, t, fixed_c1)
            results.append({
                'Alg': labels[t],
                'Arrival Rate': arrival_rate,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default T': t,
                'Default c1': fixed_c1,
                'Default c2': 0
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    os.makedirs('res', exist_ok=True)
    results_df.to_csv(f'res/servercost_cost_vs_arrivalrate_results_{predictor}_alpha_{cheap_alpha}_dist_{dist}_dataset_{dataset}_1bit_new.csv', index=False)

    if PLOT_GRAPHS:
        for (t, marker, color) in zip(T_values, markers, colors):
            label_t = labels[t]
            df_subset = results_df[results_df['Alg'] == label_t]

            plt.plot(df_subset['Arrival Rate'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label_t)

        plt.xlabel('Arrival Rate')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        os.makedirs('graphs', exist_ok=True)
        plt.savefig(f'graphs/servercost_cost_vs_arrivalrate_{predictor}_alpha_{cheap_alpha}_dist_{dist}_dataset_{dataset}_1bit.png')
        plt.clf()

if __name__ == "__main__":
    # Real-trace datasets to sweep (in this order) and the file each is loaded from.
    dataset_files = {
        'twosigma': 'real_dataset/jvupredict_twosigma.csv.gz',
        'google': 'real_dataset/jvupredict_google_all_features.csv.gz',
        'trinity': 'real_dataset/jvupredict_trinity.csv.gz',
    }
    datasets = list(dataset_files)

    # `dataset` and `real_data` are bound at module scope on purpose: the
    # test_cost_vs_* sweeps read `dataset`, and job_size_distribution reads
    # `real_data`, as globals.
    for dataset in datasets:
        file_path = dataset_files[dataset]
        real_data = load_and_prepare_data(file_path)

        #TESTS
        test_cost_vs_T()
        test_cost_vs_arrivalrate()

