# -*- coding: utf-8 -*-
"""Utils.ipynb: plotting helpers for recorded series and a regret upper-bound formula.
"""

import matplotlib.pyplot as plt
import numpy as np

def plot_results(recorded_date, _title, _label, _xlabel, _ylabel, upperbound=False):
  """Plot one recorded series against its index, optionally with an upper bound.

  Args:
    recorded_date: Sequence/array of values to plot.
    _title: Figure title.
    _label: Legend label for the series.
    _xlabel: x-axis label.
    _ylabel: y-axis label.
    upperbound: Value(s) drawn as an "Upper Bound" curve, expanded to the
      shape of ``recorded_date``. A scalar gives a flat line; an array that
      broadcasts to that shape gives a curve. ``False`` (default) or ``None``
      omits it; ``0`` is a valid bound and is plotted.
  """
  plt.plot(recorded_date, label=_label)
  # Identity checks instead of `upperbound != False`: that comparison treated
  # 0 as "no bound" (0 == False) and raised "truth value of an array is
  # ambiguous" when an array-valued bound was passed.
  if upperbound is not None and upperbound is not False:
    upper_array = upperbound * np.ones(np.shape(recorded_date))
    plt.plot(upper_array, label="Upper Bound")
  plt.title(_title)
  plt.xlabel(_xlabel)
  plt.ylabel(_ylabel)
  plt.legend()
  plt.show()

def plot_results_variance(recorded_date, vara_data, _title, _label, _xlabel,
                          _ylabel, shift=2, jump_point=False, *,
                          show_title=False):
  """Plot a recorded series with a shaded band of +/- ``vara_data`` around it.

  Args:
    recorded_date: 1-D sequence of values, plotted against their index.
    vara_data: Half-width of the shaded band; a scalar or a sequence with the
      same length as ``recorded_date``. The band spans
      ``recorded_date - vara_data`` to ``recorded_date + vara_data``.
      NOTE(review): the legend calls this 'Variance'; if a standard
      deviation is passed, the band is +/- 1 std. Confirm with callers.
    _title: Figure title; drawn only when ``show_title`` is True.
    _label: Legend label for the series.
    _xlabel: x-axis label.
    _ylabel: y-axis label.
    shift: Vertical offset added to the series maximum to place the
      jump-point annotation.
    jump_point: x position at which to draw a dashed red vertical line and a
      text annotation. ``False`` (default) or ``None`` draws nothing; ``0``
      is a valid position.
    show_title: Keyword-only. If True, apply ``_title``. Defaults to False,
      which keeps the previous behavior of not drawing a title.
  """
  # Convert to arrays so plain lists support the element-wise +/- below
  # (list - list raises TypeError), and so the line and band share the
  # same 0..n-1 x-axis.
  mean = np.asarray(recorded_date)
  spread = np.asarray(vara_data)

  plt.plot(mean, label=_label)
  plt.fill_between(np.arange(len(mean)), mean - spread, mean + spread,
                   alpha=0.3, label='Variance')

  # Identity checks: `jump_point != False` silently skipped x=0 (0 == False).
  if jump_point is not None and jump_point is not False:
    # Dashed red line parallel to the y-axis at the jump position.
    plt.axvline(x=jump_point, color='red', linestyle='--')
    # Annotation placed `shift` units above the series maximum.
    plt.text(jump_point, np.max(mean) + shift, "Agent jump the gap here", color='red')

  if show_title:
    plt.title(_title)
  plt.xlabel(_xlabel)
  plt.ylabel(_ylabel)
  plt.legend()
  plt.show()

def regret_upperbound(K, d, l_K, tau, omega, delta, beta_1, beta_2, H=1):
  """Evaluate the closed-form regret upper bound.

  With T = K * H, returns

      2 H sqrt(T log(d T / delta))
      + (2 beta_2 H^2 (omega + 1) + 2 beta_1) / (l_K tau omega)
        * sqrt(d K log(1 + K / d))

  All operations are numpy ufuncs or arithmetic, so ``K`` (and the other
  numeric arguments) may be numpy arrays; the result then has the broadcast
  shape. NOTE(review): K presumably counts episodes and H is the horizon per
  episode, so T is the total step count. Confirm against the source paper.
  """
  horizon = K * H

  # Summand 1: 2 H sqrt(T log(d T / delta)).
  confidence_log = np.log(np.divide(d * horizon, delta))
  first = 2 * H * np.sqrt(horizon * confidence_log)

  # Summand 2: coefficient * sqrt(d K log(1 + K / d)) / (l_K tau omega).
  coefficient = 2 * beta_2 * np.power(H, 2) * (omega + 1) + 2 * beta_1
  scale = l_K * tau * omega
  growth = np.sqrt(d * K * np.log(1 + np.divide(K, d)))
  second = coefficient * growth / scale

  return first + second
  