import matplotlib.pyplot as plt
import os
import numpy as np
import math

from hip.utils import setup_seaborn
# Apply the project's plot styling once, before any figure is created below.
# NOTE(review): setup_seaborn is a project helper not shown in this file;
# presumably it forwards these values to seaborn/matplotlib rc settings — confirm.
setup_seaborn(
    label_fontsize=11,
    legend_fontsize=8,
    axes_label_fontsize=8,
    axis_below=True,
)

import seaborn as sb

# Display names of the benchmarked configurations. They double as the keys of
# every data/style table below and as the legend labels passed to `plot(label=...)`.
HIP_CUDA = 'HiP w/o Offload'
HIP_UVM = 'HiP UVM w/o Cache'
HIP_CACHE = 'HiP UVM w/ Vector Map'
HIP_HASHMAP = 'HiP UVM w/ Hash Map'
FA_CUDA = 'Flash Attn w/o Offload'
FA_UVM = 'Flash Attn UVM'

def proc_copy_paste(t: str, scale: float = 1):
    """Parse whitespace-separated numeric text into a list of floats.

    Each token (for example one cell of the tab-separated rows used in this
    file) has its ``,`` thousands separators removed, is converted with
    ``float`` (so ``'NaN'`` yields a NaN), and is multiplied by ``scale``.
    """
    return [float(token.replace(',', '')) * scale for token in t.split()]

# X-axis values shared by every series; the axis is labelled '$T$ (k)' below.
# Each data row in the tables that follow has one entry per element of Ts.
Ts = [8, 16, 32, 64]

# Per-series line width overrides; render_data falls back to 3 for any series
# not listed here.
# NOTE(review): FA_CUDA is presumably drawn thicker so it stays visible where
# it overlaps HIP_CUDA (their GPU memory rows are identical) — confirm.
LINEWIDTH = {
    FA_CUDA: 5,
    HIP_CUDA: 3,
}

# Text annotated next to the last valid point of a series that contains NaN.
DEAD_REASON = {
    FA_CUDA: 'OOM',
    HIP_CUDA: 'OOM',
}

# Series plotted with a dotted line on ax1, whose y-axis is labelled
# 'GPU KV Memory (MB)'. One value per entry of Ts; NaN marks a missing
# measurement and triggers the DEAD_REASON annotation in render_data.
gpu_memory_data = {
    FA_CUDA: proc_copy_paste('4001	3918	NaN	NaN'),
    HIP_CUDA: proc_copy_paste('4001	3,918.00	NaN	NaN'),
    # -1000 lies below the lower y-limit (-200) applied to ax1, so these two
    # series are drawn off-screen.
    # NOTE(review): presumably placeholders that keep the key order (and hence
    # the colour assigned to each method) aligned with decode_throughput_data
    # — TODO confirm.
    FA_UVM: proc_copy_paste('-1000	-1000	-1000	-1000'),
    HIP_UVM: proc_copy_paste('-1000	-1000	-1000	-1000'),
    HIP_CACHE: proc_copy_paste('4510	4600	3598	2283'),
    HIP_HASHMAP: proc_copy_paste('4554	4554	3416	2104'),
}

# Not plotted at present: the only render_data call that uses this table is
# commented out further down.
# NOTE(review): units are not stated in this file; presumably MB — confirm.
cpu_memory_data = {
    HIP_CACHE: proc_copy_paste('16,004.00	31,344.00	48,780.00	41,425.50'),
    HIP_HASHMAP: proc_copy_paste('15,556.00	31,344.00	48,780.00	41,257.50'),
}

# Series plotted with a solid line on ax2, whose y-axis is labelled
# 'Decode Throughput (tok/s) ↑'. NaN entries are handled as in gpu_memory_data.
decode_throughput_data = {
    FA_CUDA: proc_copy_paste('183.4	92.0	NaN	NaN'),
    HIP_CUDA: proc_copy_paste('187.8	94.2	NaN	NaN'),
    FA_UVM: proc_copy_paste('13.4	6.8	3.3	1.9'),
    HIP_UVM: proc_copy_paste('27.3	26.3	24.8	22.9'),
    HIP_CACHE: proc_copy_paste('174.2	154.1	125.1	95.5'),
    HIP_HASHMAP: proc_copy_paste('32.5	25.0	20.5	10.2'),
}

# Matplotlib marker style for each series, looked up by label in render_data.
# Series that share a marker are grouped: ',' for the non-offloading runs,
# '^' for the plain UVM runs, '*' for the UVM runs with a map.
MARKERS = {
    HIP_CUDA: ',',
    FA_CUDA: ',',
    HIP_UVM: '^',
    FA_UVM: '^',
    HIP_CACHE: '*',
    HIP_HASHMAP: '*',
}

# Output directory for the rendered figure; created on import/run if missing.
root = './saves/plot_offloading'
os.makedirs(root, exist_ok=True)

# ax1 carries the GPU memory series; a twin axis for throughput is added later.
fig, ax1 = plt.subplots(figsize=(4, 3))

def render_data(plot_data, ax=None, linestyle='-'):
    """Plot every series in ``plot_data`` against the module-level ``Ts``.

    Args:
        plot_data: Mapping of method label to a sequence of y values (one per
            entry of ``Ts``), or to a nested mapping ``{sub_label: ys}`` whose
            series are labelled ``'<label> <sub_label>'``.
        ax: Axes to draw on. When ``None`` the ``pyplot`` module-level
            functions are used instead.
        linestyle: Matplotlib line style applied to every series.

    A series containing NaN is treated as having stopped at its first NaN:
    the last valid point before it is marked with an 'x' and annotated with
    the ``DEAD_REASON`` text of the method ('N/A' if the method has none).
    A series whose very first value is NaN has no valid point to mark and is
    left without an annotation.

    Raises:
        KeyError: If neither the series label nor its method has an entry in
            ``MARKERS``.
    """
    # Local import: only needed to normalise the line colour below, and
    # matplotlib is already a dependency of this file.
    from matplotlib.colors import to_rgb

    target = ax if ax is not None else plt

    def render_line(xs, ys, label, method=None):
        # Style tables are keyed by method name. Nested series carry a
        # composite label, so try the label first and fall back to the method.
        lines = target.plot(
            xs, ys,
            label=label,
            linewidth=LINEWIDTH.get(label, LINEWIDTH.get(method, 3)),
            linestyle=linestyle,
            marker=MARKERS[label] if label in MARKERS else MARKERS[method],
            markersize=10,
        )

        first_nan = next(
            (idx for idx, y in enumerate(ys) if math.isnan(y)),
            None,
        )
        # No NaN: nothing to annotate. NaN at index 0: there is no preceding
        # valid point, and `first_nan - 1` would wrap around to the last item.
        if first_nan is None or first_nan == 0:
            return

        last_okay_x = xs[first_nan - 1]
        last_okay_y = ys[first_nan - 1]

        # get_color() may return a colour string or an RGB(A) tuple depending
        # on the active colour cycle; normalise to RGB before darkening.
        marker_color = tuple(
            channel * 0.66 for channel in to_rgb(lines[-1].get_color())
        )

        target.annotate(
            DEAD_REASON.get(method, 'N/A'),
            xy=(last_okay_x, last_okay_y),
            xytext=(last_okay_x + 2, last_okay_y),
            fontsize=10,
            va='center',
            fontweight=800,
            linespacing=0.9,
            color='darkgray',
            zorder=100,
        )
        target.plot(
            [last_okay_x], [last_okay_y],
            marker='x', color=marker_color, markersize=10, zorder=100,
        )

    for label, data in plot_data.items():
        if isinstance(data, dict):
            for inner_label, inner_data in data.items():
                render_line(Ts, inner_data, f'{label} {inner_label}', method=label)
        else:
            render_line(Ts, data, label, method=label)

# Second y-axis on the same x-axis: throughput goes on ax2 (solid lines),
# GPU memory on ax1 (dotted lines).
ax2 = ax1.twinx()
render_data(decode_throughput_data, ax=ax2)
render_data(gpu_memory_data, ax=ax1, linestyle=':')
# CPU memory is currently left out of the figure.
# render_data(cpu_memory_data, ax=ax1, linestyle=':')

# The legend is built from ax2 only, so it lists the throughput series and
# not the dotted memory series drawn on ax1.
ax2.legend()
ax2.set_title('Decode Throughput and GPU KV Memory', fontsize=13, pad=12)
ax2.set_ylabel('Decode Throughput (tok/s) ↑', labelpad=5)
ax1.set_xlabel('$T$ (k)')

# The lower limit of -200 keeps the -1000 rows of gpu_memory_data out of view.
ax1.set_ylim(-200, 5000)
ax1.set_ylabel('GPU KV Memory (MB)', labelpad=5)
ax2.grid(False)

# Write the same figure as both a raster (PNG) and a vector (PDF) file.
plt.savefig(os.path.join(root, 'plot_offloading.png'), bbox_inches='tight', pad_inches=0)
plt.savefig(os.path.join(root, 'plot_offloading.pdf'), bbox_inches='tight', pad_inches=0)