import math
import numpy as np
import copy
import matplotlib.pyplot as plt
from matplotlib import rc
import pickle as pkl

# Global matplotlib styling: 18pt serif "Times" (to match LaTeX's ptm font),
# with all text rendered through LaTeX.
rc('font', family='serif', serif=['Times'], size=18)
rc('text', usetex=True)

class Job:
    """A job with an integral amount of work that emits signals while it runs.

    Signals are stored as positions on the job's own processing timeline; a
    signal counts as emitted once the processed time has reached its position.
    """

    def __init__(self, processing_time):
        # Total amount of work required by the job.
        self.processing_time = processing_time
        # Signal positions; empty until `sample_signals` is called or the
        # attribute is assigned directly.
        self.signals = []
        # Work still to be done; decreases by one per call to `process`.
        self.remaining_processing_time = processing_time

    def __repr__(self):
        return f"Job(processing_time={self.processing_time}, signals={self.signals}, remaining_processing_time={self.remaining_processing_time})"

    def num_signals_emitted(self):
        """Return how many signals lie at or before the processed time."""
        processed_time = self.processing_time - self.remaining_processing_time
        return sum(1 for signal in self.signals if signal <= processed_time)

    def job_completed(self):
        """Return True when no processing time remains."""
        return self.remaining_processing_time == 0

    def sample_signals(self, g):
        """Replace the job's signals with a fresh random sample.

        The rate is ``g / processing_time`` over an interval of length
        ``processing_time``, so the expected number of signals is ``g``.
        """
        rate = g / self.processing_time
        self.signals = simulate_poisson_point_process_1d(rate, self.processing_time)

    def process(self):
        """Process the job for one time unit.

        Raises:
            ValueError: if less than one unit of processing time remains.
        """
        if self.remaining_processing_time >= 1:
            self.remaining_processing_time -= 1
        else:
            # The message must be an f-string, otherwise the placeholder is
            # printed literally instead of the offending value.
            raise ValueError(
                f"Processing time {self.remaining_processing_time} cannot be less than 1."
            )


def simulate_poisson_point_process_1d(rate, length):
    """
    Draw a random set of integer points on the interval [0, length].

    The number of points is Poisson distributed with mean ``rate * length``;
    each point is then placed independently and uniformly on the integers
    0, 1, ..., length (both ends included).

    Parameters:
        rate (float): Intensity (λ), average number of points per unit length.
        length (int): Length of the interval; used as an integer bound.

    Returns:
        points (list): Point locations (NumPy integers) in ascending order.
    """
    # How many points this realisation contains.
    mean_count = rate * length
    count = np.random.poisson(mean_count)

    # Where they fall; the upper bound of randint is exclusive, hence the +1.
    locations = np.random.randint(low=0, high=length + 1, size=count)

    ordered = list(locations)
    ordered.sort()
    return ordered



def process_jobs(jobs):
    """Advance every job by one unit and split the jobs by completion.

    Each job's ``process()`` is called once, in input order. Returns a pair
    ``(completed, remaining)`` of lists that keep the input order.
    """
    finished, unfinished = [], []

    for job in jobs:
        job.process()
        bucket = finished if job.job_completed() else unfinished
        bucket.append(job)

    return finished, unfinished


def etc(jobs, threshold, g):
    """Explore-then-commit schedule; returns the sum of completion times.

    Exploration: all jobs are processed in rounds of one unit each for as
    long as every remaining job has emitted at most ``threshold`` signals.
    A round over ``k`` jobs advances the clock by ``k``, and every job that
    finishes in the round is charged the clock value at the end of the round.

    Exploitation: the remaining jobs are run to completion one after another,
    ordered by number of emitted signals (most signals first).

    Parameters:
        jobs: list of jobs; it is deep-copied, so the caller's jobs are
            left untouched.
        threshold: exploration stops once some job has emitted more than
            this many signals.
        g: not used by this function; kept so existing callers keep working.

    Returns:
        The accumulated objective (sum of completion times).
    """
    t = 0
    obj = 0
    jobs = copy.deepcopy(jobs)

    # Exploration. The explicit emptiness check is required: all() of an
    # empty sequence is True, so without it the loop would spin forever once
    # every job has completed (or when no jobs were given).
    while jobs and all(job.num_signals_emitted() <= threshold for job in jobs):
        delta = len(jobs)
        completed_jobs, jobs = process_jobs(jobs)
        t += delta
        obj += t * len(completed_jobs)

    # Sort jobs by the number of signals emitted in decreasing order
    jobs.sort(key=lambda job: job.num_signals_emitted(), reverse=True)

    # Exploitation
    for job in jobs:
        t += job.remaining_processing_time
        obj += t

    return obj

def algorithm1(jobs, threshold):
    """Round-robin with per-job commitment; returns the sum of completion times.

    At the start of every round, each job that has emitted at least
    ``threshold`` signals is run to completion immediately (in list order).
    The jobs that are left then each receive one unit of processing; a round
    over ``k`` jobs advances the clock by ``k`` and every job finishing in
    the round is charged the clock value at the end of the round.

    Parameters:
        jobs: list of jobs; it is deep-copied, so the caller's jobs are
            left untouched.
        threshold: number of emitted signals at which a job is committed to.

    Returns:
        The accumulated objective (sum of completion times).
    """
    t = 0
    obj = 0
    jobs = copy.deepcopy(jobs)

    while jobs:
        # Build a new list instead of removing from `jobs` while iterating
        # over it; in-place removal makes the loop skip the job that follows
        # each removed one.
        undecided = []
        for job in jobs:
            if job.num_signals_emitted() >= threshold:
                t += job.remaining_processing_time
                obj += t
            else:
                undecided.append(job)
        jobs = undecided

        if not jobs:
            # Everything was committed; there is no round-robin step to run.
            break

        delta = len(jobs)
        completed_jobs, jobs = process_jobs(jobs)
        t += delta
        obj += t * len(completed_jobs)

    return obj

def rr(jobs):
    """Plain round-robin; returns the sum of completion times.

    Works on a deep copy of ``jobs``. In each round every unfinished job gets
    one unit of processing, the clock advances by the number of jobs in the
    round, and each job finishing in the round is charged the clock value at
    the end of the round.
    """
    pending = copy.deepcopy(jobs)
    clock = 0
    total_completion = 0

    while len(pending) > 0:
        round_length = len(pending)
        finished, pending = process_jobs(pending)
        clock += round_length
        total_completion += clock * len(finished)

    return total_completion


def opt_val(jobs):
    """Return the weighted sum of the jobs' processing times.

    Processing times are sorted in decreasing order and the i-th one
    (1-based) is multiplied by i, so the longest job gets weight 1 and the
    shortest gets weight ``len(jobs)``.
    """
    descending = sorted((job.processing_time for job in jobs), reverse=True)
    total = 0
    for weight, length in enumerate(descending, start=1):
        total += weight * length
    return total


def threshold_opt(g):
    """Return ``ceil((g / 2) ** (2/3)) + 1``, capped from above by ``g``."""
    exponent = 2 / 3
    cap = math.ceil((g / 2.0) ** exponent) + 1
    # Same selection as min(g, cap): g wins ties.
    return cap if cap < g else g

def threshold_sqrt(g):
    """Return ``ceil(g ** 0.5) + 1``, capped from above by ``g``."""
    cap = math.ceil(g ** 0.5) + 1
    # Same selection as min(g, cap): g wins ties.
    return cap if cap < g else g

def sample_signals(jobs, g):
    """Call ``sample_signals(g)`` on every job, in order; returns nothing."""
    for current in jobs:
        current.sample_signals(g)


def create_jobs(n):
    """Create ``n`` jobs with random, heavy-tailed processing times.

    Each processing time is a draw from ``np.random.pareto`` with shape 1.1,
    multiplied by 50, truncated to an integer and raised to at least 1.
    """
    pareto_shape = 1.1
    time_scale = 50

    raw_times = np.random.pareto(pareto_shape, n) * time_scale

    jobs = []
    for value in raw_times:
        jobs.append(Job(max(1, int(value))))
    return jobs


def create_data(num_iter, name):
    """Run the experiment and pickle the aggregated ratios.

    ``num_iter`` random instances of 500 jobs are generated once. For every
    granularity ``g`` in the grid, each instance is copied, signals are
    sampled on the copy, and four schedules are evaluated on it:
    ``algorithm1`` with ``threshold_opt(g)``, ``algorithm1`` with threshold 1,
    ``rr`` and ``etc`` with ``threshold_opt(g)``. Every objective is divided
    by ``opt_val`` of the instance.

    The output is a list with one tuple per ``g``:
    ``(g, mean, std, mean, std, mean, std, mean, std)`` in the order listed
    above. It is written to ``<prefix>_results.pkl``, where ``<prefix>`` is
    the part of ``name`` before its first ``.``.
    """
    # Granularity grid: i + 1.3**i for i = 0..39, keeping values up to 2048.
    g_values = [i + 1.3**i for i in range(40) if i + 1.3**i <= 2048]

    instances = [create_jobs(500) for _ in range(num_iter)]

    def compute_avg_and_std_ratios(g):
        print(f"Processing g = {g}")
        # Insertion order fixes the column order of the returned tuple.
        ratios = {"alg1": [], "alg1_threshold_1": [], "rr": [], "etc": []}

        for instance in instances:
            trial = copy.deepcopy(instance)
            sample_signals(trial, g)
            opt = opt_val(trial)

            alg1 = algorithm1(trial, threshold_opt(g))
            alg1_threshold_1 = algorithm1(trial, 1)
            rr_result = rr(trial)
            etc_result = etc(trial, threshold_opt(g), g)

            ratios["alg1"].append(alg1 / opt)
            ratios["alg1_threshold_1"].append(alg1_threshold_1 / opt)
            ratios["rr"].append(rr_result / opt)
            ratios["etc"].append(etc_result / opt)

        summary = [g]
        for series in ratios.values():
            summary.append(np.mean(series))
            summary.append(np.std(series))
        return tuple(summary)

    results = []
    for g in g_values:
        results.append(compute_avg_and_std_ratios(g))

    # Save results to a file using pickle
    output_path = f"{name.split('.')[0]}_results.pkl"
    with open(output_path, "wb") as f:
        pkl.dump(results, f)


def create_plots(name):
    """Plot the pickled experiment results and save the figure as a PDF.

    Reads ``<prefix>_results.pkl`` (``<prefix>`` being the part of ``name``
    before its first ``.``), draws one mean curve with a +/- one standard
    deviation band per algorithm against the granularity on a logarithmic
    x-axis, and writes the figure to ``name + ".pdf"``.
    """
    # Load results from the pickle file
    results_path = f"{name.split('.')[0]}_results.pkl"
    with open(results_path, "rb") as f:
        results = pkl.load(f)

    # Transpose the list of per-g tuples into one column per quantity.
    (
        g_values,
        alg1_mean,
        alg1_std,
        k1_mean,
        k1_std,
        rr_mean,
        rr_std,
        etc_mean,
        etc_std,
    ) = zip(*results)

    plt.figure(figsize=(6, 4.5))

    # Curves in drawing order; the order determines the default colours.
    curves = [
        (rr_mean, rr_std, {"label": "RR"}),
        (etc_mean, etc_std, {"linestyle": "-.", "label": "ETC"}),
        (k1_mean, k1_std, {"linestyle": "--", "label": r"Alg 3, $k = 1$"}),
        (
            alg1_mean,
            alg1_std,
            {"linestyle": "-", "label": r"Alg 3, $k = \Theta(g^{2/3})$"},
        ),
    ]
    for mean_values, std_values, line_style in curves:
        center = np.array(mean_values)
        spread = np.array(std_values)
        plt.plot(g_values, center, **line_style)
        plt.fill_between(g_values, center - spread, center + spread, alpha=0.15)

    plt.xlabel(r"Granularity $g$")
    plt.ylabel("Empirical competitive ratio")
    plt.tick_params(axis='both', which='major')
    plt.ylim(0.9, 2.5)
    plt.xscale("log")
    plt.legend(loc="upper right", ncol=2)
    plt.tight_layout()
    plt.savefig(name + ".pdf", format="pdf")



# Run the experiment on 50 random instances, then plot the stored results.
# Both steps must use the same name so the plot reads the file just written.
EXPERIMENT_NAME = "exp_pareto_1_1"
create_data(50, EXPERIMENT_NAME)
create_plots(EXPERIMENT_NAME)
