import random
import math
import numpy as np
import matplotlib.pyplot as plt
import heapq
from typing import List, Tuple, Any, Dict

# Compute the offline optimum on an instance
def offline_opt(instance, beta, card, time_unit):
    """Return the offline optimum cost of serving ``instance``.

    The problem is a shortest path over vertices ``0..n`` (``n = len(instance)``),
    where vertex ``i`` means "the first ``i`` requests are served". From ``i``:

    * step to ``i + 1`` paying ``(1 - beta) * instance[i]``, or
    * jump to ``min(i + k, n)`` paying ``card``, with ``k = floor(1 / time_unit)``.

    Every request additionally pays ``beta * instance[i]`` whatever the path,
    so that term is added once at the end.

    Args:
        instance: sequence of per-request costs.
        beta: fraction of each request's cost that is always paid.
        card: price of one card (cost of a jump edge).
        time_unit: spacing between requests; a jump skips
            ``floor(1 / time_unit)`` requests.

    Returns:
        The minimum total cost as a float (``0.0`` for an empty instance).
    """
    k = math.floor(1 / time_unit)
    n = len(instance)

    # Every edge points forward (i -> i+1, i -> i+k), so vertices in index
    # order are a topological order: one relaxation pass gives the exact
    # shortest distances without a priority queue.
    dist = [math.inf] * (n + 1)
    dist[0] = 0.0
    for i, x in enumerate(instance):
        d = dist[i]

        step = d + float((1 - beta) * x)
        if step < dist[i + 1]:
            dist[i + 1] = step

        # A card bought near the end still only reaches the final vertex.
        j = min(i + k, n)
        jump = d + float(card)
        if jump < dist[j]:
            dist[j] = jump

    return dist[n] + beta * sum(instance)


# Counter algorithm
def online_no_sample(instance, beta, card, time_unit):
    """Return the cost of the counter algorithm on ``instance``.

    Requests are paid at their discountable price ``(1 - beta) * x`` while a
    running counter of those payments stays within ``card``. When serving the
    next request would push the counter above ``card``, a card is bought
    instead. That card covers the next ``k = floor(1 / time_unit)`` requests
    (clamped at the end), and the counter is reset. ``beta * x`` is paid for
    every request regardless.

    Args:
        instance: sequence of per-request costs.
        beta: fraction of each request's cost that is always paid.
        card: price of one card; also the counter threshold.
        time_unit: spacing between requests; a card covers
            ``floor(1 / time_unit)`` requests.

    Returns:
        Total cost as a float.

    Raises:
        ValueError: if ``floor(1 / time_unit) < 1``. A card would then cover
            no request, and the purchase branch could repeat forever on the
            same request.
    """
    k = math.floor(1 / time_unit)
    if k < 1:
        raise ValueError(
            "time_unit must be at most 1 so that a card covers at least one request"
        )
    n = len(instance)

    v = 0
    counter = 0.0
    relevant_cost = 0.0
    while v < n:
        discounted = (1 - beta) * instance[v]
        if counter + discounted > card:
            # Buy a card covering requests v .. v+k-1; request v is not paid
            # at the regular rate.
            relevant_cost += card
            counter = 0.0
            v = min(v + k, n)
        else:
            relevant_cost += discounted
            counter += discounted
            v += 1

    return relevant_cost + beta * sum(instance)

# PFSUM algorithm
def PFSUM(instance, beta, card, time_unit, prediction):
    """Return the cost of the PFSUM algorithm on ``instance``.

    With ``T = floor(1 / time_unit)``, a card bought at step ``i`` covers
    steps ``i .. i+T-1``; covered requests pay only ``beta * x``. At an
    uncovered step ``i``, a card is bought only if both of these reach
    ``gamma = card / (1 - beta)``:

    * the regular cost of the last ``T`` requests (including request ``i``), and
    * ``instance[i] + prediction[i]``.

    Otherwise request ``i`` is paid at its full cost.

    Args:
        instance: sequence of per-request costs.
        beta: fraction of each request's cost paid while a card is active.
        card: price of one card.
        time_unit: spacing between requests; a card covers
            ``floor(1 / time_unit)`` requests.
        prediction: indexable with at least ``len(instance)`` entries;
            ``prediction[i]`` is added to ``instance[i]`` in the buy test
            (presumably a forecast of upcoming regular cost — confirm).

    Returns:
        The total cost. An empty instance costs ``0``; this is a scalar, like
        every other return path, rather than a ``(cost, solution)`` tuple.
    """
    T = math.floor(1 / time_unit)
    C = card
    if len(instance) == 0:
        return 0
    # With beta == 1 a card gives no discount and is never worth buying; an
    # infinite threshold encodes that instead of dividing by zero.
    gamma = C / (1 - beta) if beta != 1 else math.inf

    cost = 0
    last_buy_time = -T
    # Sum of the regular costs of the (up to) T-1 requests before the current one.
    window = 0

    for i, x in enumerate(instance):
        if i >= T:
            window -= instance[i - T]

        if last_buy_time + T - 1 >= i:
            # Still covered by the most recent card.
            cost += beta * x
        elif window + x >= gamma and x + prediction[i] >= gamma:
            cost += C + beta * x
            last_buy_time = i
        else:
            cost += x

        window += x

    return cost

# Naïve algorithm
def naive_sample(instance, beta, card, time_unit, sample, sample_rate):
    """Solve the rescaled sample optimally, then replay that plan on the instance.

    Each sampled cost is scaled by ``(1 / sample_rate) * (1 - beta)``. A
    shortest path is computed over vertices ``0..len(sample)``, using steps
    (pay the scaled sample cost) and jumps of ``floor(1 / time_unit)``
    vertices (pay ``card``). The resulting sequence of moves is then charged
    on the real instance: steps pay ``(1 - beta) * instance[u]``, jumps pay
    ``card``. Finally ``beta * sum(instance)`` is added.

    Path vertices index both ``sample`` and ``instance``, so the two are
    presumably the same length — confirm with the caller.

    Returns:
        Total cost on the real instance.

    Raises:
        RuntimeError: if path reconstruction yields an unknown move label.
    """
    span = math.floor(1 / time_unit)
    real = [(1 - beta) * x for x in instance]
    sampled = [(1 / sample_rate) * (1 - beta) * s for s in sample]
    size = len(sampled)

    best = [float("inf")] * (size + 1)
    best[0] = 0.0
    # back[v] = (predecessor vertex, move label used to reach v)
    back = [(-1, "")] * (size + 1)

    for u in range(size):
        base = best[u]
        if base == float("inf"):
            continue
        # Step is relaxed before jump, and "<" is strict: on ties the earlier
        # candidate keeps the back-pointer.
        candidates = (
            (u + 1, sampled[u], "step"),
            (min(u + span, size), card, "jump"),
        )
        for target, weight, label in candidates:
            offer = base + weight
            if offer < best[target]:
                best[target] = offer
                back[target] = (u, label)

    # Walk back from the last vertex, collecting (source vertex, move) edges.
    edges = []
    node = size
    while node != 0:
        parent, label = back[node]
        edges.append((parent, label))
        node = parent

    total = 0.0
    for source, label in reversed(edges):
        if label == "jump":
            total += card
        elif label == "step":
            total += real[source]
        else:
            raise RuntimeError(f"Unknown move type: {label}")

    return total + beta * sum(instance)

# Helper for Algorithms 5.2 and 5.3: find the maximal runs of consecutive steps ("long gaps") in a path
def maximal_step_runs(vertex_path, move_types, m, length):
    """Group consecutive ``"step"`` moves of a path into maximal runs.

    ``move_types[i]`` is the move leaving ``vertex_path[i]``. Each maximal run
    of ``"step"`` moves becomes a dict with these keys:

    * ``"vertices"`` — the source vertices of the run's steps,
    * ``"num_steps"`` — the run length,
    * ``"is_long"`` — whether the run has at least ``m`` steps.

    Only long runs are returned, plus the first and the last run, which are
    always kept. ``length`` is accepted for interface compatibility and is
    not used.
    """
    def close(first, last):
        count = last - first + 1
        return {
            "vertices": vertex_path[first : last + 1],
            "num_steps": count,
            "is_long": count >= m,
        }

    runs = []
    open_at = None
    for idx, move in enumerate(move_types):
        if move == "step":
            if open_at is None:
                open_at = idx
        elif open_at is not None:
            runs.append(close(open_at, idx - 1))
            open_at = None
    if open_at is not None:
        runs.append(close(open_at, len(move_types) - 1))

    final = len(runs) - 1
    return [run for idx, run in enumerate(runs) if run["is_long"] or idx in (0, final)]

# Algorithm 5.2
def three_two_sample(instance, beta, card, time_unit, sample, xi, sample_rate):
    """Algorithm 5.2: keep the sample plan's long step runs, cover the rest with cards.

    1. Scale the sample by ``(1 / sample_rate) * (1 - beta)`` and compute a
       shortest path over vertices ``0..len(sample)``. Steps pay the scaled
       sample cost; jumps of ``floor(1 / time_unit)`` vertices pay ``card``.
    2. Split the path into maximal step runs via ``maximal_step_runs`` with
       threshold ``max(1, floor(xi / time_unit))``.
    3. If at most one run is returned, replay the sample path on the real
       instance (steps pay ``(1 - beta) * instance[u]``, jumps pay ``card``).
    4. Otherwise, drop a first/last run that is neither long nor touching the
       start/end of the horizon. Requests inside the kept runs pay
       ``(1 - beta) * instance[u]``. Every stretch outside them (before the
       first, after the last, and between consecutive runs) is covered by
       ``ceil(gap * time_unit)`` cards.

    ``beta * sum(instance)`` is always added. Path vertices index both
    ``sample`` and ``instance``, so the two are presumably the same length —
    confirm with the caller.

    Raises:
        RuntimeError: if path reconstruction yields an unknown move label.
    """
    k = math.floor(1/time_unit)
    eps = sample_rate
    B = [(1-beta) * x for x in instance]
    A = [(1/eps) * (1-beta) * x for x in sample]
    C = card

    n = len(A)
    dist_A = [float("inf")] * (n + 1)
    dist_A[0] = 0.0
    prev = [-1] * (n + 1)
    prev_move = [""] * (n + 1)

    # Every edge points forward, so one pass in index order is exact. The
    # strict "<" (step relaxed before jump) decides which of several equally
    # cheap sample paths gets replayed on the real costs.
    for i in range(n):
        if dist_A[i] == float("inf"):
            continue

        j1 = i + 1
        nd1 = dist_A[i] + A[i]
        if nd1 < dist_A[j1]:
            dist_A[j1] = nd1
            prev[j1] = i
            prev_move[j1] = "step"

        j2 = min(i + k, n)
        nd2 = dist_A[i] + C
        if nd2 < dist_A[j2]:
            dist_A[j2] = nd2
            prev[j2] = i
            prev_move[j2] = "jump"

    # Reconstruct the path; move_types[i] is the move leaving vertex_path[i].
    vertex_path = []
    move_types = []
    v = n
    while v != 0:
        vertex_path.append(v)
        move_types.append(prev_move[v])
        v = prev[v]
    vertex_path.append(0)
    vertex_path.reverse()
    move_types.reverse()

    long_gaps = maximal_step_runs(
        vertex_path, move_types, max(1, math.floor(xi/time_unit)), len(instance)
    )

    if len(long_gaps) <= 1:
        # Nothing to bridge: replay the sample path on the real costs.
        cost_on_B = 0.0
        for idx, move in enumerate(move_types):
            u = vertex_path[idx]
            if move == "step":
                cost_on_B += B[u]
            elif move == "jump":
                cost_on_B += C
            else:
                raise RuntimeError(f"Unknown move type: {move}")
        return cost_on_B + beta*sum(instance)

    # A boundary run that is short and does not touch the horizon's edge is
    # dropped; its requests fall into a card-covered stretch instead.
    if not long_gaps[0]["is_long"] and not long_gaps[0]["vertices"][0] == 0:
        long_gaps = long_gaps[1:]
    if not long_gaps[-1]["is_long"] and not long_gaps[-1]["vertices"][-1] == len(instance)-1:
        long_gaps = long_gaps[:-1]

    if len(long_gaps) == 0:
        return beta*sum(instance) + math.ceil(n * time_unit) * C

    # Regular price inside the kept runs. This must be computed AFTER trimming:
    # summing before it charged a dropped run's requests both at the regular
    # price and again through the cards that cover its stretch.
    total_cost = sum(B[u] for r in long_gaps for u in r["vertices"])

    def cards(gap):
        # Cost of covering `gap` consecutive requests with cards.
        return math.ceil(gap * time_unit) * C

    total_cost += cards(long_gaps[0]["vertices"][0])
    total_cost += cards(n - (long_gaps[-1]["vertices"][-1] + 1))

    for r1, r2 in zip(long_gaps, long_gaps[1:]):
        s = r1["vertices"][-1] + 1  # first vertex after r1's last step
        t = r2["vertices"][0]
        total_cost += cards(t - s)

    return total_cost + beta*sum(instance)

# Algorithm 5.3
def improved_sample(instance, beta, card, time_unit, sample, xi, sample_rate):
    """Algorithm 5.3: like Algorithm 5.2, but replay the sample between long runs
    unless it buys too many cards there.

    1. Scale the sample by ``(1 / sample_rate) * (1 - beta)`` and compute a
       shortest path over vertices ``0..len(sample)``. Steps pay the scaled
       sample cost; jumps of ``floor(1 / time_unit)`` vertices pay ``card``.
    2. Split the path into maximal step runs via ``maximal_step_runs`` with
       threshold ``floor(xi / time_unit)`` (clamped to at least 1, which is
       equivalent because every run has at least one step).
    3. If at most one run is returned, replay the sample path on the real
       instance.
    4. Otherwise, drop a first/last run that is neither long nor touching the
       start/end of the horizon. Requests in kept runs pay
       ``(1 - beta) * instance[u]``. The stretches before the first and after
       the last kept run are covered by ``ceil(gap * time_unit)`` cards. For
       each stretch between two kept runs:
       * if the sample path uses more than ``1 / xi`` jumps there, cover it
         with cards;
       * otherwise replay the sample path's moves on the real costs.

    ``beta * sum(instance)`` is always added. Path vertices index both
    ``sample`` and ``instance``, so the two are presumably the same length —
    confirm with the caller.

    Raises:
        RuntimeError: unknown move label while replaying the whole path.
        ValueError: unknown move label while replaying a sub-path.
    """
    k = math.floor(1/time_unit)
    eps = sample_rate
    B = [(1-beta) * x for x in instance]
    A = [(1/eps) * (1-beta) * x for x in sample]
    C = card

    n = len(A)
    dist_A = [float("inf")] * (n + 1)
    dist_A[0] = 0.0
    prev = [-1] * (n + 1)
    prev_move = [""] * (n + 1)

    # Every edge points forward, so one pass in index order is exact. The
    # strict "<" (step relaxed before jump) decides which of several equally
    # cheap sample paths gets replayed on the real costs.
    for i in range(n):
        if dist_A[i] == float("inf"):
            continue

        j1 = i + 1
        nd1 = dist_A[i] + A[i]
        if nd1 < dist_A[j1]:
            dist_A[j1] = nd1
            prev[j1] = i
            prev_move[j1] = "step"

        j2 = min(i + k, n)
        nd2 = dist_A[i] + C
        if nd2 < dist_A[j2]:
            dist_A[j2] = nd2
            prev[j2] = i
            prev_move[j2] = "jump"

    # Reconstruct the path; move_types[i] is the move leaving vertex_path[i].
    vertex_path = []
    move_types = []
    v = n
    while v != 0:
        vertex_path.append(v)
        move_types.append(prev_move[v])
        v = prev[v]
    vertex_path.append(0)
    vertex_path.reverse()
    move_types.reverse()

    long_gaps = maximal_step_runs(
        vertex_path, move_types, max(1, math.floor(xi/time_unit)), len(instance)
    )

    if len(long_gaps) <= 1:
        # Nothing to bridge: replay the sample path on the real costs.
        cost_on_B = 0.0
        for idx, move in enumerate(move_types):
            u = vertex_path[idx]
            if move == "step":
                cost_on_B += B[u]
            elif move == "jump":
                cost_on_B += C
            else:
                raise RuntimeError(f"Unknown move type: {move}")
        return cost_on_B + beta*sum(instance)

    # Position of each vertex along the path, for slicing sub-paths.
    pos = {u: i for i, u in enumerate(vertex_path)}

    def edge_cost(edge_idx) -> float:
        # Real-instance cost of the move leaving vertex_path[edge_idx].
        u = vertex_path[edge_idx]
        mv = move_types[edge_idx]
        if mv == "step":
            return B[u]
        elif mv == "jump":
            return C
        else:
            raise ValueError(f"Unknown move type: {mv}")

    def subpath_cost_and_jumps(v_from, v_to):
        # Real cost and jump count of the sample path from v_from to v_to.
        i_from = pos[v_from]
        i_to = pos[v_to]
        cost = 0.0
        jumps = 0
        for e in range(i_from, i_to):
            if move_types[e] == "jump":
                jumps += 1
            cost += edge_cost(e)
        return cost, jumps

    def cards(gap):
        # Cost of covering `gap` consecutive requests with cards.
        return math.ceil(gap * time_unit) * C

    # A boundary run that is short and does not touch the horizon's edge is
    # dropped; its requests fall into a card-covered stretch instead.
    if not long_gaps[0]["is_long"] and not long_gaps[0]["vertices"][0] == 0:
        long_gaps = long_gaps[1:]
    if not long_gaps[-1]["is_long"] and not long_gaps[-1]["vertices"][-1] == len(instance)-1:
        long_gaps = long_gaps[:-1]

    if len(long_gaps) == 0:
        return beta*sum(instance) + math.ceil(n * time_unit) * C

    # Regular price inside the kept runs. This must be computed AFTER trimming:
    # summing before it charged a dropped run's requests both at the regular
    # price and again through the cards that cover its stretch.
    total_cost = sum(B[u] for r in long_gaps for u in r["vertices"])

    total_cost += cards(long_gaps[0]["vertices"][0])
    total_cost += cards(n - (long_gaps[-1]["vertices"][-1] + 1))

    for r1, r2 in zip(long_gaps, long_gaps[1:]):
        s = r1["vertices"][-1] + 1  # first vertex after r1's last step
        t = r2["vertices"][0]

        if s == t:
            continue

        normal_cost, jump_count = subpath_cost_and_jumps(s, t)

        if jump_count > 1/xi:
            # The sample buys many cards here: cover the whole stretch instead.
            total_cost += cards(t - s)
        else:
            total_cost += normal_cost

    return total_cost + beta*sum(instance)

