# The first four functions in this file are adapted from:
# Zhao et al., "Learning-Augmented Algorithms for the Bahncard Problem", 2024.
# Original repository: https://github.com/Natureal/PFSUM

from matplotlib.cbook import print_cycles
import random
import math
import numpy as np
import matplotlib.pyplot as plt

def interval_generator(key, mean):
    """Draw the gap, in time steps, until the next arrival.

    Args:
        key: Name of the gap distribution. Only "Exponential" is recognised.
        mean: Mean of the exponential distribution.

    Returns:
        For "Exponential", an exponential sample rounded to the nearest
        integer and floored at 1. For any other key, the constant 1.
    """
    if key != "Exponential":
        return 1

    gap = round(np.random.exponential(mean))
    # A gap of 0 would stall the caller's position, so never go below 1.
    return gap if gap > 1 else 1

def price_generator(key, mean):
    """Draw a random integer price of at least 1.

    Args:
        key: Name of the price distribution, "Normal" or "Pareto".
        mean: For "Normal", the mean (the standard deviation is mean / 2).
            For "Pareto", the factor applied to a np.random.pareto(2) sample.

    Returns:
        The sample rounded to the nearest integer and floored at 1, or None
        when the key is not one of the recognised distributions.
    """
    if key == "Normal":
        raw = np.random.normal(mean, mean / 2)
    elif key == "Pareto":
        raw = np.random.pareto(2) * mean
    else:
        # Unrecognised distribution: no price is produced.
        return None

    return max(1, round(raw))

# generate a full instance of some time length
def instance_generator(length, key1, mean1, key2, mean2, supremum):
    """Generate a random instance covering `length` time steps.

    Each slot of the returned list is a set. Slots where an arrival occurs
    hold a single price; all other slots hold {0}.

    Args:
        length: Number of time steps (slots) in the instance.
        key1: Distribution name passed to interval_generator for the gaps
            between arrivals.
        mean1: Mean passed to interval_generator.
        key2: Distribution name passed to price_generator for the prices.
        mean2: Mean passed to price_generator.
        supremum: Upper cap applied to every generated price.

    Returns:
        A list of `length` sets, one per time step.
    """
    # Build a distinct set per slot. `[{0}] * length` would make every
    # untouched slot reference the same set object, so mutating one slot in
    # place would silently change all of them.
    instance = [{0} for _ in range(length)]
    idx = 0

    while idx < length:
        price = min(supremum, price_generator(key2, mean2))
        instance[idx] = {price}

        # Jump ahead to the time step of the next arrival.
        idx += interval_generator(key1, mean1)

    return instance

# generate a noisy instance from an instance
def noisy_instance_generator(instance, key2, mean2, perturb_prob):
    """Build a perturbed copy of an instance without modifying the input.

    For every slot, independently with probability `perturb_prob` each:
      * drop: the slot's content is replaced by {0};
      * add: a price from price_generator(key2, mean2) is added to the slot
        (after any drop has been applied).

    Args:
        instance: List of sets, one per time step.
        key2: Distribution name passed to price_generator for the noise.
        mean2: Mean passed to price_generator.
        perturb_prob: Probability of each perturbation on each slot.

    Returns:
        A new list of sets with the same length as `instance`.
    """
    noisy = []

    for slot in instance:
        perturbed = slot.copy()

        # Three uniform draws are made per slot. The third outcome is never
        # used, but the draw is kept so seeded runs consume the random stream
        # exactly as before.
        do_drop = np.random.uniform(0, 1) < perturb_prob
        do_add = np.random.uniform(0, 1) < perturb_prob
        _unused_minus = np.random.uniform(0, 1) < perturb_prob

        if do_drop:
            perturbed = {0}

        if do_add:
            perturbed.add(price_generator(key2, mean2))

        noisy.append(perturbed)

    return noisy

# create a prediction (for PFSUM) from an instance
def prediction_generator(instance, t):
    """Compute, for every time step, the total of the values that follow it.

    prediction[i] is the sum of instance[i + 1] through
    instance[min(i + t - 1, n - 1)], where n is len(instance); that is, the
    total over the t - 1 steps after i, truncated at the end of the instance.

    Args:
        instance: Sequence of numeric values, one per time step.
        t: Window length. Expected to be at least 1; smaller values make the
            window index negative for early steps, which wraps around to the
            end of the sequence.

    Returns:
        A list with one prediction per time step; empty for an empty instance.
    """
    n = len(instance)
    if n == 0:
        # Nothing to predict. Without this guard, instance[0] below would
        # raise IndexError.
        return []

    # prefix[i] = instance[0] + ... + instance[i]
    prefix = [instance[0]]
    for value in instance[1:]:
        prefix.append(value + prefix[-1])

    last = n - 1
    return [prefix[min(i + t - 1, last)] - prefix[i] for i in range(n)]

# create a sample from an instance
def instance_sampler(epsilon, A):
    """Sample each element of each collection in A independently.

    Args:
        epsilon: Probability with which each element is kept.
        A: Iterable of iterables (for example, a list of sets).

    Returns:
        A list of new sets, one per entry of A, each holding the elements of
        that entry that were kept.
    """
    # One random.random() draw per element, in iteration order.
    return [
        {item for item in group if random.random() < epsilon}
        for group in A
    ]

# generate a cluster instance
def intermittent_arrivals_generator(length, interval_mean, cluster_mean, mean_cost, cost_type, supremum):
    """Generate an instance whose arrivals come in dense runs separated by gaps.

    Arrivals are placed on consecutive time steps; after every run of 31
    consecutive arrivals the position jumps ahead by an exponentially
    distributed gap before the next run starts. Slots without an arrival
    hold {0}.

    Args:
        length: Number of time steps (slots) in the instance.
        interval_mean: Mean of the exponential gap between runs.
        cluster_mean: Accepted for interface compatibility but has no effect:
            the run length is fixed (see `run_steps` below).
        mean_cost: Mean passed to price_generator.
        cost_type: Distribution name passed to price_generator.
        supremum: Upper cap applied to every generated price.

    Returns:
        A list of `length` sets, one per time step.
    """
    # Number of single-step advances inside a run before a gap is inserted;
    # a run therefore contains run_steps + 1 consecutive arrivals.
    run_steps = 30

    # Build a distinct set per slot. `[{0}] * length` would make every
    # untouched slot reference the same set object.
    instance = [{0} for _ in range(length)]
    idx = 0
    counter = 0

    while idx < length:
        price = min(supremum, max(1, price_generator(cost_type, mean_cost)))
        instance[idx] = {price}

        if counter == run_steps:
            # End of the run: skip ahead by a random gap and start a new run.
            idx += interval_generator("Exponential", interval_mean)
            counter = 0
        else:
            idx += 1
            counter += 1

    return instance