import matplotlib.pyplot as plt
import numpy as np

# Human-readable display name for each algorithm identifier used in this file.
# The 'js' entry is shared by both StragglAR variants (ideal / worst).
labels = dict(
    js='StragglAR',
    rhd='RH/D',
    ring='Ring',
    direct='Broadcast',
    allpairs='MSCCL',
)

# Line/fill color for each algorithm identifier (named color or hex RGB).
colors = dict(
    js='dodgerblue',
    rhd='#F2C14E',
    ring='#90BE6D',
    direct='#F28482',
    allpairs='#B0A8A0',
)

# Matplotlib marker style for each algorithm identifier.
markers = dict(
    js='o',
    allpairs='s',
    rhd='^',
    ring='D',
    direct='v',
)

# Alpha-beta cost-model parameters shared by every *_cost function.
# NOTE(review): the numbers read as milliseconds (3 µs per-step latency and a
# 450 GiB/s link expressed as ms per byte) — confirm the intended units.
a = 3e-3                        # alpha: per-step latency
b = 1e3 / (450 * 2 ** 30)       # beta: transfer time per buffer unit
s = 2 ** 30                     # buffer size (1 GiB, presumably in bytes)

def ring_cost(n, a, b, s):
    """Return the modelled time of the Ring algorithm on ``n`` GPUs.

    Latency term: ``2 * (n - 1)`` steps costing ``a`` each.
    Bandwidth term: a ``2 * (n - 1) / n`` fraction of the ``s``-sized buffer,
    costing ``b`` per unit.
    """
    steps = 2 * (n - 1)
    latency = steps * a
    bandwidth = (steps / n) * s * b
    return latency + bandwidth

def rhd_cost(n, a, b, s):
    """Return the modelled time of recursive halving/doubling (RH/D).

    Latency term: ``2 * log2(n)`` steps costing ``a`` each.
    Bandwidth term: the same ``2 * (n - 1) / n`` buffer fraction as Ring.
    """
    rounds = np.log2(n)
    latency = 2 * rounds * a
    bandwidth = (2 * (n - 1) / n) * s * b
    return latency + bandwidth

def direct_cost(n, a, b, s):
    """Return the modelled time of the Broadcast ("direct") algorithm.

    ``log2(n)`` rounds, each paying the latency ``a`` and transferring the
    whole ``s``-sized buffer at ``b`` per unit.
    """
    rounds = np.log2(n)
    latency = rounds * a
    bandwidth = rounds * s * b
    return latency + bandwidth

def msccl_cost(n, a, b, s):
    """Return the modelled time of the MSCCL "allpairs" algorithm.

    Constant latency term of two steps (``2 * a``), independent of ``n``.
    Bandwidth term: the same ``2 * (n - 1) / n`` buffer fraction as Ring.
    """
    fraction = 2 * (n - 1) / n
    return 2 * a + fraction * s * b

def stragglar_best_cost(n, a, b, s):
    """Return StragglAR's modelled time assuming full overlap (ideal case).

    The model charges ``n - 2 + log2(n)`` steps: each step pays the latency
    ``a`` and transfers a ``1 / (n - 1)`` share of the ``s``-sized buffer at
    ``b`` per unit.

    Requires ``n >= 2``; ``n == 1`` divides by zero in the bandwidth term.
    """
    steps = n - 2 + np.log2(n)
    return steps * a + (steps / (n - 1)) * s * b

def stragglar_worst_cost(n, a, b, s):
    """Return StragglAR's modelled time assuming no overlap (worst case).

    Without overlap, an additional ``n - 2`` steps are charged on top of the
    full-overlap cost from ``stragglar_best_cost``. Each extra step pays the
    latency ``a`` and a ``1 / (n - 1)`` share of the ``s``-sized buffer at
    ``b`` per unit.
    """
    penalty = (n - 2) * a + ((n - 2) / (n - 1)) * s * b
    return penalty + stragglar_best_cost(n, a, b, s)


# GPU counts evaluated on the x-axis: 2**3 (8) through 2**8 (256).
powers_of_two = [2 ** k for k in range(3, 9)]

def speedups_over_ring(N, alpha=None, beta=None, size=None):
    """Compute each algorithm's modelled speedup over Ring for every n in ``N``.

    Parameters
    ----------
    N : iterable of int
        GPU counts to evaluate; iterated exactly once.
    alpha, beta, size : float, optional
        Cost-model parameters. Each one that is omitted falls back to the
        module-level ``a``, ``b`` or ``s``, read at call time. This lets
        callers sweep the model without rebinding globals.

    Returns
    -------
    dict[str, list]
        Keys ``'ring'``, ``'rhd'``, ``'direct'``, ``'allpairs'``,
        ``'js_best'`` and ``'js_worst'`` (in that order). Each list is aligned
        with ``N`` and holds ``ring_cost / algorithm_cost``, so ``'ring'`` is
        always ``1.0``.
    """
    alpha = a if alpha is None else alpha
    beta = b if beta is None else beta
    size = s if size is None else size

    # Result key -> cost model. Ring is the baseline and is handled separately.
    models = {
        'rhd': rhd_cost,
        'direct': direct_cost,
        'allpairs': msccl_cost,
        'js_best': stragglar_best_cost,
        'js_worst': stragglar_worst_cost,
    }

    sp = {'ring': [], **{key: [] for key in models}}
    for n in N:
        baseline = ring_cost(n, alpha, beta, size)
        sp['ring'].append(1.0)  # ring vs ring
        for key, cost in models.items():
            sp[key].append(baseline / cost(n, alpha, beta, size))
    return sp

sp = speedups_over_ring(powers_of_two)

# Update the global font size *before* creating the figure, so that every text
# element created afterwards (axes, ticks, legend) picks it up consistently.
plt.rcParams.update({'font.size': 18})

plt.figure(figsize=(7, 4))
plt.xscale('log', base=2)

# StragglAR is drawn as a band. The solid line is the ideal (full-overlap)
# model, the dashed hollow-marker line is the worst (no-overlap) model, and
# the region between them is shaded.
plt.plot(powers_of_two, sp['js_best'], label=f"{labels['js']} (ideal)",
         color=colors['js'], marker=markers['js'], lw=3)

plt.plot(powers_of_two, sp['js_worst'], label=f"{labels['js']} (worst)",
         color=colors['js'], linestyle='--', marker=markers['js'], lw=3,
         ms=6, markerfacecolor='none')

plt.fill_between(powers_of_two, sp['js_worst'], sp['js_best'],
                 alpha=0.13, color=colors['js'])

plt.plot(powers_of_two, sp['allpairs'], label=labels['allpairs'],
         color=colors['allpairs'], marker=markers['allpairs'], lw=3)

plt.plot(powers_of_two, sp['rhd'], label=labels['rhd'],
         color=colors['rhd'], marker=markers['rhd'], lw=3)

# Ring is the baseline (its speedup is identically 1.0). It is drawn as a
# plain black reference line instead of using colors['ring'] / markers['ring'].
plt.plot(powers_of_two, sp['ring'], label=labels['ring'], color='black', lw=3.2)

plt.plot(powers_of_two, sp['direct'], label=labels['direct'],
         color=colors['direct'], marker=markers['direct'], lw=3)

plt.xticks(powers_of_two, labels=powers_of_two, fontsize=18)
plt.tick_params(axis='y', labelsize=18)
plt.xlabel('Number of GPUs', fontsize=22)
plt.ylabel('Speedup over Ring', fontsize=22)
# NOTE(review): with the current parameters the ideal StragglAR curve reaches
# roughly 1.94 at 256 GPUs; widen this range if the model constants change.
plt.ylim(0, 2)

# The legend is anchored at y = 1.3 in axes coordinates. With a 4in-tall
# figure and the default subplot margins, its top edge falls outside the
# figure, so tight_layout shrinks the axes to keep it visible.
plt.legend(frameon=False, fontsize=18, ncol=3, loc='upper center',
           bbox_to_anchor=(0.5, 1.3))
plt.tight_layout()

plt.show()